In [None]:
# load the modules
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
from scipy.sparse.csgraph import minimum_spanning_tree

def adjacency(X, alpha):
    """Pairwise distance matrix with a separately weighted third coordinate.

    For every pair of rows of ``X`` the squared differences of columns 0 and 1
    are summed, ``alpha`` times the squared difference of column 2 is added,
    and the square root of the total is returned.

    Parameters
    ----------
    X : array of shape (n, >=3)
        One point per row; only columns 0, 1 and 2 are used.
    alpha : float
        Weight applied to the squared difference of column 2.

    Returns
    -------
    ndarray of shape (n, n)
        Symmetric matrix of weighted distances with a zero diagonal.
    """
    # Broadcasting yields an (n, n, d) cube of squared coordinate differences.
    sq_diff = np.square(X[np.newaxis, :, :] - X[:, np.newaxis, :])
    planar = sq_diff[..., :2].sum(axis=-1)
    weighted_third = alpha * sq_diff[..., 2]
    return np.sqrt(planar + weighted_third)

def cost(X, alpha, perc=95, show_plot=False):
    """Unsupervised clustering cost: mean intra-cluster over mean inter-cluster distance.

    The points are clustered by building the minimum spanning tree of the
    ``adjacency(X, alpha)`` distance matrix, removing every tree edge longer
    than the ``perc``-th percentile of the tree's edge lengths, and taking the
    connected components of what remains.

    Parameters
    ----------
    X : array of shape (n, >=3)
        One point per row; columns 0 and 1 are also used as plot coordinates.
    alpha : float
        Weight of the third coordinate, passed on to ``adjacency``.
    perc : float, optional
        Percentile (0-100) of MST edge lengths above which edges are cut.
    show_plot : bool, optional
        If true, draw the points coloured by component together with the
        surviving tree edges.

    Returns
    -------
    float
        Mean over multi-member clusters of their mean non-zero pairwise
        distance, divided by the mean non-zero distance between the cluster
        mean positions (both measured with ``adjacency`` and the same
        ``alpha``). ``nan`` if fewer than two clusters have more than one
        member, because the inter-cluster distance is then undefined.
    """
    A = adjacency(X, alpha)
    T = minimum_spanning_tree(A).toarray()
    cut = np.percentile(T[T > 0], perc)
    T[T > cut] = 0

    graph = scipy.sparse.csr_matrix(T)
    n_components, labels = scipy.sparse.csgraph.connected_components(graph)

    intra_dist = []
    mean_pos = []
    # Component labels run from 0 to n_components - 1, so label 0 is a real
    # component and must be included.
    for c in range(n_components):
        sel = labels == c
        # Singletons have no intra-cluster distance; skip them.
        if sel.sum() > 1:
            mean_pos.append(X[sel].mean(axis=0))
            dist_c = A[sel][:, sel]
            intra_dist.append(dist_c[dist_c > 0].mean())

    if show_plot:
        fig = plt.figure()
        ax = fig.add_subplot(111, aspect='equal')
        ax.scatter(X[:, 0], X[:, 1], c=labels, alpha=0.1, cmap='prism')
        i, j = np.where(T > 0)
        # Headless arrows are used to draw the surviving tree edges as segments.
        ax.quiver(X[i, 0], X[i, 1], X[j, 0] - X[i, 0], X[j, 1] - X[i, 1],
                  angles='xy', scale_units='xy', scale=1,
                  headwidth=0, headaxislength=0, headlength=0, minlength=0)
        fig.tight_layout()

    # At least two cluster centres are needed for an inter-cluster distance.
    if len(mean_pos) < 2:
        return np.nan

    inter_A = adjacency(np.stack(mean_pos, axis=0), alpha)
    inter_dist = inter_A[inter_A > 0].mean()
    return np.mean(intra_dist) / inter_dist

In [None]:
# load the data
data = np.load("clusters_zred.npy")
n = len(data)   # number of data points
print(n, data.dtype.names)

# Centre the catalogue on its mean position.
ra0, dec0 = data['RA'].mean(), data['DEC'].mean()

# One row per object: (RA offset scaled by cos(DEC), DEC offset, ZRED).
# np.radians is applied to DEC, so RA/DEC are presumably in degrees -- confirm.
X = np.column_stack((
    (ra0 - data['RA']) * np.cos(np.radians(data['DEC'])),
    data['DEC'] - dec0,
    data['ZRED'],
))

In [None]:
# Evaluate the unsupervised cost once, for alpha = 1e6 and a 95th-percentile edge cut.
cost(X, 1e6, perc=95)

In [None]:
# Build the pairwise distance matrix for one fixed alpha and display it as an image.
# `A` is reused by the minimum-spanning-tree cell further down.
alpha = 1e6
A = adjacency(X, alpha)
plt.imshow(A)
plt.colorbar()

In [None]:
# Minimum spanning tree of the distance matrix, with its longest edges removed.
perc = 95  # depends on expected rate of outliers
T = minimum_spanning_tree(A).toarray()
cut = np.percentile(T[T > 0], perc)
T = np.where(T > cut, 0, T)
i, j = np.nonzero(T > 0)

# Points plus the surviving tree edges, drawn as headless arrows.
fig, ax = plt.subplots(subplot_kw={'aspect': 'equal'})
ax.scatter(X[:, 0], X[:, 1], alpha=0.1)
ax.quiver(
    X[i, 0], X[i, 1], X[j, 0] - X[i, 0], X[j, 1] - X[i, 1],
    angles='xy', scale_units='xy', scale=1,
    headwidth=0, headaxislength=0, headlength=0, minlength=0,
)
fig.tight_layout()

In [None]:
# Connected components of the pruned tree are the clusters.
graph = scipy.sparse.csr_matrix(T)
n_components, labels = scipy.sparse.csgraph.connected_components(graph)
bc = np.bincount(labels)  # number of members per component
# Total number of components, and how many of them have more than one member.
print(n_components, np.count_nonzero(bc > 1))

# Scatter plot of the points coloured by component label.
fig, ax = plt.subplots(subplot_kw={'aspect': 'equal'})
ax.scatter(X[:, 0], X[:, 1], c=labels, alpha=0.1, cmap='prism')

In [None]:
# Sweep alpha over 10^-6 ... 10^5 (one value per decade) and plot the
# unsupervised cost for each alpha on a log-scaled x-axis.
alphas = 10.**np.arange(-6,6,1)
plt.semilogx(alphas, [cost(X, a, perc=95) for a in alphas])

## Clustering likelihood


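The quantity $\lVert V V^T - Y Y^T\rVert^2$ measures the quadratic deviation of the predicted cluster labels from the true labels, where $V$ and $Y$ are the one-hot encodings of the predicted and the true labels, respectively. Cluster labels are only defined up to a permutation, so $V$ and $Y$ cannot be compared directly. To avoid this ordering issue, each one-hot matrix is multiplied by its own transpose, which gives a symmetric matrix whose entry $(i, j)$ is 1 if points $i$ and $j$ belong to the same cluster and 0 otherwise. One can consider this quadratic deviation a (negative log-)likelihood of the clustering labels given the true labels (supervised learning).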

In [None]:
def vvt(v):
    """Return the matrix product of ``v`` with its own transpose, ``v v^T``."""
    transposed = v.T
    return np.matmul(v, transposed)

def one_hot_encode(labels):
    """Return the one-hot membership matrix of a 1-D sequence of labels.

    Parameters
    ----------
    labels : 1-D array-like
        One label per sample.

    Returns
    -------
    ndarray of float, shape (n, m)
        ``n`` is the number of labels and ``m`` the number of distinct
        labels. Column ``j`` corresponds to the ``j``-th entry of
        ``np.unique(labels)`` (sorted order); row ``i`` holds a single 1 in
        the column of ``labels[i]`` and 0 elsewhere.
    """
    # return_inverse gives, for every sample, the column index of its label.
    # This replaces an O(n * m) Python loop over list.index.
    unique_labels, inverse = np.unique(labels, return_inverse=True)
    inverse = inverse.reshape(-1)
    idx = np.zeros((inverse.size, unique_labels.size))
    idx[np.arange(inverse.size), inverse] = 1
    return idx

In [None]:
# Show V V^T built from the `labels` currently in the notebook namespace.
plt.imshow(vvt(one_hot_encode(labels)))

In [None]:
# Reference matrix built the same way from the catalogue's MEM_MATCH_ID field
# (presumably the true cluster membership id -- confirm against the data source).
true_VVT = vvt(one_hot_encode(data['MEM_MATCH_ID']))
plt.imshow(true_VVT)

In [None]:
def cost_vvt(X, alpha, true_VVT, perc=95, show_plot=False):
    """Supervised clustering cost: squared deviation of V V^T from a reference.

    The points are clustered exactly as in the unsupervised case: minimum
    spanning tree of ``adjacency(X, alpha)``, edges above the ``perc``-th
    percentile of tree edge lengths removed, connected components as
    clusters. The resulting labels are one-hot encoded and turned into
    ``V V^T``, which is compared element-wise with ``true_VVT``.

    Parameters
    ----------
    X : array of shape (n, >=3)
        One point per row; columns 0 and 1 are also used as plot coordinates.
    alpha : float
        Weight of the third coordinate, passed on to ``adjacency``.
    true_VVT : array of shape (n, n)
        Reference matrix to compare against.
    perc : float, optional
        Percentile (0-100) of MST edge lengths above which edges are cut.
    show_plot : bool, optional
        If true, draw the points coloured by component together with the
        surviving tree edges.

    Returns
    -------
    float
        Sum of squared element-wise differences between ``V V^T`` and
        ``true_VVT``.
    """
    distances = adjacency(X, alpha)
    tree = minimum_spanning_tree(distances).toarray()
    threshold = np.percentile(tree[tree > 0], perc)
    pruned = np.where(tree > threshold, 0, tree)

    _, labels = scipy.sparse.csgraph.connected_components(
        scipy.sparse.csr_matrix(pruned)
    )

    if show_plot:
        fig = plt.figure()
        ax = fig.add_subplot(111, aspect='equal')
        ax.scatter(X[:, 0], X[:, 1], c=labels, alpha=0.1, cmap='prism')
        src, dst = np.nonzero(pruned > 0)
        # Headless arrows are used to draw the surviving tree edges as segments.
        ax.quiver(
            X[src, 0], X[src, 1], X[dst, 0] - X[src, 0], X[dst, 1] - X[src, 1],
            angles='xy', scale_units='xy', scale=1,
            headwidth=0, headaxislength=0, headlength=0, minlength=0,
        )
        fig.tight_layout()

    residual = vvt(one_hot_encode(labels)) - true_VVT
    return np.square(residual).sum()

In [None]:
# Supervised cost versus alpha (10^-10 ... 10^8, every second decade),
# using cost_vvt's default edge cut (perc=95).
alphas = 10.**np.arange(-10,10,2)
plt.semilogx(alphas, [cost_vvt(X, a, true_VVT) for a in alphas])

In [None]:
# Supervised cost versus alpha (10^-10 ... 10^8, every second decade),
# cutting tree edges above the 90th percentile.
alphas = 10.**np.arange(-10,10,2)
plt.semilogx(alphas, [cost_vvt(X, a, true_VVT, perc=90) for a in alphas])

In [None]:
# Supervised cost versus alpha (10^-10 ... 10^8, every second decade),
# cutting tree edges above the 80th percentile.
alphas = 10.**np.arange(-10,10,2)
plt.semilogx(alphas, [cost_vvt(X, a, true_VVT, perc=80) for a in alphas])