In [16]:
# -------------------------------------------------------------
#
# scipy's "linkage()" function can't handle our data sizes
#
# ------------------------------------------------------------
%matplotlib inline


import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import AgglomerativeClustering, ward_tree
from sklearn.neighbors import kneighbors_graph
from scipy.cluster.hierarchy import dendrogram
from sklearn.utils.validation import check_memory

K = 100  # neighbours per sample when building the connectivity graph
N = 8    # number of clusters requested from the model
FILE = 'data/pills-blue.npy'


class AgglomerativeClusterer():
    """Ward-linkage agglomerative clustering of a feature matrix stored in a .npy file.

    The matrix is loaded once at construction.  An optional k-nearest-neighbour
    connectivity graph (see ``compute_connectivity``) is handed to scikit-learn
    when fitting.
    """

    # Cache location shared by ``run_clustering`` and ``Z`` so both go through
    # the same joblib memory directory.
    CACHE_DIR = 'HAC_cache'

    def __init__(self, filename=FILE, N=10):
        """Load the data and set default clustering parameters.

        Args:
            filename: path of the ``.npy`` file holding the sample matrix.
            N: default number of clusters.
        """
        self.filename = filename
        self.X = np.load(filename)
        print('X', self.X.shape)
        self.idx = np.arange(self.X.shape[0])
        self.link_type = 'ward'
        # NOTE(review): recent scikit-learn releases renamed the `affinity`
        # argument to `metric` -- confirm against the installed version.
        self.affinity = 'euclidean'
        self.N = N
        self.event_labels = {i: i for i in range(self.N)}
        # Stays None (unconstrained clustering) until compute_connectivity() runs.
        self.connectivity = None

    def to_range(self, l, cmap):
        """Return the colour ``cmap`` assigns to the last element of ``l``.

        ``l[-1]`` is scaled from ``[0, self.N)`` onto the ``cmap.N`` entries of
        the colour map and the resulting integer index is looked up.
        """
        total = cmap.N
        idx = int((l[-1] / self.N) * total)
        return cmap(idx)

    def boundaries(self, v):
        """Split ``v`` into runs of equal consecutive values.

        Returns:
            A list of ``(start, end, value)`` tuples.  ``start`` is the index
            where a run begins; ``end`` is the start of the following run, or
            ``len(v) - 1`` for the last run.
        """
        change_points = np.where(v[:-1] != v[1:])[0] + 1
        bounds = [0] + list(change_points) + [len(v) - 1]
        return [(lo, hi, v[lo]) for lo, hi in zip(bounds[:-1], bounds[1:])]

    def remap_labels(self, labels):
        """Renumber cluster labels in place by descending cluster size.

        The most populous cluster becomes 0, the next 1, and so on.  The
        mapping is derived from the label values actually present, so it
        works when their count differs from ``self.N`` (e.g. when
        ``run_clustering`` is called with an explicit ``n``) and when the
        values are not the contiguous range ``0..N-1``.

        Args:
            labels: 1-D integer array of cluster labels; modified in place.

        Returns:
            The same ``labels`` array, after remapping.
        """
        uniq, inverse, counts = np.unique(labels, return_inverse=True,
                                          return_counts=True)
        # Positions into `uniq`, most frequent first.
        order = np.argsort(counts)[::-1]
        # rank[j] is the new label for the old label uniq[j].
        rank = np.empty(uniq.size, dtype=np.intp)
        rank[order] = np.arange(uniq.size)
        labels[...] = rank[inverse].reshape(labels.shape)
        return labels

    def compute_connectivity(self, k=100):
        """Build the k-nearest-neighbour graph of ``self.X`` and store it.

        Each sample is excluded from its own neighbour list.  The graph is
        kept in ``self.connectivity`` for later fits.

        Returns:
            Seconds spent inside ``kneighbors_graph``.
        """
        t0 = time.time()
        graph = kneighbors_graph(self.X, k, include_self=False, n_jobs=-1)
        elapsed_time = time.time() - t0
        self.connectivity = graph
        print('setting connectivity:', graph.shape, graph.nnz)
        return elapsed_time

    def run_clustering(self, n=None):
        """Fit an ``AgglomerativeClustering`` model on ``self.X``.

        Args:
            n: number of clusters; falls back to ``self.N`` when falsy.

        Returns:
            ``(fit_seconds, model, labels, counts)`` where ``labels`` are the
            model's labels renumbered by ``remap_labels`` (this rewrites
            ``model.labels_`` in place) and ``counts`` holds the size of each
            remapped cluster.
        """
        n = n or self.N
        model = AgglomerativeClustering(linkage=self.link_type, affinity=self.affinity,
                                        n_clusters=n, connectivity=self.connectivity,
                                        memory=self.CACHE_DIR)
        start = time.time()
        model.fit(self.X)
        total_time = time.time() - start

        labels = self.remap_labels(model.labels_)
        return (total_time, model, labels, np.unique(labels, return_counts=True)[1])

    def Z(self):
        """Compute the Ward tree of ``self.X`` through the joblib cache.

        Returns:
            ``(seconds, children, parents, distances)``; the two remaining
            values returned by ``ward_tree`` are discarded.
        """
        mem = check_memory(self.CACHE_DIR)
        func = mem.cache(ward_tree)
        start = time.time()
        children, _, __, parents, dists = func(self.X, self.connectivity,
                                               n_clusters=self.N, return_distance=True)
        total_time = time.time() - start
        return total_time, children, parents, dists
        

# Build the clusterer, attach the k-NN connectivity constraint, then fit.
ac = AgglomerativeClusterer(filename=FILE, N=N)
ctime = ac.compute_connectivity(k=K)
print('connectivity time: ', ctime)
t, mdl, lbls, counts = ac.run_clustering()
print('clustering time: ', t)

# Summarise the fitted model.
for caption, value in (
    ('labels:', lbls.shape),
    ('counts:', counts),
    ('# leaves:', mdl.n_leaves_),
    ('# components:', mdl.n_components_),
    ('children:', mdl.children_.shape),
):
    print(caption, value)
print(mdl.children_[:10])
print(mdl.get_params(deep=True))

# Compute the Ward tree directly and report the shape of each returned array.
Zt, Zc, Zp, Zd = ac.Z()
print('Z time:', Zt)
for caption, value in (('children:', Zc), ('parents:', Zp), ('distances:', Zd)):
    print(caption, value.shape)

X (67111, 13)
setting connectivity: (67111, 67111) 6711100
connectivity time:  6.314779758453369
clustering time:  0.1887836456298828
labels: (67111,)
counts: [64382   730   723   605   270   163   135   103]
# leaves: 67111
# components: 1
children: (67110, 2)
[[28132 37154]
 [35320 37688]
 [41842 42054]
 [33153 33779]
 [14034 14255]
 [ 9094  9348]
 [30400 31237]
 [11291 11515]
 [43398 44219]
 [47321 49828]]
{'affinity': 'euclidean', 'compute_full_tree': 'auto', 'connectivity': <67111x67111 sparse matrix of type '<class 'numpy.float64'>'
	with 6711100 stored elements in Compressed Sparse Row format>, 'linkage': 'ward', 'memory': 'HAC_cache', 'n_clusters': 8, 'pooling_func': <function mean at 0x7fd2a40dfb70>}
Z time: 0.1536102294921875
children: (67103, 2)
parents: (134214,)
distances: (67103,)
