In [2]:
from collections import defaultdict

# 素朴な実装
各要素が親へのポインタだけを持つ最も単純な実装。union() の順序によっては木が一直線に伸び、find() が要素数に比例する回数だけ親を辿ることになる。

In [6]:
class UnionFindBasic():
    """Naive disjoint-set (union-find) with no balancing and no path compression.

    ``parents[i]`` is the parent of element ``i``; an element whose parent is
    itself is the root (representative) of its group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups for the elements ``0 .. n-1``."""
        self.parents = list(range(n))

    def find(self, x):
        """Return the root of the group containing ``x``.

        Walks parent pointers with a loop rather than recursion: this naive
        structure can degenerate into a chain as long as ``n``, and a recursive
        walk would then exceed Python's recursion limit. ``parents`` is not
        modified.
        """
        while self.parents[x] != x:
            x = self.parents[x]
        return x

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y``.

        The root of ``y``'s group is attached directly under the root of
        ``x``'s group; nothing happens if both already share a root.
        """
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        self.parents[y] = x

In [7]:
# Every element starts as the root of its own singleton group.
ufb = UnionFindBasic(5)
print(ufb.parents)
assert ufb.parents == [0, 1, 2, 3, 4]

[0, 1, 2, 3, 4]


In [8]:
# Start from a fresh instance so this cell gives the same result when re-run on its own.
ufb = UnionFindBasic(5)

# union(x, y) hangs y's root under x's root; with no balancing this builds the
# chain 4 -> 3 -> 2 -> 1 -> 0.
basic_steps = [
    ((3, 4), [0, 1, 2, 3, 3]),
    ((2, 3), [0, 1, 2, 2, 3]),
    ((1, 2), [0, 1, 1, 2, 3]),
    ((0, 4), [0, 0, 1, 2, 3]),
]
for (a, b), expected in basic_steps:
    ufb.union(a, b)
    print(ufb.parents)
    assert ufb.parents == expected, (a, b, ufb.parents)

roots = [ufb.find(i) for i in range(5)]
print(roots)
assert roots == [0, 0, 0, 0, 0]
# [0, 0, 0, 0, 0]

[0, 1, 2, 3, 3]
[0, 1, 2, 2, 3]
[0, 1, 1, 2, 3]
[0, 0, 1, 2, 3]
[0, 0, 0, 0, 0]


# 経路圧縮
find()で根を探す際に、辿った経路上のすべての要素の親を根に直接繋ぎ直す。これにより以降の find() は短い経路で済むようになる。

In [9]:
class UnionFindPathCompression():
    """Disjoint-set (union-find) with path compression in ``find``.

    ``parents[i]`` is the parent of element ``i``; an element whose parent is
    itself is the root of its group. ``union`` does no balancing.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups for the elements ``0 .. n-1``."""
        self.parents = list(range(n))

    def find(self, x):
        """Return the root of ``x``'s group.

        As a side effect, every node on the path from ``x`` to the root is
        re-pointed directly at the root. This is done in two loops (locate the
        root, then rewrite parents) instead of by recursion. Without union by
        rank/size, the tree can still be a long chain before its first
        compression, and a recursive walk would then exceed Python's recursion
        limit.
        """
        root = x
        while self.parents[root] != root:
            root = self.parents[root]

        # Second pass: make each node on the original path a direct child of root.
        while x != root:
            next_node = self.parents[x]
            self.parents[x] = root
            x = next_node
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y``.

        The root of ``y``'s group is attached directly under the root of
        ``x``'s group; nothing happens if both already share a root.
        """
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        self.parents[y] = x

In [10]:
# Same union sequence as for the naive version, but find() now compresses paths.
ufpc = UnionFindPathCompression(5)
print(ufpc.parents)
assert ufpc.parents == [0, 1, 2, 3, 4]

pc_steps = [
    ((3, 4), [0, 1, 2, 3, 3]),
    ((2, 3), [0, 1, 2, 2, 3]),
    ((1, 2), [0, 1, 1, 2, 3]),
    # find(4) walks 4 -> 3 -> 2 -> 1 and re-points 2, 3, 4 at root 1,
    # then root 1 is attached under 0.
    ((0, 4), [0, 0, 1, 1, 1]),
]
for (a, b), expected in pc_steps:
    ufpc.union(a, b)
    print(ufpc.parents)
    assert ufpc.parents == expected, (a, b, ufpc.parents)

roots = [ufpc.find(i) for i in range(5)]
print(roots)
assert roots == [0, 0, 0, 0, 0]

[0, 1, 2, 3, 4]
[0, 1, 2, 3, 3]
[0, 1, 2, 2, 3]
[0, 1, 1, 2, 3]
[0, 0, 1, 1, 1]
[0, 0, 0, 0, 0]


# ランク
各根に木の高さの目安（ランク）を保存しておき、併合する際はランクの低い方の根をランクの高い方の根の子にする。経路圧縮と併用すると木は低くなるがランクは減らないため、ランクは実際の高さの上界として扱われる。

union()でランクを元に併合する。ランクが同じグループを併合する場合は親（根が変わらない方）のランクを1増やす。

In [13]:
class UnionFindByRank():
    """Disjoint-set (union-find) using union by rank plus path compression.

    Attributes:
        parents: ``parents[i]`` is the parent of ``i``; roots point to themselves.
        rank: per-element counter, meaningful only for roots; it grows by one
            when two roots of equal rank are merged and is used to decide which
            root becomes the child.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups, all with rank 0."""
        self.parents = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        """Return the root of ``x``, re-pointing ``x``'s path at that root."""
        parent = self.parents[x]
        if parent != x:
            # Chained assignment: store the root both locally and as x's new parent.
            parent = self.parents[x] = self.find(parent)
        return parent

    def union(self, x, y):
        """Merge the groups of ``x`` and ``y``.

        The lower-rank root is hung under the higher-rank root. On a tie,
        ``y``'s root goes under ``x``'s root and ``x``'s root gains one rank.
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return

        if self.rank[root_x] < self.rank[root_y]:
            self.parents[root_x] = root_y
            return

        self.parents[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1

In [14]:
ufbr = UnionFindByRank(5)
print(ufbr.parents)
assert ufbr.parents == [0, 1, 2, 3, 4]

# After the first union, root 3 has rank 1, so each later rank-0 singleton
# root is hung under 3 and the tree never gets deeper than one level.
rank_steps = [
    ((3, 4), [0, 1, 2, 3, 3]),
    ((2, 3), [0, 1, 3, 3, 3]),
    ((1, 2), [0, 3, 3, 3, 3]),
    ((0, 4), [3, 3, 3, 3, 3]),
]
for (a, b), expected in rank_steps:
    ufbr.union(a, b)
    print(ufbr.parents)
    assert ufbr.parents == expected, (a, b, ufbr.parents)

[0, 1, 2, 3, 4]
[0, 1, 2, 3, 3]
[0, 1, 3, 3, 3]
[0, 3, 3, 3, 3]
[3, 3, 3, 3, 3]


In [15]:
class UnionFindBySize():
    """Disjoint-set (union-find) using union by size plus path compression.

    Attributes:
        parents: ``parents[i]`` is the parent of ``i``; roots point to themselves.
        size: ``size[r]`` is the element count of the group rooted at ``r``;
            it is only updated for roots, so non-root entries are stale.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups, each of size 1."""
        self.parents = list(range(n))
        self.size = [1] * n

    def find(self, x):
        """Return the root of ``x``, re-pointing every node on its path at the root."""
        root = x
        while self.parents[root] != root:
            root = self.parents[root]

        node = x
        while node != root:
            following = self.parents[node]
            self.parents[node] = root
            node = following
        return root

    def union(self, x, y):
        """Merge the groups of ``x`` and ``y``.

        The smaller group's root is hung under the larger group's root; on a
        tie, ``y``'s root goes under ``x``'s root.
        """
        big, small = self.find(x), self.find(y)
        if big == small:
            return

        if self.size[big] < self.size[small]:
            big, small = small, big

        self.size[big] += self.size[small]
        self.parents[small] = big
            

In [16]:
ufbs = UnionFindBySize(5)
print(ufbs.parents)
assert ufbs.parents == [0, 1, 2, 3, 4]

# Each single-element root is hung under the growing group rooted at 3.
size_steps = [
    ((3, 4), [0, 1, 2, 3, 3]),
    ((2, 3), [0, 1, 3, 3, 3]),
    ((1, 2), [0, 3, 3, 3, 3]),
    ((0, 4), [3, 3, 3, 3, 3]),
]
for (a, b), expected in size_steps:
    ufbs.union(a, b)
    print(ufbs.parents)
    assert ufbs.parents == expected, (a, b, ufbs.parents)

[0, 1, 2, 3, 4]
[0, 1, 2, 3, 3]
[0, 1, 3, 3, 3]
[0, 3, 3, 3, 3]
[3, 3, 3, 3, 3]


In [17]:
class UnionFind():
    """Disjoint-set (union-find) using union by size plus path compression,
    with group sizes stored inside ``parents`` itself.

    ``parents[i] < 0`` means ``i`` is a root and ``-parents[i]`` is the size of
    its group; otherwise ``parents[i]`` is the parent of ``i``.
    """

    def __init__(self, n):
        """Create ``n`` singleton roots, each encoded as ``-1`` (size 1)."""
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of ``x``, re-pointing ``x``'s path at that root."""
        parent = self.parents[x]
        if parent < 0:
            return x
        root = self.find(parent)
        self.parents[x] = root
        return root

    def union(self, x, y):
        """Merge the groups of ``x`` and ``y``.

        The smaller group's root is hung under the larger group's root; on a
        tie, ``y``'s root goes under ``x``'s root.
        """
        x, y = self.find(x), self.find(y)
        if x == y:
            return
        # Roots hold -size, so the more negative entry is the larger group.
        if self.parents[y] < self.parents[x]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

In [18]:
# Roots store -(group size); every other entry stores its parent index.
uf = UnionFind(5)
print(uf.parents)
assert uf.parents == [-1, -1, -1, -1, -1]

uf_steps = [
    ((3, 4), [-1, -1, -1, -2, 3]),
    ((2, 3), [-1, -1, 3, -3, 3]),
    ((1, 2), [-1, 3, 3, -4, 3]),
    ((0, 4), [3, 3, 3, -5, 3]),
]
for (a, b), expected in uf_steps:
    uf.union(a, b)
    print(uf.parents)
    assert uf.parents == expected, (a, b, uf.parents)


[-1, -1, -1, -1, -1]
[-1, -1, -1, -2, 3]
[-1, -1, 3, -3, 3]
[-1, 3, 3, -4, 3]
[3, 3, 3, -5, 3]
