In [1]:
from IPython.display import HTML
with open('../style.css') as file:
    css = file.read()
HTML(css)

# An Object-Oriented Implementation of the Union-Find Algorithm

The union-find algorithm takes a *binary relation* $R \subseteq M \times M$ and computes the *equivalence relation* $\approx_R$
generated by $R$ on the set $M$.
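
The relation $\approx_R$ is the smallest equivalence relation on $M$ that contains $R$, i.e. the closure of $R$ under reflexivity, symmetry and transitivity. To make this concrete, here is a brute-force sketch that computes the closure directly as a set of pairs. The function name `equivalence_closure` is our own and not part of this notebook. It is far slower than union-find and is only meant as a reference for cross-checking results.

```python
def equivalence_closure(M, R):
    """Smallest equivalence relation on M containing R, as a set of pairs.

    Brute-force fixed-point iteration; intended only as a cross-check.
    """
    E = {(x, x) for x in M}                   # reflexivity
    E |= set(R) | {(y, x) for (x, y) in R}    # R and its symmetric counterpart
    while True:
        # transitivity: (x, y) and (y, z) in E imply (x, z) in E
        new = {(x, z) for (x, y1) in E for (y2, z) in E if y1 == y2}
        if new <= E:                          # nothing new: fixed point reached
            return E
        E |= new


M = set(range(1, 10))
R = {(1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7)}
E = equivalence_closure(M, R)
print((1, 7) in E, (1, 2) in E)   # True False
```

Here $1 \approx_R 7$ holds because of the chain $(1, 4), (4, 7)$, while no chain of pairs connects $1$ and $2$.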

The class `Node` represents the elements of the equivalence relation that is to be constructed.
It maintains three member variables:
* `mValue` is the element of the set $M$ that the node represents.
* `mParent` is a pointer to the parent node, or `None` if the node is the root of its tree.
* `mHeight` is the height of the tree rooted at `self`, where a single node has height 1.

In [2]:
class Node:
    def __init__(self, value):
        self.mValue  = value
        self.mParent = None
        self.mHeight = 1

Given a node $n$, the method call $n.\texttt{find}()$ follows the parent pointers starting at $n$
and returns the root of the tree containing $n$.

In [3]:
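def find(self):
    node = self
    # follow the parent pointers until a root (a node without parent) is reached
    while node.mParent is not None:
        node = node.mParent
    return node

Node.find = find   # attach find as a method of the class Node
del find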
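
A common refinement that this notebook does *not* use is *path compression*: after `find` has located the root, every node on the traversed path is redirected to point directly at that root, so later searches from these nodes take a single step. The sketch below assumes a minimal copy of `Node` so that it runs on its own; `find_compress` is a hypothetical name of ours. When combined with union by height, `mHeight` is then only an upper bound on the actual height, which is what the literature calls a *rank*.

```python
class Node:
    """Minimal stand-in for the notebook's Node class."""
    def __init__(self, value):
        self.mValue  = value
        self.mParent = None
        self.mHeight = 1


def find_compress(node):
    """Return the root of node's tree and make every node on the
    path from node to the root a direct child of the root."""
    root = node
    while root.mParent is not None:      # first pass: locate the root
        root = root.mParent
    while node is not root:              # second pass: redirect the path
        parent = node.mParent
        node.mParent = root
        node = parent
    return root


a, b, c, d = (Node(v) for v in "abcd")
b.mParent, c.mParent, d.mParent = a, b, c       # chain d -> c -> b -> a
print(find_compress(d) is a)                     # True
print(all(n.mParent is a for n in (b, c, d)))    # True: the chain is now flat
```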

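Given two nodes $x$ and $y$, the call $\texttt{union}(x, y)$ links the trees containing $x$ and $y$ so that afterwards the equation
$$ x.\texttt{find}() = y.\texttt{find}() $$
holds. To keep the trees flat, the root of the lower tree becomes a child of the root of the higher tree (*union by height*).
Only if both trees have the same height does the height of the new root increase by one.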

In [4]:
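def union(x, y):
    root_x = x.find()
    root_y = y.find()
    if root_x != root_y:                   # x and y are not yet in the same tree
        if root_x.mHeight < root_y.mHeight:
            root_x.mParent = root_y        # attach the lower tree below the higher one
        elif root_x.mHeight > root_y.mHeight:
            root_y.mParent = root_x
        else:
            root_y.mParent  = root_x       # equal heights: the merged tree is one level higher
            root_x.mHeight += 1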
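
Union by height guarantees that a tree of height $h$ contains at least $2^{h-1}$ nodes: the height only grows when two trees of equal height $h$ are merged, and each of them already has at least $2^{h-1}$ nodes. Hence with $n$ nodes no tree is higher than $\lfloor \log_2 n \rfloor + 1$, and `find` needs $\mathcal{O}(\log n)$ steps. The following sketch checks this bound empirically. It repeats minimal copies of `Node`, `find` and `union` (as free functions) so that it is self-contained; the helpers `balanced_worst_case` and `max_height_after_random_unions` are illustrative names of ours.

```python
import random


class Node:
    def __init__(self, value):
        self.mValue  = value
        self.mParent = None
        self.mHeight = 1


def find(node):
    while node.mParent is not None:
        node = node.mParent
    return node


def union(x, y):
    root_x, root_y = find(x), find(y)
    if root_x is root_y:
        return
    if root_x.mHeight < root_y.mHeight:
        root_x.mParent = root_y
    elif root_x.mHeight > root_y.mHeight:
        root_y.mParent = root_x
    else:
        root_y.mParent  = root_x
        root_x.mHeight += 1


def balanced_worst_case(k):
    """Merge 2**k singletons pairwise in rounds; every union joins two
    trees of equal height, so the final height is exactly k + 1."""
    nodes = [Node(i) for i in range(2 ** k)]
    step = 1
    while step < len(nodes):
        for i in range(0, len(nodes), 2 * step):
            union(nodes[i], nodes[i + step])
        step *= 2
    return find(nodes[0]).mHeight


def max_height_after_random_unions(n, seed=0):
    """Perform 2n random unions on n nodes and return the largest tree height."""
    rng = random.Random(seed)
    nodes = [Node(i) for i in range(n)]
    for _ in range(2 * n):
        union(rng.choice(nodes), rng.choice(nodes))
    return max(find(v).mHeight for v in nodes)


print(balanced_worst_case(10))                                # 11
n = 1000
# n.bit_length() equals floor(log2(n)) + 1 for n >= 1
print(max_height_after_random_unions(n) <= n.bit_length())   # True
```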

Given a set $M$ and a *binary relation* $R \subseteq M \times M$, the function $\texttt{partition}(M, R)$ computes and returns a partition $\mathcal{P}$ of $M$ that represents
the *equivalence relation* $\approx_R$ generated by $R$ on the set $M$: two elements lie in the same block of $\mathcal{P}$ if and only if they are equivalent with respect to $\approx_R$.

In [5]:
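def partition(M, R):
    Nodes = { x: Node(x) for x in M }       # one node for every element of M
    for x, y in R:
        node_x = Nodes[x]
        node_y = Nodes[y]
        union(node_x, node_y)               # merge the classes of x and y
    # every tree, identified by its root, forms one equivalence class
    Roots = { Nodes[x] for x in M if Nodes[x].mParent is None }
    return [{y for y in M if Nodes[y].find() == r} for r in Roots]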
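
The last line of `partition` calls `find` once for every combination of an element of $M$ and a root, that is $|M| \cdot |\texttt{Roots}|$ times. A possible alternative groups the elements by their root in a single pass. The sketch below additionally swaps the `Node` objects for two dictionaries `parent` and `height` and marks a root by being its own parent instead of by `None`; this is a different representation from the one in this notebook, and the name `partition_dict` is hypothetical.

```python
from collections import defaultdict


def partition_dict(M, R):
    """Union-find (union by height) on dictionaries; returns the
    equivalence classes as a list of sets."""
    parent = {x: x for x in M}      # a root is its own parent here
    height = {x: 1 for x in M}

    def find(x):
        while parent[x] != x:
            x = parent[x]
        return x

    for x, y in R:
        rx, ry = find(x), find(y)
        if rx == ry:
            continue
        if height[rx] < height[ry]:
            rx, ry = ry, rx          # make rx the root of the higher tree
        parent[ry] = rx
        if height[rx] == height[ry]:
            height[rx] += 1

    blocks = defaultdict(set)
    for x in M:
        blocks[find(x)].add(x)       # exactly one find per element
    return list(blocks.values())


M = set(range(1, 10))
R = {(1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7)}
print(sorted(sorted(b) for b in partition_dict(M, R)))
# [[1, 4, 7, 9], [2, 6], [3, 5, 8]]
```

As in the original, the order of the blocks carries no meaning, which is why the usage example sorts them before printing.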

In [6]:
def demo():
    M = set(range(1, 10))
    R = { (1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7) }
    P = partition(M, R)
    return P

In [7]:
P = demo()
P

[{3, 5, 8}, {1, 4, 7, 9}, {2, 6}]
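
The result can be sanity-checked: the blocks must be non-empty, pairwise disjoint and cover $M$, and every pair in $R$ must lie inside a single block. The helpers `is_partition_of` and `respects` below are hypothetical names for this check. These conditions are necessary but not sufficient, since the single block $\{M\}$ passes them as well; what union-find adds is that the partition is the finest one with these properties.

```python
from itertools import combinations


def is_partition_of(P, M):
    """True if the blocks of P are non-empty, pairwise disjoint and cover M."""
    return (all(B for B in P)
            and all(A.isdisjoint(B) for A, B in combinations(P, 2))
            and set().union(*P) == set(M))


def respects(P, R):
    """True if both components of every pair in R lie in the same block of P."""
    block_of = {x: i for i, B in enumerate(P) for x in B}
    return all(block_of[x] == block_of[y] for x, y in R)


M = set(range(1, 10))
R = {(1, 4), (7, 9), (3, 5), (2, 6), (5, 8), (1, 9), (4, 7)}
P = [{3, 5, 8}, {1, 4, 7, 9}, {2, 6}]          # the output shown above
print(is_partition_of(P, M), respects(P, R))   # True True
```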