# K-Means 

K-Means is an unsupervised clustering algorithm that aims to partition $n$ observations into $k$ clusters, where $k$ is a hyperparameter chosen in advance.

The algorithm is straightforward:

1. Initialize $k$ centroids randomly (choose $k$ random observations to be the initial centroids)
2. Assign each observation to its closest centroid
3. Recompute each centroid as the mean of the observations assigned to it
4. Repeat steps 2 and 3 until the centroids stop changing (or a maximum number of iterations is reached)

In [17]:
import numpy as np

class KMeans:
    """Minimal K-Means clustering (Lloyd's algorithm) built on NumPy.

    Parameters
    ----------
    k : int
        Number of clusters to form.
    max_iter : int
        Maximum number of assign/update rounds performed by ``fit_predict``.
    eps : float
        Convergence threshold: iteration stops once the summed Euclidean
        displacement of all centroids in one round is ``<= eps``.

    Attributes
    ----------
    centroids : numpy.ndarray of shape (k, n_features)
        Cluster centers, set by ``fit_predict``.
    X : numpy.ndarray
        The data passed to the last ``fit_predict`` call.
    """

    def __init__(self, k=3, max_iter=100, eps=1e-4):
        self.k = k
        self.max_iter = max_iter
        self.eps = eps

    def fit_predict(self, X):
        """Cluster ``X`` and return the cluster index of every sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        numpy.ndarray of int, shape (n_samples,)
            Index (``0 .. k-1``) of the centroid closest to each sample.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or ``k`` is not in ``[1, n_samples]``.
        """
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-D (n_samples, n_features), got {X.ndim} dimension(s)"
            )
        self.X = X
        self.n_samples, self.n_features = X.shape
        if not 1 <= self.k <= self.n_samples:
            raise ValueError(
                f"k must be between 1 and n_samples={self.n_samples}, got {self.k}"
            )

        # Initialize cluster centers with k distinct random samples. The global
        # NumPy RNG is used, so np.random.seed(...) makes runs reproducible.
        random_sample_idxs = np.random.choice(self.n_samples, self.k, replace=False)
        # Fancy indexing copies, so the centroids never alias rows of X, and
        # the attribute is an ndarray from the very start.
        self.centroids = X[random_sample_idxs].astype(float)

        for _ in range(self.max_iter):
            # assign samples to closest centroids (create clusters)
            clusters = self._create_clusters(self.centroids)

            # calculate new centroids from the clusters
            centroids_old = self.centroids
            self.centroids = self._get_centroids(clusters, self.n_features)

            # stop once the centroids have (almost) stopped moving
            if self._is_converged(centroids_old, self.centroids):
                break

        # Label against the *final* centroids so labels and self.centroids
        # agree; this also covers max_iter=0, where the loop never runs.
        clusters = self._create_clusters(self.centroids)
        return self._get_cluster_labels(clusters)

    def _get_cluster_labels(self, clusters):
        """Turn a list of per-cluster sample indices into a label vector.

        Samples that appear in no cluster are labelled ``-1``.
        """
        labels = np.full(self.n_samples, -1, dtype=int)
        for cluster_idx, cluster in enumerate(clusters):
            labels[np.asarray(cluster, dtype=int)] = cluster_idx
        return labels

    def _is_converged(self, centroids_old, centroids):
        """Return True when the total centroid displacement is <= ``eps``."""
        distances = [
            self._euclidean_distance(old, new)
            for old, new in zip(centroids_old, centroids)
        ]
        return sum(distances) <= self.eps

    def _get_centroids(self, clusters, n_features):
        """Compute each centroid as the mean of the samples in its cluster.

        A cluster with no samples keeps its previous centroid from
        ``self.centroids``; averaging an empty selection would yield NaN,
        and a NaN centroid then wins every ``argmin`` in later rounds.
        If no previous centroid exists the row is left at zero.
        """
        centroids = np.zeros((self.k, n_features))
        previous = getattr(self, "centroids", None)
        for cluster_idx, cluster in enumerate(clusters):
            if len(cluster) > 0:
                centroids[cluster_idx] = np.mean(self.X[cluster], axis=0)
            elif previous is not None:
                centroids[cluster_idx] = previous[cluster_idx]
        return centroids

    def _create_clusters(self, centroids):
        """Group sample indices by their closest centroid.

        Returns a list of ``k`` lists; entry ``c`` holds the indices of the
        samples whose nearest centroid is ``c``. Ties go to the lowest index.
        """
        samples = np.asarray(self.X, dtype=float)
        centers = np.asarray(centroids, dtype=float)
        # (n_samples, k) matrix of sample-to-centroid Euclidean distances
        diffs = samples[:, np.newaxis, :] - centers[np.newaxis, :, :]
        distances = np.sqrt(np.sum(diffs ** 2, axis=2))
        nearest = np.argmin(distances, axis=1)
        return [np.flatnonzero(nearest == c).tolist() for c in range(self.k)]

    def _closest_centroid(self, sample, centroids):
        """Return the index of the centroid nearest to a single sample."""
        distances = [self._euclidean_distance(sample, point) for point in centroids]
        return np.argmin(distances)

    def _euclidean_distance(self, x1, x2):
        """Return the Euclidean (L2) distance between two points."""
        return np.sqrt(np.sum((x1 - x2) ** 2))


In [23]:
# Make blobs to try the k-means algorithm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from matplotlib.animation import FuncAnimation
from IPython import display
import matplotlib.pyplot as plt

# Make blobs
n_clusters = 4
X, y = make_blobs(n_samples=1000, centers=n_clusters, cluster_std=0.60, random_state=0)

num_iterations = 50


def animate(frame: int):
    """Draw the K-Means result after ``frame`` iterations on the current figure.

    Each frame refits from scratch with ``max_iter=frame``. The global NumPy
    RNG is re-seeded first so every fit starts from the same initial
    centroids, which makes consecutive frames show one run's progression.
    """
    # Clear the plot
    plt.clf()

    # Nothing to draw before the first iteration
    if frame < 1:
        return

    # KMeans draws its initial centroids from NumPy's global RNG; a fixed seed
    # gives identical starting points for every frame.
    np.random.seed(0)
    kmeans = KMeans(k=n_clusters, max_iter=frame)
    y_pred = kmeans.fit_predict(X)
    centroids = np.asarray(kmeans.centroids)

    # Colour samples by their predicted cluster (not the ground-truth blob
    # labels) and mark the current centroids
    plt.scatter(X[:, 0], X[:, 1], c=y_pred, cmap='viridis', alpha=0.5)
    plt.scatter(centroids[:, 0], centroids[:, 1], c='red', marker='x')
    plt.title(f'K-means clustering after {frame} iterations')


fig = plt.figure()
# Frames run 1..num_iterations so the animation has no empty "0 iterations" frame
anim = FuncAnimation(fig, animate, init_func=None, frames=range(1, num_iterations + 1), interval=500, blit=False)
video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close()
