# Clustering with Scikit-Learn

## The data file

To demonstrate clustering in Scikit-learn we will be using the following image as input data.

<img src=https://raw.githubusercontent.com/emmanueliarussi/DataScienceCapstone/master/5_DataMining/data/butterfly.jpg width="500">

We will cluster the pixels' RGB values to obtain a multi-label segmentation of the image. First, let's load the image into a NumPy array using [`Image.open()`](https://pillow.readthedocs.io/en/stable/):

In [None]:
from PIL import Image
import numpy as np
import requests
from io import BytesIO

# Download the image and open it
response = requests.get('https://raw.githubusercontent.com/emmanueliarussi/DataScienceCapstone/master/5_DataMining/data/butterfly.jpg')
response.raise_for_status()  # Stop early if the download failed
image = Image.open(BytesIO(response.content))
image = np.array(image)

# Shape is (height, width, channels): 481x960 pixels, and the last axis holds the 3 RGB values.
original_shape = image.shape
print(original_shape)

Image arrays can be visualized using [`plt.imshow()`](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.imshow.html):

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(18,8))

# No axis
plt.axis('off')
plt.imshow(image)

To perform a simple pixel clustering, we first need to collapse the width and height dimensions into one, *flattening* the array into a list of pixels with one RGB triplet per row:

In [None]:
# Flatten Width and Height dims, keep RGB. 
X = np.reshape(image, [-1, 3])
print(X.shape) 
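
To make the reshape concrete, here is a small sketch on a made-up 2x3 "image" (the names `toy_image` and `toy_X` are hypothetical and not part of the notebook). It shows that flattening only changes the layout: each row is still one pixel, and the original shape can be restored afterwards.

```python
import numpy as np

# Hypothetical 2x3 "image" with 3 channels per pixel
toy_image = np.arange(2 * 3 * 3).reshape(2, 3, 3)

# Collapse height and width into one axis; -1 lets NumPy infer its length
toy_X = np.reshape(toy_image, [-1, 3])
print(toy_X.shape)

# Row k of toy_X is the pixel at (k // width, k % width) of the image
print(toy_X[4], toy_image[1, 1])

# The flattening is lossless: reshaping back recovers the image
restored = toy_X.reshape(toy_image.shape)
print(np.array_equal(restored, toy_image))
```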

Now run [`MeanShift()`](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) on the flattened pixel array `X` to cluster the pixels. Mean shift clustering aims to discover “blobs” in a smooth density of samples. It is a centroid-based algorithm that works by repeatedly updating candidate centroids to be the mean of the points within a given region around them. These candidates are then filtered in a post-processing stage to eliminate near-duplicates and form the final set of centroids.

In [None]:
from sklearn.cluster import MeanShift
# Run MeanShift
ms = MeanShift(bin_seeding=True, n_jobs=4)
ms.fit(X)
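
To illustrate the update rule described above, here is a minimal NumPy sketch that shifts a single candidate centroid, assuming a flat kernel of radius `bandwidth`. The helper `mean_shift_step`, the synthetic blobs, and the bandwidth value are all made up for illustration; this is not scikit-learn's implementation, which also handles seeding and the filtering of near-duplicate centroids.

```python
import numpy as np

def mean_shift_step(points, candidate, bandwidth):
    """Move a candidate to the mean of the points within `bandwidth` of it."""
    distances = np.linalg.norm(points - candidate, axis=1)
    neighbors = points[distances <= bandwidth]
    return neighbors.mean(axis=0)

# Two synthetic, well-separated 2D blobs (assumed data)
rng = np.random.default_rng(0)
blob_a = rng.normal(loc=[0.0, 0.0], scale=0.3, size=(100, 2))
blob_b = rng.normal(loc=[5.0, 5.0], scale=0.3, size=(100, 2))
points = np.vstack([blob_a, blob_b])

# Start from a point of the first blob and iterate until it stops moving
candidate = points[0].copy()
for _ in range(50):
    updated = mean_shift_step(points, candidate, bandwidth=1.5)
    if np.linalg.norm(updated - candidate) < 1e-6:
        break
    candidate = updated

print(candidate)
```

The candidate drifts toward the dense center of the blob it started in and ignores the other blob, which lies outside its window.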

Print some clustering output to get a sense of what was done. In our run, this yielded 4 clusters (and thus 4 centroid colors). If you re-run this notebook with different parameters, you may get different results.

In [None]:
# Cluster centers (one RGB color per cluster)
cluster_centers = ms.cluster_centers_
print("Cluster centers:\n{}".format(cluster_centers))

# Number of clusters found, counted from the per-pixel labels
labels_unique = np.unique(ms.labels_)
n_clusters_ = len(labels_unique)
print("Number of estimated clusters: {}".format(n_clusters_))
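
As a rough illustration of how parameters change the result, the sketch below varies the bandwidth on synthetic data. When `bandwidth` is not given, `MeanShift` estimates it with `estimate_bandwidth`; here we call that function ourselves with two different `quantile` values. The toy "pixels", the quantile values, and the names `toy_pixels` and `counts` are assumptions made for this example.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

# Synthetic "pixels": three noisy color blobs (assumed data, not the butterfly image)
rng = np.random.default_rng(0)
colors = np.array([[200, 30, 30], [30, 200, 30], [30, 30, 200]], dtype=float)
toy_pixels = np.vstack([c + rng.normal(scale=5.0, size=(200, 3)) for c in colors])

counts = {}
for quantile in (0.1, 0.9):
    # A larger quantile gives a larger bandwidth, i.e. a wider averaging window
    bandwidth = estimate_bandwidth(toy_pixels, quantile=quantile, random_state=0)
    toy_ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(toy_pixels)
    counts[quantile] = len(toy_ms.cluster_centers_)
    print("quantile={}: bandwidth={:.1f}, clusters={}".format(
        quantile, bandwidth, counts[quantile]))
```

A wider window tends to merge nearby blobs into fewer clusters, while a narrower one can split them further.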

Since centers are RGB colors, we can plot them using [`sns.palplot`](https://seaborn.pydata.org/tutorial/color_palettes.html):

In [None]:
import seaborn as sns
color_clusters = np.floor(cluster_centers).astype(int)/255
sns.palplot(sns.color_palette(color_clusters))

Finally, let's plot the image with each pixel assigned to its cluster:

In [None]:
# Per-pixel labeling
labels = ms.labels_

# Reshape into segmented image
segmented_image = np.reshape(labels, original_shape[:2])  

In [None]:
# Replace each label with its RGB centroid color
image_color_clusters = np.array(image, copy=True)

for i in range(n_clusters_):
    mask = (segmented_image == i)
    # Round back to 0-255 so float error is not truncated on assignment
    image_color_clusters[mask] = np.rint(color_clusters[i] * 255)
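
The per-cluster loop above can also be written as a single indexing operation. This is a minimal sketch on made-up data (`toy_centers` and `toy_labels` are hypothetical stand-ins for the centroids and the reshaped label map): indexing the centroid array with the label map picks one centroid row per pixel.

```python
import numpy as np

# Hypothetical centroids (K=2 clusters, RGB) and a 2x3 label map
toy_centers = np.array([[255.0, 0.0, 0.0], [0.0, 0.0, 255.0]])
toy_labels = np.array([[0, 1, 1], [1, 0, 0]])

# Integer-array indexing: each label selects its centroid row,
# so the result has shape (height, width, 3)
toy_segmented = np.floor(toy_centers).astype(np.uint8)[toy_labels]
print(toy_segmented.shape)
```

Both approaches should give the same image; the indexed version avoids building one mask per cluster.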

In [None]:
# Plot input and segmentation side by side
plt.figure(figsize=(16,8))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(image_color_clusters)
plt.axis('off')