# Disjoint Sets using union by rank and path compression

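Supports three operations:

* make_set: creates a new set containing a single element
* union: merges the sets containing two elements
* find_set: returns the representative (root) of the set containing a node, applying path compression along the way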

In [1]:
class Node:
    """One element of a disjoint-set forest.

    Attributes:
        data: the value stored in this element.
        parent: the node this element points to (``None`` until assigned).
        rank: integer used by union by rank to decide which root becomes
            the parent when two sets are merged.
    """

    def __init__(self, data, parent=None, rank=0):
        self.data, self.parent, self.rank = data, parent, rank

In [2]:
class DisjointSet:
    """Disjoint-set forest using union by rank and path compression.

    Elements are identified by their ``data`` value; ``self.map`` maps each
    value to its ``Node``.
    """

    def __init__(self):
        # data value -> Node holding that value
        self.map = {}

    def print_set(self):
        """Print each element's current parent.

        The parent shown is the stored pointer, which may not be the root if
        path compression has not yet been applied to that node.
        """
        for k, v in self.map.items():
            print('Parent of Node:', k, 'is', v.parent.data)

    def make_set(self, data):
        """Create a singleton set containing ``data``.

        If ``data`` is already present this is a no-op. Replacing the existing
        node would leave other nodes pointing at the stale object, corrupting
        the forest.
        """
        if data in self.map:
            return
        node = Node(data=data)
        node.parent = node  # a root is its own parent
        self.map[data] = node

    def find_set(self, node):
        """Return the root node of the set containing ``node``.

        Every node on the path from ``node`` to the root is re-pointed
        directly at the root (path compression). The walk is iterative, so
        long parent chains cannot exceed Python's recursion limit.
        """
        # First pass: locate the root.
        root = node
        while root.parent is not root:
            root = root.parent

        # Second pass: point every node on the path straight at the root.
        while node is not root:
            next_node = node.parent
            node.parent = root
            node = next_node
        return root

    def union(self, data1, data2):
        """Merge the sets containing ``data1`` and ``data2``.

        Returns:
            False if both values were already in the same set, True otherwise.

        Raises:
            KeyError: if either value was never added with ``make_set``.
        """
        node1 = self.map[data1]
        node2 = self.map[data2]

        root1 = self.find_set(node1)
        root2 = self.find_set(node2)

        # Compare roots by identity: comparing ``data`` can be fooled by values
        # that are not equal to themselves (e.g. NaN).
        if root1 is root2:
            return False

        # Union by rank: attach the lower-ranked root beneath the higher one.
        # On a tie root1 becomes the parent and its rank grows by one.
        if root1.rank < root2.rank:
            root1.parent = root2
        else:
            if root1.rank == root2.rank:
                root1.rank += 1
            root2.parent = root1
        return True

In [3]:
ds = DisjointSet()

# One singleton set for each element 1..7.
for element in range(1, 8):
    ds.make_set(element)

# Build the groups {1, 2, 3} and {4, 5, 6, 7}, then join them with (3, 7).
for a, b in [(1, 2), (2, 3), (4, 5), (6, 7), (5, 6), (3, 7)]:
    ds.union(a, b)

ds.print_set()

Parent of Node: 1 is 4
Parent of Node: 2 is 1
Parent of Node: 3 is 1
Parent of Node: 4 is 4
Parent of Node: 5 is 4
Parent of Node: 6 is 4
Parent of Node: 7 is 4


Note that the parent of 2 and 3 is 1, but the parent of 1 is 4. This means path compression has not yet been applied to nodes 2 and 3. Below, we call find_set() on nodes 2 and 3, which compresses their paths and updates each node's parent to 4.

In [4]:
# Calling find_set on nodes 2 and 3 compresses their paths to the root.
for key in (2, 3):
    ds.find_set(ds.map[key])

# Print every node's parent again to show the effect.
ds.print_set()

Parent of Node: 1 is 4
Parent of Node: 2 is 4
Parent of Node: 3 is 4
Parent of Node: 4 is 4
Parent of Node: 5 is 4
Parent of Node: 6 is 4
Parent of Node: 7 is 4
