# In [None]:
'''
  Find a Minimum Spanning Tree (MST) of a connected, undirected, weighted
    graph using Prim's algorithm with an indexed binary min-heap.
  Time complexity = O(E log V)

  Parameters:
  -----------
    graph: dict / defaultdict
           An undirected graph as an adjacency dictionary: { u: [[v1, w1], [v2, w2]] }.
           Every edge must be listed in both directions.
    V    : integer
           Number of vertices, labelled 0 .. V-1

  Returns:
  --------
    mst    : list
             Edges of the MST as [parent, child] pairs, rooted at vertex 0
    minCost: number
             Total weight of the MST edges

  Examples:
  --------- 
      Undirected graph            Minimum Spanning Tree
            10                            10
          0-----1                       0-----1
          |\    |                        \  
          | \   |     Prim Algorithm      \       Total weights of MST 
         6|  \5 |15  ================>     \5     = 10 + 5 + 4 = 19
          |   \ |                           \
          |    \|                            \
          2-----3                       2-----3
             4                             4 

    >>> graph = {0: [[1, 10], [2, 6], [3, 5]], 1: [[0, 10], [3, 15]],
    ...          2: [[0, 6], [3, 4]], 3: [[0, 5], [1, 15], [2, 4]]}
    >>> V = 4
    >>> print(MST_Prim(graph, V))
    ([[0, 1], [3, 2], [0, 3]], 19)

  References:
    https://www.geeksforgeeks.org/prims-mst-for-adjacency-list-representation-greedy-algo-6/?ref=rp
    https://www.youtube.com/watch?v=oP2-8ysT3QQ&list=PLrmLmBdmIlpu2f2g8ltqaaCZiq6GJvl1j&index=3
'''

class Heap():
  '''
    Indexed binary min-heap of [vertex, key] nodes with decrease-key support.

    Attributes:
      array: list of [vertex, key] nodes. The first `size` entries form the heap;
             entries at index >= size hold vertices that were already extracted.
      size : number of nodes currently in the heap.
      pos  : pos[v] is the index of vertex v's node inside `array`.

    Callers append nodes to `array`, record each vertex's index in `pos`, and
    then set `size`; keys must already satisfy the heap property at that point
    (e.g. all keys equal).
  '''
  def __init__(self):
    self.array = []  # heap storage of [vertex, key] nodes
    self.size = 0    # number of live nodes at the front of `array`
    self.pos = []    # pos[v] = index of vertex v in `array`

  def newMinHeapNode(self, v, dist):
    '''Create a heap node for vertex v with key dist.'''
    return [v, dist]

  def swapMinHeapNode(self, a, b):
    '''Swap the nodes at indices a and b of `array` (pos is NOT updated here).'''
    self.array[a], self.array[b] = self.array[b], self.array[a]

  def minHeapify(self, idx):
    '''Sift the node at index idx down until the min-heap property holds.'''
    smallest = idx
    left = 2*idx + 1
    right = 2*idx + 2

    if left < self.size and self.array[left][1] < self.array[smallest][1]:
      smallest = left
    if right < self.size and self.array[right][1] < self.array[smallest][1]:
      smallest = right

    if smallest != idx:
      # Keep pos[] in sync with the swap, then continue sifting down.
      self.pos[self.array[smallest][0]] = idx
      self.pos[self.array[idx][0]] = smallest
      self.swapMinHeapNode(smallest, idx)
      self.minHeapify(smallest)

  def extractMin(self):
    '''
      Remove and return the node with the smallest key, or None if empty.

      The removed node is parked at index size-1, just past the shrunken heap,
      so that array[pos[v]][0] == v holds for every vertex v afterwards.
    '''
    if self.isEmpty():
      return None

    root = self.array[0]
    last = self.size - 1
    lastNode = self.array[last]
    # Swap root and last node. Merely copying the last node over the root would
    # leave the same node object at two indices and pos[root] pointing at it.
    self.array[0] = lastNode
    self.array[last] = root
    self.pos[lastNode[0]] = 0
    self.pos[root[0]] = last
    # Shrink the heap and restore the heap property from the root.
    self.size -= 1
    self.minHeapify(0)

    return root

  def isEmpty(self):
    '''Return True when the heap holds no nodes.'''
    return self.size == 0

  def decreaseKey(self, v, dist):
    '''
      Set vertex v's key to dist and sift it up. Time = O(log(n)).

      v must currently be in the heap and dist should not exceed its current
      key: only upward sifting is performed.
    '''
    i = self.pos[v]
    self.array[i][1] = dist
    while i > 0:
      parent = (i - 1) // 2
      if not self.array[i][1] < self.array[parent][1]:
        break
      # Swap this node with its parent, keeping pos[] in sync.
      self.pos[self.array[i][0]] = parent
      self.pos[self.array[parent][0]] = i
      self.swapMinHeapNode(i, parent)
      i = parent

  def isInMinHeap(self, v):
    '''Return True if vertex v has not been extracted yet.'''
    return self.pos[v] < self.size

# Main function
def MST_Prim(graph, V):
  '''
    Build a minimum spanning tree with Prim's algorithm, rooted at vertex 0.

    Parameters:
      graph: dict / defaultdict (or list) indexed by vertex; graph[u] is a list
             of [v, w] pairs (neighbour v, edge weight w). A vertex missing from
             a dict is treated as having no neighbours.
      V    : number of vertices, labelled 0 .. V-1.

    Returns:
      (mst, minCost): mst is a list of edges [parent, child] and minCost is the
      sum of their weights. If the graph is disconnected, a minimum spanning
      forest (one tree per component) is returned rather than pseudo-edges
      [-1, v] with an infinite cost. V == 0 yields ([], 0).
  '''
  if V <= 0:
    return [], 0

  minDist = [float('inf')] * V  # cheapest known edge linking v to the tree
  parent = [-1] * V             # tree vertex v hangs from; -1 for tree roots
  minHeap = Heap()              # vertices not yet in the tree, keyed by minDist

  # Every key starts at inf, so the initial array is already a valid heap.
  for v in range(V):
    minHeap.array.append(minHeap.newMinHeapNode(v, minDist[v]))
    minHeap.pos.append(v)
  minHeap.size = V

  # Vertex 0 gets key 0 so it is extracted first and becomes the root.
  minDist[0] = 0
  minHeap.decreaseKey(0, minDist[0])

  while not minHeap.isEmpty():
    u = minHeap.extractMin()[0]
    try:
      neighbours = graph[u]
    except KeyError:  # vertex absent from a plain dict: no incident edges
      neighbours = ()
    for v, w in neighbours:
      # Relax (u, v) if v is still outside the tree and this edge is cheaper.
      if minHeap.isInMinHeap(v) and w < minDist[v]:
        minDist[v] = w
        parent[v] = u
        minHeap.decreaseKey(v, w)

  # Roots (vertex 0 and the first vertex of any other component) have no
  # parent edge; skip them so only real edges are reported and summed.
  mst = [[parent[v], v] for v in range(V) if parent[v] != -1]
  minCost = sum(minDist[v] for v in range(V) if parent[v] != -1)
  return mst, minCost