# K-means iteration
Stough, DIP

Here we run k-means clustering on the pixel colors of an image
to find a small set of representative colors for that image.

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np

# For importing from alternative directory sources
import sys  
sys.path.insert(0, '../dip_utils')

from matrix_utils import (arr_info,
                          make_linmap)
from vis_utils import (vis_rgb_cube,
                       vis_hists,
                       vis_pair,
                       vis_surface)

from scipy.spatial.distance import cdist

K = 16  # number of clusters, i.e. how many representative colors to find
MAXITER = 20  # maximum number of assign/recompute passes of the k-means loop
NUMPOINTS = 100  # NOTE(review): only mentioned in a comment in the visible cells, never used in code -- confirm whether still needed

In [None]:
# Read the image as floating point, then flatten it into the data matrix X:
# one row per pixel, one column for each of the first three channels.
I = plt.imread('../dip_pics/bellagio.jpg').astype(float)
X = np.column_stack([I[..., channel].ravel() for channel in range(3)])

In [None]:
# Show the original image; the float pixel values are divided by 255 before display.
plt.figure()
plt.imshow(I/255)

In [None]:
# Peek at the first ten rows of X (one row per pixel, one column per color channel).
X[:10,:]

In [None]:
# For fun: draw several random palettes and keep one whose colors are well
# spread out, judged by the variance of each color channel.
clusterColors = np.random.rand(K, 3)  # K random colors, one per cluster
varsSoFar = clusterColors.var(axis=0)  # length-3 vector: variance of each column

for _ in range(3 * K):
    candidate = np.random.rand(K, 3)  # another K random colors
    candidateVars = candidate.var(axis=0)
    # Accept the candidate only when it beats the current palette in every channel.
    if (candidateVars > varsSoFar).all():
        clusterColors, varsSoFar = candidate, candidateVars

In [None]:
# Display the random palette that was kept.
clusterColors

&nbsp;

### Pick some initial cluster centers.

In [None]:
# K-means: initialization.
# Use K of the data points themselves as the initial cluster centers.
# The row indices are drawn without replacement so that the same pixel
# cannot be chosen twice (a plain randint draw could generate repeats).
numRows = len(X)
whichinit = np.random.choice(numRows, size=K, replace=False)
CC = np.take(X, whichinit, axis=0).copy()  # Cluster Centers, K x 3

In [None]:
# Keep a snapshot of the starting centers: the k-means loop overwrites CC in
# place, and the initial centers are plotted later for comparison.
CC_init = CC.copy()
CC  # display the initial centers

&nbsp;

### The main Expectation-Maximization loop

Basically, we assign each point to its nearest cluster center, and then
recompute each center as the mean of the points assigned to it.

In [None]:
# K-means: alternate two steps until the assignments stop changing, or for at
# most MAXITER passes:
#   1. assign every data point to its closest cluster center;
#   2. move each center to the mean of the points assigned to it.
# CC is updated in place; whichCluster holds the assignment from the last pass.
prevCluster = None
for i in range(MAXITER):
    # cdist computes the distance between every row of X (the points) and
    # every row of CC (the centers), so D is len(X) x K.
    D = cdist(X, CC, 'euclidean')

    whichCluster = np.argmin(D, axis=1) # length len(X): index of the closest center

    # Converged: an unchanged assignment would reproduce exactly the same
    # centers, so further passes (each a full cdist over every pixel) cannot
    # change the result.
    if prevCluster is not None and np.array_equal(whichCluster, prevCluster):
        break
    prevCluster = whichCluster

    # K-means: recompute the cluster centers as the mean of the data in each cluster
    for c in range(K):
        members = whichCluster == c  # mask of the points that were closest to c
        # A cluster with no members keeps its previous center.
        if np.any(members):
            CC[c,:] = np.mean(X[members, :], axis=0)

In [None]:
# Doing this on a big image, don't want to scatter 100Ks of points, really slow,
# so plot a random subsample of at most 500*K points. The size is capped at
# len(X) because drawing without replacement fails when more samples are
# requested than there are points.
rands = np.sort(np.random.choice(len(X), size=min(500*K, len(X)), replace=False))


# Three panels with shared axes; each plots column 0 of X against column 1.
f, ax = plt.subplots(1,3, figsize=(9,3), sharex=True, sharey=True)
ax[0].scatter(X[rands,0], X[rands,1], c='gray', s=20)
ax[0].set_title('Original Data')


# Middle panel: the same points with the initial centers drawn in their own color.
ax[1].scatter(X[rands,0], X[rands,1], c='gray', alpha=.5, s=20)
ax[1].scatter(CC_init[:,0], CC_init[:,1], c=CC_init/255, s=50)
ax[1].set_title('Initial Cluster Centers')


# Right panel: every sampled point takes the color of the center it was assigned to.
pointColors = CC[whichCluster[rands], :]
clusterEdgeColors = 1 - clusterColors # edge colors: complement of the random palette.
# Alternative: use the complement of each center's own color for contrast:
# clusterEdgeColors = 1 - CC/255

ax[2].scatter(X[rands,0], X[rands,1], c=pointColors/255, alpha=.5, s=20)
ax[2].scatter(CC[:,0], CC[:,1], c=CC/255, edgecolors=clusterEdgeColors, s=50)
ax[2].set_title('Recomputed Clusters')

plt.tight_layout()

In [None]:
# Side-by-side comparison of the original image and its K-color version.
f, ax = plt.subplots(1, 2, figsize=(8, 3), sharex=True, sharey=True)
origAx, reconAx = ax

origAx.imshow(I / 255)
origAx.set_title('Original Image')

# Reconstructed image: replace every pixel by the center of its cluster and
# fold the flat list of colors back into the shape of the original image.
Ir = CC[whichCluster, :].reshape(I.shape)
reconAx.imshow(Ir / 255)  # scaled by 255 because the data is floating point
reconAx.set_title('{} color reconstruction'.format(K))

plt.tight_layout()

In [None]:
# Pass the original image to the vis_utils helper.
# NOTE(review): presumably plots the image's colors inside the RGB cube -- confirm against vis_utils.
vis_rgb_cube(I)

In [None]:
# Same helper on the reconstructed image, which contains only the cluster-center colors.
# NOTE(review): presumably plots those colors inside the RGB cube -- confirm against vis_utils.
vis_rgb_cube(Ir)