In [None]:
# default_exp module3

# Import

In [None]:
# export
import numpy as np

from cs371.utils import *

# Disjoint Set

## Do naive approach

In [None]:
# export
class MyDisjointSet:
    """Naive disjoint set that stores every group ("bubble") as a Python set.

    `bubbles` is a list of sets; initially each of the N elements 0..N-1
    sits alone in its own bubble.
    """

    def __init__(self, N):
        self.N = N
        self.bubbles = [{element} for element in range(N)]

    def _get_bubble_idx(self, val):
        """Return the list position of the bubble holding `val`, or -1 if none does."""
        hits = (pos for pos, members in enumerate(self.bubbles) if val in members)
        return next(hits, -1)

    def find(self, i, j):
        """Return True when `i` and `j` are both known and share a bubble."""
        pos_i = self._get_bubble_idx(i)
        pos_j = self._get_bubble_idx(j)
        return pos_i != -1 and pos_i == pos_j

    def union(self, i, j):
        """Merge the bubble of `j` into the bubble of `i`.

        Nothing happens when either element is unknown or both already
        share a bubble.
        """
        pos_i = self._get_bubble_idx(i)
        pos_j = self._get_bubble_idx(j)
        if -1 in (pos_i, pos_j) or pos_i == pos_j:
            return
        merged = self.bubbles[pos_i] | self.bubbles[pos_j]
        self.bubbles[pos_i] = merged
        # Removing j's old bubble shifts every later bubble one slot to the left.
        del self.bubbles[pos_j]

Run examples from http://www.ctralie.com/Teaching/CS371_S2021/ClassExercises/Week2/Week2_UnionFind/

In [None]:
# Replay the union/find sequence from the linked Week 2 class exercise and
# check the list of bubbles after every union.
ds = MyDisjointSet(10)
assert_allclose(ds.bubbles, [{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}])
ds.union(0, 2)
assert_allclose(ds.bubbles, [{0, 2}, {1}, {3}, {4}, {5}, {6}, {7}, {8}, {9}])
ds.union(1, 8)
assert_allclose(ds.bubbles, [{0, 2}, {1, 8}, {3}, {4}, {5}, {6}, {7}, {9}])
ds.union(8, 7)  # 7 joins the bubble that already holds 1 and 8
assert_allclose(ds.bubbles, [{0, 2}, {1, 8, 7}, {3}, {4}, {5}, {6}, {9}])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(ds.bubbles, [{0, 2}, {1, 8, 7, 6}, {3}, {4}, {5}, {9}])
ds.union(0, 1)  # merges the two multi-element bubbles into one
assert_allclose(ds.bubbles, [{0, 1, 2, 6, 7, 8}, {3}, {4}, {5}, {9}])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

In [None]:
# export
class MyDisjointSet2:
    """Disjoint set that stores one bubble label per element in a NumPy array.

    `idx_bubbles[k]` is the label of the bubble containing element `k`;
    initially every element is labelled with its own index.
    """

    def __init__(self, N):
        self.N = N
        self.idx_bubbles = np.arange(N)

    def _get_bubble_idx(self, val):
        """Return the bubble label stored for element `val`."""
        return self.idx_bubbles[val]

    def find(self, i, j):
        """Return True when `i` and `j` carry the same bubble label."""
        label_i = self._get_bubble_idx(i)
        label_j = self._get_bubble_idx(j)
        if label_i == -1 or label_i != label_j:
            return False
        return True

    def union(self, i, j):
        """Relabel the bubble with the larger label so it joins the smaller one."""
        label_i = self._get_bubble_idx(i)
        label_j = self._get_bubble_idx(j)
        if label_i == -1 or label_j == -1 or label_i == label_j:
            return
        keep, drop = min(label_i, label_j), max(label_i, label_j)
        # Boolean mask selects every member of the dropped bubble at once.
        self.idx_bubbles[self.idx_bubbles == drop] = keep

In [None]:
# Same union/find sequence, now checking the per-element bubble labels.
# A merged bubble keeps the smaller of the two labels.
ds = MyDisjointSet2(10)
ds.union(0, 2)
assert_allclose(ds.idx_bubbles, [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(ds.idx_bubbles, [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(ds.idx_bubbles, [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(ds.idx_bubbles, [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)  # every element labelled 1 is relabelled 0
assert_allclose(ds.idx_bubbles, [0, 0, 0, 3, 4, 5, 0, 0, 0, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

In [None]:
# export
class MyDisjointSet3:
    """Disjoint set stored as a parent-pointer array.

    `parents[k]` is the parent of element `k`; an element that is its own
    parent is the root (representative) of its set.
    """

    def __init__(self, N):
        self.N = N
        self.parents = np.arange(N)

    def _get_root(self, val):
        """Follow parent pointers from `val` until a self-parented element is reached.

        The walk is iterative so that a long parent chain cannot exhaust
        Python's recursion limit.
        """
        while self.parents[val] != val:
            val = self.parents[val]
        return val

    def find(self, i, j):
        """Return True when `i` and `j` have the same root."""
        return bool(self._get_root(i) == self._get_root(j))

    def union(self, i, j):
        """Attach the root of `j`'s set beneath the root of `i`'s set.

        Does nothing when both elements already share a root.
        """
        root_i = self._get_root(i)
        root_j = self._get_root(j)
        # Compare the roots already in hand instead of walking both chains again.
        if root_i != root_j:
            self.parents[root_j] = root_i

In [None]:
# Same union/find sequence against the parent-pointer array.
ds = MyDisjointSet3(10)
ds.union(0, 2)
assert_allclose(ds.parents, [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(ds.parents, [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(ds.parents, [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(ds.parents, [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)  # only root 1 is re-pointed to 0; elements 6, 7 and 8 still point at 1
assert_allclose(ds.parents, [0, 0, 0, 3, 4, 5, 1, 1, 1, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

Use simple forest class from textbook

In [None]:
# export
class DisjointSetsForest:
    """Disjoint-set forest over the items of `L`, kept as a parent dictionary.

    Every item starts as its own parent; a root is an item whose parent
    entry is the very same object.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}

    def _root(self, item):
        """Walk the parent links upward from `item` and return the root."""
        above = self._parent[item]
        while above is not item:
            item, above = above, self._parent[above]
        return item

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Hang the root of `b`'s tree under the root of `a`'s tree (no-op if already joined)."""
        if self.find(a, b):
            return
        new_root = self._root(a)
        self._parent[self._root(b)] = new_root

In [None]:
# Same union/find sequence against the dictionary-based forest.
# list(ds._parent.values()) lists the parent of each key in insertion order (0..9).
ds = DisjointSetsForest(np.arange(10))
ds.union(0, 2)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)
assert_allclose(list(ds._parent.values()), [0, 0, 0, 3, 4, 5, 1, 1, 1, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

Try to build a case with a long path

In [None]:
# Each union hangs the current root of 0's tree under a fresh singleton,
# so the path from 0 to its root grows by one link per call.
ds = DisjointSetsForest(np.arange(5))
ds.union(1, 0)
print(ds._parent)
ds.union(2, 0)
print(ds._parent)
ds.union(3, 0)
print(ds._parent)
ds.union(4, 0)
print(ds._parent)

{0: 1, 1: 1, 2: 2, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 2, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 3, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 3, 3: 4, 4: 4}


0 -> 1 -> 2 -> 3 -> 4 -> 4

## Try simple path compression

In [None]:
# export
class DisjointSetsPathCompression:
    """Disjoint-set forest whose root lookup shortens the path it walks.

    While climbing, `_root` re-points every visited item at its
    grandparent, so repeated lookups progressively flatten the tree.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}

    def _root(self, item):
        """Return the root of `item`, linking each visited item to its grandparent."""
        links = self._parent
        current = item
        while True:
            above = links[current]
            if above is current:
                return current
            # Skip one level: current now points at what was its grandparent.
            links[current] = links[above]
            current = above

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root (compresses both paths)."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Hang the root of `b`'s tree under the root of `a`'s tree (no-op if already joined)."""
        if self.find(a, b):
            return
        # Roots are looked up again after `find`; each lookup compresses further.
        new_root = self._root(a)
        self._parent[self._root(b)] = new_root

In [None]:
# Same chain-building unions as the plain forest, to see how much the
# path-compressing `_root` shortens the path from 0.
ds = DisjointSetsPathCompression(np.arange(5))
ds.union(1, 0)
print(ds._parent)
ds.union(2, 0)
print(ds._parent)
ds.union(3, 0)
print(ds._parent)
ds.union(4, 0)
print(ds._parent)

{0: 1, 1: 1, 2: 2, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 2, 3: 3, 4: 4}
{0: 2, 1: 2, 2: 3, 3: 3, 4: 4}
{0: 3, 1: 2, 2: 3, 3: 4, 4: 4}


0 -> 3 -> 4

Path from 1 to root is still long, see the effect of calling `_root()`

In [None]:
# `_root` rewrites parent links as it walks, so print the table before and
# after a single lookup starting at 1 (the return value is discarded).
print(ds._parent)
ds._root(1)
print(ds._parent)

{0: 3, 1: 2, 2: 3, 3: 4, 4: 4}
{0: 3, 1: 3, 2: 4, 3: 4, 4: 4}


Path went from: `1 -> 2 -> 3 -> 4 -> 4` to: `1 -> 3 -> 4 -> 4`

Do two pass path compression where every node in path gets set to new root.

In [None]:
# export
class DisjointSetsTwoPassPC:
    """Disjoint-set forest with two-pass path compression.

    A root lookup first climbs to the root, then walks the same path a
    second time and points every item on it directly at that root.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}

    def _root(self, item):
        """Return the root of `item` and flatten the path that led to it."""
        top = item
        while True:
            above = self._parent[top]
            if above is top:
                break
            top = above
        self._compress(item, top)
        return top

    def _compress(self, item, root_new):
        """Re-point every non-root item on the path from `item` to `root_new`."""
        node = item
        following = self._parent[node]
        while following is not node:
            self._parent[node] = root_new
            # Step along the old link that was read before it was overwritten.
            node, following = following, self._parent[following]

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root (compresses both paths)."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Hang the root of `b`'s tree under the root of `a`'s tree (no-op if already joined)."""
        if self.find(a, b):
            return
        new_root = self._root(a)
        self._parent[self._root(b)] = new_root

In [None]:
# Same chain-building unions, now with the two-pass compressing `_root`.
ds = DisjointSetsTwoPassPC(np.arange(5))
ds.union(1, 0)
print(ds._parent)
ds.union(2, 0)
print(ds._parent)
ds.union(3, 0)
print(ds._parent)
ds.union(4, 0)
print(ds._parent)

{0: 1, 1: 1, 2: 2, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 2, 3: 3, 4: 4}
{0: 2, 1: 2, 2: 3, 3: 3, 4: 4}
{0: 3, 1: 2, 2: 3, 3: 4, 4: 4}


See the effect of calling `_root(1)`

In [None]:
# Print the parent table before and after one two-pass lookup starting at 1
# (the return value is discarded; only the side effect on the table matters).
print(ds._parent)
ds._root(1)
print(ds._parent)

{0: 3, 1: 2, 2: 3, 3: 4, 4: 4}
{0: 3, 1: 4, 2: 4, 3: 4, 4: 4}


Path went from: `1 -> 2 -> 3 -> 4 -> 4` to: `1 -> 4`

## Try "Merge by Height"

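This doesn't seem correct: `union` compares and updates `_height` for the arguments `a` and `b` as passed in, not for their roots, so the stored heights do not describe the trees that are actually being merged.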

In [None]:
# export
class DisjointSetsMergeByHeight:
    """Disjoint-set forest with a per-item height table used to orient merges.

    NOTE: heights are read and written under the arguments given to
    `union`, not under the roots of their trees, so `_height` does not
    track the height of the trees being merged.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}
        self._height = {item : 0 for item in L}

    def _root(self, item):
        """Walk the parent links upward from `item` and return the root."""
        above = self._parent[item]
        while above is not item:
            item, above = above, self._parent[above]
        return item

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Join the two trees, keeping on top the argument with the larger recorded height."""
        if self.find(a, b):
            return
        heights = self._height
        # Orientation is decided by the heights stored for a and b themselves.
        upper, lower = (b, a) if heights[a] < heights[b] else (a, b)
        self._parent[self._root(lower)] = self._root(upper)
        heights[upper] = max(heights[upper], heights[lower] + 1)

In [None]:
# Same chain-building unions as for the plain forest; the recorded output
# shows the same long chain, i.e. the height bookkeeping did not prevent it.
ds = DisjointSetsMergeByHeight(np.arange(5))
ds.union(1, 0)
print(ds._parent)
ds.union(2, 0)
print(ds._parent)
ds.union(3, 0)
print(ds._parent)
ds.union(4, 0)
print(ds._parent)

{0: 1, 1: 1, 2: 2, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 2, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 3, 3: 3, 4: 4}
{0: 1, 1: 2, 2: 3, 3: 4, 4: 4}


In [None]:
# Same union/find sequence as the earlier cells, checking the parent of every key.
ds = DisjointSetsMergeByHeight(np.arange(10))
ds.union(0, 2)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)
assert_allclose(list(ds._parent.values()), [0, 0, 0, 3, 4, 5, 1, 1, 1, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

Implement one on Wikipedia which stores height in root nodes, this makes more sense to me.

In [None]:
# export
class DisjointSetsMergeByHeight2:
    """Disjoint-set forest that merges by height, with heights stored on roots.

    `union` always attaches the root with the smaller recorded height
    beneath the root with the larger one.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}
        self._height = {item : 0    for item in L}

    def _root(self, item):
        """Walk the parent links upward from `item` and return the root."""
        above = self._parent[item]
        while above is not item:
            item, above = above, self._parent[above]
        return item

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Join the trees of `a` and `b`, keeping the taller root on top."""
        taller, shorter = self._root(a), self._root(b)
        if taller is shorter:
            return                                   # already in the same set
        if self._height[taller] < self._height[shorter]:
            taller, shorter = shorter, taller        # ties keep a's root on top
        self._parent[shorter] = taller
        # The merged tree only grows when both trees had the same height.
        if self._height[taller] == self._height[shorter]:
            self._height[taller] += 1

In [None]:
# Same chain-building unions; print the parents after each step and the
# heights at the end to see whether the tree stays shallow.
ds = DisjointSetsMergeByHeight2(np.arange(5))
ds.union(1, 0)
print(ds._parent.values())
ds.union(2, 0)
print(ds._parent.values())
ds.union(3, 0)
print(ds._parent.values())
ds.union(4, 0)
print(ds._parent.values())
print(ds._height.values())

dict_values([1, 1, 2, 3, 4])
dict_values([1, 1, 1, 3, 4])
dict_values([1, 1, 1, 1, 4])
dict_values([1, 1, 1, 1, 1])
dict_values([0, 1, 0, 0, 0])


In [None]:
# Same union/find sequence as the earlier cells, checking the parent of every key.
ds = DisjointSetsMergeByHeight2(np.arange(10))
ds.union(0, 2)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)
assert_allclose(list(ds._parent.values()), [0, 0, 0, 3, 4, 5, 1, 1, 1, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

In [None]:
# Display the heights recorded by the previous cell's sequence of unions.
ds._height.values()

dict_values([2, 1, 0, 0, 0, 0, 0, 0, 0, 0])

Store heights but don't swap `a` and `b`, to see the effect

In [None]:
# export
class DisjointSetsMergeByHeight3:
    """Disjoint-set forest that records root heights but never reorders a merge.

    The root of `b` is always attached beneath the root of `a`, whatever
    their heights; the heights are only kept up to date.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}
        self._height = {item : 0    for item in L}

    def _root(self, item):
        """Walk the parent links upward from `item` and return the root."""
        above = self._parent[item]
        while above is not item:
            item, above = above, self._parent[above]
        return item

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Attach the root of `b` under the root of `a` and refresh that root's height."""
        top, bottom = self._root(a), self._root(b)
        if top is bottom:
            return                               # already in the same set
        self._parent[bottom] = top
        # The attached tree sits one level below the new root.
        grown = self._height[bottom] + 1
        if grown > self._height[top]:
            self._height[top] = grown

In [None]:
# Run the usual sequence of unions without the height-based swap.
ds = DisjointSetsMergeByHeight3(np.arange(10))
ds.union(0, 2)
ds.union(1, 8)
ds.union(7, 8) # Arguments swapped relative to the union(8, 7) used in the earlier cells
ds.union(1, 6)
ds.union(0, 1)

In [None]:
# Display the heights recorded by the previous cell's sequence of unions.
ds._height.values()

dict_values([3, 1, 0, 0, 0, 0, 0, 2, 0, 0])

In [None]:
# Chain-building unions without the swap: display the heights and parents at the end.
ds = DisjointSetsMergeByHeight3(np.arange(5))
ds.union(1, 0)
ds.union(2, 0)
ds.union(3, 0)
ds.union(4, 0)
ds._height.values(), ds._parent.values()

(dict_values([0, 1, 2, 3, 4]), dict_values([1, 2, 3, 4, 4]))

In [None]:
# Same chain-building unions with the swapping variant, for comparison
# with the no-swap result in the previous cell.
ds = DisjointSetsMergeByHeight2(np.arange(5))
ds.union(1, 0)
ds.union(2, 0)
ds.union(3, 0)
ds.union(4, 0)
ds._height.values(), ds._parent.values()

(dict_values([0, 1, 0, 0, 0]), dict_values([1, 1, 1, 1, 1]))

## Try "Merge by Weight"

Weight is the number of nodes in each set. Merge by weight works on its own, but it is mainly used together with path compression.

In [None]:
# export
class DisjointSetsMergeByWeight:
    """Disjoint-set forest that merges by weight.

    `_weight[root]` is the number of items in that root's tree; `union`
    attaches the lighter root beneath the heavier one.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}
        self._weight = {item : 1    for item in L}

    def _root(self, item):
        """Walk the parent links upward from `item` and return the root."""
        above = self._parent[item]
        while above is not item:
            item, above = above, self._parent[above]
        return item

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Join the trees of `a` and `b`, keeping the heavier root on top."""
        heavy, light = self._root(a), self._root(b)
        if heavy is light:
            return                               # already in the same set
        if self._weight[heavy] < self._weight[light]:
            heavy, light = light, heavy          # ties keep a's root on top
        self._parent[light] = heavy
        self._weight[heavy] += self._weight[light]

In [None]:
# Chain-building unions with merge by weight: display the weights and parents at the end.
ds = DisjointSetsMergeByWeight(np.arange(5))
ds.union(1, 0)
ds.union(2, 0)
ds.union(3, 0)
ds.union(4, 0)
ds._weight.values(), ds._parent.values()

(dict_values([1, 5, 1, 1, 1]), dict_values([1, 1, 1, 1, 1]))

## Combine path compression with merge by weight

We cannot combine path compression with merge by height, since path compression changes the height of a tree (but not its weight), so the classes below merge by weight even though their names say "Height".

In [None]:
# export
class DisjointSetsMergeByHeightAndPathCompression:
    """Disjoint-set forest combining merge by weight with path compression.

    Despite the class name, merges are ordered by `_weight` (items per
    tree). `_root` re-points each visited item at its grandparent.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}
        self._weight = {item : 1    for item in L}

    def _root(self, item):
        """Return the root of `item`, linking each visited item to its grandparent."""
        links = self._parent
        current = item
        while True:
            above = links[current]
            if above is current:
                return current
            # Skip one level: current now points at what was its grandparent.
            links[current] = links[above]
            current = above

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root (compresses both paths)."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Join the trees of `a` and `b`, keeping the heavier root on top."""
        heavy, light = self._root(a), self._root(b)
        if heavy is light:
            return                               # already in the same set
        if self._weight[heavy] < self._weight[light]:
            heavy, light = light, heavy          # ties keep a's root on top
        self._parent[light] = heavy
        self._weight[heavy] += self._weight[light]

In [None]:
# Same union/find sequence as the earlier cells, checking the parent of every key.
ds = DisjointSetsMergeByHeightAndPathCompression(np.arange(10))
ds.union(0, 2)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)  # expected parents differ from the earlier cells here: root 0 ends up under root 1
assert_allclose(list(ds._parent.values()), [1, 1, 0, 3, 4, 5, 1, 1, 1, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

In [None]:
# Chain-building unions: display the weights and parents at the end.
ds = DisjointSetsMergeByHeightAndPathCompression(np.arange(5))
ds.union(1, 0)
ds.union(2, 0)
ds.union(3, 0)
ds.union(4, 0)
ds._weight.values(), ds._parent.values()

(dict_values([1, 5, 1, 1, 1]), dict_values([1, 1, 1, 1, 1]))

In [None]:
# export
class DisjointSetsMergeByHeightAndPathCompression2:
    """Disjoint-set forest combining merge by weight with two-pass path compression.

    Despite the class name, merges are ordered by `_weight` (items per
    tree). `_root` points every item on the walked path directly at the root.
    """

    def __init__(self, L):
        self._parent = {item : item for item in L}
        self._weight = {item : 1    for item in L}

    def _root(self, item):
        """Return the root of `item` and flatten the path that led to it."""
        top = item
        while True:
            above = self._parent[top]
            if above is top:
                break
            top = above
        self._compress(item, top)
        return top

    def _compress(self, item, root_new):
        """Re-point every non-root item on the path from `item` to `root_new`."""
        node = item
        following = self._parent[node]
        while following is not node:
            self._parent[node] = root_new
            # Step along the old link that was read before it was overwritten.
            node, following = following, self._parent[following]

    def find(self, a, b):
        """Return True when `a` and `b` end at the same root (compresses both paths)."""
        root_a = self._root(a)
        root_b = self._root(b)
        return root_a is root_b

    def union(self, a, b):
        """Join the trees of `a` and `b`, keeping the heavier root on top."""
        heavy, light = self._root(a), self._root(b)
        if heavy is light:
            return                               # already in the same set
        if self._weight[heavy] < self._weight[light]:
            heavy, light = light, heavy          # ties keep a's root on top
        self._parent[light] = heavy
        self._weight[heavy] += self._weight[light]

In [None]:
# Same union/find sequence as the earlier cells, checking the parent of every key.
ds = DisjointSetsMergeByHeightAndPathCompression2(np.arange(10))
ds.union(0, 2)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 8, 9])
ds.union(1, 8)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 7, 1, 9])
ds.union(8, 7)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 6, 1, 1, 9])
assert_allclose(ds.find(0, 3), False)
assert_allclose(ds.find(1, 7), True)
ds.union(1, 6)
assert_allclose(list(ds._parent.values()), [0, 1, 0, 3, 4, 5, 1, 1, 1, 9])
ds.union(0, 1)  # expected parents differ from the earlier cells here: root 0 ends up under root 1
assert_allclose(list(ds._parent.values()), [1, 1, 0, 3, 4, 5, 1, 1, 1, 9])
assert_allclose(ds.find(0, 7), True)
assert_allclose(ds.find(1, 9), False)

In [None]:
# Chain-building unions: display the weights and parents at the end.
ds = DisjointSetsMergeByHeightAndPathCompression2(np.arange(5))
ds.union(1, 0)
ds.union(2, 0)
ds.union(3, 0)
ds.union(4, 0)
ds._weight.values(), ds._parent.values()

(dict_values([1, 5, 1, 1, 1]), dict_values([1, 1, 1, 1, 1]))

# Build

In [None]:
# NOTE(review): build_notebook is not defined in this notebook — presumably it
# comes from the `cs371.utils` star import; the recorded output reports
# module3.ipynb as converted.
build_notebook()

<IPython.core.display.Javascript object>

Converted module3.ipynb.
