In [1]:
# NOTE(review): `ufsets` is not read by any cell in this notebook; it looks like a
# leftover from experimentation — consider deleting this cell.
ufsets = set()

In [125]:
from collections import defaultdict

class UnionFind:
    """Disjoint-set forest (union-find) over the integers 0 .. n-1.

    Combines union by size with full path compression, so `root`, `unite`,
    `connected` and `size` run in inverse-Ackermann amortized time.
    The number of disjoint sets is tracked incrementally and is available
    in O(1) through `group_count()`.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets {0}, {1}, ..., {n-1}.

        Args:
            n: number of elements; valid elements are 0 .. n-1.
        """
        # _parents[x] is the parent of x, or -1 when x is a root.
        self._parents = [-1] * n
        # _sizes[r] is the number of elements in the set rooted at r;
        # entries for non-root elements are stale and never read.
        self._sizes = [1] * n
        # Current number of disjoint sets; every successful merge removes one.
        self._count = n

    def root(self, x):
        """Return the representative (root) of the set containing ``x``.

        Every element on the path from ``x`` to the root is re-pointed
        directly at the root (full path compression). The walk is iterative,
        so it consumes no Python call-stack frames even if ``_parents`` holds
        a long chain.
        """
        r = x
        while self._parents[r] != -1:
            r = self._parents[r]
        # Second pass: compress the path x -> ... -> r.
        while x != r:
            nxt = self._parents[x]
            self._parents[x] = r
            x = nxt
        return r

    def unite(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        The smaller tree is attached under the root of the larger one (on a
        tie, ``y``'s root goes under ``x``'s root), which keeps trees shallow.

        Returns:
            True if two distinct sets were merged, False if ``x`` and ``y``
            were already in the same set.
        """
        rx = self.root(x)
        ry = self.root(y)

        if rx == ry:
            return False
        if self._sizes[rx] < self._sizes[ry]:
            rx, ry = ry, rx
        # From here on rx roots the larger (or equal-sized) tree.
        self._parents[ry] = rx
        self._sizes[rx] += self._sizes[ry]
        self._count -= 1
        return True

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` belong to the same set."""
        return self.root(x) == self.root(y)

    def size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return self._sizes[self.root(x)]

    def group_count(self):
        """Return the current number of disjoint sets (e.g. connected components)."""
        return self._count

    def groups(self):
        """Return every set as a Python ``set``, ordered by each set's smallest element.

        Calls `root` on every element, so it also fully compresses all paths.
        """
        bins = defaultdict(set)
        for i in range(len(self._parents)):
            bins[self.root(i)].add(i)
        return list(bins.values())

    def __repr__(self):
        return "<UnionFind> " + str(self.groups())


In [126]:
# Start from ten singleton sets, one per element 0..9.
uf = UnionFind(n=10)
uf

<UnionFind> [{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}]

In [127]:
# Merge {2} and {3}. The True/False returned by `unite` is not shown because
# only the cell's last expression is displayed.
uf.unite(2, 3)
uf

<UnionFind> [{0}, {1}, {2, 3}, {4}, {5}, {6}, {7}, {8}, {9}]

In [128]:
# 2 and 3 were merged above, so they now share a root.
uf.connected(2, 3)

True

In [129]:
# Both members of the merged set report the same set size.
tuple(uf.size(v) for v in (2, 3))

(2, 2)

In [130]:
# Chain 5-6-9 together, then join that set with {2, 3}.
for a, b in ((5, 6), (6, 9), (2, 6)):
    uf.unite(a, b)
uf

<UnionFind> [{0}, {1}, {2, 3, 5, 6, 9}, {4}, {7}, {8}]

In [132]:
# 2, 3, 5, 6 and 9 now form one set, so any member reports size 5.
uf.size(5)

5

In [131]:
# Peek at the internal parent array (-1 marks a root). Entries may already be
# path-compressed by earlier `root` calls (including those made by `__repr__`
# when `uf` was displayed), so e.g. 3 points straight at 5 here even though
# `unite(2, 3)` originally attached it under 2.
uf._parents

[-1, -1, 5, 5, -1, -1, 5, -1, -1, 5]

## UnionFind で無向グラフの連結成分の個数を求める

In [133]:
N = 10  # vertices are labelled 0 .. N-1; vertex 4 has no incident edges
E = [
    (0, 1), (1, 2), (1, 3), (3, 0),  # together these form {0, 1, 2, 3}
    (5, 6), (6, 9),                  # together these form {5, 6, 9}
    (7, 8),                          # forms {7, 8}
]

In [138]:
uf = UnionFind(N)
for s, t in E:
    uf.unite(s, t)

# A vertex is its own root exactly when it represents its component,
# so counting roots counts connected components (no intermediate list needed).
n_components = sum(1 for v in range(N) if uf.root(v) == v)
n_components

4


In [139]:
# Show the resulting components; vertex 4 has no edges and stays a singleton.
uf

<UnionFind> [{0, 1, 2, 3}, {4}, {9, 5, 6}, {8, 7}]

In [140]:
# One representative vertex from each component, queried in order.
tuple(map(uf.size, (0, 4, 5, 8)))

(4, 1, 3, 2)

In [141]:
# Internal parent array after building the graph (-1 marks a root): roots are
# 0, 4, 5 and 7, matching the four components counted above. Entries reflect
# any path compression done by earlier `root` calls.
uf._parents

[-1, 0, 0, 0, -1, -1, 5, -1, 7, 5]