# In [13]:
import itertools
import numpy as np

class SteinerTree:
    """Brute-force Steiner tree solver scored with a QUBO-style Hamiltonian.

    The graph is a weighted adjacency matrix: ``graph[u][v] > 0`` means an edge
    of that weight joins ``u`` and ``v``. A candidate solution is a pair
    ``(x, y)`` where ``x`` is a symmetric 0/1 matrix marking chosen edges and
    ``y`` is a 0/1 vector marking chosen vertices.

    The Hamiltonian is ``H = H_A + H_B``: ``H_A`` (weighted by ``a``) penalises
    constraint violations and ``H_B`` (weighted by ``b``) is the total weight
    of the chosen edges. ``H_A`` is zero for every candidate produced by
    :meth:`solve`, so minimising ``H`` over them minimises the tree cost.
    """

    def __init__(self, graph, required, optional, a=1.0, b=1.0):
        """
        Args:
            graph: square weighted adjacency matrix (list of lists or array);
                a positive entry marks an edge.
            required: terminal vertices that must appear in the tree.
            optional: Steiner vertices that may be used to connect terminals.
            a: weight of the constraint penalty H_A.
            b: weight of the edge-cost term H_B.
        """
        self.graph = graph
        self.required = required
        self.optional = optional
        self.a = a
        self.b = b
        self.num_nodes = len(graph)

    def compute_hamiltonian(self, x, y):
        """Compute the Hamiltonian ``H = H_A + H_B`` for a candidate ``(x, y)``.

        H_A consists of three penalty terms, each zero for a valid tree:
          * tree size: in a rooted tree every vertex except the root has
            exactly one parent edge, so ``k`` selected vertices need exactly
            ``k - 1`` edges;
          * every required vertex is selected (``y[v] == 1``);
          * an optional vertex is selected iff it touches a chosen edge.
        Vertex degree is deliberately NOT penalised: required or optional
        vertices may legitimately be interior to the tree (degree > 1).
        H_B is the summed weight of the chosen edges, each counted once.

        Args:
            x: symmetric 0/1 edge-selection matrix (num_nodes x num_nodes).
            y: 0/1 vertex-selection vector of length num_nodes.

        Returns:
            The scalar value of ``H``.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        # A vertex "touches" the chosen edge set when its degree is non-zero.
        touched = (x.sum(axis=1) > 0).astype(int)
        # Count each undirected edge once (upper triangle, same as H_B below).
        num_edges = np.triu(x, 1).sum()

        h_a = 0.0
        # Tree-size constraint: |E| == |V_selected| - 1.
        h_a += self.a * (y.sum() - 1 - num_edges) ** 2

        # Required vertices must be selected.
        for v in self.required:
            h_a += self.a * (1 - y[v]) ** 2

        # Optional vertices: selected <=> touched by some chosen edge.
        for v in self.optional:
            h_a += self.a * (y[v] - touched[v]) ** 2

        # Tree cost (H_B): each undirected edge counted once via u < v.
        h_b = self.b * sum(
            self.graph[u][v] * x[u][v]
            for u in range(self.num_nodes)
            for v in range(u + 1, self.num_nodes)
        )

        return h_a + h_b

    def get_neighbors(self, node):
        """Return the vertices joined to ``node`` by a positive-weight edge."""
        return [v for v, cost in enumerate(self.graph[node]) if cost > 0]

    def is_valid_tree(self, edges, nodes=None):
        """Return True if every vertex in ``nodes`` is connected by ``edges``.

        Args:
            edges: iterable of ``(u, v)`` vertex pairs.
            nodes: vertices that must all lie in one connected component;
                defaults to ``self.required``.

        Only connectivity is checked. When the caller passes ``len(nodes) - 1``
        distinct edges whose endpoints all lie in ``nodes`` (as :meth:`solve`
        does), connectivity implies the edges form a spanning tree of
        ``nodes``, since a connected graph on k vertices with k - 1 edges is a
        tree.
        """
        targets = list(self.required if nodes is None else nodes)
        if not targets:
            return True  # nothing needs to be connected

        adjacency_list = {i: [] for i in range(self.num_nodes)}
        for u, v in edges:
            adjacency_list[u].append(v)
            adjacency_list[v].append(u)

        # Iterative DFS from the first target (avoids Python's recursion limit).
        start = targets[0]
        visited = {start}
        stack = [start]
        while stack:
            node = stack.pop()
            for neighbor in adjacency_list[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append(neighbor)

        return all(node in visited for node in targets)

    def solve(self):
        """Find the minimum-Hamiltonian Steiner tree by exhaustive search.

        For every subset S of the optional vertices, the selected vertex set is
        ``required | S``. Every set of ``|selected| - 1`` edges with both
        endpoints inside the selected set that connects all selected vertices
        (i.e. every spanning tree of the induced subgraph) is scored with
        :meth:`compute_hamiltonian`. Runtime is exponential in both the number
        of optional vertices and the number of edges.

        Returns:
            ``(min_hamiltonian, best_x, best_y)``; ``(inf, None, None)`` if no
            tree connects the required vertices. Ties keep the first tree
            found.

        Raises:
            ValueError: if there are no required vertices.
        """
        if len(self.required) == 0:
            raise ValueError("at least one required node is needed")

        min_hamiltonian = float('inf')
        best_x = None
        best_y = None

        # Candidate edges (u < v) do not depend on the subset; build them once.
        all_edges = [
            (u, v)
            for u in range(self.num_nodes)
            for v in self.get_neighbors(u)
            if u < v
        ]

        optional_subsets = itertools.chain.from_iterable(
            itertools.combinations(self.optional, r)
            for r in range(len(self.optional) + 1)
        )
        for subset in optional_subsets:
            selected = set(self.required) | set(subset)
            y = np.zeros(self.num_nodes, dtype=int)
            y[list(selected)] = 1

            # Only edges inside the selected vertex set can belong to the tree.
            edges = [(u, v) for u, v in all_edges if u in selected and v in selected]

            # A spanning tree on k vertices uses exactly k - 1 edges.
            for edge_subset in itertools.combinations(edges, len(selected) - 1):
                if not self.is_valid_tree(edge_subset, selected):
                    continue

                x = np.zeros((self.num_nodes, self.num_nodes), dtype=int)
                for u, v in edge_subset:
                    x[u][v] = 1
                    x[v][u] = 1  # keep x symmetric

                h = self.compute_hamiltonian(x, y)
                if h < min_hamiltonian:
                    min_hamiltonian = h
                    best_x = x
                    best_y = y

        return min_hamiltonian, best_x, best_y


# Example Usage
if __name__ == "__main__":
    # Example: a weighted path graph 0-1-2-3-4, given as an adjacency matrix
    # of edge weights. The only way to join the two terminals is the whole path.
    graph = [
        [0, 2, 0, 0, 0],
        [2, 0, 3, 0, 0],
        [0, 3, 0, 4, 0],
        [0, 0, 4, 0, 5],
        [0, 0, 0, 5, 0],
    ]

    required = [0, 4]  # Terminals: the two endpoints of the path
    optional = [1, 2, 3]  # Candidate Steiner vertices in between

    steiner_tree = SteinerTree(graph, required, optional)
    hamiltonian, x, y = steiner_tree.solve()

    print("Minimum Hamiltonian:", hamiltonian)
    print("X Matrix (tree inclusion):\n", x)
    print("Y Array (node inclusion):", y)


# Output of the example above:
# Minimum Hamiltonian: 14.0
# X Matrix (tree inclusion):
#  [[0 1 0 0 0]
#  [1 0 1 0 0]
#  [0 1 0 1 0]
#  [0 0 1 0 1]
#  [0 0 0 1 0]]
# Y Array (node inclusion): [1 1 1 1 1]
