<img src='images/union-find.png' width=600>

In [140]:
# Test graph 1: n = 5 nodes.
# Each tuple is (node_a, node_b, weight); convert_graph stores it in both directions.
# Expected max_space_ result for test1:
# For K = 2 -> 8
# For K = 3 -> 4
# For K = 4 -> 1

test1 = [
    (1, 2, 1),
    (1, 3, 4),
    (1, 4, 5),
    (1, 5, 10),
    (2, 3, 5),
    (2, 4, 4),
    (2, 5, 8),
    (3, 4, 1),
    (3, 5, 12),
    (4, 5, 11)
]

# Test graph 2: n = 5 nodes, same (node_a, node_b, weight) layout.
# Expected max_space_ result for test2:
# For K = 2 -> 5
# For K = 3 -> 2
# For K = 4 -> 1

test2 = [
    (1, 2, 1),
    (1, 3, 2),
    (1, 4, 4),
    (1, 5, 5),
    (2, 3, 4),
    (2, 4, 3),
    (2, 5, 6),
    (3, 4, 1),
    (3, 5, 7),
    (4, 5, 8)
]


In [118]:
from collections import defaultdict
def convert_graph(lst):
    """Build an undirected, weighted adjacency map from an edge list.

    Args:
        lst: iterable of edges. Each edge is indexable, with the two endpoint
            labels at positions 0 and 1 and the weight at position 2.

    Returns:
        A ``defaultdict(dict)`` mapping each endpoint to ``{neighbour: weight}``.
        Every edge is recorded in both directions; if the same pair appears
        more than once, the weight of the last occurrence wins.
    """
    adjacency = defaultdict(dict)
    for edge in lst:
        first, second, weight = edge[0], edge[1], edge[2]
        adjacency[first][second] = weight
        adjacency[second][first] = weight
    return adjacency


In [124]:
class UnionFind:
    """One element of a leader-pointer union-find structure.

    Attributes (all set in ``__init__``):
        node: the label this element represents.
        leader: the element representing this element's cluster; a new
            element is its own leader.
        members: list of the elements in the cluster; for a new element it
            contains only the element itself.
    """

    def __init__(self, n):
        # A fresh element starts as a singleton cluster led by itself.
        self.node, self.leader = n, self
        self.members = [self]

    def __repr__(self):
        template = '<UnionFind {}>'
        return template.format(self.node)
    

def union(uf1, uf2, s):
    """Merge the clusters that contain ``uf1`` and ``uf2``.

    The cluster with strictly more members keeps its leader. When both
    clusters have the same size, the leader of ``uf2``'s cluster is kept and
    ``uf1``'s cluster is absorbed. Every element of the absorbed cluster is
    re-pointed at the surviving leader and has its ``members`` list reset to
    an empty list.

    Args:
        uf1: a UnionFind element.
        uf2: a UnionFind element.
        s: collection of leader labels; the label of the absorbed leader is
            removed from it with ``s.remove``.

    Returns:
        None. Nothing happens when both elements already share a leader.

    Raises:
        AssertionError: if either argument is not a UnionFind instance.
        Whatever ``s.remove`` raises when the absorbed leader's label is not
        in ``s`` (``KeyError`` for a set).
    """
    assert isinstance(uf1, UnionFind) and isinstance(uf2, UnionFind), 'Not instance of UnionFind'

    first_leader, second_leader = uf1.leader, uf2.leader
    if first_leader is second_leader:
        return

    # "<=" means a tie is resolved in favour of uf2's leader.
    if len(first_leader.members) <= len(second_leader.members):
        absorbed, survivor = first_leader, second_leader
    else:
        absorbed, survivor = second_leader, first_leader

    s.remove(absorbed.node)
    moved = absorbed.members
    survivor.members.extend(moved)
    for element in moved:
        element.leader = survivor
        element.members = []
                

def cluster(lst, k):
    """Greedily merge the closest clusters until only ``k`` remain.

    Every node starts in its own cluster. Edges are processed in ascending
    weight order and the clusters of each edge's endpoints are merged, until
    the number of clusters is no larger than ``k``.

    Args:
        lst: list of ``(node_a, node_b, weight)`` edges. It is not modified;
            a sorted copy is used internally.
        k: target number of clusters.

    Returns:
        A tuple ``(ufs, nodes, remaining)`` where ``ufs`` maps each node label
        to its UnionFind element, ``nodes`` is the set of labels of the
        current cluster leaders, and ``remaining`` is the list of sorted
        edges that were not examined.

    Raises:
        ValueError: if the edges are exhausted before ``k`` clusters are
            reached (for example ``k < 1`` or a disconnected graph).
    """
    graph = convert_graph(lst)
    nodes = set(graph.keys())

    # Sort a copy so the caller's edge list keeps its original order.
    edges = sorted(lst, key=lambda edge: edge[2])
    ufs = {n: UnionFind(n) for n in nodes}

    i = 0
    while len(nodes) > k:
        if i >= len(edges):
            raise ValueError(
                'ran out of edges with {} clusters left; cannot reach k={}'.format(len(nodes), k)
            )
        edge = edges[i]
        union(ufs[edge[0]], ufs[edge[1]], nodes)
        i += 1

    return ufs, nodes, edges[i:]


def max_space(ufs, nodes, lst_rest):
    """Return the weight of the first edge in ``lst_rest`` that joins two clusters.

    When ``lst_rest`` is in ascending weight order (as returned by
    ``cluster``), this is the spacing of the current clustering.

    Side effect: each examined edge is passed to ``union``, so the edge that
    is returned has been merged and ``ufs`` / ``nodes`` end up with one
    cluster fewer than on entry.

    Args:
        ufs: dict mapping node label to its UnionFind element.
        nodes: set of labels of the current cluster leaders.
        lst_rest: list of ``(node_a, node_b, weight)`` edges still to examine.

    Returns:
        The weight of the first edge whose endpoints were in different
        clusters.

    Raises:
        ValueError: if no edge in ``lst_rest`` connects two different clusters.
    """
    cur_len = len(nodes)

    # Read from the lst_rest argument, not from a module-level edge list.
    for edge in lst_rest:
        union(ufs[edge[0]], ufs[edge[1]], nodes)
        if len(nodes) != cur_len:
            # The cluster count dropped, so this edge crossed two clusters.
            return edge[2]

    raise ValueError('no edge in lst_rest connects two different clusters')


import copy
def max_space_(lst, k):
    """Return the maximum spacing of a ``k``-clustering of the graph.

    Every node starts in its own cluster. Edges are merged in ascending
    weight order; the weight of the edge that reduces the number of clusters
    from ``k`` to ``k - 1`` is the spacing of the ``k``-clustering.

    Args:
        lst: list of ``(node_a, node_b, weight)`` edges. It is not modified;
            a sorted copy is used internally.
        k: number of clusters; must satisfy ``2 <= k <= number of nodes``.

    Returns:
        The weight of the cheapest edge separating two of the ``k`` clusters.

    Raises:
        ValueError: if ``k`` is out of range, or if the edges run out before
            the cluster count drops below ``k`` (disconnected graph).
    """
    graph = convert_graph(lst)
    nodes = set(graph.keys())

    if not 2 <= k <= len(nodes):
        raise ValueError(
            'k must be between 2 and the number of nodes ({}), got {}'.format(len(nodes), k)
        )

    # Sort a copy so the caller's edge list keeps its original order.
    edges = sorted(lst, key=lambda edge: edge[2])
    ufs = {n: UnionFind(n) for n in nodes}

    for edge in edges:
        union(ufs[edge[0]], ufs[edge[1]], nodes)
        if len(nodes) < k:
            # This edge merged two of the k clusters: its weight is the spacing.
            return edge[2]

    raise ValueError('edges exhausted before fewer than k={} clusters remained'.format(k))

In [122]:
# Merge test1 down to 2 clusters and display the result.
# NOTE: this binds the module-level name `lst` to the edges cluster() left unexamined.
ufs, nodes, lst = cluster(test1, 2)
ufs, nodes, lst 

({1: <UnionFind 1>,
  2: <UnionFind 2>,
  3: <UnionFind 3>,
  4: <UnionFind 4>,
  5: <UnionFind 5>},
 {4, 5},
 [(2, 4, 4),
  (1, 4, 5),
  (2, 3, 5),
  (2, 5, 8),
  (1, 5, 10),
  (4, 5, 11),
  (3, 5, 12)])

In [123]:
# Spacing of the 2-clustering of test1 built in the previous cell (expected: 8).
max_space(ufs, nodes, lst)

8

In [128]:
# test1 with K = 3 (expected: 4).
max_space_(test1, 3)

4

In [144]:
# test2 with K = 2 (expected: 5).
max_space_(test2, 2)

5

In [130]:
import re

# Raw string: in a plain literal '\d' is an invalid escape sequence, which
# current Python versions warn about.
pattern = re.compile(r'\d+')

with open('datasets/clustering1.txt') as f:
    content = f.readlines()

# The first line holds the number of nodes.
num_of_nodes = int(content[0].strip())

# Each remaining line becomes a tuple of the integers found on it.
# Lines without any digits (e.g. a trailing blank line) are skipped so they
# do not turn into empty edge tuples.
edges = [
    tuple(int(token) for token in tokens)
    for tokens in (pattern.findall(line) for line in content[1:])
    if tokens
]

In [131]:
# Node count parsed from the first line of the data file.
num_of_nodes

500

In [139]:
# Maximum spacing of a 4-clustering of the edge list loaded from the data file.
max_space_(edges, 4)

106

The format is:

[# of nodes] [# of bits for each node’s label]

[first bit of node 1] … [last bit of node 1]

[first bit of node 2] … [last bit of node 2]

…

For example, the third line of the file “0 1 1 0 0 1 1 0 0 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1” denotes the 24 bits associated with node #2.

The distance between two nodes u and v in this problem is defined as the Hamming distance— the number of differing bits — between the two nodes’ labels. For example, the Hamming distance between the 24-bit label of node #2 above and the label “0 1 0 0 0 1 0 0 0 1 0 1 1 1 1 1 1 0 1 0 0 1 0 1” is 3 (since they differ in the 3rd, 7th, and 21st bits).

The question is: what is the largest value of k such that there is a k-clustering with spacing at least 3? That is, how many clusters are needed to ensure that no pair of nodes with all but 2 bits in common get split into different clusters?

NOTE: The graph implicitly defined by the data file is so big that you probably can’t write it out explicitly, let alone sort the edges by cost. So you will have to be a little creative to complete this part of the question. For example, is there some way you can identify the smallest distances without explicitly looking at every pair of nodes?

In [276]:
# Plan:
# 1. convert every node's bit label to a decimal int
# 2. build the decimal values that correspond to 1-bit and 2-bit differences
# 3. keep a hash table mapping each node's int to its union-find element

def convert_nodes(lst):
    """Map each bit-list label to a UnionFind element keyed by its integer value.

    Args:
        lst: iterable of labels, each a sequence of bits with the most
            significant bit first. Only entries equal to 1 contribute to the
            value; anything else counts as 0.

    Returns:
        dict mapping the integer value of each label to a fresh
        ``UnionFind`` for that value. Identical labels collapse into a
        single entry.
    """
    output = {}
    for label in lst:
        # reversed() puts the last bit at position 0, giving it weight 2**0.
        number = sum(
            2 ** position
            for position, bit in enumerate(reversed(label))
            if bit == 1
        )
        output[number] = UnionFind(number)
    return output

# Decimal values of the binary numbers that have exactly one bit set.
def distance_1d(bits):
    """Return ``[2**0, 2**1, ..., 2**(bits - 1)]``.

    Args:
        bits: number of bit positions to cover.

    Returns:
        List with one power of two per bit position, in ascending order;
        empty when ``bits`` is 0.
    """
    return [2 ** position for position in range(bits)]

# Decimal values of the binary numbers that have exactly two bits set.
def distance_2d(bits):
    """Return every sum ``2**a + 2**b`` with ``a != b`` and both below ``bits``.

    Args:
        bits: number of bit positions to cover.

    Returns:
        List of the distinct sums, in set-iteration order; empty when
        ``bits`` is smaller than 2.
    """
    # Each unordered pair is generated twice (a, b) and (b, a); the set
    # removes the duplicate.
    two_bit_values = {
        2 ** first + 2 ** second
        for first in range(bits)
        for second in range(bits)
        if second != first
    }
    return list(two_bit_values)

def distance_all(bits):
    """Return the 1-bit and 2-bit values for ``bits`` positions as one list.

    Args:
        bits: number of bit positions to cover.

    Returns:
        List of the distinct values produced by ``distance_1d(bits)`` and
        ``distance_2d(bits)``, in set-iteration order.
    """
    combined = distance_1d(bits) + distance_2d(bits)
    return list(set(combined))

def cluster_(ufs_dict, nodes_set, diff_lst):
    """Merge every pair of nodes whose labels differ exactly by one of the masks.

    Args:
        ufs_dict: dict mapping each node's integer label to its UnionFind
            element.
        nodes_set: set of labels of the current cluster leaders; ``union``
            removes one label from it for every merge.
        diff_lst: integer bit masks (e.g. from ``distance_all``); the set
            bits of a mask are the positions in which two labels may differ.

    Returns:
        ``(ufs_dict, nodes_set, len(nodes_set))`` - the last item is the
        number of clusters left.
    """
    # Iterate over a snapshot of all labels: union() removes entries from
    # nodes_set, and changing a set's size while iterating over it raises
    # RuntimeError.
    for n in list(ufs_dict):
        for mask in diff_lst:
            # XOR flips exactly the bits set in mask. Adding or subtracting
            # the mask can carry/borrow into other bits and reach a label at
            # a different Hamming distance.
            v = n ^ mask
            if v in ufs_dict:
                union(ufs_dict[n], ufs_dict[v], nodes_set)
    return ufs_dict, nodes_set, len(nodes_set)

In [279]:
# Expected result 6

with open('datasets/clustering_test.txt') as f:
    content = f.readlines()

# Header line: number of nodes, then number of bits per label.
num, bits = [int(x) for x in re.findall(pattern, content[0])]

# One bit label per remaining line. Lines without any digits (e.g. a trailing
# blank line) are dropped; otherwise they would parse as an empty label, which
# convert_nodes turns into a spurious node with value 0.
nodes = [
    [int(i) for i in row]
    for row in (re.findall(pattern, line) for line in content[1:])
    if row
]
ufs = convert_nodes(nodes)
n_set = set(ufs)
d_lst = distance_all(bits)

In [None]:
# Merge the nodes loaded above; the last element of the returned tuple is the cluster count.
cluster_(ufs, n_set, d_lst)