# $k$-means Clustering

In [1]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import distance
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets

## Initialization

In [2]:
# Iris flower data set; background: https://en.wikipedia.org/wiki/Iris_flower_data_set
dataset = load_iris()

# Keep only the last two feature columns: petal length and petal width
data = dataset.data[:, 2:]

k = 3            # Number of clusters (prototypes)
iterations = 20  # Number of update steps performed after the random initialization

## Implement the Algorithm
### Functions

In [3]:
def reassign(prototypes):
    """
    Assignment step of k-means: find the closest prototype for every data point.

    Reads the module-level ``data`` array (one row per data point).

    Parameters
    ----------
    prototypes : array of shape (number of prototypes, number of features)
        Current prototype positions.

    Returns
    -------
    1-D integer array with one entry per row of ``data``
        Row index into ``prototypes`` of the nearest prototype of each data point.
    """
    # Index the prototypes so that each data point can be queried for its single nearest one
    searcher = NearestNeighbors(n_neighbors=1)
    searcher.fit(prototypes)

    # The query result has one column per requested neighbour; only the first (and only) one is needed
    nearest = searcher.kneighbors(data, return_distance=False)
    return nearest[:, 0]

def recalculate(indices):
    """
    Update step of k-means: move every prototype to the mean of its assigned data points.

    Reads the module-level ``data`` (one row per data point) and ``k`` (number of clusters).

    Parameters
    ----------
    indices : array-like of int, one entry per row of ``data``
        Cluster index (``0 <= index < k``) currently assigned to each data point.

    Returns
    -------
    numpy.ndarray of shape (k, data.shape[1])
        Row ``i`` is the mean of all data points assigned to cluster ``i``. A cluster
        without any assigned point is re-seeded with the data point that lies farthest
        from the prototype of its own cluster, so the result never contains NaN.
        (If ``data`` has no rows at all, such prototypes stay at zero.)
    """
    indices = np.asarray(indices)
    prototypes = np.zeros((k, data.shape[1]))
    empty_clusters = []

    for i in range(k):
        points = data[indices == i, :]

        if len(points) == 0:
            # np.mean over an empty selection yields NaN (plus a RuntimeWarning); the NaN
            # prototype would then be passed on to every following iteration
            empty_clusters.append(i)
        else:
            prototypes[i, :] = np.mean(points, axis=0)

    if empty_clusters and data.shape[0] > 0:
        # Squared distance of every data point to the updated prototype of its own cluster.
        # Every point belongs to a non-empty cluster, so only valid prototypes are used here.
        sq_dist = np.sum((data - prototypes[indices, :]) ** 2, axis=1)

        for i in empty_clusters:
            farthest = np.argmax(sq_dist)
            prototypes[i, :] = data[farthest, :]
            sq_dist[farthest] = -np.inf  # A data point re-seeds at most one empty cluster

    return prototypes

def total_variance(indices, prototypes):
    """
    Variance criterion of a clustering: summed squared distance of all points to their prototype.

    Reads the module-level ``data`` (one row per data point) and ``k`` (number of clusters).

    Parameters
    ----------
    indices : integer array with one entry per row of ``data``
        Cluster index assigned to each data point.
    prototypes : array of shape (k, number of features)
        Prototype position of each cluster.

    Returns
    -------
    Sum over all clusters of the squared Euclidean distances between the cluster's
    prototype and the data points assigned to it.
    """
    def cluster_scatter(label):
        # Keep the prototype two-dimensional (a single row) because cdist expects matrices
        center = prototypes[label:label + 1, :]
        members = data[indices == label, :]
        return np.sum(distance.cdist(center, members, metric='sqeuclidean'))

    # Clusters are accumulated in ascending order, starting from 0
    return sum(cluster_scatter(label) for label in range(k))

### Run the Algorithm

In [4]:
np.random.seed(2)  # For reproducibility

# History of the run: entry t of each list describes the state after t update steps
# (entry 0 belongs to the random initialization)
cluster_prototypes = []
cluster_indices = []
variances = []

for i in range(iterations+1):
    if i == 0:
        # Initialization: use k randomly drawn data points as prototypes (k instead of a
        # hard-coded 3, so that the number of prototypes always matches the number of clusters).
        # Note: randint draws with replacement, so the same data point can be picked more than once.
        cluster_prototypes.append(data[np.random.randint(0, len(data), k), :])
    else:
        # Update step: move the prototypes based on the previous assignment
        cluster_prototypes.append(recalculate(cluster_indices[-1]))

    # Assignment step for the newest prototypes, followed by the quality of that clustering
    cluster_indices.append(reassign(cluster_prototypes[-1]))
    variances.append(total_variance(cluster_indices[-1], cluster_prototypes[-1]))

## Visualize the Result

In [5]:
%matplotlib widget
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(5, 6), gridspec_kw={'height_ratios': [2, 1]})
fig.subplots_adjust(hspace=0.35, top=0.95, bottom=0.1)

def plot_cluster(iteration):
    """
    Redraw both panels of the figure for one recorded iteration.

    Uses the module-level axes ``ax1`` and ``ax2``, the recorded lists ``cluster_indices``,
    ``cluster_prototypes`` and ``variances`` as well as ``data`` and ``k``.

    Parameters
    ----------
    iteration : int
        Position in the recorded lists which should be displayed.
    """
    assignment = cluster_indices[iteration]
    centers = cluster_prototypes[iteration]

    # Remove everything drawn by a previous call
    for axis in (ax1, ax2):
        axis.clear()

    # Upper panel: the points of each cluster together with its prototype (diamond) in one colour
    for cluster in range(k):
        color = 'C{:d}'.format(cluster)
        members = data[assignment == cluster, :]
        ax1.scatter(members[:, 0], members[:, 1], c=color, s=15)
        ax1.scatter(centers[cluster, 0], centers[cluster, 1], c=color, marker='D', s=50, edgecolor='k', linewidth=2)

    ax1.set(xlabel='Petal Lengths (cm)', ylabel='Petal Widths (cm)', title='Cluster result')
    ax1.set_axisbelow(True)
    ax1.grid()

    # Lower panel: variance over all iterations with a marker at the displayed one
    ax2.plot(variances)
    ax2.scatter(iteration, variances[iteration], c='gray', zorder=3)
    ax2.set(xlabel='Iteration', ylabel='$D_{var}$', title='Variance criterion')
    ax2.set_axisbelow(True)
    ax2.grid()

# Slider over all recorded iterations (0 is the random initialization); every change redraws the figure
iteration_slider = widgets.IntSlider(
    description='Iteration:',
    min=0,
    max=iterations,
)
interact(plot_cluster, iteration=iteration_slider);

FigureCanvasNbAgg()

interactive(children=(IntSlider(value=0, description='Iteration:', max=20), Output()), _dom_classes=('widget-i…

In [6]:
# Names of this notebook's result variables
# NOTE(review): presumably read by an external checker, nothing in this notebook uses it - confirm
sol_vars = [
    'cluster_prototypes',
    'cluster_indices',
    'variances',
]