### Libraries

In [None]:
import matplotlib
import tools as t
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from scipy import spatial
from scipy.sparse import csr_matrix
from mpl_toolkits.mplot3d import Axes3D
from scipy.sparse.csgraph import connected_components

# Apply seaborn's dark-grid theme to all subsequent figures.
sns.set_style('darkgrid')
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

### Hyperparameters

In [None]:
# Target number of clusters: enters the sigma bound and the final
# connected-component check.
k = 2
# Length of the lazy random walk (exponent applied to the transition matrix).
l = 500
# Number of vertices sampled as random-walk starting points.
n_verts = 30
# Edge threshold, compared against *squared* Euclidean distances between points.
dist_thrshld = 2.5
# Seed NumPy's global RNG so the vertex sampling below is reproducible
# under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)

### Create Affinity Matrix

In [None]:
# Choose points from a specific layer and epoch.
# NOTE(review): the `.p` extension suggests a pickle. NumPy >= 1.16.3 refuses to
# unpickle unless allow_pickle=True. Unpickling can execute arbitrary code, so
# only load files you produced yourself.
# TODO: make this path configurable / relative instead of an absolute home path.
DATA_PATH = '/home/zz452/cluster/data/modelTishby_0_layer_data.p'
data = np.load(DATA_PATH, allow_pickle=True)
layer = data[-2]  # second-to-last entry of the saved container
n_epochs, n_sample, n_dim = layer.shape
points = layer[-1, :, :]  # last epoch: one row per sample, shape (n_sample, n_dim)

# Pairwise *squared* Euclidean distances; dist_thrshld is compared against these.
M_d = spatial.distance_matrix(points, points, p=2)**2

# Adjacency: connect points closer than the threshold, excluding self-edges.
remove_diag = ~np.eye(n_sample, dtype=bool)
mask = (M_d < dist_thrshld) & remove_diag

# Lazy random walk: stay put with probability 1/2, move to each neighbour with
# probability 1/(2*d), where d is the maximum degree. The leftover mass of each
# column is added back onto the diagonal, so every column sums to 1.
d = mask.sum(0).max()
move_prob = 1/(2*d) if d > 0 else 0.0  # no edges at all: avoid 1/0 -> NaN matrix
stay_prob = np.eye(n_sample)*0.5
M = move_prob*mask + stay_prob
add_self_loop = np.diag(1 - M.sum(0))
M = M + add_self_loop
assert np.allclose(M.sum(0), 1), 'transition matrix columns must sum to 1'
M = M.astype('float32')

# Visualize
# 3-D scatter of the selected points. Only the first three coordinates are
# plotted, so this assumes n_dim >= 3.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(points[:, 0], points[:, 1], points[:, 2])
ax.set_title('Latent Visualizer')
ax.set_xlabel('dim 0')
ax.set_ylabel('dim 1')
ax.set_zlabel('dim 2')

# Histogram of pairwise *squared* distances (the quantity thresholded by
# dist_thrshld), with log-scaled counts. The dashed line marks the threshold.
fig_dist, ax_dist = plt.subplots(figsize=(5, 5))
ax_dist.hist(M_d.flatten(), log=True, bins=100)
ax_dist.axvline(dist_thrshld, color='r', linestyle='--', label='dist_thrshld')
ax_dist.set_title('Squared Distance Distribution')
ax_dist.set_xlabel('squared Euclidean distance')
ax_dist.set_ylabel('count (log scale)')
ax_dist.legend();

### Random Sample of Vertices

In [None]:
# Only sample vertices that have at least one neighbour (self-edges were
# already removed from `mask`).
singles_mask = mask.sum(0) > 0
connected_idx = np.flatnonzero(singles_mask)

# Total number of vertices that are connected to something other than themselves.
n = connected_idx.size
assert n > 0, 'No vertex has a neighbour; increase dist_thrshld'

# S: one-hot start vectors, one column per sample. Each column picks a connected
# vertex uniformly at random, independently of the others (with replacement).
sampled = np.random.choice(connected_idx, size=n_verts, replace=True)
S = np.zeros((M.shape[1], n_verts), dtype='float32')
S[sampled, np.arange(n_verts)] = 1

# Column-vector form of the mask, kept for any downstream use.
singles_mask = singles_mask.reshape(-1, 1)

### Random Walk

In [None]:
# l-step lazy walk from every sampled vertex: column j of S_l is M^l applied to
# the one-hot start vector in column j of S.
M_l = np.linalg.matrix_power(M, l)
S_l = np.matmul(M_l, S)

# Squared L2 norm of each walk distribution (one value per sampled vertex).
p_l2 = np.linalg.norm(S_l, ord=2, axis=0)**2
sigma = 192*n_verts*k/n
keep_idx = p_l2<sigma  # boolean, one entry per sampled vertex

# Every sampled vertex must pass the sigma test. keep_idx always has length
# n_verts, so a length check would never fail; check the entries instead.
# TODO: resample the failing vertices instead of stopping here.
assert keep_idx.all(), 'Sample more vertices, didnt pass sigma test'

### Similarity Graph

In [None]:
# Walk distributions as rows: one row per sampled vertex.
walk_rows = np.swapaxes(S_l, 0, 1)
# Pairwise squared L2 distances between the sampled vertices' walk distributions.
H = spatial.distance_matrix(walk_rows, walk_rows, p=2)**2
# Similarity graph on the samples: join two samples whose distributions are
# within 1/(4n) in squared L2 distance.
H = H <= 1/(4*n)
# Drop self-loops directly on the boolean adjacency. This replaces the 9999
# sentinel, which also shadowed the earlier `remove_diag` mask.
np.fill_diagonal(H, False)

graph = csr_matrix(H)
n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True)

# More islands than the target k means the samples did not group into <= k clusters.
print('{} islands'.format(n_components))
if n_components>k:
    print('Need more clustering （・∩・)')
else:
    print('Good amount of clustering (･o･)')