### Union-Find
1. array 사용

In [50]:
class DisjointSet:
    def __init__(self, n):
        self.data = list(range(n))
        self.size = n
        self.length = len(set(self.data))

    def find(self, index):
        return self.data[index]


    def union(self, x, y):
        x, y = self.find(x), self.find(y)

        if x == y:
            return

        for i in range(self.size):
            if self.find(i) == y:
                self.data[i] = x


#     @property
#     def length(self):
#         return len(set(self.data))




disjoint = DisjointSet(10)

disjoint.union(0, 1)
disjoint.union(1, 2)
disjoint.union(2, 3)
disjoint.union(4, 5)
disjoint.union(5, 6)
disjoint.union(6, 7)
disjoint.union(8, 9)

print(disjoint.data)
print(disjoint.length)

[0, 0, 0, 0, 4, 4, 4, 4, 8, 8]
10


2. Tree 사용

In [43]:
# Union-by-size
class DisjointSet:
    def __init__(self, n):
        self.data = [-1 for _ in range(n)]
        self.size = n

    def find(self, index):
        value = self.data[index]
        if value < 0:
            return index

        return self.find(value)

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.data[x] < self.data[y]:
            self.data[x] += self.data[y]
            self.data[y] = x
        else:
            self.data[y] += self.data[x]
            self.data[x] = y

        self.size -= 1


disjoint = DisjointSet(10)

disjoint.union(0, 1)
disjoint.union(1, 2)
disjoint.union(2, 3)
disjoint.union(4, 5)
disjoint.union(5, 6)
disjoint.union(6, 7)
disjoint.union(8, 9)

print(disjoint.data)
print(disjoint.size)


[1, -4, 1, 1, 5, -4, 5, 5, 9, -2]
3


In [44]:
# Union-by-height

class DisjointSet:
    def __init__(self, n):
        self.data = [-1] * n
        self.size = n

    def find(self, index):
        value = self.data[index]
        if value < 0:
            return index

        return self.find(value)

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.data[x] < self.data[y]:
            self.data[y] = x
        elif self.data[x] > self.data[y]:
            self.data[x] = y
        else:
            self.data[x] -= 1
            self.data[y] = x

        self.size -= 1


disjoint = DisjointSet(10)

disjoint.union(0, 1)
disjoint.union(1, 2)
disjoint.union(2, 3)
disjoint.union(4, 5)
disjoint.union(5, 6)
disjoint.union(6, 7)
disjoint.union(8, 9)

print(disjoint.data)
print(disjoint.size)



[-2, 0, 0, 0, -2, 4, 4, 4, -2, 8]
3


In [45]:
#  path comprehension

class DisjointSet:
    def __init__(self, n):
        self.data = [-1 for _ in range(n)]
        self.size = n

    def upward(self, change_list, index):
        value = self.data[index]
        if value < 0:
            return index

        change_list.append(index)
        return self.upward(change_list, value)

    def find(self, index):
        change_list = []
        result = self.upward(change_list, index)

        for i in change_list:
            self.data[i] = result

        return result

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.data[x] < self.data[y]:
            self.data[y] = x
        elif self.data[x] > self.data[y]:
            self.data[x] = y
        else:
            self.data[x] -= 1
            self.data[y] = x

        self.size -= 1


disjoint = DisjointSet(10)

disjoint.union(0, 1)
disjoint.union(1, 2)
disjoint.union(2, 3)
disjoint.union(4, 5)
disjoint.union(5, 6)
disjoint.union(6, 7)
disjoint.union(8, 9)

print(disjoint.data)
print(disjoint.size)


[-2, 0, 0, 0, -2, 4, 4, 4, -2, 8]
3
