# Clustering

The goal of this interactive demo is to show you how a machine learning model can perform clustering. Keep in mind, however, that the code in this notebook was simplified for the demo and should not be used as a plug-and-play example for real machine learning projects.

In this notebook we will explore three different types of clustering approaches:

- Centroid based
- Density based
- Connectivity based

## Centroid based clustering

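Centroid based clustering is one of the simpler and more straightforward clustering approaches. You specify how many clusters you want, and the model partitions your data into exactly that many clusters. In k-means, for example, each cluster is represented by its centroid, the mean of the points assigned to it.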
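To make this concrete, here is a minimal, pure-Python sketch of Lloyd's algorithm, the classic iteration behind k-means: alternate between assigning every point to its nearest centroid and moving every centroid to the mean of its assigned points. The function name `kmeans`, the random initialization and the toy points are illustrative assumptions; scikit-learn's `KMeans` (used below) adds a smarter initialization (k-means++) and optional restarts (`n_init`).

```python
import math
import random


def kmeans(points, k, n_iter=100, seed=0):
    """Minimal Lloyd's algorithm (illustrative sketch). Returns (centroids, labels)."""
    rng = random.Random(seed)
    # Initialize the centroids with k distinct data points chosen at random
    centroids = [list(p) for p in rng.sample(points, k)]
    labels = None
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point
        new_labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points
        ]
        if new_labels == labels:
            break  # assignments stopped changing: converged
        labels = new_labels
        # Update step: move each centroid to the mean of its assigned points
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the old centroid if its cluster became empty
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return centroids, labels


# Two well-separated toy blobs in 2D
points = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
centroids, labels = kmeans(points, k=2)
print(labels)
```

Which blob receives label 0 depends on the random initialization, but the partition itself (first three points vs. last three) is the same for every seed on this toy data.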

Let's take a look at a rather simple dataset, a single image.

In [None]:
# Load an image of a flower
import imageio as io
img = io.v2.imread("flower.jpg") / 255.0
print(f"The image has a shape of {img.shape}.")

# Plot image
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.imshow(img);

Very beautiful! As we can see, this color image has a size of 214 x 320 pixels. Let's treat each of these pixels as an individual data point, and the three RGB color channels (red, green and blue) as the dataset's features.

In [None]:
# Reshape image to 2-dimensional dataset
X = img.reshape(-1, 3)
X.shape

Given that this dataset only has 3 dimensions, let's go ahead and visualize each pixel and its corresponding color value in a 3D plot.

In [None]:
from utils import plot_rgb_space
plot_rgb_space(X)

Great! We can see that there are a lot of green and red colors, both dark and bright, but not much blue. So let's use a centroid based clustering approach to partition this RGB color space into N clusters.

In [None]:
# Number of clusters
n_clusters = 8

For the centroid clustering routine, we will use `KMeans`, a simple but efficient model.

In [None]:
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=n_clusters)

# Fit model to data
%time kmeans.fit(X)

print("\nFinished training.")

Once the model is trained, we can take a closer look at the centroids it found. In our case, these centroids are points in the RGB color space, i.e. colors, and can therefore be visualized as follows:

In [None]:
# Plotting of RGB-color centroids
plt.figure(figsize=(12, 2))
plt.imshow(kmeans.cluster_centers_[None, ...], aspect="auto", interpolation="nearest")
plt.axis("off");

Going one step further, we can now replace the color of each pixel in the original image (any of the $256^3$ possible RGB combinations) with the color of its closest centroid. This is known as color quantization.

In [None]:
# Compute closest centroid label
centroid_labels = kmeans.predict(X)

# Reshape centroid labels back into an image
img_centroid = kmeans.cluster_centers_[centroid_labels].reshape(img.shape)

# Plot centroid labeled image
plt.figure(figsize=(10, 6))
plt.imshow(img_centroid)
plt.title(f"Quantized image using {n_clusters} centroids");
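Under the hood, this quantization step is just a nearest-centroid lookup. The pure-Python sketch below mimics `kmeans.predict` followed by `kmeans.cluster_centers_[centroid_labels]` on a handful of pixels; the palette, the pixel values and the helper name `quantize` are made up for illustration. It also shows why quantization compresses an image: a full RGB color needs 24 bits (8 per channel), whereas a label into a palette of $k$ colors only needs $\lceil \log_2 k \rceil$ bits, e.g. 3 bits for the notebook's default of 8 clusters (plus the small palette itself).

```python
import math


def quantize(pixels, palette):
    """Replace every RGB pixel by the nearest palette color (its centroid)."""
    labels = [
        min(range(len(palette)), key=lambda c: math.dist(p, palette[c]))
        for p in pixels
    ]
    return [palette[i] for i in labels], labels


# Hypothetical 3-color palette, with RGB values scaled to [0, 1] as in the notebook
palette = [(0.1, 0.5, 0.1), (0.8, 0.2, 0.2), (0.95, 0.95, 0.95)]
pixels = [(0.0, 0.6, 0.0), (0.9, 0.1, 0.3), (1.0, 1.0, 0.9), (0.2, 0.4, 0.2)]

quantized, labels = quantize(pixels, palette)
print(labels)  # [0, 1, 2, 0]

# Bits needed per pixel: full 24-bit color vs. an index into the palette
bits_full = math.log2(256 ** 3)                  # 24.0
bits_label = math.ceil(math.log2(len(palette)))  # 2 for a 3-color palette
print(bits_full, bits_label)
```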

<div class="alert alert-success">
  <h2>Exercise</h2>
    <p></p>
Change the <code>n_clusters</code> parameter above to any value between 1 and 1000 and rerun all the code after it. How does this affect the quantized image? Is there a sweet spot for the number of clusters?
</div>
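One common heuristic for this question is the elbow method: fit k-means for increasing values of k, record the inertia (the sum of squared distances from each point to its closest centroid, available as `kmeans.inertia_` on a fitted scikit-learn model) and look for the k after which the curve stops dropping sharply. The sketch below is a self-contained, pure-Python illustration on hypothetical 1-D data with three obvious groups; it uses a deterministic farthest-point initialization instead of scikit-learn's randomized k-means++, an assumption made only so the numbers are reproducible. For image quantization, the "best" number of colors remains partly a visual judgment.

```python
import math


def farthest_point_init(points, k):
    """Deterministic init: start with points[0], then repeatedly add the point
    that is farthest from all centroids chosen so far."""
    centroids = [points[0]]
    while len(centroids) < k:
        centroids.append(
            max(points, key=lambda p: min(math.dist(p, c) for c in centroids))
        )
    return [list(c) for c in centroids]


def kmeans_inertia(points, k, n_iter=100):
    """Run Lloyd's algorithm and return the inertia: the sum of squared
    distances from each point to its assigned (closest) centroid."""
    centroids = farthest_point_init(points, k)
    labels = None
    for _ in range(n_iter):
        new_labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points
        ]
        if new_labels == labels:
            break  # converged
        labels = new_labels
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return sum(math.dist(p, centroids[lab]) ** 2 for p, lab in zip(points, labels))


# Hypothetical 1-D data with three obvious groups
data = [(1,), (2,), (3,), (10,), (11,), (12,), (20,), (21,), (22,)]
inertias = {k: kmeans_inertia(data, k) for k in range(1, 5)}
for k, inertia in inertias.items():
    print(k, round(inertia, 2))
# 1 548.0
# 2 127.5
# 3 6.0
# 4 4.5
```

The inertia always shrinks as k grows, so the goal is not its minimum but the "elbow": here the large drops stop at k = 3, which matches the three groups in the data.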

## Density based clustering

For the density based clustering approach, let's use a different dataset. The main reason is that these algorithms can struggle when the data points are packed very closely (i.e. densely) together without clear gaps between groups, as in the 3-dimensional RGB plot above.

To simplify things, let's quickly create a synthetic dataset:

In [None]:
# Create synthetic dataset
from utils import create_synthethic_dataset

X = create_synthethic_dataset(n_points_per_cluster=250)

# Visualize synthetic dataset
plt.figure(figsize=(7, 7))
plt.scatter(*X.T, s=10, alpha=0.5);

The great thing about density based clustering routines is that we don't have to specify how many clusters we want to extract, which is not always easy to know in advance. However, we do need to specify the density criteria that decide whether a region is dense enough to count as a cluster.
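The density criterion behind DBSCAN-style algorithms, including OPTICS, can be sketched in a few lines: a point lies in a dense region (it is a *core point*) if at least `min_samples` points fall within a radius `eps`. Equivalently, its *core distance*, the distance to its `min_samples`-th nearest neighbor, is at most `eps`. In this sketch the point itself counts toward `min_samples`. OPTICS does not need a single fixed `eps`; instead it records these core distances, which feed into the reachability plot further below. The helper names and toy points are hypothetical.

```python
import math


def core_distance(points, i, min_samples):
    """Distance from points[i] to its min_samples-th nearest neighbor.

    The point itself counts as its own first neighbor (at distance 0).
    """
    dists = sorted(math.dist(points[i], q) for q in points)
    return dists[min_samples - 1]


def is_core_point(points, i, min_samples, eps):
    """True if at least min_samples points (itself included) lie within eps."""
    return core_distance(points, i, min_samples) <= eps


# Hypothetical toy data: a tight group of three points plus one isolated point
points = [(0, 0), (0, 1), (1, 0), (5, 5)]
print([is_core_point(points, i, min_samples=3, eps=1.5) for i in range(len(points))])
# [True, True, True, False]
```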

In [None]:
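# Create density-based clustering model of type OPTICS
from sklearn.cluster import OPTICS
clust = OPTICS(
    min_samples=50,         # samples in a neighborhood needed for a core point
    xi=0.05,                # min. steepness in the reachability plot marking a cluster boundary
    min_cluster_size=0.05,  # smallest cluster, as a fraction of all samples
)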

# Train the density clustering model
%time clust.fit(X)

print("\nFinished training.")

Once the model is trained, we can take a closer look at which clusters it found.

In [None]:
# Plot data with cluster labels (label -1 marks outliers)
plt.figure(figsize=(7, 7))
for idx in sorted(set(clust.labels_) - {-1}):
    Xk = X[clust.labels_ == idx]
    plt.scatter(Xk[:, 0], Xk[:, 1], s=10, alpha=0.5, label=idx)
plt.plot(X[clust.labels_ == -1, 0],
         X[clust.labels_ == -1, 1],
         "k+", alpha=0.2, label="outlier")
plt.legend();
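A detail worth knowing when working with these labels: scikit-learn's `OPTICS` (like `DBSCAN`) marks noise points with the label `-1`. Counting clusters as `len(set(labels)) - 1` is therefore only correct if at least one noise point exists; removing `-1` explicitly is more robust. A small illustration with made-up labels and a hypothetical helper `summarize_labels`:

```python
def summarize_labels(labels):
    """Split density-based cluster labels into cluster ids and a noise count."""
    cluster_ids = sorted(set(labels) - {-1})  # -1 marks noise, not a cluster
    n_noise = sum(1 for lab in labels if lab == -1)
    return cluster_ids, n_noise


with_noise = [0, 0, 1, -1, 1, 2, -1]
without_noise = [0, 0, 1, 1, 2]

print(summarize_labels(with_noise))     # ([0, 1, 2], 2)
print(summarize_labels(without_noise))  # ([0, 1, 2], 0)
# The naive count undercounts when there is no noise:
print(len(set(without_noise)) - 1)      # 2, although there are 3 clusters
```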

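To better understand how the density algorithm identified certain points as outliers, let's take a look at the reachability plot. It shows the points in the order in which OPTICS processed them, together with their reachability distance: dense clusters appear as valleys, while outliers and cluster boundaries show up as peaks.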

In [None]:
# Extract important cluster properties
import numpy as np

space = np.arange(len(X))
reachability = clust.reachability_[clust.ordering_]
labels = clust.labels_[clust.ordering_]

# Create reachability plot
plt.figure(figsize=(12, 5))
for idx in sorted(set(clust.labels_) - {-1}):
    Xk = space[labels == idx]
    Rk = reachability[labels == idx]
    plt.scatter(Xk, Rk, alpha=0.5, label=idx)
plt.plot(space[labels == -1], reachability[labels == -1],
         "k.", alpha=0.2, label="outlier")
plt.ylabel("Reachability (epsilon distance)")
plt.legend();
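The quantity on the y-axis can be sketched from its definition. The reachability distance of a point $p$ from an already processed point $o$ is $\max(\mathrm{core}(o), d(o, p))$, where $\mathrm{core}(o)$ is the core distance of $o$, and OPTICS repeatedly visits the unprocessed point with the smallest reachability seen so far. The simplified $O(n^2)$ sketch below, written from scratch for illustration and not scikit-learn's implementation, builds that ordering with no `max_eps` cutoff (so every point is a neighbor of every other). On two hypothetical toy blobs, the jump in reachability marks where a new cluster begins, just like the peaks between valleys in the plot above.

```python
import math


def core_distance(points, i, min_samples):
    """Distance to the min_samples-th nearest neighbor (the point itself counts)."""
    return sorted(math.dist(points[i], q) for q in points)[min_samples - 1]


def optics_ordering(points, min_samples):
    """Simplified OPTICS without an eps cutoff. Returns (ordering, reachability)."""
    n = len(points)
    core = [core_distance(points, i, min_samples) for i in range(n)]
    reach = [math.inf] * n  # the very first point stays at infinity
    processed = [False] * n
    ordering = []
    for _ in range(n):
        # Visit the unprocessed point with the smallest reachability so far
        # (ties are broken by the lowest index)
        i = min((j for j in range(n) if not processed[j]), key=lambda j: reach[j])
        processed[i] = True
        ordering.append(i)
        # Update the reachability of the remaining points as seen from point i
        for j in range(n):
            if not processed[j]:
                reach_from_i = max(core[i], math.dist(points[i], points[j]))
                reach[j] = min(reach[j], reach_from_i)
    return ordering, reach


# Two hypothetical toy blobs
points = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
ordering, reach = optics_ordering(points, min_samples=2)
print(ordering)                                # [0, 1, 2, 3, 4, 5]
print([round(reach[i], 2) for i in ordering])  # [inf, 1.0, 1.0, 13.45, 1.0, 1.0]
```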

## Connectivity based clustering

Last but not least, let's take a look at a connectivity based clustering routine. As with the other approaches, there are multiple models that can perform this task, each with different advantages and disadvantages.

For this example, we will use a simplified version of the MNIST dataset, and as the connectivity based clustering model, we will use `AgglomerativeClustering`.

So let's start with preparing the dataset!

In [None]:
# Load digits dataset from file
data = np.load('digits.npy')
X = data[:, :2]  # the two (reduced) feature dimensions
y = data[:, 2]   # the true digit label of each sample
X.shape

The digits dataset contains many handwritten examples of the digits 0 to 9. Because its dimensionality has already been reduced to two, we can easily visualize it in a 2D plot.

In [None]:
from utils import plot_clustering
plot_clustering(X, y, y, title="Ground truth / Correct labeling")

Let's now use an `AgglomerativeClustering` model to cluster this 2-dimensional dataset. While there are multiple parameters that we could tweak, let's only adjust the `distance_threshold` parameter.

In [None]:
# Linkage distance threshold at or above which clusters will not be merged
distance_threshold = 0.01

In [None]:
# Create AgglomerativeClustering model
# (n_clusters must be None when a distance_threshold is given)
from sklearn.cluster import AgglomerativeClustering
model = AgglomerativeClustering(distance_threshold=distance_threshold, n_clusters=None)

# Train model on dataset
%time model.fit(X)

print("\nFinished training.")

Once the connectivity based clustering model is trained, we can plot the original dataset again and color-code each point by its predicted cluster label.

In [None]:
plot_clustering(X, y, model.labels_, "Detected clusters")

The great thing about connectivity based clustering routines is that they allow us to perform hierarchical clustering. In other words, the model can tell us which clusters are closest to each other and could potentially be merged, or how a big cluster could be split into smaller ones.
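The merge history that makes this hierarchy possible can be illustrated with a naive agglomerative procedure: start with every point as its own cluster and repeatedly merge the two closest clusters, recording each merge. The sketch below uses *single linkage* (the distance between two clusters is the distance between their closest members) because it is the simplest to write; note that scikit-learn's `AgglomerativeClustering` defaults to Ward linkage, so its merges can differ. Each recorded row `[id_a, id_b, distance, size]`, with new clusters numbered `n, n+1, ...`, follows the layout of a SciPy linkage matrix, the input format of SciPy's `dendrogram` function. Cutting the history at a threshold mirrors `distance_threshold`: merges at or above it are not performed. The function names and toy points are hypothetical.

```python
import math
from itertools import combinations


def single_linkage(points):
    """Naive single-linkage agglomerative clustering (illustrative only).

    Returns one row [id_a, id_b, distance, size] per merge; merged clusters
    receive the new ids n, n + 1, ... in merge order.
    """
    n = len(points)
    clusters = {i: [i] for i in range(n)}  # cluster id -> member point indices

    def linkage(a, b):
        # Single linkage: distance between the closest pair of members
        return min(math.dist(points[i], points[j])
                   for i in clusters[a] for j in clusters[b])

    merges = []
    while len(clusters) > 1:
        a, b = min(combinations(sorted(clusters), 2), key=lambda ab: linkage(*ab))
        d = linkage(a, b)
        new_id = n + len(merges)
        clusters[new_id] = clusters.pop(a) + clusters.pop(b)
        merges.append([a, b, d, len(clusters[new_id])])
    return merges


def cut(merges, n, threshold):
    """Flat cluster labels after replaying only the merges below `threshold`."""
    clusters = {i: [i] for i in range(n)}
    for step, (a, b, d, _) in enumerate(merges):
        if d >= threshold:
            break  # single-linkage merge heights never decrease, so stop here
        clusters[n + step] = clusters.pop(a) + clusters.pop(b)
    labels = [0] * n
    for label, members in enumerate(clusters.values()):
        for i in members:
            labels[i] = label
    return labels


points = [(0, 0), (0, 1), (5, 0), (5, 1.5)]
merges = single_linkage(points)
for row in merges:
    print(row)
# [0, 1, 1.0, 2]
# [2, 3, 1.5, 2]
# [4, 5, 5.0, 4]
print(cut(merges, len(points), threshold=2.0))  # [0, 0, 1, 1]
```

Lowering the threshold below 1.5 keeps points 2 and 3 apart (three clusters), while raising it above 5.0 merges everything into one cluster, which is the same trade-off explored by varying `distance_threshold` in the exercise below.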

A dendrogram is a great way to visualize this hierarchical relationship between samples and clusters.

Note that the color coding in the following figure doesn't correspond to the colors in the digit plot above! Instead, it shows the 8 clusters, each represented by two nodes (i.e. two smaller clusters), and how these 8 clusters hierarchically combine into the full dataset.

In [None]:
# Plot dendrogram
from utils import plot_dendrogram

plt.figure(figsize=(15, 4))
plt.title("Hierarchical Clustering Dendrogram")
plot_dendrogram(model, truncate_mode="level", p=3, color_threshold=distance_threshold)
plt.xlabel("Number of points in node.")
plt.show()

<div class="alert alert-success">
  <h2>Exercise</h2>
    <p></p>
Change the <code>distance_threshold</code> parameter above to any value between 0 and 0.1 (use the y-axis of the dendrogram to decide where to cut) and rerun the code after it. How does this affect the clustering?
</div>