In this programming problem you'll code up Prim's minimum spanning tree algorithm.

The input file (`edges.txt` in the code below) describes an undirected graph with integer edge costs. It has the format

[number_of_nodes] [number_of_edges]

[one_node_of_edge_1] [other_node_of_edge_1] [edge_1_cost]

[one_node_of_edge_2] [other_node_of_edge_2] [edge_2_cost]

...

For example, the third line of the file is "2 3 -8874", indicating that there is an edge connecting vertex #2 and vertex #3 that has cost -8874.

You should NOT assume that edge costs are positive, nor should you assume that they are distinct.
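
As a rough illustration of the format, the sketch below parses a tiny in-memory sample. The function name `parse_edge_list` and the sample data are made up for this example (only the `2 3 -8874` line is quoted from the description above); the real file is much larger.

```python
def parse_edge_list(text):
    """Parse '[n] [m]' followed by one 'u v cost' line per edge."""
    # Skip blank lines so stray empty lines do not break parsing.
    rows = [line.split() for line in text.splitlines() if line.strip()]
    n, m = int(rows[0][0]), int(rows[0][1])
    edges = [(int(u), int(v), int(c)) for u, v, c in rows[1:]]
    if len(edges) != m:
        raise ValueError("header announces {0} edges, found {1}".format(m, len(edges)))
    return n, m, edges


# Hypothetical 3-node sample with one negative and one repeated cost.
sample = """3 3
1 2 6807
2 3 -8874
1 3 6807
"""
n, m, edges = parse_edge_list(sample)
print(edges[1])  # (2, 3, -8874)
```

Costs are kept as signed integers, so negative and duplicate values need no special handling at this stage.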


Your task is to run Prim's minimum spanning tree algorithm on this graph. You should report the overall cost of a minimum spanning tree --- an integer, which may or may not be negative --- in the box below.

IMPLEMENTATION NOTES: This graph is small enough that the straightforward O(mn) time implementation of Prim's algorithm should work fine. OPTIONAL: For those of you seeking an additional challenge, try implementing a heap-based version. The simpler approach, which should already give you a healthy speed-up, is to maintain relevant edges in a heap (with keys = edge costs). The superior approach stores the unprocessed vertices in the heap, as described in lecture. Note this requires a heap that supports deletions, and you'll probably need to maintain some kind of mapping between vertices and their positions in the heap.
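
The simpler edge-heap variant mentioned above could look roughly like the following sketch. It uses Python's standard `heapq` module rather than a hand-written heap, and it discards stale edges lazily when they are popped. The function name `prim_lazy_heap`, the adjacency-dict layout, and the small graph are assumptions for illustration, not the assignment's data.

```python
import heapq


def prim_lazy_heap(G, s):
    """Return the total MST cost of a connected graph G.

    G maps node -> list of (neighbor, cost) pairs (an assumed layout).
    """
    visited = {s}
    total = 0
    # Heap of (cost, endpoint) for edges leaving the visited set.
    frontier = [(cost, v) for v, cost in G[s]]
    heapq.heapify(frontier)

    while frontier and len(visited) < len(G):
        cost, v = heapq.heappop(frontier)
        if v in visited:
            continue  # stale edge: both endpoints are already spanned
        visited.add(v)
        total += cost
        for w, c in G[v]:
            if w not in visited:
                heapq.heappush(frontier, (c, w))
    return total


# Hypothetical 4-node example with a negative and a repeated cost.
G = {
    1: [(2, 1), (3, 4), (4, 3)],
    2: [(1, 1), (4, -2)],
    3: [(1, 4), (4, 4)],
    4: [(1, 3), (2, -2), (3, 4)],
}
print(prim_lazy_heap(G, 1))  # 3
```

Because the heap may hold up to one entry per edge, this version runs in O(m log m) time, which is the same as O(m log n) up to a constant factor. It avoids the need for a heap that supports deletions or key updates.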

# Notes on Prim's MST

## Finding a Minimum Spanning Tree (MST)
**Input:** Undirected graph $G = (V, E)$ and a cost $c_e$ for each edge $e \in E$.<br>
**Output:** A minimum-cost tree $T \subseteq E$ that spans all vertices. Note that $T$ has no cycles and the subgraph $(V, T)$ is connected.<br>
**Assumption:** The input graph $G$ is connected; otherwise no spanning tree exists.
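
To make the output condition concrete, here is a small sketch that checks whether a candidate edge set is a spanning tree. It relies on the fact that $n - 1$ acyclic edges on $n$ vertices must be connected, and it detects cycles with a union-find structure. The name `is_spanning_tree` and the example inputs are illustrative assumptions.

```python
def is_spanning_tree(vertices, tree_edges):
    """Check that tree_edges (pairs (u, v)) connect all vertices without a cycle."""
    if len(tree_edges) != len(vertices) - 1:
        return False  # a spanning tree on n vertices has exactly n - 1 edges
    parent = {v: v for v in vertices}

    def find(v):
        # Follow parent pointers to the root, shortening the path on the way.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    for u, v in tree_edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return False  # this edge would close a cycle
        parent[ru] = rv
    return True  # n - 1 acyclic edges on n vertices are necessarily connected


print(is_spanning_tree({1, 2, 3, 4}, [(1, 2), (2, 4), (1, 3)]))  # True
print(is_spanning_tree({1, 2, 3, 4}, [(1, 2), (2, 4), (1, 4)]))  # False
```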

## Prim's Algorithm
```python
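def prim_mst(G):
  # G maps each node to a list of (neighbor, cost) pairs.
  s = next(iter(G))   # s in V chosen arbitrarily
  X = {s}             # vertices spanned so far
  T = []              # invariant: X = vertices spanned by tree-so-far T

  # Increase number of spanned vertices in cheapest way possible.
  while len(X) != len(G):
    # cheapest edge e = (u, v) of G with u in X, v not in X
    cost, u, v = min((c, u, v) for u in X for v, c in G[u] if v not in X)
    T.append((u, v, cost))   # add e to T
    X.add(v)                 # add v to X

  return T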
```
```python
def prim_mst_heap(G):
  # select node 1 of G to begin with
  s = 1

  T = {s: 0}            # T[v] = cheapest known cost of connecting v to the tree
  Q = binary_heap()     # nodes not yet spanned, keyed by T[v]

  # every node except s starts at infinite cost
  for node in G.keys():
    if node != s:
      T[node] = float('Infinity')
    Q.insert((node, T[node]))

  while not Q.is_empty():
    v, vd = Q.extract_min()
    for u, ud in G[v]:
      # only nodes still in the heap may be updated
      if u in Q and ud < T[u]:
        T[u] = ud
        Q.update_node_value(u, ud)

  return T              # sum(T.values()) is the MST cost
```

In [1]:
# modified the initialization of my previous heap for Dijkstra
class binary_heap:
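    """ Binary min-heap of (node, key) pairs with a node-to-index map.
    Originally written for Dijkstra's algorithm, where the key is a distance;
    for Prim's algorithm the key is the cheapest connecting edge cost.
    array -- a list of (node, key) pairs
    """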
    
    def __init__(self, array=()):
        self.data = []         # node-distance pairs
        self.indices = {}      # map node to its index
        self.size = len(array) # the number of nodes currently in the heap
        
        if self.size != 0:
            for i, pair in enumerate(array):
                node, distance = pair
                self.data.append(pair)
                self.indices[node] = i
            self.heapify()
        
        return
    
    def __repr__(self):
        return "Key value:\n" + str(self.data) + "\nNode locataion:\n" + str(self.indices)
    
    def __contains__(self, node):
        return node in self.indices
    
    def is_empty(self):
        return self.size == 0
    
    def validate_index(self, i):
        """ Check if index i lies in bound. """
        if i < 0 or i >= self.size:
            print("Index i = {0}, size of heap: {1}".format(i, self.size))
            raise ValueError("Index out of range.")
        return
    
    def parent_index(self, i):
        """ Return the parent index of child i. """
        self.validate_index(i)
        return (i - 1) // 2 if i != 0 else 0  # floor division keeps the index an integer
    
    def children_indices(self, i):
        """ Return the children indices of parent i. """
        self.validate_index(i)
        return 2 * i + 1, 2 * i + 2
    
    def is_leaf(self, i):
        """ Check if index i is a leaf or not. """
        self.validate_index(i)
        c1, c2 = self.children_indices(i)
        return c1 >= self.size and c2 >= self.size
    
    def one_child(self, i):
        """ Check if the node at index i has exactly one child. """
        self.validate_index(i)
        c1, c2 = self.children_indices(i)
        return c1 < self.size and c2 >= self.size
    
    def min_value_child(self, i):
        """ Return the child index of parent i with the smaller value. """
        self.validate_index(i)
        c1, c2 = self.children_indices(i)
        
        c = None
        if c2 < self.size: # node i has two children
            c = c1
            if self.data[c1][1] > self.data[c2][1]:
                c = c2
        else:
            if c1 < self.size: # node i has only one child
                c = c1
            else:              # node i is a leaf
                c = i
        return c
    
    def up_heapify(self, i):
        """ Bubble up from index i. """
        self.validate_index(i)
        ic = i
        ip = self.parent_index(ic)
        while self.data[ic][1] < self.data[ip][1]:
            node_c, node_p = self.data[ic][0], self.data[ip][0]
            self.data[ic], self.data[ip] = self.data[ip], self.data[ic] # swap data
            self.indices[node_c], self.indices[node_p] = ip, ic         # update index map
            ic = ip
            ip = self.parent_index(ic)
        return
    
    def down_heapify(self, i):
        """ Bubble down from index i. """
        self.validate_index(i)
        ip = i
        ic = self.min_value_child(ip)
        while self.data[ic][1] < self.data[ip][1]:
            node_c, node_p = self.data[ic][0], self.data[ip][0]
            self.data[ic], self.data[ip] = self.data[ip], self.data[ic] # swap data
            self.indices[node_c], self.indices[node_p] = ip, ic         # update index map
            ip = ic
            ic = self.min_value_child(ip)
        return
    
    def heapify(self):
        """ Heapify the current self """
        start = self.parent_index(self.size - 1)
        while start >= 0:
            self.down_heapify(start)
            start -= 1
        return
    
    def insert(self, pair):
        self.data.append(pair)
        self.indices[pair[0]] = self.size
        self.size += 1
        self.up_heapify(self.size - 1)
        return
    
    def get_min(self):
        return self.data[0]
    
    def get_value(self, node):
        if node in self.indices:
            i = self.indices[node]
            return self.data[i][1]
        else:
            raise ValueError("Requested node not in the heap!")
    
    def extract_min(self):
        if self.size == 0:
            raise ValueError("Cannot extract min from an empty heap!")
        
        m_node, m_value = self.get_min()
        self.data[0] = self.data[-1]
        self.data.pop()
        self.size -= 1
        self.indices.pop(m_node)
        if self.size > 0:
            self.indices[self.data[0][0]] = 0
            self.down_heapify(0)
        return (m_node, m_value)
    
    def update_node_value(self, node, value):
        if node in self.indices:
            i = self.indices[node]
            self.data[i] = (node, value)
            
            ip = self.parent_index(i)
            ic = self.min_value_child(i)
            if value < self.data[ip][1]:
                self.up_heapify(i)
            elif value > self.data[ic][1]:
                self.down_heapify(i)
        else:
            raise ValueError("Requested node not in the heap!")
        return

In [2]:
# some tests of binary_heap
h = binary_heap([(1,0),(2,10),(3,4),(4,4),(6,8)])
print("Original {0}".format(h))
print("Is empty ? {0}".format(h.is_empty()))
h.insert((5,1))
print("After insert (5, 1) {0}".format(h))
h.extract_min()
print("After extract_min {0}".format(h))
h.update_node_value(2, 8)
print("After updating value of node 2 to 8 {0}".format(h))
h.update_node_value(5, 6)
print("After updating value of node 5 to 6 {0}".format(h))
h.update_node_value(2, 1)
print("After updating value of node 2 to 1 {0}".format(h))
print("{0} {1}".format(h.is_leaf(0), h.is_leaf(3)))

Original Key value:
[(1, 0), (4, 4), (3, 4), (2, 10), (6, 8)]
Node locataion:
{1: 0, 2: 3, 3: 2, 4: 1, 6: 4}
Is empty ? False
After insert (5, 1) Key value:
[(1, 0), (4, 4), (5, 1), (2, 10), (6, 8), (3, 4)]
Node locataion:
{1: 0, 2: 3, 3: 5, 4: 1, 5: 2, 6: 4}
After extract_min Key value:
[(5, 1), (4, 4), (3, 4), (2, 10), (6, 8)]
Node locataion:
{2: 3, 3: 2, 4: 1, 5: 0, 6: 4}
After updating value of node 2 to 8 Key value:
[(5, 1), (4, 4), (3, 4), (2, 8), (6, 8)]
Node locataion:
{2: 3, 3: 2, 4: 1, 5: 0, 6: 4}
After updating value of node 5 to 6 Key value:
[(4, 4), (5, 6), (3, 4), (2, 8), (6, 8)]
Node locataion:
{2: 3, 3: 2, 4: 0, 5: 1, 6: 4}
After updating value of node 2 to 1 Key value:
[(2, 1), (4, 4), (3, 4), (5, 6), (6, 8)]
Node locataion:
{2: 0, 3: 2, 4: 1, 5: 3, 6: 4}
False True
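
The printed states above can also be checked mechanically instead of by eye. The sketch below tests the two properties the class depends on: every parent key is at most its child's key, and the index map agrees with the array. The helper name `check_heap` is an assumption, not part of the original class.

```python
def check_heap(data, indices):
    """Return True if data is a valid min-heap and indices mirrors it."""
    for i, (node, value) in enumerate(data):
        if indices.get(node) != i:
            return False  # index map out of sync with the array
        if i > 0 and data[(i - 1) // 2][1] > value:
            return False  # parent holds a larger key than its child
    return len(indices) == len(data)


# Final state copied from the printed output above.
data = [(2, 1), (4, 4), (3, 4), (5, 6), (6, 8)]
indices = {2: 0, 3: 2, 4: 1, 5: 3, 6: 4}
print(check_heap(data, indices))  # True
```

With the class in scope, the same check could be run after each operation as `check_heap(h.data, h.indices)`.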


In [3]:
# Read the graph from the file with name filename
def read_graph(filename):
    """ Return G as a dict: node -> list of (neighbor, cost) pairs. """
    G = {}
    nV, nE = 0, 0
    with open(filename, 'r') as f:
        for line in f:
            ls = line.split()
            if not ls:
                continue  # skip blank lines
            if len(ls) == 2:
                # header line: number of nodes, number of edges
                nV, nE = int(ls[0]), int(ls[1])
            else:
                key = int(ls[0])
                node, distance = int(ls[1]), int(ls[2])
                # undirected graph: store the edge in both directions
                G.setdefault(key, []).append((node, distance))
                G.setdefault(node, []).append((key, distance))
    
    if DEBUG:
        print("File input:\n  Number of vertices: {0}\n  Number of edges: {1}".format(nV, nE))
        print(G)
    assert len(G) == nV, "Error in read_graph: number of vertices does not match the number in {0}".format(filename)
    assert sum(len(G[n]) for n in G.keys()) == 2 * nE, "Error in read_graph: number of edges does not match the number in {0}".format(filename)
    
    return G

In [4]:
def prim_mst_heap(G):
    # select node 1 of G to begin with
    s = 1
    
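    # T[v] = cost of the cheapest edge found so far connecting v to the tree;
    # once the heap is empty, sum(T.values()) is the total MST cost
    T = {s: 0}
    
    # nodes not yet spanned: every key starts at infinity, then s is lowered to 0
    Q = binary_heap([(node, float('Infinity')) for node in G.keys()])
    Q.update_node_value(s, T[s])
    
    # mark all nodes as unexplored and give every node except s an infinite cost
    explored = {}
    for node in G.keys():
        explored[node] = False
        if node != s:
            T[node] = float('Infinity')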
    
    while not Q.is_empty():
        v, vd = Q.extract_min()
        explored[v] = True
        for u, ud in G[v]:
            if ud < T[u] and not explored[u]:
                T[u] = ud
                Q.update_node_value(u, ud)

    return T

In [5]:
DEBUG = True

# test case 1
# total cost of MST: 7
G = read_graph('./test1.txt')
T = prim_mst_heap(G)
assert sum(T.values()) == 7, "prim_mst_heap does not pass test1"
print("prim_mst_heap passes test1")

# test case2
# total cost of MST: 14
G = read_graph('./test2.txt')
T = prim_mst_heap(G)
assert sum(T.values()) == 14, "prim_mst_heap does not pass test2"
print("prim_mst_heap passes test2")

File input:
  Number of vertices: 4
  Number of edges: 5
{1: [(2, 1), (3, 4), (4, 3)], 2: [(1, 1), (4, 2)], 3: [(1, 4), (4, 5)], 4: [(2, 2), (3, 5), (1, 3)]}
prim_mst_heap passes test1
File input:
  Number of vertices: 6
  Number of edges: 10
{1: [(2, 6), (4, 5), (5, 4)], 2: [(1, 6), (4, 1), (5, 2), (3, 5), (6, 3)], 3: [(2, 5), (6, 4)], 4: [(1, 5), (2, 1), (5, 2)], 5: [(1, 4), (2, 2), (4, 2), (6, 4)], 6: [(2, 3), (3, 4), (5, 4)]}
prim_mst_heap passes test2
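
As an independent cross-check of the expected totals, the two graphs printed above can be fed to a different algorithm. The sketch below uses Kruskal's algorithm with a small union-find, which is not the method this notebook implements. The name `kruskal_cost` is an assumption for illustration.

```python
def kruskal_cost(G):
    """Total MST cost of a connected graph G: node -> list of (neighbor, cost)."""
    parent = {v: v for v in G}

    def find(v):
        # Follow parent pointers to the root, shortening the path on the way.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    # Each undirected edge appears twice in G; keep the copy with u < v.
    edges = sorted((c, u, v) for u in G for v, c in G[u] if u < v)
    total = 0
    for c, u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:  # the edge joins two different components
            parent[ru] = rv
            total += c
    return total


# Adjacency lists copied from the DEBUG output above.
G1 = {1: [(2, 1), (3, 4), (4, 3)], 2: [(1, 1), (4, 2)],
      3: [(1, 4), (4, 5)], 4: [(2, 2), (3, 5), (1, 3)]}
G2 = {1: [(2, 6), (4, 5), (5, 4)],
      2: [(1, 6), (4, 1), (5, 2), (3, 5), (6, 3)],
      3: [(2, 5), (6, 4)],
      4: [(1, 5), (2, 1), (5, 2)],
      5: [(1, 4), (2, 2), (4, 2), (6, 4)],
      6: [(2, 3), (3, 4), (5, 4)]}
print("{0} {1}".format(kruskal_cost(G1), kruskal_cost(G2)))  # 7 14
```

Both totals agree with the values asserted in the test cell, which gives some confidence that the heap-based implementation is not passing by accident.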


In [6]:
# timer grabbed from 
# https://stackoverflow.com/questions/7370801/measure-time-elapsed-in-python
from timeit import default_timer as timer
class benchmark(object):
    def __init__(self, msg, fmt="%0.3g"):
        self.msg = msg
        self.fmt = fmt

    def __enter__(self):
        self.start = timer()
        return self

    def __exit__(self, *args):
        t = timer() - self.start
        print(("%s : " + self.fmt + " seconds") % (self.msg, t))
        self.time = t

In [7]:
DEBUG = False

with benchmark("Read Graph") as r:
    G = read_graph('./edges.txt')

with benchmark("Heap implementation O[m * log(n)]") as r:
    T = prim_mst_heap(G)

print(sum(T.values()))

Read Graph : 0.0203 seconds
Heap implementation O[m * log(n)] : 0.0474 seconds
