**NOTE: This notebook is written for the Google Colab platform. However, it can also be run (possibly with minor modifications) as a standard Jupyter notebook.**



In [None]:
#@title -- Installation of Packages -- { display-mode: "form" }
import sys
!{sys.executable} -m pip install git+https://github.com/michalgregor/class_utils.git

In [None]:
#@title -- Import of Necessary Packages -- { display-mode: "form" }
import numpy as np
from PIL import Image
from sklearn.cluster import MiniBatchKMeans
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt

In [None]:
#@title -- Downloading Data -- { display-mode: "form" }
DATA_HOME = "https://github.com/michalgregor/ml_notebooks/blob/main/data/{}?raw=1"

from class_utils.download import download_file_maybe_extract
download_file_maybe_extract(DATA_HOME.format("images/photo_rome.jpg"), directory="data")

# also create a directory for storing any outputs
import os
os.makedirs("output", exist_ok=True)

In [None]:
#@title -- Auxiliary Functions -- { display-mode: "form" }

def plot_colors(colors, cluster_centers=None):
    _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 6))
    ax1.scatter(colors[:, 0], colors[:, 1], s=10, c=colors/255.0)

    ax1.set_xlabel("red")
    ax1.set_ylabel("green")
    ax1.grid(ls='--')
    ax1.set_axisbelow(True)

    if cluster_centers is not None:
        ax1.scatter(cluster_centers[:, 0], cluster_centers[:, 1], s=100,
            c='orange', edgecolors='k', linewidths=2.5
        )

    ax2.scatter(colors[:, 0], colors[:, 2], s=10, c=colors/255.0)
    ax2.set_xlabel("red")
    ax2.set_ylabel("blue")
    ax2.grid(ls='--')
    ax2.set_axisbelow(True)

    if cluster_centers is not None:
        ax2.scatter(cluster_centers[:, 0], cluster_centers[:, 2], s=100,
            c='orange', edgecolors='k', linewidths=2.5
        )

    ax3.scatter(colors[:, 1], colors[:, 2], s=10, c=colors/255.0)
    ax3.set_xlabel("green")
    ax3.set_ylabel("blue")
    ax3.grid(ls='--')
    ax3.set_axisbelow(True)

    if cluster_centers is not None:
        ax3.scatter(cluster_centers[:, 1], cluster_centers[:, 2], s=100,
            c='orange', edgecolors='k', linewidths=2.5
        )

## k-Means for Colour Quantization

In this example, we are going to apply $k$-means to something a little different: we are going to take an image and try to compress it by quantizing its colour space. Our image is stored as RGB, so each pixel is described by three 8-bit numbers, one for each colour channel (red, green, blue). This gives us 256 different levels for each colour channel, i.e. $256^3 = 16\ 777\ 216$ different colours.

Let's say that we will instead store a small palette of colours along with the image and for each pixel just store the index of its colour in that palette. That way we can use one 8-bit number for each pixel in place of three. Naturally, the number does not even need to be an 8-bit one: it could be smaller or larger depending on how large our palette is.
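
To make the potential saving concrete, the storage arithmetic can be sketched as follows. The function names and the 1000 × 1000 image size are purely illustrative assumptions. The sketch also ignores file-format overhead and any further compression of the kind used by JPEG or PNG.

```python
def bits_per_index(palette_size):
    """Smallest number of bits able to address every palette entry."""
    return max(1, (palette_size - 1).bit_length())

def raw_size_bits(n_pixels, n_channels=3, bits_per_channel=8):
    """Size of an uncompressed image: one value per channel per pixel."""
    return n_pixels * n_channels * bits_per_channel

def quantized_size_bits(n_pixels, palette_size, n_channels=3, bits_per_channel=8):
    """Size of the palette itself plus one palette index per pixel."""
    palette_bits = palette_size * n_channels * bits_per_channel
    index_bits = n_pixels * bits_per_index(palette_size)
    return palette_bits + index_bits

# hypothetical 1000 x 1000 image quantized to a 32-colour palette
n_pixels = 1000 * 1000
ratio = raw_size_bits(n_pixels) / quantized_size_bits(n_pixels, 32)
print(f"compression ratio: {ratio:.2f}")
```

With a 32-colour palette each index needs only 5 bits, so the indices could in principle be packed more tightly than one byte per pixel.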

In any case, the principle is simple enough – the real question is how to find a good palette. We want to pick colours so that the compressed image is not distorted too much. What we are going to do, then, is take the pixels from our image and perform $k$-means clustering on them. This way we'll obtain a palette with $k$ colours that represent clusters in colour space.

### Loading an Image

Let's start by loading and displaying an image. As we'll be able to see, it mostly has green, purple, blue, brown and white colours.



In [None]:
img = np.array(Image.open("data/photo_rome.jpg"))
plt.figure(figsize=(10, 6))
plt.imshow(img)
plt.axis('off');

### Reshaping the Image

Now let's reshape our image into a matrix of points in the colour space, i.e. an $m \times n$ matrix, where $m$ is the total number of pixels in the image and $n$ is the dimension of the colour space – in our case $n=3$ because our image is in RGB.



In [None]:
img_shape = img.shape
X = img.reshape(-1, img_shape[2])
X.shape
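
To see what this reshape does to the pixel order, here is a minimal sketch on a hypothetical 2 × 3 toy image (the names `toy_img`, `toy_X` and `restored` are assumptions made for illustration). With NumPy's default row-major layout, pixel `(i, j)` of an image of width `W` ends up in row `i * W + j` of the matrix, and reshaping back to the original shape undoes the operation exactly.

```python
import numpy as np

# toy image with height 2, width 3 and 3 colour channels
toy_img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)

# flatten the two spatial dimensions; keep the channel dimension
toy_X = toy_img.reshape(-1, toy_img.shape[2])
print(toy_X.shape)

# pixel (1, 2) is stored in row 1 * 3 + 2 of the matrix
print(toy_img[1, 2], toy_X[1 * 3 + 2])

# reshaping back recovers the original image
restored = toy_X.reshape(toy_img.shape)
```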

### Exploring the Colour Space

Afterwards we can use `np.unique` to check how many unique colours there are in the image – as it turns out, there are actually quite a lot.



In [None]:
colors = np.unique(X, axis=0)
len(colors)

To get a better visual understanding of what regions our image occupies in the colour space, we can plot the points in three planes: in the red vs. green plane, the red vs. blue plane and the green vs. blue plane. We colour each point by its actual RGB colour.

As we can see, our colours do indeed cover just a relatively small sub-space of the colour space so some compression should be possible.



In [None]:
np.random.seed(10)
sel_colors = colors[np.random.randint(0, len(colors), size=2500)]
plot_colors(sel_colors)

### Mini-Batch $k$-Means

Next we are going to apply $k$-means. We are going to set the number of clusters to 32 – this means we'll be looking for a 32-colour palette. The number of points we are working with is relatively large. For this reason we will be using the mini-batch version of $k$-means – this will make finding the cluster centres much faster.

The idea behind mini-batch $k$-means is not to use all points in every step, but just draw a different sub-sample at every step and work with that. This way one can even apply $k$-means to data that does not fit into memory all at once. Note that virtually the same idea is used to train artificial neural networks on very large datasets.



In [None]:
model = MiniBatchKMeans(n_clusters=32)
model.fit(X)
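
To illustrate the mini-batch idea described above, here is a simplified NumPy sketch. It is not scikit-learn's actual implementation, which adds smarter initialisation, convergence checks and other refinements. The function name and its parameters are assumptions made for illustration. At every step it draws a random sub-sample, assigns each sampled point to its nearest centre and then nudges that centre towards the point using a per-centre running mean.

```python
import numpy as np

def minibatch_kmeans_sketch(X, k, batch_size=256, n_steps=100, seed=0):
    """Simplified mini-batch k-means; returns a (k, n_features) array of centres."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)

    # initialise the centres with k randomly chosen data points
    centers = X[rng.choice(len(X), size=k, replace=False)].copy()
    counts = np.zeros(k)  # number of points assigned to each centre so far

    for _ in range(n_steps):
        # draw a fresh random sub-sample at every step
        batch = X[rng.integers(0, len(X), size=batch_size)]

        # assign every point in the batch to its nearest centre
        dists = ((batch[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        nearest = dists.argmin(axis=1)

        # move each centre towards its points (per-centre running mean)
        for x, c in zip(batch, nearest):
            counts[c] += 1
            centers[c] += (x - centers[c]) / counts[c]

    return centers
```

Because the step size `1 / counts[c]` shrinks as a centre accumulates points, early batches move the centres a lot and later ones only fine-tune them.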

Once we have fitted our model, we run the dataset through it and obtain the cluster identifiers – these determine which palette colour we are going to assign to each original pixel. We also retrieve the palette itself by copying cluster centres from the model and casting them back to 8-bit integers. With these two elements, our image is effectively quantized.



In [None]:
clusts = model.predict(X)
cluster_centers = model.cluster_centers_.astype(np.uint8)
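
One detail worth being aware of: `astype(np.uint8)` discards the fractional part of each value rather than rounding it, so a palette entry may end up up to one level darker than the actual cluster centre. The effect is visually negligible, but the centres can be rounded and clipped first if desired. A minimal sketch on hypothetical toy values (the name `toy_centers` is an assumption made for illustration):

```python
import numpy as np

# hypothetical cluster centres with fractional channel values
toy_centers = np.array([[12.9, 200.4, 254.7]])

# plain cast: the fractional part is simply dropped
truncated = toy_centers.astype(np.uint8)

# round to the nearest level and keep values inside the 8-bit range
rounded = np.clip(np.rint(toy_centers), 0, 255).astype(np.uint8)

print(truncated, rounded)
```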

To inspect the results, we are again going to plot the colours in the three different planes, now also displaying the points corresponding to our cluster centres.



In [None]:
plot_colors(sel_colors, cluster_centers)

### Reconstructing the Quantized Image

Finally, let's reconstruct the image from our quantized version and see what the result looks like. The only thing we need to do is to walk through all the points again and read out their matching colours from the palette. Once we are done, we reshape the data back into the original image shape and we display the reconstructed image.

As we can see, the colours are definitely less vibrant, but even with 32 colours the bulk of the image is preserved quite well. The most notable exception is the sky: it was originally formed by a smooth gradient of colours, which has now been very visibly quantized into bands.



In [None]:
quantized_X = cluster_centers[clusts]
quantized_img = quantized_X.reshape(img_shape)

plt.figure(figsize=(14, 10))
plt.imshow(quantized_img)
plt.axis('off');
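
Besides visual inspection, the distortion can also be summarized in a single number. The following is a minimal sketch using the mean squared error. The function name and the toy arrays are assumptions made for illustration, and other measures (e.g. perceptual ones) may correspond better to what the eye sees.

```python
import numpy as np

def quantization_mse(original, quantized):
    """Mean squared error between two images of the same shape."""
    # cast to float first: subtracting uint8 arrays directly would wrap around
    diff = original.astype(float) - quantized.astype(float)
    return float(np.mean(diff ** 2))

# hypothetical 1 x 2 toy image and a slightly distorted copy of it
toy_original = np.array([[[10, 20, 30], [200, 210, 220]]], dtype=np.uint8)
toy_quantized = np.array([[[12, 20, 30], [200, 210, 216]]], dtype=np.uint8)
toy_error = quantization_mse(toy_original, toy_quantized)
print(toy_error)
```

In this notebook, `quantization_mse(img, quantized_img)` would yield one value per reconstruction, which makes it easy to compare different palette sizes.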

---
### Task: Re-run with Different Palette Sizes

**Run the algorithm again with different palette sizes, e.g. with 16 colours and with 64 colours. Plot the resulting reconstructions.** 

---


In [None]:

# ---


### Quantizing with a Random Palette

Now, just for the sake of comparison, let us attempt colour quantization with a randomly drawn palette. We are first going to pick a number of colours at random and then – for each pixel – find the nearest match in this new palette using `NearestNeighbors`. Finally, we again display the resulting reconstruction. As you are going to see, the result is far from ideal.



In [None]:
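np.random.seed(10)

# draw 32 random colours to serve as the palette
cluster_centers_ = np.random.uniform(0, 255, (32, 3))

# for each pixel, find the index of the nearest palette colour
model = NearestNeighbors(n_neighbors=1)
model.fit(cluster_centers_)
clusts = model.kneighbors(X, return_distance=False).ravel()

# look up the palette colours and restore the original image shape
cluster_centers = cluster_centers_.astype(np.uint8)
quantized_X = cluster_centers[clusts]
quantized_img = quantized_X.reshape(img_shape)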

In [None]:
plt.figure(figsize=(10, 6))
plt.imshow(quantized_img)
plt.axis('off');
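
Conceptually, the lookup performed by `NearestNeighbors` with `n_neighbors=1` picks, for each pixel, the palette colour with the smallest Euclidean distance (its default metric). A brute-force NumPy sketch of the same idea is shown below. The function name and the toy data are assumptions made for illustration. For a full-size image the intermediate distance array would be large, so the pixels would likely need to be processed in chunks.

```python
import numpy as np

def nearest_palette_index(pixels, palette):
    """Index of the closest palette colour (Euclidean distance) for each pixel."""
    pixels = np.asarray(pixels, dtype=float)
    palette = np.asarray(palette, dtype=float)

    # pairwise squared distances, shape (n_pixels, n_colours)
    dists = ((pixels[:, None, :] - palette[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)

# hypothetical palette (black, white, red) and three pixels to match
toy_palette = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0]])
toy_pixels = np.array([[10, 5, 0], [250, 240, 245], [200, 30, 20]])
toy_indices = nearest_palette_index(toy_pixels, toy_palette)
print(toy_indices)
```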