## **Minimum Spanning Tree**

In [116]:
import networkx as nx
from random import randint
from time import time
import heapq

In [117]:
def timer(func):
    def wrapper(*args):
        start = time()
        result = func(*args)
        end = time()
        print(f"{end-start:.8f}")
        return result
    return wrapper

Random Connected Weighted Graph

In [118]:
nodes = 300
G = nx.gnp_random_graph(nodes,0.5,directed=False)
for (u,v,w) in G.edges(data=True):
    w['weight'] = randint(1,20)

In [119]:
def find(parent,x):
    if parent[x] == x:
        return x
    return find(parent,parent[x])

def union(parent,x,y):
    x_root = find(parent,x)
    y_root = find(parent,y)

    if x_root != y_root:
        parent[y_root] = x_root

@timer
def kruskal(G):
    min_span_tree = {}
    parent = []
    total_weight = 0
    for i in G.nodes():
        parent.insert(i,i)

    for u,v,w in sorted(G.edges(data=True),key=lambda x: x[2]['weight']):
        if find(parent,u) != find(parent,v):
            min_span_tree[u] = (v,w['weight'])
            min_span_tree[v] = (v,w['weight'])
            total_weight += w['weight']
            union(parent,u,v)
    
    return min_span_tree,total_weight


@timer
def prim(G):
    min_span_tree = {}
    total_weight = 0
    cost = [float('inf') for _ in range(nodes)]
    visited = []
    heap = []

    visited.append(0)
    for u,v,w in G.edges(0,data=True):
        heapq.heappush(heap,(w['weight'],u,v))

    while heap:
        weight,v,u = heapq.heappop(heap)
        if u not in visited:
            visited.append(u)
            min_span_tree[v] = (u,weight)
            min_span_tree[u] = (v,weight)
            total_weight += weight

            for _,v,w in G.edges(u,data=True):
                if v not in visited and w['weight'] < cost[v]:
                    cost[v] = w['weight']
                    heapq.heappush(heap,(w['weight'],u,v))
    
    return min_span_tree,total_weight


Minimum Spanning Tree Tests

In [120]:
print("Kruskal")
for test in range(5):
    mst, k_weight = kruskal(G)
print("\nPrim")
for test in range(5):
    mst, p_weight = prim(G)

assert(k_weight == p_weight)

Kruskal
0.05451322
0.05499983
0.05622411
0.07916069
0.05765820
Prim
0.03003883
0.03399968
0.02899957
0.02700520
0.02751541
