In [1]:
#a minimum spanning tree (mst) is a subset of the edges of a connected, weighted, undirected graph
#that connects every vertex
#without forming any cycle (so n vertices are joined by exactly n-1 edges)
#and whose total edge weight is the minimum possible
#note that a vertex in an mst may have any number of neighbours, not just one or two
import copy

In [2]:
#Prim's algorithm is perfect for the broadcast problem
#details of broadcast problem can be found in the following link
# http://interactivepython.org/runestone/static/pythonds/Graphs/PrimsSpanningTreeAlgorithm.html

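#say you want to broadcast a message to a user at the far end of the network
#and also reach everyone along the way
#prim is a perfect solution because the tree it grows touches every vertex
#but if we stop as soon as a destination that is not at the far end is reached
#a few vertices may never be visited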

#the algorithm is somewhat similar to bfs
#details about bfs are in the following link
# https://github.com/je-suis-tm/graph-theory/blob/master/BFS%20DFS%20on%20DCG.ipynb
#except we don't pop child vertices in left-to-right (first-in-first-out) order
#instead we always pop the vertex reachable through the edge with minimum weight
#which is why a priority queue (heap) can be used as well

#please note that we always pick the vertex with the minimum edge weight in our queue
#so the path inside the tree is not necessarily the shortest path to our destination
#think of it as a lazy guy who tries to go home
#he is on a triangle with edges 3,4,5
#he always picks the road that is the shortest to the next turn
#so from one point to another, he takes 3+4=7 to reach the destination
#however, there is a direct route which only costs 5
#if we try to find the shortest path to the destination
#dijkstra is an ideal solution
#details of dijkstra can be found in the following link
# https://github.com/je-suis-tm/graph-theory/blob/master/dijkstra%20shortest%20path.ipynb
class graph:
    """Weighted adjacency-map graph with a per-vertex visited flag.

    self.graph   -- {vertexid: {neighbour: weight}}
    self.visited -- {vertexid: 0 or 1}; a vertex gets an entry when it is first
                    added as a source via append(), or when visit() is called on it
    Edges are stored exactly as appended (directed); append both directions
    to model an undirected graph.
    """

    def __init__(self):
        self.graph, self.visited = {}, {}

    def append(self, vertexid, edge, weight):
        """Store the edge vertexid -> edge with `weight`, overwriting any earlier weight."""
        neighbours = self.graph.get(vertexid)
        if neighbours is None:
            #first time this vertex appears as a source: register it as unvisited
            neighbours = self.graph[vertexid] = {}
            self.visited[vertexid] = 0
        neighbours[edge] = weight

    def reveal(self):
        """Return the underlying adjacency dict itself (not a copy)."""
        return self.graph

    def vertex(self):
        """List the vertices registered via append(), in insertion order."""
        return [v for v in self.graph]

    def edge(self, vertexid):
        """List the neighbours of vertexid, in insertion order."""
        return [n for n in self.graph[vertexid]]

    def weight(self, vertexid, edge):
        """Return the weight of the stored edge vertexid -> edge."""
        neighbours = self.graph[vertexid]
        return neighbours[edge]

    def size(self):
        """Return the number of vertices registered via append()."""
        return len(self.graph)

    def visit(self, vertexid):
        """Mark vertexid as visited (flag 1)."""
        self.visited[vertexid] = 1

    def go(self, vertexid):
        """Return the visited flag of vertexid (0 = unvisited, 1 = visited)."""
        return self.visited[vertexid]

    def route(self):
        """Return the visited-flag dict itself (not a copy)."""
        return self.visited

In [3]:
#we use a dictionary instead of a list as the queue
#because we need to pop the vertex with the minimum edge weight
#result is a list that keeps the order in which vertices are visited
#this is similar in spirit to the ordering produced by topological sort
#details of topological sort can be found in the following link
# https://github.com/je-suis-tm/graph-theory/blob/master/topological%20sort.ipynb
#the end point is optional

def prim(df,start,end=None,verbose=True):
    """Grow a minimum spanning tree from `start` with Prim's algorithm.

    df      -- graph object exposing edge(v), weight(v,e), visit(v) and go(v);
               it should be undirected (every edge stored in both directions).
               Vertices are marked visited on df itself, so pass a fresh copy
               if df is reused afterwards.
    start   -- vertex the tree is grown from.
    end     -- optional vertex; once it is popped from the queue the search
               stops and only the part of the tree grown so far is returned.
    verbose -- when True, print the queue at every step.

    Returns a list of [parent, vertex] edges in the order the vertices were
    added to the tree. The start vertex contributes no edge, whether or not
    `end` is given.
    """
    #queue maps each frontier vertex to the cheapest known edge weight into the tree
    queue={start:0}

    #route records which tree vertex each vertex was reached from
    route={start:start}
    result=[]

    while queue:
        if verbose:
            print(queue)

        #on ties, min() returns the key inserted into the queue first
        key=min(queue,key=queue.get)
        queue.pop(key)
        result.append(key)
        df.visit(key)

        if key==end:
            break

        for i in df.edge(key):
            w=df.weight(key,i)
            #add newly discovered unvisited vertices, and lower the cost of
            #queued ones whenever a cheaper edge into the tree is found
            if df.go(i)==0 and (i not in queue or queue[i]>w):
                queue[i]=w
                route[i]=key

    #the start vertex has no incoming tree edge, so it is left out
    return [[route[i],i] for i in result if i!=start]

In [4]:
#note that for prim to work
#we need an undirected graph
#in other words, every edge must be stored in both directions
#with the same weight, so connected vertices are mutually connected
#each undirected edge is listed once as (vertex, vertex, weight)
#and appended in both directions, so the two directions can never disagree
undirected_edges=[(1,2,6),(1,3,5),(2,4,8),(2,6,3),(3,4,2),
                  (3,5,7),(4,5,7),(5,7,9),(6,7,5),(7,8,13)]

df=graph()
for u,v,w in undirected_edges:
    df.append(u,v,w)
    df.append(v,u,w)

![Weighted undirected example graph used by prim and kruskal below](./preview/prim.jpg)

In [5]:
#prim flags vertices as visited on the graph it is given,
#so hand it a deep copy and keep df reusable for later cells
prim(copy.deepcopy(df),start=1)

{1: 0}
{2: 6, 3: 5}
{2: 6, 4: 2, 5: 7}
{2: 6, 5: 7}
{5: 7, 6: 3}
{5: 7, 7: 5}
{5: 7, 8: 13}
{8: 13}


[[1, 3], [3, 4], [1, 2], [2, 6], [6, 7], [3, 5], [7, 8]]

In [6]:
def trace_root(disjointset,target):
    """Return the root (representative) of `target` in a disjoint-set forest.

    disjointset -- dict mapping each vertex to its parent; a root maps to itself.
    target      -- vertex whose root is wanted.

    Parent links are followed iteratively, so long chains cannot hit the
    recursion limit, and the dict is not modified.

    Raises ValueError if the parent links form a cycle that never reaches a
    root. A cycle means the disjoint set is corrupted; detecting it avoids
    looping forever.
    """
    seen=set()
    while disjointset[target]!=target:
        if target in seen:
            raise ValueError(f'disjoint set has a cycle through {target!r}')
        seen.add(target)
        target=disjointset[target]
    return target

In [7]:
def kruskal(df,verbose=True):
    """Build a minimum spanning tree (or forest) with Kruskal's algorithm.

    df      -- graph object exposing vertex(), edge(v) and weight(v,e); it
               should be undirected (every edge stored in both directions).
               Only one direction of each pair is used. df is only read,
               never modified.
    verbose -- when True, print every candidate edge as it is examined.

    Returns a list of [parent, child] edges in the order they were accepted
    (ascending weight; ties keep the order in which edges were discovered).
    """
    #collect each undirected edge once, keyed by a (vertex, neighbour) tuple
    #tuples work for any hashable vertex id, unlike "i-j" strings parsed back with int()
    d={}
    for i in df.vertex():
        for j in df.edge(i):
            if (j,i) not in d:
                d[(i,j)]=df.weight(i,j)

    #sorted() is stable, so equal weights keep their discovery order
    sort=sorted(d.items(),key=lambda x:x[1])

    #to avoid cycles we use a disjoint set (union-find):
    #each vertex points at its parent and a root points at itself
    #for more details, you can go to geeksforgeeks
    # https://www.geeksforgeeks.org/union-find/
    disjointset={i:i for i in df.vertex()}

    def find_root(target):
        #vertices that only appear as edge targets start as their own root
        disjointset.setdefault(target,target)
        while disjointset[target]!=target:
            target=disjointset[target]
        return target

    output=[]
    for (parent,child),weight in sort:
        if verbose:
            print(f'from {parent} to {child} at {weight}')

        root_parent=find_root(parent)
        root_child=find_root(child)

        #endpoints already sharing a root means this edge would close a cycle
        if root_parent!=root_child:
            #merge the two trees by linking their roots, not the endpoints,
            #so no existing parent link is ever overwritten
            disjointset[root_child]=root_parent
            output.append([parent,child])

    return output

In [8]:
#kruskal only reads the graph; the deep copy just keeps df isolated
kruskal(df=copy.deepcopy(df))

from 3 to 4 at 2
from 2 to 6 at 3
from 1 to 3 at 5
from 6 to 7 at 5
from 1 to 2 at 6
from 3 to 5 at 7
from 4 to 5 at 7
from 2 to 4 at 8
from 5 to 7 at 9
from 7 to 8 at 13


[[3, 4], [2, 6], [1, 3], [6, 7], [1, 2], [3, 5], [7, 8]]