# Prim algorithm with heap data structure 

In [436]:
import math
import time
from collections import defaultdict

## Definition of Node and Graph objects for graph manipulation

In [437]:
class Node:
    """Graph vertex together with the per-node state used by Prim's algorithm.

    Attributes are plain public fields:
      tag           -- identifier given at construction
      key           -- priority used by the heap (None until assigned)
      parent        -- predecessor node (None until assigned)
      isPresent     -- True until the node is extracted from the heap
      index         -- slot of the node inside the heap array
      adjacencyList -- list of [neighbour, cost] entries
    """

    def __init__(self, tag: int):
        self.tag = tag
        self.key = self.parent = None
        self.isPresent = True
        # Remember the heap slot so it never has to be searched for with the
        # O(n) list.index() method.
        self.index = tag - 1
        self.adjacencyList = list()

    # For test
    def print(self):
        print(f"tag = {self.tag} adjList= {self.adjacencyList} key= {self.key}")

class Graph:
    """Undirected weighted graph whose nodes are stored in a dict keyed by tag."""

    def __init__(self):
        # A plain dict is used on purpose: a defaultdict with Node as factory
        # can never create a default entry (Node requires a tag argument), so
        # looking up an unknown tag used to fail with an unrelated TypeError.
        self.nodes = {}

    def createNodes(self, nums: int):
        """Create the nodes tagged 1..nums (both ends included)."""
        for tag in range(1, nums + 1):
            self.nodes[tag] = Node(tag)

    def addNode(self, tag: int, adjTag: int, adjCost: int):
        """Add the undirected edge tag <-> adjTag with cost adjCost.

        Both endpoints must already exist (see createNodes).

        Raises:
            KeyError: if either tag does not name an existing node.
        """
        try:
            first = self.nodes[tag]
            second = self.nodes[adjTag]
        except KeyError as err:
            raise KeyError(
                f"edge ({tag}, {adjTag}) references unknown node {err.args[0]}"
            ) from None
        # The graph is undirected, so the edge is recorded on both endpoints.
        first.adjacencyList.append([second, adjCost])
        second.adjacencyList.append([first, adjCost])

    def buildGraph(self, input):
        """Populate the graph from an open text stream.

        The first non-blank line must start with the number of vertices; every
        following non-blank line describes one edge as three integers:
        ``tag adjTag cost``. Blank lines (for example a trailing newline at the
        end of the file) are ignored.

        The parameter keeps its original name ``input`` for compatibility with
        existing callers, even though it shadows the builtin inside this method.

        Raises:
            ValueError: if the stream contains no data, or a field is not an integer.
        """
        rows = [fields for fields in (line.split() for line in input.readlines()) if fields]
        if not rows:
            raise ValueError("graph input is empty: expected a header line with the node count")
        # Only the first token of the header line is used: the vertex count.
        self.createNodes(int(rows[0][0]))
        for fields in rows[1:]:
            info = list(map(int, fields))  # Convert all the fields of the line to int
            self.addNode(info[0], info[1], info[2])

## MinHeap data structure implementation

In [438]:
# ArrayHeap is a list subclass carrying an explicit heap-size counter
class ArrayHeap(list):
    """List used as heap storage, with the element count mirrored in ``heapSize``."""

    def __init__(self, array):
        list.__init__(self, array)
        # Starts equal to the number of elements copied from `array`; it is
        # not updated automatically when the list is later modified.
        self.heapSize = len(array)

class MinHeap:
    """Binary min-heap of node objects ordered by their ``key`` attribute.

    Every stored node must expose ``key``, ``index`` and ``isPresent``. The heap
    keeps ``node.index`` equal to the node's current slot in ``arrayHeap`` so a
    caller can reach a node in O(1) (e.g. to call ``bubbleUp`` after lowering
    its key), and clears ``node.isPresent`` when the node is extracted.
    """

    def __init__(self, array: list, root: "Node"):
        """Store the nodes of ``array`` with the root node moved to slot 0.

        The root is looked up by position: it is the node found at
        ``array[root.tag - 1]``. No heapify pass is performed, so the resulting
        layout is a valid min-heap only if that node holds the smallest key.

        An empty ``array`` yields an empty heap and ``root`` is not inspected.
        """
        self.arrayHeap = ArrayHeap(array)
        if not self.arrayHeap:
            return
        rootPos = root.tag - 1
        if rootPos != 0:
            # Move the starting node to the front; the others shift right by one.
            self.arrayHeap.insert(0, self.arrayHeap.pop(rootPos))
        # Always refresh the bookkeeping, not only when the root was moved, so
        # the heap does not depend on the nodes arriving with correct indexes.
        for i, node in enumerate(self.arrayHeap):
            node.index = i
            node.isPresent = True

    # All the following methods work with a zero-based array.
    def parent(self, i: int):
        """Return the slot of the parent of slot ``i`` (-1 for the root)."""
        return (i - 1) // 2

    def left(self, i):
        """Return the slot of the left child of slot ``i``."""
        return 2*i + 1

    def right(self, i):
        """Return the slot of the right child of slot ``i``."""
        return 2*i + 2

    def _swap(self, a: int, b: int):
        """Exchange the nodes in slots ``a`` and ``b`` and fix their ``index``."""
        heap = self.arrayHeap
        heap[a], heap[b] = heap[b], heap[a]
        heap[a].index = a
        heap[b].index = b

    # Execution time: O(lg n)
    def minHeapify(self, i: int):
        """Sink the node in slot ``i`` until neither child has a smaller key."""
        heap = self.arrayHeap
        while True:
            lastSlot = heap.heapSize - 1
            l = self.left(i)
            r = self.right(i)
            minimo = i
            if l <= lastSlot and heap[l].key < heap[minimo].key:
                minimo = l
            if r <= lastSlot and heap[r].key < heap[minimo].key:
                minimo = r
            if minimo == i:
                return
            self._swap(i, minimo)
            i = minimo

    def bubbleUp(self, index: int):
        """Raise the node in slot ``index`` while its parent has a larger key."""
        current = index
        parent = self.parent(current)
        while current > 0 and self.arrayHeap[parent].key > self.arrayHeap[current].key:
            self._swap(current, parent)
            current = parent
            parent = self.parent(current)

    # Execution time: O(lg n)
    def extractMin(self):
        """Remove and return the node with the smallest key.

        The returned node has ``isPresent`` set to False.

        Raises:
            IndexError: if the heap is empty.
        """
        heap = self.arrayHeap
        if heap.heapSize <= 0:
            # The previous guard tested `< 0`, which can never be true, so an
            # empty heap surfaced as a bare "list index out of range".
            raise IndexError("extractMin underflow")
        minimo = heap[0]  # Save the minimum node
        minimo.isPresent = False
        last = heap.pop()  # Detach the last node
        heap.heapSize -= 1
        if heap.heapSize > 0:
            # Put the former last node at the root and sink it into place.
            heap[0] = last
            last.index = 0
            self.minHeapify(0)
        return minimo

## Prim algorithm

In [439]:
def MSTPrim(g: Graph, r: Node):
    """Run Prim's algorithm on ``g`` starting from node ``r``.

    The result is written onto the nodes: ``parent`` is the node through which
    each node was attached to the tree (None for ``r``) and ``key`` is the cost
    of that attaching edge (0 for ``r``). A node that cannot be reached from
    ``r`` keeps ``key == math.inf`` and ``parent is None``.

    Prints the starting tag and the elapsed time; returns None.
    """
    # perf_counter is meant for measuring short intervals; the run can take
    # well under a millisecond.
    start = time.perf_counter()
    for node in g.nodes.values():
        node.key = math.inf
        # Also clear what an earlier call may have left on the nodes.
        node.parent = None
        node.isPresent = True
    r.key = 0
    q = MinHeap(list(g.nodes.values()), r) # Pass also the root node in order to build the heap starting from it
    # Truthiness test instead of `len(...) is not 0`: identity comparison with
    # an int literal is implementation-dependent.
    while q.arrayHeap:
        u = q.extractMin()
        for neighbour, cost in u.adjacencyList:
            if neighbour.isPresent and cost < neighbour.key:
                neighbour.parent = u
                neighbour.key = cost
                q.bubbleUp(neighbour.index) # bubbleUp maintains the minheap condition
    print("Start = node", r.tag,"\nMSTPrim execution time =", time.perf_counter() - start)

## Main

In [440]:
start = time.time()
startingNode = 1 # Root node tag
result = Graph()
# Open the dataset in a with-block so the file handle is closed as soon as
# the graph has been built.
with open("dataset/input_random_01_10.txt", "r") as dataset:
    result.buildGraph(dataset)
MSTPrim(result, result.nodes.get(startingNode))
print("Program execution time =", time.time() - start)
# After MSTPrim each node's key is the cost of the edge that attached it to the
# tree (0 for the root), so the sum of the keys is the cost of the tree.
# The total is no longer stored in a variable named `sum`, which hid the builtin.
totalCost = sum(node.key for node in result.nodes.values())
print("Final cost =", totalCost, "\n")


Start = node 1 
MSTPrim execution time = 9.989738464355469e-05
Program execution time = 0.005508899688720703
Final cost = 29316 

