<img src='images/union-find.png' width=600>

In [140]:
# test1: every pairwise edge of a graph with n = 5 nodes.
# Each tuple is indexed as (node_a, node_b, weight) by the functions in this notebook.
# Expected max_space_(test1, K):
#   K = 2 -> 8
#   K = 3 -> 4
#   K = 4 -> 1

test1 = [
    (1, 2, 1),
    (1, 3, 4),
    (1, 4, 5),
    (1, 5, 10),
    (2, 3, 5),
    (2, 4, 4),
    (2, 5, 8),
    (3, 4, 1),
    (3, 5, 12),
    (4, 5, 11)
]

# test2: every pairwise edge of a second graph with n = 5 nodes.
# Each tuple is indexed as (node_a, node_b, weight) by the functions in this notebook.
# Expected max_space_(test2, K):
#   K = 2 -> 5
#   K = 3 -> 2
#   K = 4 -> 1

test2 = [
    (1, 2, 1),
    (1, 3, 2),
    (1, 4, 4),
    (1, 5, 5),
    (2, 3, 4),
    (2, 4, 3),
    (2, 5, 6),
    (3, 4, 1),
    (3, 5, 7),
    (4, 5, 8)
]


In [118]:
from collections import defaultdict
def convert_graph(lst):
    """Turn an edge list into a symmetric adjacency mapping.

    Args:
        lst: iterable of edges, each indexed as (node_a, node_b, weight).

    Returns:
        A ``defaultdict(dict)`` where ``result[a][b]`` and ``result[b][a]``
        both hold the weight of the edge between ``a`` and ``b``. If the same
        pair appears more than once, the last occurrence wins.
    """
    adjacency = defaultdict(dict)
    for edge in lst:
        weight = edge[2]
        head, tail = edge[0], edge[1]
        # Undirected graph: record the edge in both directions.
        adjacency[head][tail] = weight
        adjacency[tail][head] = weight
    return adjacency


In [124]:
class UnionFind:
    """One element of a disjoint-set structure with an explicit leader pointer.

    Attributes are set up so that a new element is a singleton cluster:
        node:    the label passed to the constructor.
        leader:  the UnionFind object representing this element's cluster
                 (initially the element itself).
        members: list of UnionFind objects in the cluster led by this
                 element (initially just itself). ``union`` empties this
                 list on every element that is absorbed into another cluster.
    """

    def __init__(self, n):
        # A fresh element leads a cluster that contains only itself.
        self.leader = self
        self.members = [self]
        self.node = n

    def __repr__(self):
        label = self.node
        return '<UnionFind {}>'.format(label)
    

def union(uf1, uf2, s):
    """Merge the clusters containing ``uf1`` and ``uf2``.

    The cluster with strictly more members keeps its leader; the other
    cluster is absorbed into it. When both clusters have the same size,
    ``uf1``'s cluster is absorbed into ``uf2``'s. Nothing happens when both
    elements already share a leader.

    Args:
        uf1, uf2: UnionFind elements (an AssertionError is raised otherwise).
        s: collection of leader node labels; the absorbed leader's label is
            removed from it with ``s.remove`` (so it must be present).
    """
    assert isinstance(uf1, UnionFind) and isinstance(uf2, UnionFind), 'Not instance of UnionFind'

    first_root, second_root = uf1.leader, uf2.leader
    if first_root is second_root:
        return

    # Pick which leader survives; ties go to the second argument's leader.
    if len(first_root.members) > len(second_root.members):
        absorbed_root, surviving_root = second_root, first_root
    else:
        absorbed_root, surviving_root = first_root, second_root

    s.remove(absorbed_root.node)

    # Hold on to the list object: the absorbed leader's own ``members``
    # attribute is rebound to a new empty list inside the loop below.
    moved = absorbed_root.members
    surviving_root.members.extend(moved)
    for element in moved:
        element.leader = surviving_root
        element.members = []
                

def cluster(lst, k):
    """Merge the closest clusters until at most ``k`` clusters remain.

    Edges are processed in order of increasing weight; each edge merges the
    clusters of its two endpoints (a no-op when they already share one).

    Args:
        lst: list of edges, each indexed as (node_a, node_b, weight).
            The list is not modified.
        k: number of clusters at which to stop merging.

    Returns:
        A tuple ``(ufs, nodes, rest)`` where
        ``ufs`` maps each node label to its UnionFind element,
        ``nodes`` is the set of node labels of the remaining cluster leaders,
        ``rest`` is the weight-sorted list of edges that were not processed.

    Raises:
        ValueError: if every edge has been processed and more than ``k``
            clusters still remain (e.g. ``k < 1`` or a disconnected graph).
    """
    graph = convert_graph(lst)
    nodes = set(graph.keys())

    # Sort a copy so the caller's edge list keeps its original order.
    edges = sorted(lst, key=lambda x: x[2])
    ufs = {n: UnionFind(n) for n in nodes}

    i = 0
    while len(nodes) > k:
        if i == len(edges):
            raise ValueError(
                'cannot reach {} clusters: all edges used with {} clusters left'
                .format(k, len(nodes)))
        edge = edges[i]
        union(ufs[edge[0]], ufs[edge[1]], nodes)
        i += 1

    return ufs, nodes, edges[i:]


def max_space(ufs, nodes, lst_rest):
    """Return the spacing of the clustering described by ``ufs`` and ``nodes``.

    Walks ``lst_rest`` in order and returns the weight of the first edge whose
    endpoints lie in different clusters. That edge is merged in the process,
    so ``ufs`` and ``nodes`` are modified: afterwards ``nodes`` holds one
    cluster fewer than before the call.

    Args:
        ufs: mapping of node label -> UnionFind element.
        nodes: set of node labels of the current cluster leaders.
        lst_rest: edges indexed as (node_a, node_b, weight), expected in
            ascending weight order (as returned by ``cluster``).

    Returns:
        The weight of the first edge in ``lst_rest`` that joins two clusters.

    Raises:
        ValueError: if no edge in ``lst_rest`` joins two different clusters.
    """
    cur_len = len(nodes)

    # Iterate over the argument itself; the previous version indexed the
    # module-level name ``lst`` and ignored ``lst_rest`` entirely.
    for edge in lst_rest:
        union(ufs[edge[0]], ufs[edge[1]], nodes)
        if len(nodes) != cur_len:
            # This edge crossed a cluster boundary, so its weight is the spacing.
            return edge[2]

    raise ValueError('no remaining edge connects two different clusters')


import copy
def max_space_(lst, k):
    """Return the maximum spacing of a ``k``-clustering of the graph ``lst``.

    Edges are merged in order of increasing weight. The spacing is the weight
    of the edge whose merge takes the cluster count from ``k`` down to
    ``k - 1``.

    Args:
        lst: list of edges, each indexed as (node_a, node_b, weight).
            The list is not modified.
        k: number of clusters.

    Returns:
        The weight of the edge that reduces the cluster count below ``k``.

    Raises:
        ValueError: if ``k`` exceeds the number of nodes, or if the edges run
            out before the cluster count drops below ``k`` (e.g. ``k < 2`` or
            a disconnected graph), so the spacing is undefined.
    """
    graph = convert_graph(lst)
    nodes = set(graph.keys())

    # With fewer than k nodes no merge would ever run and there is no spacing
    # to report (this used to surface as an UnboundLocalError).
    if k > len(nodes):
        raise ValueError(
            'k = {} exceeds the number of nodes ({})'.format(k, len(nodes)))

    # Sort a copy so the caller's edge list keeps its original order.
    edges = sorted(lst, key=lambda x: x[2])
    ufs = {n: UnionFind(n) for n in nodes}

    for edge in edges:
        union(ufs[edge[0]], ufs[edge[1]], nodes)
        if len(nodes) < k:
            # This merge broke up the k-clustering; its weight is the spacing.
            return edge[2]

    raise ValueError(
        'spacing of a {}-clustering is undefined: all edges used with {} clusters left'
        .format(k, len(nodes)))

In [122]:
# Merge test1 down to 2 clusters. The third returned item is the part of the
# weight-sorted edge list that was not processed.
ufs, nodes, lst = cluster(test1, 2)
ufs, nodes, lst 

({1: <UnionFind 1>,
  2: <UnionFind 2>,
  3: <UnionFind 3>,
  4: <UnionFind 4>,
  5: <UnionFind 5>},
 {4, 5},
 [(2, 4, 4),
  (1, 4, 5),
  (2, 3, 5),
  (2, 5, 8),
  (1, 5, 10),
  (4, 5, 11),
  (3, 5, 12)])

In [123]:
# Spacing of the 2-clustering built by the previous cell (expected 8 for test1).
# max_space performs one more merge on `ufs`/`nodes`, so re-run the cluster
# cell before re-running this one.
max_space(ufs, nodes, lst)

8

In [128]:
# Single-call variant: spacing of a 3-clustering of test1 (expected 4).
max_space_(test1, 3)

4

In [144]:
# Spacing of a 2-clustering of test2 (expected 5).
max_space_(test2, 2)

5

In [130]:
# Load the clustering1 dataset: the first line holds the node count and each
# following line is parsed into a tuple of the integers it contains
# (max_space_ indexes these tuples as node_a, node_b, weight).
with open('datasets/clustering1.txt') as f:
    content = f.readlines()

num_of_nodes = int(content[0].strip())

import re
# Raw string: '\d' in a plain literal is an invalid escape sequence, which
# newer Python versions report with a warning.
pattern = re.compile(r'\d+')

# Lines without any integer (e.g. a trailing blank line) are skipped so they
# do not turn into empty edge tuples.
edges = [
    tuple(int(token) for token in tokens)
    for tokens in (pattern.findall(line) for line in content[1:])
    if tokens
]

In [131]:
# Node count read from the first line of datasets/clustering1.txt.
num_of_nodes

500

In [139]:
# Maximum spacing of a 4-clustering of the clustering1 dataset.
max_space_(edges, 4)

106

In [154]:
# Expected result 6

with open('datasets/clustering_test.txt') as f:
    content = f.readlines()

num, bits = [int(x) for x in re.findall(pattern, content[0])]
nodes = [re.findall(pattern, n) for n in content[1:]]
nodes = [tuple([int(i) for i in j])for j in nodes]