# Standard Nystrom-based Spectral Clustering

This algorithm is from "[Fast Spectral Clustering via the Nystrom Method](http://www.cs.columbia.edu/~jebara/papers/ALT2013FSCVTNM.pdf)" by Choromanska et al.

In [186]:
from itertools import permutations

import numpy as np
import scipy as sp
import scipy.cluster.vq
import scipy.sparse.linalg
import scipy.spatial.distance

In [187]:
# Load the preprocessed USPS table: one sample per row, with the class label
# stored in the last column (the next cell reads it as A[:, -1]).
A = np.loadtxt(fname='../data/processed/usps.csv', delimiter=',')

In [188]:
# Shuffle the samples in place (unseeded, so every run differs), then keep
# only the rows whose label -- the last column -- is below 3.
np.random.shuffle(A)
inds = A[:, -1] < 3
# Features and integer labels of the kept rows.
# NOTE(review): `:-2` excludes the label column AND the column just before it;
# confirm that column is not a feature.
X, Y = A[inds, :-2], A[inds, -1].astype(int)
k = np.unique(Y).size  # number of distinct classes = number of clusters
n, d = X.shape
n, d

(2199, 254)

In [189]:
# Number of landmark (sampled) points used by the Nystrom approximation.
m = 500
# Landmark selection. Random sampling would be
#     inds = np.random.choice(n, m, replace=False)
# but the rows were already shuffled when A was loaded, so the first m rows of
# X are a uniformly random subset as well.
inds = np.arange(m)
# Every remaining (non-landmark) row index, in ascending order. np.setdiff1d
# avoids a Python-level O(n * m) membership scan and always yields an integer
# array, even when m == n and nothing is left.
inv_inds = np.setdiff1d(np.arange(n), inds)

In [190]:
# Gaussian kernel scale: mu = 1 / mean ||x_i - x_j||^2, averaged over all
# m * m ordered pairs of landmark rows (the zero i == j terms included).
# squareform(pdist(...)) builds the full pairwise matrix in compiled code
# instead of an O(m^2) Python-level double loop.
mu = 1 / sp.spatial.distance.squareform(
    sp.spatial.distance.pdist(X[inds], 'sqeuclidean')
).mean()

In [191]:
# Landmark-landmark Gaussian affinities:
#     A_11[i, j] = exp(-mu * ||x_i - x_j||^2),  x_i, x_j rows of X[inds].
# squareform(pdist(...)) fills both triangles from a single set of pair
# distances, so A_11 is exactly symmetric with ones on the diagonal.
A_11 = np.exp(-mu * sp.spatial.distance.squareform(
    sp.spatial.distance.pdist(X[inds], 'sqeuclidean')))

In [192]:
# Affinities between each remaining (non-landmark) row and every landmark,
# using the same Gaussian kernel as A_11: shape (n - m, m).
C = np.exp(-mu * sp.spatial.distance.cdist(X[inv_inds], X[inds], 'sqeuclidean'))
# Put the landmark block on top. C is now A_{:1} from the notes: an n x m
# matrix whose rows are ordered [landmarks (inds); remaining rows (inv_inds)].
C = np.vstack((A_11, C))

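$\hat{D} \leftarrow \text{diag}(A_{:1} A_{11}^{-1} A_{:1}^{\intercal} \mathbf{1})$ <br>
$D_{:1} \leftarrow \text{diag}(A_{:1} \mathbf{1})$, where $\mathbf{1}$ is the all-ones vector of the appropriate length and $D_{11}$ denotes the leading $m \times m$ block of $D_{:1}$ (i.e. $\text{diag}(A_{11} \mathbf{1})$).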

In [193]:
# Nystrom estimate of the full degree vector, d_hat = A_:1 A_11^{-1} A_:1^T 1.
# Evaluated right-to-left so no n x n intermediate is formed, and with a linear
# solve instead of an explicit inverse of the (possibly ill-conditioned)
# kernel block A_11.
dd_hat = C.dot(np.linalg.solve(A_11, C.T.dot(np.ones(n))))
D_hat = np.diag(dd_hat)  # not used by the cells below
# \hat{D}^{-1/2}. NOTE(review): nothing here guarantees the estimated degrees
# are positive; a non-positive entry becomes inf/nan in this matrix.
D_hat_inv = np.diag(1 / np.sqrt(dd_hat))

# Degrees w.r.t. the sampled columns only, D_:1 = diag(A_:1 1) (not used below),
# and D_11^{-1/2} from its leading m x m block, diag(A_11 1).
D_n1 = np.diag(C.dot(np.ones(m)))
D_11_inv = np.diag(1 / np.sqrt(A_11.dot(np.ones(m))))

$ \hat{M}_{:1} = [ M_{11} \hat{M}_{21}^{\intercal} ]^{\intercal} \leftarrow \hat{D}^{-\frac{1}{2}} A_{:1} D_{11}^{-\frac{1}{2}}$

In [194]:
# \hat{M}_{:1} = \hat{D}^{-1/2} A_{:1} D_{11}^{-1/2}. Both scaling matrices are
# diagonal, so scale the rows and columns of C by broadcasting their diagonals
# instead of two dense matrix products (the first alone is O(n^2 m)).
M_hat_n1 = (np.diag(D_hat_inv)[:, None] * C) * np.diag(D_11_inv)[None, :]

In [195]:
# Split \hat{M}_{:1} = [M_11 \hat{M}_21^T]^T. The rows of C (and therefore of
# M_hat_n1) were stacked as [landmarks; remaining rows], so slice by position;
# indexing with inds / inv_inds picks the right rows only while
# inds == arange(m).
M_11 = M_hat_n1[:m, :].T
M_21 = M_hat_n1[m:, :].T  # shape (m, n - m): the notes' \hat{M}_21^T

# M_11^{-1/2} from the eigendecomposition of M_11 itself: M_11^{-1} has the
# same eigenvectors and reciprocal eigenvalues, so no explicit inverse is
# needed. (Despite its name, M_11_inv holds M_11^{-1/2}.)
# NOTE(review): U @ diag(.) @ U.T equals M_11^{-1/2} only when U is orthogonal,
# i.e. when M_11 is symmetric. With \hat{D} on the left and D_11 on the right,
# M_11 is generally not symmetric, and np.linalg.eig may then return complex
# values -- confirm this matches the intended algorithm.
V, U = np.linalg.eig(M_11)
M_11_inv = U.dot(np.diag(1 / np.sqrt(V))).dot(U.T)

S = M_11 + M_11_inv.dot(M_21).dot(M_21.T).dot(M_11_inv)
U, L, T = np.linalg.svd(S)

# Approximate eigenvectors, one row per row of C:
#     V = \hat{M}_{:1} M_11^{-1/2} U_S Lambda_S^{-1/2}
V = M_hat_n1.dot(M_11_inv).dot(U).dot(np.diag(1 / np.sqrt(L)))

In [196]:
# Spectral embedding: the first k columns of V (one row per point), which are
# then clustered with k-means.
Z = V[:, 0:k]
centroids, distortion = sp.cluster.vq.kmeans(obs=Z, k_or_guess=k)
centroids, distortion

(array([[-0.02330462, -0.01969856],
        [-0.01934146,  0.02120317]]), 0.0042111799406794751)

In [197]:
# Hard-assign every embedded point to its nearest centroid (Euclidean
# distance), vectorised over all points at once.
y_hat = np.linalg.norm(Z[:, None, :] - centroids[None, :, :], axis=2).argmin(axis=1)

# Z inherits the row order of C -- [landmarks (inds); remaining rows
# (inv_inds)] -- while Y follows the row order of X. Scatter the predictions
# back so that y_hat_aligned[i] is the prediction for X[i] / Y[i]; otherwise
# the comparison below is only valid while inds == arange(m).
row_order = np.concatenate((inds, inv_inds)).astype(int)
y_hat_aligned = np.empty(n, dtype=int)
y_hat_aligned[row_order] = y_hat

# k-means cluster ids are arbitrary, so score every one-to-one mapping
# (cluster id -> class label) and keep the best; k! mappings, fine for small k.
# The labels actually present in Y are used rather than assuming 1..k.
labels = np.unique(Y)
perms = [dict(enumerate(p)) for p in permutations(labels)]

accuracy = np.zeros(len(perms))
for i, P in enumerate(perms):
    relabel = np.array([P[c] for c in range(k)])
    accuracy[i] = (Y == relabel[y_hat_aligned]).sum() / n * 100

# Leave `yy` holding the relabelling of the best-scoring mapping (aligned with
# Y), so later cells inspect the winning assignment rather than whichever
# permutation happened to be tried last.
best = perms[int(accuracy.argmax())]
yy = np.array([best[c] for c in range(k)])[y_hat_aligned]
accuracy.max()

98.954070031832657

In [198]:
# Histogram of the relabelled predictions `yy` from the accuracy cell
# (index = class label); compare with np.bincount(Y) below.
np.bincount(yy, minlength=0)

array([   0, 1171, 1028], dtype=int64)

In [199]:
# Histogram of the true labels (index = class label).
np.bincount(Y, minlength=0)

array([   0, 1194, 1005], dtype=int64)