In [1]:
class UnionFind():
    """Disjoint-set (union-find) structure over the integers ``0 .. n-1``.

    ``parents[i]`` holds the parent index of ``i`` when ``i`` is not a root.
    For a root it holds the negated size of that root's group, so every
    entry starts as ``-1`` (a singleton group).
    """

    def __init__(self, n):
        """Create ``n`` singleton groups labelled ``0 .. n-1``."""
        self.n = n
        self.parents = [-1] * n

    def _check_index(self, x):
        """Raise ``IndexError`` unless ``0 <= x < n``.

        Without this guard a negative ``x`` would be accepted by Python's
        negative list indexing, be mistaken for a root, and corrupt the
        size/parent bookkeeping on the next ``unite``.
        """
        if not 0 <= x < self.n:
            raise IndexError(
                f'element {x} is out of range for UnionFind of size {self.n}'
            )

    def root(self, x):
        """Return the root (representative) of the group containing ``x``.

        Every node visited on the way up is re-pointed directly at the root
        (path compression).

        Raises:
            IndexError: if ``x`` is not in ``0 .. n-1``.
        """
        self._check_index(x)

        # First pass: climb until an entry is negative, i.e. a root.
        top = x
        while self.parents[top] >= 0:
            top = self.parents[top]

        # Second pass: re-point every node on the path straight at the root.
        while x != top:
            above = self.parents[x]
            self.parents[x] = top
            x = above
        return top

    def unite(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged).

        Raises:
            IndexError: if ``x`` or ``y`` is not in ``0 .. n-1``.
        """
        x = self.root(x)
        y = self.root(y)

        if x == y:
            return
        # Sizes are stored negated, so the larger value is the smaller group.
        # Swap so that ``x`` is the larger group and absorbs ``y``.
        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def is_same(self, x, y):
        """Return ``True`` when ``x`` and ``y`` belong to the same group."""
        return self.root(x) == self.root(y)

    def size(self, x):
        """Return the number of elements in the group containing ``x``."""
        return -self.parents[self.root(x)]

    def get_groups(self):
        """Return a dict mapping each root to the list of its group's members.

        Members are listed in increasing order, and roots appear in the order
        their first member is encountered while scanning ``0 .. n-1``.
        """
        members = {}
        for member in range(self.n):
            members.setdefault(self.root(member), []).append(member)
        return members

    def __str__(self):
        """Return one line per group showing its root and its members."""
        members = self.get_groups()
        return '\n'.join([f'parents: {member}, member: {members[member]}' for member in members])

In [3]:
# Demo: start from seven singletons, then merge {0, 1, 2, 3} and {4, 5}; 6 stays alone.
uf = UnionFind(7)
for pair in ((0, 2), (0, 3), (0, 1), (4, 5)):
    uf.unite(*pair)

# Group sizes as seen from one element of each group, one per line.
print(*map(uf.size, (0, 4, 6)), sep='\n')

# The same grouping twice: as the raw dict, then via the class's string form.
print(uf.get_groups())
print(uf)

4
2
1
{0: [0, 1, 2, 3], 4: [4, 5], 6: [6]}
parents: 0, member: [0, 1, 2, 3]
parents: 4, member: [4, 5]
parents: 6, member: [6]


In [5]:
# Scratch check: dict.get(key, default) hands back the default (False here)
# when the key is absent, instead of raising KeyError.
dic = {1: [23, -1]}
dic.get(2, False)

False