In [None]:
class DisjointSet:
    """Union-find over the integers ``0 .. n-1``.

    Uses union by rank and path compression. ``parent[i]`` is the parent link
    of ``i`` (a root points to itself), and ``rank[i]`` is an upper bound on
    the height of the tree rooted at ``i``.
    """

    def __init__(self, n):
        self.parent = [i for i in range(n)]
        self.rank = [0 for _ in range(n)]

    def find(self, u):
        """Return the root of ``u``'s set, pointing every node on the path at it."""
        links = self.parent
        # First pass: walk up to the root.
        root = u
        while links[root] != root:
            root = links[root]
        # Second pass: compress the path so each visited node links to the root.
        node = u
        while links[node] != node:
            links[node], node = root, links[node]
        return root

    def union(self, u, v):
        """Merge the sets containing ``u`` and ``v``; a no-op if already joined."""
        big, small = self.find(u), self.find(v)
        if big == small:
            return
        # Attach the lower-rank root beneath the higher-rank one; on a tie,
        # u's root wins and its rank grows by one.
        if self.rank[big] < self.rank[small]:
            big, small = small, big
        self.parent[small] = big
        if self.rank[big] == self.rank[small]:
            self.rank[big] += 1

def kruskal(graph):
    """Return a minimum spanning forest of ``graph`` using Kruskal's algorithm.

    Args:
        graph: Adjacency list over vertices ``0 .. len(graph)-1``; ``graph[u]``
            is an iterable of ``(v, weight)`` pairs. An undirected edge may be
            listed from both endpoints; duplicates are harmless.

    Returns:
        A list of ``(u, v, weight)`` edges, in nondecreasing weight order.
        For a disconnected graph this is a spanning forest (one tree per
        component). Ties are broken by the ``(weight, u, v)`` tuple order.
    """
    n = len(graph)
    mst = []
    disjoint_set = DisjointSet(n)
    edges = sorted(
        (weight, u, v) for u, neighbors in enumerate(graph) for v, weight in neighbors
    )

    for weight, u, v in edges:
        root_u = disjoint_set.find(u)
        root_v = disjoint_set.find(v)
        if root_u == root_v:
            # Already connected (includes self-loops): the edge would form a cycle.
            continue
        # Union the roots directly so union() does not repeat the find() work.
        disjoint_set.union(root_u, root_v)
        mst.append((u, v, weight))
        # A spanning tree on n vertices has n-1 edges; nothing more can be added.
        if len(mst) == n - 1:
            break

    return mst

In [None]:
import heapq

def prim(graph, start_vertex=0):
    """Return a minimum spanning tree via Prim's algorithm (lazy binary heap).

    Args:
        graph: Adjacency list over vertices ``0 .. len(graph)-1``; ``graph[u]``
            is an iterable of ``(v, weight)`` pairs.
        start_vertex: Vertex the tree is grown from (default ``0``).

    Returns:
        A list of ``(u, v, weight)`` edges in the order they were added. Only
        the component containing ``start_vertex`` is spanned. An empty graph
        yields ``[]``.
    """
    n = len(graph)
    mst = []
    if n == 0:
        return mst
    visited = [False] * n
    visited[start_vertex] = True
    remaining = n - 1  # vertices still to be attached to the tree
    pq = [(weight, start_vertex, neighbor) for neighbor, weight in graph[start_vertex]]
    heapq.heapify(pq)

    while pq and remaining:
        weight, u, v = heapq.heappop(pq)
        if visited[v]:
            # Stale entry: v was reached by a cheaper edge earlier.
            continue
        visited[v] = True
        remaining -= 1
        mst.append((u, v, weight))
        for neighbor, edge_weight in graph[v]:
            if not visited[neighbor]:
                heapq.heappush(pq, (edge_weight, v, neighbor))

    return mst

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

# Build a small connected example graph in which every edge has weight 1.
G = nx.Graph()
G.add_nodes_from(range(8))
G.add_weighted_edges_from(
    (a, b, 1)  # edit the third element to try other weights
    for a, b in [(0, 1), (0, 2), (0, 3), (1, 4), (1, 5), (2, 6), (2, 7)]
)

mst = nx.minimum_spanning_tree(G)
print("Edges of the Minimum Spanning Tree:")
for u, v in mst.edges():
    print((u, v))

# Draw the original graph and its MST side by side, sharing one layout so
# the nodes appear in the same positions in both panels.
shared_draw_options = dict(with_labels=True, node_size=2000, font_size=12, font_weight='bold')
panels = [
    (G, 'skyblue', 'Original Graph'),
    (mst, 'lightgreen', 'Minimum Spanning Tree'),
]
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G)
for panel_index, (shown, fill, heading) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    nx.draw(shown, pos, node_color=fill, **shared_draw_options)
    plt.title(heading)

plt.tight_layout()
plt.show()
