In [None]:
# default_exp prim 

# Prim

> Prim's algorithm: minimum spanning tree for connected, weighted graphs

In [None]:
# export
from queue import Queue
from typing import List, Dict
from collections import namedtuple

import graph_utils.core as gu
from graph_utils.dijsktra import priority_dict

In [None]:
# export
# One entry of the per-vertex table kept while the tree is being built:
#   distance_from_source - cost currently recorded for reaching the vertex
#   preceding_vertex     - vertex it was reached from
# The typename matches the name it is bound to, so repr() reads `Row(...)`
# and pickle can locate the class by name.
Row = namedtuple('Row', ['distance_from_source', 'preceding_vertex'])

In [None]:
# Sample undirected graph on 8 vertices used by the check further down.
# Each triple is forwarded to add_edge as (vertex, vertex, weight).
g = gu.AdjacencyMatrixGraph(8, directed=False)
for start, end, weight in [
    (0, 1, 1),
    (1, 2, 2),
    (1, 3, 2),
    (2, 3, 2),
    (1, 4, 3),
    (3, 5, 1),
    (5, 4, 3),
    (3, 6, 1),
    (6, 7, 1),
    (0, 7, 1),
]:
    g.add_edge(start, end, weight)

In [None]:
# export
def spanning_tree(graph: gu.Graph, source):
    """Build a minimum spanning tree of `graph` with Prim's algorithm.

    Starting from `source`, the tree grows one vertex at a time by always
    taking the cheapest edge that links a not-yet-visited vertex to the tree.

    Args:
        graph: object exposing `numVertices`, `get_adjacent_vertices(v)` and
            `get_edge_weight(v1, v2)`. Vertices are assumed to be the integers
            `0 .. numVertices - 1`, since the table below is keyed that way.
        source: vertex the tree is grown from.

    Returns:
        A set of strings of the form "u->v", one per tree edge, where `u` is
        the vertex that was already in the tree and `v` the vertex the edge
        attached. Vertices never reached from `source` contribute no edge.
    """
    # Best known connection per vertex: weight of the cheapest edge linking it
    # to the tree so far, and the tree vertex on the other end of that edge.
    distance_table = {vertex: Row(None, None) for vertex in range(graph.numVertices)}

    # The source joins the tree at no cost and is its own predecessor.
    distance_table[source] = Row(0, source)

    # Maps vertex -> weight of its cheapest connecting edge; pop_smallest()
    # hands back the vertex with the lowest weight first.
    priority_queue = priority_dict()
    priority_queue[source] = 0

    visited_vertices = set()

    # Named `tree_edges` rather than `spanning_tree` so the function's own
    # name is not shadowed inside its body.
    tree_edges = set()

    while priority_queue:
        current_vertex = priority_queue.pop_smallest()

        if current_vertex in visited_vertices:
            continue
        visited_vertices.add(current_vertex)

        if current_vertex != source:
            parent = distance_table[current_vertex].preceding_vertex
            # set.add is idempotent, so no membership check is needed first.
            tree_edges.add(f"{parent}->{current_vertex}")

        for neighbor in graph.get_adjacent_vertices(current_vertex):
            # A vertex already in the tree must not have its row rewritten
            # or be queued again.
            if neighbor in visited_vertices:
                continue

            # Read the weight from the graph that was passed in, not from a
            # notebook-level global.
            weight = graph.get_edge_weight(current_vertex, neighbor)
            best_so_far = distance_table[neighbor].distance_from_source

            if best_so_far is None or best_so_far > weight:
                distance_table[neighbor] = Row(weight, current_vertex)
                priority_queue[neighbor] = weight

    return tree_edges

In [None]:
# Tree grown from vertex 3 on the sample graph; each edge is written as
# "<vertex already in the tree>-><newly attached vertex>".
expected_edges = {
    '3->2',
    '3->5',
    '3->6',
    '5->4',
    '6->7',
    '7->0',
    '0->1',
}
assert spanning_tree(g, 3) == expected_edges