In [1]:
class Node:
    """Graph vertex that stores a label and its weighted outgoing edges."""

    def __init__(self, data):
        # Label identifying this vertex.
        self.data = data
        # Adjacency list of (neighbor_node, weight) tuples, in insertion order.
        self.edge = list()

    def __str__(self):
        """Return the label as text."""
        # __str__ must return a str; converting keeps non-string labels
        # (e.g. integers) from raising TypeError.
        return str(self.data)

    def add_edge(self, neighbor, weight):
        """Append an edge to ``neighbor`` with the given ``weight``."""
        self.edge.append((neighbor, weight))


In [5]:
  def merge_sort(array, n):
    """Sort the first ``n`` items of ``array`` in ascending order (merge sort).

    When ``n`` is 0 or 1 the input object itself is returned unchanged;
    otherwise the two halves are copied, sorted recursively and combined
    with ``merge`` into a new list.
    """
    # Base case: nothing to split, hand the input back as is.
    if n <= 1:
      return array

    half = n // 2
    # Copy each half element by element so the input is never modified.
    lower = [array[idx] for idx in range(half)]
    upper = [array[idx] for idx in range(half, n)]
    return merge(merge_sort(lower, len(lower)), merge_sort(upper, len(upper)))

  def merge(left, right):
    """Merge two ascending sequences into one new ascending list.

    The inputs are only read, never modified. When elements compare equal
    the one from ``left`` is emitted first, so the merge is stable.
    """
    result = list()
    i = 0
    j = 0
    # Walk both inputs with cursors instead of pop(0): popping from the front
    # is linear per call and would also destroy the caller's lists.
    while i < len(left) and j < len(right):
      if left[i] <= right[j]:
        result.append(left[i])
        i += 1
      else:
        result.append(right[j])
        j += 1
    # At most one of these tails is non-empty.
    result.extend(left[i:])
    result.extend(right[j:])
    return result

In [7]:
import heapq

import numpy as np

# Apart from the Prim's-algorithm code, this Graph class is the one from the dijkstra.ipynb notebook used in class.
class Graph:
  """Undirected weighted graph stored as adjacency lists.

  ``nodes`` maps each vertex label to its node object; every node keeps an
  ``edge`` list of ``(neighbor_node, weight)`` tuples.
  """

  def __init__(self):
    # Vertex label -> node. Insertion order is the order used by __str__.
    self.nodes = dict()

  def __str__(self):
    """Return one line per vertex: ``label: neighbor(weight) neighbor(weight) ``."""
    lines = []
    for node in self.nodes.values():
      # Every neighbor entry keeps its trailing space, as in the original output.
      neighbors = ''.join(f'{nbr.data}({w}) ' for nbr, w in node.edge)
      lines.append(f'{node.data}: {neighbors}')
    return '\n'.join(lines)

  def insert_info(self, data_tuple):
    """Add the undirected edge ``(label_i, label_j, weight)`` to the graph.

    Endpoints that do not exist yet are created, and the edge is recorded in
    the adjacency list of both endpoints.
    """
    data_i, data_j, weight = data_tuple[0], data_tuple[1], data_tuple[2]

    node_i = self.get_node(data_i)
    node_j = self.get_node(data_j)

    node_i.add_edge(node_j, weight)
    node_j.add_edge(node_i, weight)

  def get_node(self, data):
    """Return the node labelled ``data``, creating and registering it if needed."""
    if data not in self.nodes:
      self.nodes[data] = Node(data)
    return self.nodes[data]

  def _prim_edges(self, start_data):
    """Run Prim's algorithm from ``start_data`` without printing anything.

    Returns:
      ``(edges, total_weight)`` where ``edges`` is a list of
      ``(previous_label, current_label, weight)`` tuples in the order the
      edges were selected. Only vertices reachable from ``start_data`` are
      covered.
    """
    edges = []
    total_weight = 0
    visited = set()
    # Heap entries are (weight, label, sequence, previous_label). Ties on
    # weight fall back to the label (labels must be mutually orderable in that
    # case) and then to insertion order; the unique sequence number guarantees
    # the previous-label slot is never compared.
    frontier = []
    sequence = 0
    current_data = start_data

    while True:
      visited.add(current_data)
      if len(visited) == len(self.nodes):
        break  # every vertex is already in the tree

      # Offer every edge leaving the newly added vertex.
      for neighbor, weight in self.nodes[current_data].edge:
        if neighbor.data not in visited:
          heapq.heappush(frontier, (weight, neighbor.data, sequence, current_data))
          sequence += 1

      # Drop candidates whose endpoint joined the tree in the meantime.
      while frontier and frontier[0][1] in visited:
        heapq.heappop(frontier)
      if not frontier:
        break  # remaining vertices are unreachable from start_data

      min_weight, current_data, _, previous_data = heapq.heappop(frontier)
      # Record the weight of the edge that was actually selected.
      edges.append((previous_data, current_data, min_weight))
      total_weight += min_weight

    return edges, total_weight

  # Prim's algorithm: find and print a minimum spanning tree
  def prim(self, start_data):
    """Print a minimum spanning tree grown from ``start_data``.

    Each selected edge is printed as ``previous - current`` in selection
    order, followed by a ``weight : <total>`` line. If the graph is not
    connected, only the component containing ``start_data`` is spanned.

    Raises:
      ValueError: if ``start_data`` is not a vertex of the graph.
    """
    if start_data not in self.nodes:
      raise ValueError(f"{start_data} is not in graph")

    mst, total_weights = self._prim_edges(start_data)

    for start, end, _ in mst:
      print(f"{start} - {end}")

    print(f"weight : {total_weights}")
      


In [8]:
# Edges of the example graph as (vertex, vertex, weight) triples.
data_tuples = [
    ('A', 'B', 4),
    ('A', 'H', 8),
    ('B', 'C', 8),
    ('B', 'H', 11),
    ('C', 'D', 7),
    ('C', 'F', 4),
    ('C', 'I', 2),
    ('D', 'F', 14),
    ('D', 'E', 9),
    ('E', 'F', 10),
    ('F', 'G', 2),
    ('G', 'I', 6),
    ('G', 'H', 1),
    ('H', 'I', 7),
]

# Build the undirected graph; each triple is added to both endpoints.
graph = Graph()
for data_tuple in data_tuples:
    graph.insert_info(data_tuple)

# Show the adjacency list of every vertex.
print(graph)

A: B(4) H(8) 
B: A(4) C(8) H(11) 
H: A(8) B(11) G(1) I(7) 
C: B(8) D(7) F(4) I(2) 
D: C(7) F(14) E(9) 
F: C(4) D(14) E(10) G(2) 
I: C(2) G(6) H(7) 
E: D(9) F(10) 
G: F(2) I(6) H(1) 


In [10]:
# Run Prim's algorithm from vertex 'A'; prim prints the selected edges and the total weight.
graph.prim('A')

A - B
B - C
C - I
C - F
F - G
G - H
C - D
D - E
weight : 37
