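This is an efficient implementation of Prim's algorithm for finding a minimum spanning tree (MST) using a heap data structure.
The heap implements a min-priority queue whose minimum entry is the node that is currently cheapest to attach, i.e. the node to be added to the MST in the next iteration.
The running time of heap-based Prim's algorithm is O(m log n), where m = number of edges and n = number of nodes. Note that deleteHeapNode below finds a node by a linear scan of the heap, so reaching this bound would additionally require tracking each node's position in the heap.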

The input file stores one edge per line, together with its cost, in the format: tail_node head_node cost
For example, the third line of the edges.txt file is "2 3 -8874", indicating that there is an edge connecting vertex #2 and vertex #3 that has cost -8874.

In [56]:
# Undirected adjacency list: node -> list of (neighbour, cost) pairs.
adj_lst = {}

# The context manager closes the file even if a line fails to parse.
with open('clustering1.txt') as edge_file:
    for line in edge_file:
        # split() with no argument tolerates blank lines, trailing spaces and
        # repeated separators; split(' ') would pass '' or '\n' to int().
        edge = [int(e) for e in line.split()]
        # Only lines holding exactly three integers describe an edge;
        # every other line is ignored.
        if len(edge) == 3:
            tail, head, edge_cost = edge
            # Record the edge in both directions because the graph is undirected.
            adj_lst.setdefault(tail, []).append((head, edge_cost))
            adj_lst.setdefault(head, []).append((tail, edge_cost))

In [57]:
def heapify_up(h, i):
    """Sift the entry at index ``i`` towards the root of the min-heap ``h``.

    ``h`` is a list used as a 1-indexed binary heap: the root sits at index 1
    and the parent of index ``k`` is ``k // 2``. Entries are ordered by their
    second element (``entry[1]``). The list is modified in place and nothing
    is returned.
    """
    child = i
    while child != 1:
        parent = child // 2
        # Stop as soon as the entry is no longer strictly cheaper than its parent.
        if not (h[child][1] < h[parent][1]):
            break
        h[child], h[parent] = h[parent], h[child]
        child = parent

In [58]:
def heapify_down(h, i):
    """Sift the entry at index ``i`` towards the leaves of the min-heap ``h``.

    ``h`` is a list used as a 1-indexed binary heap: the children of index
    ``k`` are ``2 * k`` and ``2 * k + 1``. Entries are ordered by their second
    element (``entry[1]``). The list is modified in place and nothing is
    returned.
    """
    parent = i
    while True:
        last = len(h) - 1
        left = 2 * parent
        if left > last:
            return  # no children: the entry is already a leaf

        right = left + 1
        # With two children pick the cheaper one; on a tie the right child wins.
        if right <= last and not (h[left][1] < h[right][1]):
            smallest = right
        else:
            smallest = left

        # Done once the parent is not strictly more expensive than that child.
        if not (h[parent][1] > h[smallest][1]):
            return

        h[parent], h[smallest] = h[smallest], h[parent]
        parent = smallest

In [59]:
def extractMin(h):
    """Remove and return the root entry (index 1) of the 1-indexed min-heap ``h``.

    The last entry of the list is moved into the root slot, the list shrinks
    by one, and ``heapify_down`` is run from the root to restore heap order.
    """
    root = h[1]
    # Overwrite the root with the last entry, then drop that last slot.
    h[1] = h[-1]
    h.pop()
    heapify_down(h, 1)
    return root

In [60]:
def deleteHeapNode(h, w):
    """Remove the entry for node ``w`` from the 1-indexed min-heap ``h``.

    The entry is located by a linear scan over ``entry[0]``, replaced by the
    last entry of the list, and heap order is then restored by sifting the
    moved entry up or down as needed.

    Args:
        h: list used as a 1-indexed binary min-heap; index 0 is a placeholder
            and entries are ordered by ``entry[1]``.
        w: node identifier to look for in ``entry[0]``.

    Returns:
        The entry that was removed.

    Raises:
        ValueError: if no heap entry belongs to ``w``. Previously the scan
            fell through and the last entry of the heap was deleted instead.
    """
    # Start at the root: index 0 is only a placeholder and must never be
    # matched, even when a node shares its identifier.
    for i in range(1, len(h)):
        if h[i][0] == w:
            break
    else:
        raise ValueError(f'node {w!r} is not in the heap')

    deleted_node = h[i]
    last_entry = h.pop()

    # The deleted entry was the last one, so no hole is left to repair.
    if i == len(h):
        return deleted_node

    # Fill the hole with the former last entry and move it to a legal position.
    h[i] = last_entry
    if i > 1 and h[i][1] < h[i // 2][1]:
        heapify_up(h, i)
    else:
        heapify_down(h, i)

    return deleted_node

In [61]:
def insertHeapNode(h, node):
    """Add ``node`` to the min-heap ``h`` and restore heap order.

    The entry is placed in the first free slot at the end of the list and
    then sifted up with ``heapify_up``.
    """
    new_index = len(h)  # slot the new entry occupies after the append
    h.append(node)
    heapify_up(h, new_index)

In [62]:
def testHeapInv(h):
    for i in range(1,len(h)):
        break
            
        try:
            if(h[i][1] > h[2*i][1] or h[i][1] > h[2*i + 1][1]):
                print('heap damaged at: ', h[i])
                return False
        except IndexError:
            return True

    hl = [l[0] for l in h]
    s = set([y for y in hl if hl.count(y) > 1])

    if(len(s) > 0):
        print('heap damaged at node(s): ', s)
        return False

    return True

# Hand-built sample heap, laid out one tree level per row. Index 0 is a
# placeholder; every parent cost is no larger than the costs of its children.
heap = [
    (0, 0),
    (1, 4),
    (6, 7), (7, 7),
    (8, 10), (9, 16), (10, 8), (11, 11),
    (12, 15), (2, 17), (3, 20), (4, 17), (5, 15), (30, 16),
]
# Variant that breaks heap order: (11, 110) is more expensive than its child (22, 21).
# heap = [(0,0),(1, 4), (6, 7), (7, 7), (8, 10), (9, 16), (10, 8), (11, 110), (12, 15), (2, 17), (3, 20), (4, 17), (5, 15),  (30, 16),(22, 21)]
testHeapInv(heap)

True

In [63]:
from functools import reduce
# Flatten every adjacency list into one list of (neighbour, cost) pairs.
# A single comprehension is linear; reduce() with list "+" copies the
# accumulated list again on every step.
all_edges = [edge for neighbours in adj_lst.values() for edge in neighbours]
max_cost = max(edge_cost for _, edge_cost in all_edges)
# Stand-in for "infinity": it must be strictly larger than every real cost.
# max_cost * 2 alone fails when max_cost is zero or negative, so fall back to
# max_cost + 1 in that case. For max_cost >= 1 the value is still max_cost * 2.
inf_cost = max(max_cost * 2, max_cost + 1)

In [64]:
# Initial heap for Prim's algorithm started from node 1. Entries are
# (node, cost, parent) triples and index 0 holds a placeholder. A list sorted
# by cost is already a valid min-heap, so no sifting is needed here.
heap = [(0, 0, 0)]
# Nodes that must not be added (again): the start node itself, which guards
# against a self-loop, plus every neighbour already queued, which guards
# against parallel edges. Scanning in increasing cost keeps the cheapest edge.
queued = {1}
for z in sorted(adj_lst[1], key=lambda x: x[1]):
    if z[0] not in queued:
        queued.add(z[0])
        heap.append((z[0], z[1], 1))

# Every other node starts at the "infinite" cost; the set is built once
# instead of rebuilding the exclusion list for every key.
nodes_not_linked_to_1 = [z for z in adj_lst if z not in queued]
heap = heap + [(z, inf_cost, 1) for z in nodes_not_linked_to_1]

In [65]:
from datetime import datetime

cost = 0          # total cost of the edges chosen so far
v = set(adj_lst)  # every node of the graph
x = {1}           # nodes already spanned by the tree; the search starts at node 1
e = []            # chosen tree edges as (node, parent) pairs

execution_time_begin = datetime.now()
while x != v:
    # Sanity check before every extraction; stop instead of continuing
    # with a corrupted heap.
    if not testHeapInv(heap):
        print('Damaged heap: ', heap)
        break

    # The result is deliberately not bound to the name ``min``: doing so at
    # module level would shadow the builtin min() for all code that runs later.
    new_node, new_cost, new_parent = extractMin(heap)
    x.add(new_node)
    cost += new_cost
    e.append((new_node, new_parent))

    # Re-key each neighbour still outside the tree: take its entry out of the
    # heap, lower its cost if the edge from new_node is cheaper, and put it back.
    for neighbour, edge_cost in adj_lst[new_node]:
        if neighbour in x:
            continue
        w_heap = deleteHeapNode(heap, neighbour)
        if edge_cost < w_heap[1]:
            w_heap = (neighbour, edge_cost, new_node)
        insertHeapNode(heap, w_heap)

execution_time_end = datetime.now()

print('MST cost (with heap): ', cost)
print('nodes in MST: ', len(x))
print('edges in MST: ', len(e))

execution_time = (execution_time_end - execution_time_begin).total_seconds() * 1000
print(f"The execution time is {execution_time} milliseconds")

MST cost (with heap):  12320
nodes in MST:  500
edges in MST:  499
The execution time is 1785.871 milliseconds


In [66]:
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network

# Directed graph built from the (node, parent) pairs collected in ``e``.
G_ALL_EDGES = nx.DiGraph(e)

# Attach a label to each node through its attribute dictionary.
for node, attrs in G_ALL_EDGES.nodes(data=True):
    attrs['label'] = f'Node {node}'

# Physics settings passed to the pyvis network as an options string.
physics_options = """
var options = {
    "physics": {
        "enabled": true,
        "stabilization": {
            "enabled": true,
            "iterations": 500
        },
        "maxVelocity": 20
    }
}
"""

net = Network(notebook=True, cdn_resources='remote')
net.set_options(physics_options)

# Copy the graph into the pyvis network and write it out as an HTML page.
net.from_nx(G_ALL_EDGES)
net.show("mst_heap.html")

mst_heap.html
