# $k$-means-based approximate spectral clustering

This algorithm is from "[Fast Approximate Spectral Clustering](https://people.eecs.berkeley.edu/~jordan/papers/yan-huang-jordan-kdd09.pdf)" by Yan et al.

In [17]:
from itertools import permutations

import numpy as np
import scipy as sp
import scipy.cluster.vq
import scipy.linalg
import scipy.sparse.linalg
import scipy.spatial.distance

In [2]:
# Load the preprocessed USPS table from CSV as a 2-D float array
# (one CSV row per array row, comma-separated columns).
A = np.loadtxt('../data/processed/usps.csv', delimiter=',')

In [85]:
# Boolean row mask: keep only rows whose last-column value (used as the
# class label below) is smaller than 6.
inds = np.less(A[:, -1], 6)
# Feature matrix: every column except the last two.
# NOTE(review): the last column is the label; confirm that the
# second-to-last column is intentionally excluded from the features.
X = A[inds, :-2]
# Integer class labels of the selected rows.
Y = A[inds, -1].astype(int)
# Number of distinct classes present in the selected subset.
k = np.unique(Y).size
# n observations, d features; echoed as the cell output.
n, d = np.shape(X)
n, d

(4240, 254)

In [86]:
# data reduction ratio
gamma = 8
k_prime = n // gamma

In [87]:
# Preliminary k-means on the raw observations: asks for k_prime centroids
# that will stand in for the full data set during spectral clustering.
# The call is not seeded, so the result differs between runs.  The number
# of rows in centroids_prime can be smaller than k_prime (the recorded run
# below yields 502 centroids).
centroids_prime, distortion_prime = sp.cluster.vq.kmeans(
    obs=X,
    k_or_guess=k_prime,
)
# Echo the distortion value returned by kmeans.
distortion_prime

3.0195420842678549

In [88]:
# Assign every observation to its nearest representative centroid.
#
# y_prime[i] is the row index in centroids_prime of the centroid closest
# (Euclidean distance) to X[i].  scipy.cluster.vq.vq performs this
# nearest-centroid search in compiled code, replacing a Python loop that
# made about n * len(centroids_prime) separate np.linalg.norm calls.
# The result is an integer array (the previous version stored the indices
# as floats), so it can be used directly for indexing.
y_prime = sp.cluster.vq.vq(X, centroids_prime)[0]

In [89]:
# The representative points used for spectral clustering are the
# preliminary k-means centroids themselves (same array, no copy).
X_prime = centroids_prime
# n_prime representative points with d_prime features each; n_prime can be
# smaller than k_prime (502 vs. 4240 // 8 = 530 in the recorded run).
n_prime, d_prime = X_prime.shape
n_prime, d_prime

(502, 254)

In [90]:
# Pairwise matrix over the representative points:
#     W[i, j] = ||X_prime[i] - X_prime[j]||^2   (squared Euclidean distance)
# The matrix is symmetric with a zero diagonal.  cdist computes it in
# compiled code, replacing an O(n_prime^2) Python double loop; values agree
# with the previous norm(...)**2 formulation up to floating-point rounding.
#
# NOTE(review): W is used downstream in the role of the affinity matrix,
# but squared distance is a dissimilarity (larger = less alike).  An
# earlier revision used the Gaussian kernel exp(-||x_i - x_j||^2), i.e.
# np.exp(-W).  Switching kernels would change the spectrum of the
# Laplacian, so the choice of eigenvectors taken later would need to be
# revisited at the same time -- confirm which form is intended.
W = sp.spatial.distance.cdist(X_prime, X_prime, metric='sqeuclidean')

In [91]:
# Column sums of W, one value per representative point.
ww = np.sum(W, axis=0)
# Diagonal matrix of those sums (not referenced by the cells shown here).
D = np.diag(ww)
# Diagonal matrix holding 1 / sqrt(column sum), i.e. D^(-1/2).
D_ = np.diag(1 / np.sqrt(ww))
# Symmetric normalised form  L = I - D^(-1/2) W D^(-1/2).
L = np.eye(n_prime) - np.dot(np.dot(D_, W), D_)

In [92]:
# Partial eigen-decomposition of the symmetric matrix L.
#
# eigh orders eigenvalues ascending, so the (0-based, inclusive) index range
# [n_prime - 2, n_prime - 1] selects the two largest eigenvalues:
#   V -- array with those 2 eigenvalues,
#   Z -- (n_prime, 2) array whose columns are the matching eigenvectors.
#
# `subset_by_index` replaces the `eigvals=` keyword, which was deprecated
# in SciPy 1.5 and removed in later releases (this call needs SciPy >= 1.5).
#
# NOTE(review): only 2 eigenvectors are kept although the embedding is later
# clustered into k groups; confirm that 2 (rather than k) is intended.
V, Z = sp.linalg.eigh(L, subset_by_index=[n_prime - 2, n_prime - 1])

In [93]:
# Rescale each column of the spectral embedding to unit variance, which is
# the preprocessing scipy's kmeans expects.
Z_ = sp.cluster.vq.whiten(obs=Z)
# Final k-means in the embedded space: one cluster per class.  The call is
# not seeded, so centroids and distortion vary between runs.
centroids, distortion = sp.cluster.vq.kmeans(
    obs=Z_,
    k_or_guess=k,
)
# Echo both results as the cell output.
centroids, distortion

(array([[-0.47361874,  0.28963956],
        [ 0.05891703, -1.42207706],
        [ 0.91877875,  1.30359236],
        [ 0.98327114, -0.27152449],
        [-1.47151427,  0.32201446]]), 0.54906943243154493)

In [94]:
# Cluster label of every representative point.
#
# y_hat_prime[i] is the row index in `centroids` of the centroid closest
# (Euclidean distance) to the whitened embedding Z_[i].
# scipy.cluster.vq.vq does the nearest-centroid search in compiled code
# instead of a Python loop of np.linalg.norm calls, and returns integer
# indices (the previous version stored them as floats).
y_hat_prime = sp.cluster.vq.vq(Z_, centroids)[0]

In [95]:
# Propagate cluster labels from the representative points back to the
# original observations: observation i inherits the label of the centroid
# it was assigned to, i.e. y_hat[i] = y_hat_prime[y_prime[i]].
# Both index arrays are cast to int first, so the lookup works whether
# they are stored as floats or as integers.
y_hat = np.asarray(y_hat_prime).astype(int)[np.asarray(y_prime).astype(int)]

In [96]:
# Enumerate every possible matching between cluster indices and class
# labels: each entry is a dict mapping cluster index 0..k-1 to one of the
# labels 1..k, with one dict per permutation of the labels (k! in total).
# NOTE(review): assumes the true labels in Y are exactly 1..k -- confirm.
perms = [
    dict(enumerate(p))
    for p in permutations(np.arange(1, k + 1))
]

In [97]:
# Score every cluster-to-label matching and report the best one.
# accuracy[m] is the percentage of observations whose relabelled cluster
# equals the true label in Y under matching perms[m].
accuracy = np.zeros(len(perms))
for m, mapping in enumerate(perms):
    yy = np.array(y_hat, copy=True)
    for cluster, label in mapping.items():
        # Mask against the untouched y_hat (not yy) so that an already
        # relabelled entry is never rewritten by a later mapping entry.
        yy[y_hat == cluster] = label
    accuracy[m] = np.sum(Y == yy) / n * 100
# Best achievable accuracy (in percent) over all matchings.
accuracy.max()

67.09905660377359

In [98]:
# Size of each predicted cluster: entry c is the number of observations
# whose y_hat label equals c.
np.bincount(y_hat)

array([ 873, 1132,  786,  746,  703], dtype=int64)