In [3]:
# Banner printed after each verification cell passes.  It is a raw string so
# every backslash in the ASCII art is kept verbatim: in a normal string the
# doubled backslash in the "__\ \_\\" row collapses to a single one (shifting
# the rest of that row left by a column), and sequences such as "\ ", "\_" and
# "\/" are invalid escapes that emit warnings on recent Python versions.
ok=r"""
         _          _                  _           _        
        /\ \       / /\               / /\        / /\      
       /  \ \     / /  \             / /  \      / /  \     
      / /\ \ \   / / /\ \           / / /\ \__  / / /\ \__  
     / / /\ \_\ / / /\ \ \         / / /\ \___\/ / /\ \___\ 
    / / /_/ / // / /  \ \ \        \ \ \ \/___/\ \ \ \/___/ 
   / / /__\/ // / /___/ /\ \        \ \ \       \ \ \       
  / / /_____// / /_____/ /\ \   _    \ \ \  _    \ \ \      
 / / /      / /_________/\ \ \ /_/\__/ / / /_/\__/ / /      
/ / /      / / /_       __\ \_\\ \/___/ /  \ \/___/ /       
\/_/       \_\___\     /____/_/ \_____\/    \_____\/        
                                                             
"""

In this notebook you will implement various data structures for the Union Find problem, as well as Kruskal's algorithm for finding the minimum spanning tree.

### Q1. Union Find

First, consider an approach different from the Disjoint Forest data structure taught in class. Recall that in Union Find we must maintain a partition $S_0,\ldots,S_{k-1}$ of $\{0,\ldots,n-1\}$, which is initially a partition into singletons ($k=n$ and $S_i = \{i\}$). We must then support two operations:
* **find(x)**: return the name of the set containing $x\in\{0,1,\ldots,n-1\}$
* **union(x,y)**: let $S_i$ be the set containing $x$ and $S_j$ be the set containing $y$. If $S_i \ne S_j$, remove these two sets from the partition and add in the new set $S_i\cup S_j$; if they are the same set, do nothing.

Now, consider the following solution: we represent each element in $\{0,\ldots,n-1\}$ as a node, and each set in the partition as a linked list on the nodes in that set. Then the "name" of the set is simply the item represented by the head node of the linked list.

#### Q1a. Complete the implementation below

In [9]:
class Node:
    """One element of the Union Find universe, stored as a doubly linked list cell."""

    def __init__(self, val):
        # the element (0..n-1) this node stands for
        self.val = val
        # neighbours in the list; None before the head / after the tail
        self.prev = self.next = None
        
class UnionFindLL:
    """Union Find in which every set is a doubly linked list of Node objects.

    The name of a set is the ``val`` of the head node of its list.  Both
    operations walk a list, so either may take Omega(n) time.
    """

    def __init__(self, n):
        self.n = n
        # self.nodes[i] is the list node holding element i
        self.nodes = [Node(i) for i in range(n)]

    def find(self, x):
        """Return the name of x's set, i.e. the ``val`` of the head of x's list."""
        node = self.nodes[x]
        # walk backwards until there is no predecessor
        while node.prev is not None:
            node = node.prev
        return node.val

    def union(self, x, y):
        """Merge the sets of x and y by appending x's list after the tail of y's.

        Nothing happens when x and y already share a list.
        """
        x_name, y_name = self.find(x), self.find(y)
        if x_name == y_name:
            return
        tail = self.nodes[y_name]
        while tail.next is not None:
            tail = tail.next
        x_head = self.nodes[x_name]
        # splice: y's tail <-> x's head
        tail.next, x_head.prev = x_head, tail
            

#### Verification

In [11]:
import random

# staff Union Find implementation, using Disjoint Forest
class DisjointForest:
    """Reference Union Find: a disjoint-set forest with union by rank and path compression."""

    def __init__(self, n):
        self.n = n
        self.rank = [0 for _ in range(n)]
        # every element starts as the root of its own one-node tree
        self.parent = [i for i in range(n)]

    def find(self, x):
        """Return the root of x's tree, re-pointing every node on the path straight at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # second pass: path compression
        while x != root:
            up = self.parent[x]
            self.parent[x] = root
            x = up
        return root

    def union(self, x, y):
        """Link the roots of x and y, hanging the lower-rank root below the other."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        elif self.rank[rx] == self.rank[ry]:
            # equal ranks: the surviving root's tree grows one level taller
            self.rank[rx] += 1
        self.parent[ry] = rx

def test_union_find(structure, n):
    UF = structure(n)
    DF = DisjointForest(n)
    for _ in range(n):
        i = random.randint(0, n-1)
        j = random.randint(0, n-2)
        if j >= i:
            j += 1
        UF.union(i, j)
        DF.union(i, j)
    mapping1 = {}
    mapping2 = {}
    cur1,cur2 = 0,0
    for i in range(n):
        a = UF.find(i)
        if a not in mapping1:
            mapping1[a] = cur1
            cur1 += 1
        b = DF.find(i)
        if b not in mapping2:
            mapping2[b] = cur2
            cur2 += 1
    for i in range(n):
        assert mapping1[UF.find(i)] == mapping2[DF.find(i)]

# Compare the plain linked-list implementation with the reference on 50 elements;
# the banner is printed only if every assertion above passes.
test_union_find(structure=UnionFindLL, n=50)
print(ok)


         _          _                  _           _        
        /\ \       / /\               / /\        / /\      
       /  \ \     / /  \             / /  \      / /  \     
      / /\ \ \   / / /\ \           / / /\ \__  / / /\ \__  
     / / /\ \_\ / / /\ \ \         / / /\ \___\/ / /\ \___\ 
    / / /_/ / // / /  \ \ \        \ \ \ \/___/\ \ \ \/___/ 
   / / /__\/ // / /___/ /\ \        \ \ \       \ \ \       
  / / /_____// / /_____/ /\ \   _    \ \ \  _    \ \ \      
 / / /      / /_________/\ \ \ /_/\__/ / / /_/\__/ / /      
/ / /      / / /_       __\ \_\ \/___/ /  \ \/___/ /       
\/_/       \_\___\     /____/_/ \_____\/    \_____\/        
                                                            



In your implementation above, note that **find** is slow: it could take $\Omega(n)$ time (consider the case that $x$ is the last element in its linked list, in which case we have to walk all the way back to the beginning of the linked list before finding its set's name). This could be sped up by modifying the **Node** class to have an extra field called **first**, where node.first is the Node object that is at the beginning of node's linked list. Then **find** can be implemented to take $O(1)$ time. For reasons that will become clear soon, we'll also add another field **last** to the Node class, and we will maintain the invariant that *if* node is the head of its linked list, then node.last will be the tail (the last node) of its linked list (if node is not the head, we make no promises about what node.last will point to).

**union**, though, is still slow. Let $L_x$ be $x$'s linked list, and similarly $L_y$, and suppose we would like to append $L_y$ to the end of $L_x$. We first do a **find** on each of $x$ and $y$ to obtain their heads ($O(1)$ time). Given the previous paragraph, we can then find each of their tails in $O(1)$ time once we have their heads (just look at the **last** field). Then we can make the tail of $L_x$ and the head of $L_y$ point to each other, and voila! But not so fast: recall that each node is supposed to remember who its head is in the **first** field. The head has now changed *for every single node in $L_y$*! That's a lot of **first** fields that might need to be updated, which could again take $\Omega(n)$ time.

But there's a way to make this work! Recall there was a choice: in the implementation of **union** above, we could have either put $x$'s linked list at the end of $y$'s, or vice versa. Intuitively, which should we do? We should make the choice that requires updating fewer **first** fields! Let **size(x)** denote the size (number of nodes) of $L_x$. We then show the following claim:

**Claim:** Suppose that during **union**, if **size(x)** $\ge$ **size(y)** we append $L_y$ to the end of $L_x$; otherwise, we append $L_x$ to the end of $L_y$. Then any sequence of $m$ calls to **union** and $f$ calls to **find** takes $O(m + f + n\log n)$ time.

**Proof:** Both operations amount to (1) looking at a constant number of fields, and possibly (2) changing some number of **first** fields. All work of type (1) takes $O(m+f)$ time combined over all operations ($O(1)$ time per operation). For work of type (2), note that for each $x\in\{0,\ldots,n-1\}$, each time we change $x$'s **first** field, it is because $x$'s list was no larger than the other list and was appended to it, so $x$ now lies in a set at least twice as big as its previous set. Since each $x$ is initially in a set of size $1$ and can never be in a set of size more than $n$, $x$'s **first** field changes at most $\log_2 n$ times, so the total work of type (2) is
$$
\sum_{x=0}^{n-1}\text{(times changed $x$'s first field)} \le \sum_{x=0}^{n-1} \log_2 n \le n\log_2 n
$$
Thus the total work, types (1) and (2) combined, is $O(m+f+n\log n)$.

#### Q1b. Below, implement this modified version of the linked list approach, which is faster.

In [16]:
class Node2:
    """Linked-list node that additionally caches a pointer to its list's head.

    ``first`` is meant to point at the head of the node's list at all times.
    ``last`` and ``size`` only need to be accurate on a head node; on other
    nodes they may be stale.
    """

    def __init__(self, val):
        self.val = val
        self.prev = None  # predecessor in the list; None at the head
        self.next = None  # successor in the list; None at the tail
        # a fresh node is a one-element list: its own head and its own tail
        self.first = self.last = self
        self.size = 1
        
class UnionFindLL2:
    """Linked-list Union Find with cached head pointers and union by size.

    Each node's ``first`` points at its list's head, so ``find`` is O(1).
    ``union`` appends the smaller list after the larger one and re-points only
    the moved nodes' ``first`` fields.
    """

    def __init__(self, n):
        self.n = n
        self.nodes = [Node2(i) for i in range(n)]

    def find(self, x):
        """Return the name of x's set: the ``val`` of the head of x's list."""
        head = self.nodes[x].first
        return head.val

    def union(self, x, y):
        """Merge the sets of x and y, appending the smaller list to the larger.

        Only the surviving head's ``last`` and ``size`` are updated; the
        appended list's old head keeps stale values, which is allowed for
        non-head nodes.
        """
        big = self.nodes[self.find(x)]
        small = self.nodes[self.find(y)]
        if big is small:
            return
        if big.size < small.size:
            big, small = small, big
        # splice: big's tail <-> small's head
        tail = big.last
        tail.next, small.prev = small, tail
        big.size += small.size
        big.last = small.last
        # every node that moved now has a new head
        node = small
        while node is not None:
            node.first = big
            node = node.next
        

#### Verification

In [17]:
# Compare the head-pointer / union-by-size implementation with the reference on
# 300 elements; the banner is printed only if every assertion passes.
test_union_find(structure=UnionFindLL2, n=300)
print(ok)


         _          _                  _           _        
        /\ \       / /\               / /\        / /\      
       /  \ \     / /  \             / /  \      / /  \     
      / /\ \ \   / / /\ \           / / /\ \__  / / /\ \__  
     / / /\ \_\ / / /\ \ \         / / /\ \___\/ / /\ \___\ 
    / / /_/ / // / /  \ \ \        \ \ \ \/___/\ \ \ \/___/ 
   / / /__\/ // / /___/ /\ \        \ \ \       \ \ \       
  / / /_____// / /_____/ /\ \   _    \ \ \  _    \ \ \      
 / / /      / /_________/\ \ \ /_/\__/ / / /_/\__/ / /      
/ / /      / / /_       __\ \_\ \/___/ /  \ \/___/ /       
\/_/       \_\___\     /____/_/ \_____\/    \_____\/        
                                                            



### Q2. Kruskal's MST Algorithm

Now it is time to implement Kruskal's algorithm!

The input graph G is represented as a list of $n$ lists, where each element of G[u] (for $u \in \{0,...,n-1\}$ a vertex)
is a list of length 2: [v, w] means there's an edge $(u,v)$ of weight $w$.

You can assume that we will only feed simple graphs G to your implementation as input. Also G is undirected,
so if an edge (u,v) exists it will be found in both G[u] and G[v].

If the graph is not connected, you should return None.

In [32]:
def kruskal(G, structure=None):
    """Return a minimum spanning tree of G using Kruskal's algorithm.

    Args:
        G: adjacency list on vertices 0..n-1 (n = len(G)).  Each entry of G[u]
            is a pair [v, w] meaning an undirected edge (u, v) of weight w.  G
            is assumed simple, with every edge listed under both endpoints.
        structure: Union Find class, called as ``structure(n)`` and expected to
            provide ``find``/``union``.  Defaults to UnionFindLL2.

    Returns:
        A list of n-1 edges [u, v] forming a minimum spanning tree, or None if
        G is not connected (this includes the empty graph, n = 0).
    """
    if structure is None:
        structure = UnionFindLL2
    n = len(G)

    # Every edge as [w, u, v], sorted by weight (ties broken by u, then v).
    # Each undirected edge appears once per orientation; the second copy is
    # always rejected below because its endpoints are already joined.
    edges = sorted([w, u, v] for u in range(n) for v, w in G[u])

    uf = structure(n)
    T = []
    for _, u, v in edges:
        # an edge belongs to the MST iff it joins two different components
        if uf.find(u) != uf.find(v):
            uf.union(u, v)
            T.append([u, v])
            if len(T) == n - 1:
                break  # tree is complete; every remaining edge would be rejected

    # fewer than n-1 tree edges means some vertex was never reached
    if len(T) != n - 1:
        return None
    return T
            

#### Verification

In [33]:
# staff MST implementation, using Prim
from heapq import heappush, heappop
def prim(G, s=0):
    """Return a minimum spanning tree of G as a list of [u, v] edges, via lazy Prim.

    Grows a tree from vertex ``s``.  The heap may contain several entries for
    the same vertex; stale ones are skipped when popped.  Returns None if some
    vertex is unreachable from ``s``.
    """
    n = len(G)
    best = [float('inf')] * n   # cheapest known edge weight into each vertex
    parent = [-1] * n           # tree vertex at the other end of that edge
    in_tree = [False] * n
    best[s] = 0
    frontier = [(0, s)]
    tree = []
    while frontier:
        cost, u = heappop(frontier)
        if in_tree[u]:
            continue
        in_tree[u] = True
        if u != s:
            tree.append([parent[u], u])
        for v, w in G[u]:
            if w < best[v] and not in_tree[v]:
                best[v] = w
                parent[v] = u
                heappush(frontier, (w, v))
    # any vertex never popped is in a different component from s
    if not all(in_tree):
        return None
    return tree

# now test MST
def is_connected(T, n):
    """Return True iff the edge list T connects every vertex in {0, ..., n-1}.

    Each element of T is an edge whose first two entries are its endpoints.
    Uses an iterative depth-first search from vertex 0, so long path-shaped
    trees do not hit Python's recursion limit (the previous recursive version
    raised RecursionError once a path exceeded roughly 1000 vertices).
    Assumes n >= 1, since vertex 0 is the search's starting point.
    """
    adj = [[] for _ in range(n)]
    for e in T:
        adj[e[0]].append(e[1])
        adj[e[1]].append(e[0])

    seen = [False] * n
    seen[0] = True
    stack = [0]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                stack.append(v)
    return all(seen)
    

import networkx as nx
# Randomized comparison of kruskal against the staff Prim implementation on 30
# random G(n, p) graphs.  A dedicated, seeded RNG makes any failure reproducible
# without disturbing the global `random` state used elsewhere in the notebook.
MST_TEST_SEED = 0
mst_rng = random.Random(MST_TEST_SEED)
for _ in range(30):
    n = 50
    weights = {}
    random_graph = nx.gnp_random_graph(n, 0.1, seed=mst_rng.randrange(2**32))
    G = [[] for _ in range(n)]
    for e in random_graph.edges:
        w = mst_rng.randint(-1000, 1000)
        G[e[0]].append([e[1], w])
        G[e[1]].append([e[0], w])
        # store both orientations so an edge can be looked up in either order
        weights[e] = w
        weights[(e[1], e[0])] = w
    T1 = kruskal(G)
    T2 = prim(G)
    if T1 is None or T2 is None:
        # disconnected graph: both implementations must report it
        assert T1 is None and T2 is None
        continue
    assert len(T1) == n - 1
    for e in T1:
        assert isinstance(e, list) and len(e) == 2
        assert e[0] != e[1] and 0 <= e[0] < n and 0 <= e[1] < n
        assert (e[0], e[1]) in weights
    w1 = sum(weights[(e[0], e[1])] for e in T1)
    w2 = sum(weights[(e[0], e[1])] for e in T2)
    assert w1 == w2, "kruskal's tree weight differs from prim's"
    assert is_connected(T1, n)
print(ok)


         _          _                  _           _        
        /\ \       / /\               / /\        / /\      
       /  \ \     / /  \             / /  \      / /  \     
      / /\ \ \   / / /\ \           / / /\ \__  / / /\ \__  
     / / /\ \_\ / / /\ \ \         / / /\ \___\/ / /\ \___\ 
    / / /_/ / // / /  \ \ \        \ \ \ \/___/\ \ \ \/___/ 
   / / /__\/ // / /___/ /\ \        \ \ \       \ \ \       
  / / /_____// / /_____/ /\ \   _    \ \ \  _    \ \ \      
 / / /      / /_________/\ \ \ /_/\__/ / / /_/\__/ / /      
/ / /      / / /_       __\ \_\ \/___/ /  \ \/___/ /       
\/_/       \_\___\     /____/_/ \_____\/    \_____\/        
                                                            

