In [44]:


from abc import *

class DisjointSet(metaclass=ABCMeta):
    """Interface for a disjoint-set (union-find) structure stored in a plain parent list.

    Implementations hold no state of their own: every operation receives the ``parent``
    list, where ``parent[x]`` is the parent of element ``x``. Both operations are declared
    as static methods, because that is how implementations define them and how callers
    invoke them (``SubClass.find(parent, x)``).
    """

    @staticmethod
    @abstractmethod
    def find(parent: list[int], x: int) -> int:
        """Return the root (representative) of the set containing ``x``."""

    @staticmethod
    @abstractmethod
    def union(parent: list[int], a: int, b: int) -> None:
        """Merge the sets containing ``a`` and ``b`` by updating ``parent`` in place."""


# Default implementation (w/o path compression)
class SimpleDisjointSet(DisjointSet):
    """Disjoint set over a parent list, without path compression.

    ``find`` only reads ``parent``. ``union`` attaches the root with the larger index
    under the root with the smaller index.
    """

    @staticmethod
    def find(parent: list[int], x: int) -> int:
        """Return the root of ``x``: follow parent links until ``parent[r] == r``.

        This is iterative rather than recursive. Without path compression a chain of parent
        links can be arbitrarily long (e.g. after ``union(k, k + 1)`` for decreasing ``k``),
        and a recursive walk would raise ``RecursionError`` once it passes Python's recursion
        limit. ``parent`` is not modified.
        """
        while parent[x] != x:
            x = parent[x]
        return x

    @staticmethod
    def union(parent: list[int], a: int, b: int) -> None:
        """Merge the sets containing ``a`` and ``b``; the smaller root index becomes the root."""
        a = SimpleDisjointSet.find(parent, a)
        b = SimpleDisjointSet.find(parent, b)
        if a < b:
            parent[b] = a
        else:
            # Also reached when a == b (already in the same set); parent[a] = a is then a no-op.
            parent[a] = b
            

# Another implementation (w/ path compression)
class PathCompressionDisjointSet(DisjointSet):
    """Disjoint set over a parent list, with path compression in ``find``.

    ``union`` attaches the root with the larger index under the root with the smaller index.
    """

    @staticmethod
    def find(parent: list[int], x: int) -> int:
        """Return the root of ``x`` and relink every node on the walked path directly to it.

        This runs in two iterative passes (locate the root, then relink the path) instead of
        recursing. ``union`` only links roots, so chains can still grow long before any
        ``find`` walks them (e.g. ``union(k, k + 1)`` for decreasing ``k``). A recursive walk
        would raise ``RecursionError`` past Python's recursion limit.
        """
        root = x
        while parent[root] != root:
            root = parent[root]

        # Second pass: point each node on the path straight at the root.
        while x != root:
            next_node = parent[x]
            parent[x] = root
            x = next_node
        return root

    @staticmethod
    def union(parent: list[int], a: int, b: int) -> None:
        """Merge the sets containing ``a`` and ``b``; the smaller root index becomes the root."""
        a = PathCompressionDisjointSet.find(parent, a)
        b = PathCompressionDisjointSet.find(parent, b)
        if a < b:
            parent[b] = a
        else:
            # Also reached when a == b (already in the same set); parent[a] = a is then a no-op.
            parent[a] = b


# Check whether an undirected graph contains at least one cycle, using a disjoint set (union-find).
def detect_cycle(graph):
    """
        Check whether an undirected graph contains at least one cycle, using a disjoint set.

        :param graph: List of tuple(destination node number, weight). List index is source node number.
            An undirected edge u-v may be listed in both directions (graph[u] has v and
            graph[v] has u, as in this notebook's examples) or in only one direction; the
            mirrored entry is recognised and not counted as a second edge.
        :return: returns True if 'graph' contains any cycle or False otherwise.
            A self-loop, or two parallel edges between the same pair of nodes, counts as a cycle.
    """
    parent = list(range(len(graph)))

    # Count of directed entries (src, dst) whose mirror (dst, src) has not been seen yet.
    # When graph[dst] later lists src, that entry is the other half of an edge that was
    # already unioned. It must be skipped: otherwise every undirected edge would be
    # reported as a cycle, because its two endpoints are already in the same set.
    unmatched = {}

    for src, adjacency in enumerate(graph):
        for dst, _ in adjacency:
            mirror = (dst, src)
            if unmatched.get(mirror, 0) > 0:
                unmatched[mirror] -= 1
                continue
            unmatched[(src, dst)] = unmatched.get((src, dst), 0) + 1

            if PathCompressionDisjointSet.find(parent, src) == PathCompressionDisjointSet.find(parent, dst):
                return True
            PathCompressionDisjointSet.union(parent, src, dst)

    return False


def kruskal(graph) -> int:
    """
        Kruskal's algorithm: total weight of a minimum spanning tree.

        :param graph: List of tuple(destination node number, weight). List index is source node number.
        :return: returns total cost of MST. Edges are considered in ascending weight order, and
            an edge is kept only if its endpoints are currently in different sets. If the graph
            is disconnected, the result is the total of a minimum spanning forest.
    """
    parent = list(range(len(graph)))

    # Flatten the adjacency lists into (src, dst, weight) triples, ordered by weight.
    # sorted() is stable, so equal-weight edges keep their adjacency-list order.
    edges = sorted(
        ((src, dst, weight) for src, adjacency in enumerate(graph) for dst, weight in adjacency),
        key=lambda edge: edge[2],
    )

    total_cost = 0
    for src, dst, weight in edges:
        if PathCompressionDisjointSet.find(parent, src) == PathCompressionDisjointSet.find(parent, dst):
            continue  # endpoints already connected: this edge would close a cycle
        total_cost += weight
        PathCompressionDisjointSet.union(parent, src, dst)

    return total_cost



In [45]:



INF = int(1e9)

# Demo nodes are 1-indexed: slot 0 holds an unused sentinel, and nodes 1..6 start as
# their own roots. (find/union on node 0 would try to read parent[INF].)
parent = [INF, *range(1, 7)]

for a, b in ((1, 4), (2, 3), (2, 4), (5, 6)):
    PathCompressionDisjointSet.union(parent, a, b)

print("union complete.")

# Root of every node; find() also compresses paths as it goes.
print(*(PathCompressionDisjointSet.find(parent, node) for node in range(1, len(parent))), end=' ')

print("\n==========================")
# Raw parent links after the finds above.
print(*(parent[node] for node in range(1, len(parent))), end=' ')


def _add_edge(adjacency, u, v, weight):
    """Record the undirected edge u-v by appending it to both endpoints' lists."""
    adjacency[u].append((v, weight))
    adjacency[v].append((u, weight))


# Triangle 0-1-2, which contains a cycle.
graph = [[] for _ in range(3)]
for u, v, w in ((0, 1, 100), (1, 2, 100), (2, 0, 100)):
    _add_edge(graph, u, v, w)

print("\n==========================")
print(detect_cycle(graph))


# kruskal example graph from [https://www.javatpoint.com/kruskal-algorithm]
kg = [[] for _ in range(5)]
for u, v, w in ((0, 1, 1), (0, 2, 7), (0, 3, 10), (0, 4, 5), (1, 2, 3), (2, 3, 4), (3, 4, 2)):
    _add_edge(kg, u, v, w)

print("\n==========================")
print(kruskal(kg))






union complete.
1 1 1 1 5 5 
1 1 1 1 5 5 
True

10
