In [2]:
%load_ext autoreload
%autoreload 2
import copy
from graph_algo.adt import LinkedGraph as LG, MatrixGraph as MG, Vertex, Edge, MapGraph
import  math
from data_structures import LinkedQueue, LinkedStack, SortedPriorityQueue
from graph_algo.union_find import UF

In [3]:
def make_graph(string, max_vertices, directed=False):
    g = MapGraph(directed=directed)
    vertices = [Vertex(i) for i in range(max_vertices + 1)]
    for uv in string.split(" "):
        u, v, w = map(lambda x: int(x[1]) if x[0]<2 else float(x[1]),
                   enumerate(uv.split("-")))
        g.insert_edge(Edge(vertices[u], vertices[v], w))
    return g


In [4]:


g = make_graph("0-6-0.51 0-1-0.32 0-2-0.29 4-3-0.34 5-3-0.18 "
               "7-4-0.46 5-4-0.4 0-5-0.6 6-4-0.51 7-0-0.31 7-6-0.25 7-1-0.21", 7)

In [5]:
g

Edge(Vertex(value=0), Vertex(value=1), 0.32)
Edge(Vertex(value=0), Vertex(value=2), 0.29)
Edge(Vertex(value=0), Vertex(value=5), 0.6)
Edge(Vertex(value=0), Vertex(value=6), 0.51)
Edge(Vertex(value=4), Vertex(value=3), 0.34)
Edge(Vertex(value=5), Vertex(value=3), 0.18)
Edge(Vertex(value=5), Vertex(value=4), 0.4)
Edge(Vertex(value=6), Vertex(value=4), 0.51)
Edge(Vertex(value=7), Vertex(value=0), 0.31)
Edge(Vertex(value=7), Vertex(value=1), 0.21)
Edge(Vertex(value=7), Vertex(value=4), 0.46)
Edge(Vertex(value=7), Vertex(value=6), 0.25)

In [6]:
for i, edge in enumerate(g.edges):
    print(i, edge, edge.weight)

0 Edge(Vertex(value=0), Vertex(value=1), 0.32) 0.32
1 Edge(Vertex(value=0), Vertex(value=2), 0.29) 0.29
2 Edge(Vertex(value=0), Vertex(value=5), 0.6) 0.6
3 Edge(Vertex(value=0), Vertex(value=6), 0.51) 0.51
4 Edge(Vertex(value=4), Vertex(value=3), 0.34) 0.34
5 Edge(Vertex(value=5), Vertex(value=3), 0.18) 0.18
6 Edge(Vertex(value=5), Vertex(value=4), 0.4) 0.4
7 Edge(Vertex(value=6), Vertex(value=4), 0.51) 0.51
8 Edge(Vertex(value=7), Vertex(value=0), 0.31) 0.31
9 Edge(Vertex(value=7), Vertex(value=1), 0.21) 0.21
10 Edge(Vertex(value=7), Vertex(value=4), 0.46) 0.46
11 Edge(Vertex(value=7), Vertex(value=6), 0.25) 0.25


In [7]:

def find_prim_mst(g):
    "Prim mst algorithm"

    root = None
    for v in g.vertices:
        v.is_visited = False
        v.parent = None
        v.mst_weight = float("inf")
        if root is None:
            v.mst_weight = 0
            root = v

    q = SortedPriorityQueue()
    q.add(root.mst_weight, root)
    mst = MapGraph()
    while not q.is_empty():
        _, vertex = q.remove_min()
        vertex.is_visited = True
        if vertex.parent is not None:
            mst.insert_edge(Edge(vertex.parent, vertex, vertex.mst_weight))
        for adj in g.adjacent(vertex):
            if not adj.is_visited:
                edge = g.get_edge(vertex, adj)
                if not q.contains(adj):
                    adj.parent = vertex
                    adj.mst_weight = edge.weight
                    q.add(adj.mst_weight, adj)

                elif edge.weight < adj.mst_weight:
                    assert adj.mst_weight < float("inf")
                    adj.mst_weight = edge.weight
                    adj.parent = vertex

    return mst

        
    
min_span_tree = find_prim_mst(copy.deepcopy(g))

0 Vertex(value=0)
0.32 Vertex(value=1)
0.29 Vertex(value=2)
0.6 Vertex(value=5)
0.51 Vertex(value=6)
0.31 Vertex(value=7)
0.46 Vertex(value=4)
0.34 Vertex(value=3)


In [8]:
min_span_tree

Edge(Vertex(value=0), Vertex(value=2), 0.29)
Edge(Vertex(value=0), Vertex(value=7), 0.31)
Edge(Vertex(value=3), Vertex(value=5), 0.18)
Edge(Vertex(value=4), Vertex(value=3), 0.34)
Edge(Vertex(value=7), Vertex(value=1), 0.21)
Edge(Vertex(value=7), Vertex(value=4), 0.46)
Edge(Vertex(value=7), Vertex(value=6), 0.25)

In [9]:
def find_krus_mst(g):
    "Kuskal mst algorithm"
    edges = list(g.edges)
    edges = sorted(edges, key=lambda edge: edge.weight)
    mst = MapGraph()
    uf = UF(g.vertices)
    for edge in edges:
        if mst.count_edges() == g.count_vertices() - 1:
            break
        head, tail = edge
        if uf.find(head) != uf.find(tail):
            uf.union(head, tail)
            mst.insert_edge(edge)
            head.mst, tail.mst = True, True

    return mst
min_span_tree = find_krus_mst(make_graph("0-6-0.51 0-1-0.32 0-2-0.29 4-3-0.34 5-3-0.18 "
               "7-4-0.46 5-4-0.4 0-5-0.6 6-4-0.51 7-0-0.31 7-6-0.25 7-1-0.21", 7))

In [10]:
min_span_tree

Edge(Vertex(value=0), Vertex(value=2), 0.29)
Edge(Vertex(value=4), Vertex(value=3), 0.34)
Edge(Vertex(value=5), Vertex(value=3), 0.18)
Edge(Vertex(value=7), Vertex(value=0), 0.31)
Edge(Vertex(value=7), Vertex(value=1), 0.21)
Edge(Vertex(value=7), Vertex(value=4), 0.46)
Edge(Vertex(value=7), Vertex(value=6), 0.25)