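# Prim's Minimum Spanning Tree (MST)

Prim's algorithm grows a single tree one vertex at a time. At each step it adds the cheapest edge that connects the tree to a vertex not yet in it, using a min-heap (`heapq`) of candidate edges.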

<img src="../images/Screenshot 2025-03-21 at 11.01.39 AM.png"/>

In [14]:
import heapq

def spanning_tree(V, adj):
    """Compute a minimum spanning tree with Prim's algorithm (adjacency-list input).

    Args:
        V: Number of vertices, labelled 0..V-1.
        adj: Indexable by vertex (dict or list); ``adj[i]`` is a list of
            ``(neighbor, weight)`` pairs. The graph is treated as undirected,
            so each edge should be listed under both of its endpoints.

    Returns:
        ``(total_weight, mst)`` where ``mst`` is a list of ``(node, parent)``
        edges in the order they were added to the tree.

    Only the connected component containing the start vertex is spanned; for a
    disconnected graph the result covers that component alone.
    """
    if V == 0:
        return 0, []

    # Start from an endpoint of the globally cheapest edge. Any start vertex
    # gives an MST of the same total weight; the choice only affects the
    # order of `mst`. With no edges at all, fall back to vertex 0.
    start = 0
    min_weight = float('inf')
    for i in range(V):
        for neighbor, weight in adj[i]:
            if weight < min_weight:
                min_weight = weight
                start = i

    visited = [False] * V
    mst = []  # (node, parent) edges of the tree
    total_weight = 0
    pq = [(0, start, -1)]  # min-heap of (edge weight, node, parent); -1 marks the root

    while pq:
        weight, node, parent = heapq.heappop(pq)

        # Stale heap entry: node was already attached via a cheaper edge.
        if visited[node]:
            continue

        visited[node] = True
        total_weight += weight
        if parent != -1:
            mst.append((node, parent))

        for adj_node, edge_w in adj[node]:
            if not visited[adj_node]:
                heapq.heappush(pq, (edge_w, adj_node, node))

    return total_weight, mst

# Example: a 5-vertex weighted undirected graph as an adjacency list.
# Each vertex maps to its (neighbor, weight) pairs; every edge is listed
# under both endpoints.
adj_list = {
    0: [(1, 2), (3, 6)],
    1: [(0, 2), (2, 3), (3, 8), (4, 5)],
    2: [(1, 3), (4, 7)],
    3: [(0, 6), (1, 8), (4, 9)],
    4: [(1, 5), (2, 7), (3, 9)],
}

V = len(adj_list)
# spanning_tree returns (total_weight, mst_edges); the whole tuple is printed.
print("Total Weight of MST:", spanning_tree(V, adj_list))


Total Weight of MST: (16, [(1, 0), (2, 1), (4, 1), (3, 0)])


In [15]:
import heapq

# Note: this cell redefines `spanning_tree`, replacing the adjacency-list
# version from the previous cell with an adjacency-matrix version.
def spanning_tree(V, adj):
    """Compute a minimum spanning tree with Prim's algorithm (adjacency-matrix input).

    Args:
        V: Number of vertices, labelled 0..V-1.
        adj: V x V matrix (list of lists). ``adj[i][j] > 0`` is the weight of
            the edge i--j; any value <= 0 means "no edge". The graph is
            treated as undirected.

    Returns:
        ``(total_weight, mst)`` where ``mst`` is a list of ``(node, parent)``
        edges in the order they were added to the tree.

    Only the connected component containing the start vertex is spanned; for a
    disconnected graph the result covers that component alone.
    """
    if V == 0:
        return 0, []

    # Start from an endpoint of the globally cheapest edge (upper triangle
    # only, since the matrix is symmetric). Any start vertex gives an MST of
    # the same total weight. With no edges, fall back to vertex 0 rather than
    # using -1 as a node index, which would wrap to the last row and collide
    # with the root-parent sentinel.
    start = 0
    min_weight = float('inf')
    for i in range(V):
        for j in range(i + 1, V):
            weight = adj[i][j]
            if 0 < weight < min_weight:
                min_weight = weight
                start = i

    visited = [False] * V
    mst = []  # (node, parent) edges of the tree
    total_weight = 0
    pq = [(0, start, -1)]  # min-heap of (edge weight, node, parent); -1 marks the root

    while pq:
        weight, node, parent = heapq.heappop(pq)

        # Stale heap entry: node was already attached via a cheaper edge.
        if visited[node]:
            continue

        visited[node] = True
        total_weight += weight
        if parent != -1:
            mst.append((node, parent))

        row = adj[node]
        for adj_node in range(V):
            edge_w = row[adj_node]
            if edge_w > 0 and not visited[adj_node]:
                heapq.heappush(pq, (edge_w, adj_node, node))

    return total_weight, mst

# Example: the same 5-vertex graph as an adjacency matrix.
# adj_matrix[i][j] is the weight of edge i--j; 0 means "no edge".
adj_matrix = [
    [0, 2, 0, 6, 0],
    [2, 0, 3, 8, 5],
    [0, 3, 0, 0, 7],
    [6, 8, 0, 0, 9],
    [0, 5, 7, 9, 0],
]

V = len(adj_matrix)
total_weight, mst = spanning_tree(V, adj_matrix)
print("Total Weight of MST:", total_weight)
print("Edges in MST:", mst)  # (node, parent) pairs

Total Weight of MST: 16
Edges in MST: [(1, 0), (2, 1), (4, 1), (3, 0)]
