## KMeans Image Segmentation

This code reads an image and treats each pixel as a three-element feature vector (one value per colour channel) of unlabelled data. It performs KMeans clustering on the pixels and then re-colours the image, painting every pixel with the per-channel median colour of the cluster it belongs to.

In [None]:
# The usual loading of modules
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as image
import pandas as pd
from sklearn import cluster
from scipy import misc
import numpy as np

In [None]:
# Read the source picture from disk and show it without axes.
# (A smaller image makes the clustering below run faster.)
imagex = image.imread("parrot.jpeg")
plt.figure(figsize=(15, 8))
plt.axis("off")
plt.imshow(imagex)

# Unpack the array shape; the third dimension is the number of colour channels.
x, y, z = imagex.shape
# Flatten the pixel grid into one row per pixel, one column per channel ...
image_2d = imagex.reshape(-1, z)
# ... and allocate a same-shaped, zero-filled buffer for the re-coloured pixels.
image_2dr = np.zeros((x * y, z))

In [None]:
class KMeansCluster(object):
    """Plain k-means clustering with k-means++ style seeding.

    Parameters
    ----------
    n_clusters : int
        Number of clusters to fit.
    thresh : float, optional
        Convergence threshold. Iteration stops once no cluster center has
        moved (Euclidean distance) by more than this amount in one update.

    After ``fit`` the results are available as ``cluster_centers`` (one row
    per cluster) and ``labels`` / ``cluster_labels`` (cluster index of every
    input row; both names refer to the same array).
    """

    def __init__(self, n_clusters, thresh=10**(-4)):
        self.n_clusters = n_clusters
        self.thresh = thresh
        # Filled in by fit().
        self.cluster_centers = None
        self.cluster_labels = None
        self.labels = None
        self.n_features = None
        self.n_vectors = None

    def __choose_start_centers__(self, data):
        """Pick ``n_clusters`` initial centers from the rows of ``data``.

        The first center is drawn uniformly at random. Every further center
        is drawn from the rows not chosen so far, with probability
        proportional to the squared distance to the nearest center that has
        already been chosen.

        Returns an array of shape ``(n_clusters, n_features)``.
        """
        # Work on a float copy so the caller's array is never modified.
        possible_centers = np.array(data, dtype=float)
        start_centers = np.zeros(shape=(self.n_clusters, possible_centers.shape[-1]))
        # Choose the first center uniformly from all points.
        center_index = np.random.choice(len(possible_centers))
        start_centers[0] = possible_centers[center_index]
        for i in range(1, self.n_clusters):
            # Only choose from rows that haven't been selected before.
            possible_centers = np.delete(possible_centers, center_index, axis=0)
            # Distance from every remaining row to every chosen center,
            # shape (i, n_remaining).
            d_vecs = possible_centers - start_centers[:i, None]
            dists = np.linalg.norm(d_vecs, axis=-1)
            shortest_sq = np.min(dists, axis=0)**2
            total = np.sum(shortest_sq)
            if total > 0:
                next_center_probs = shortest_sq/total
            else:
                # Every remaining row coincides with a chosen center, so the
                # weights would be 0/0; fall back to a uniform draw.
                next_center_probs = None
            center_index = np.random.choice(len(possible_centers), p=next_center_probs)
            start_centers[i] = possible_centers[center_index]
        return start_centers

    def fit(self, data):
        """Cluster the rows of ``data``.

        Parameters
        ----------
        data : array-like of shape (n_vectors, n_features)
            One feature vector per row.

        Raises
        ------
        ValueError
            If ``data`` is not 2-D, or ``n_clusters`` is not between 1 and
            the number of rows.

        Sets ``cluster_centers``, ``labels`` and ``cluster_labels``. The
        labels are the assignment used to compute the final center update.
        """
        points = np.asarray(data, dtype=float)
        if points.ndim != 2:
            raise ValueError("data must be a 2-D array of shape (n_vectors, n_features)")
        if not 1 <= self.n_clusters <= len(points):
            raise ValueError("n_clusters must be between 1 and the number of data points")
        self.n_features = points.shape[-1]
        self.n_vectors = len(points)
        self.cluster_centers = self.__choose_start_centers__(points)
        # Always run at least one assignment pass so labels exist even when
        # thresh is large.
        while True:
            # Distances between every point and every current center,
            # shape (n_clusters, n_vectors).
            d_vectors = points - self.cluster_centers[:, None]
            distances = np.linalg.norm(d_vectors, axis=-1)
            # Index of the nearest center for every point.
            self.labels = np.argmin(distances, axis=0)
            self.cluster_labels = self.labels
            # Start from the old centers so that a cluster which received no
            # points keeps its position instead of becoming NaN (0/0).
            new_cluster_centers = np.copy(self.cluster_centers)
            for k in range(self.n_clusters):
                mask = self.labels == k
                if mask.any():
                    new_cluster_centers[k] = points[mask].mean(axis=0)
            changes = np.linalg.norm(new_cluster_centers-self.cluster_centers, axis=-1)
            self.cluster_centers = new_cluster_centers
            if not np.any(changes > self.thresh):
                break
                    

In [None]:
# THIS IS THE PART YOU ARE TO REPLACE
# Number of clusters (i.e. distinct colours in the output image).
cluster_count = 5
# Build the clusterer and run it on the flattened pixel array.
kmeans_cluster = KMeansCluster(cluster_count)
kmeans_cluster.fit(image_2d)
# Cluster index assigned to each pixel.
cluster_labels = kmeans_cluster.labels
# The fitted centers are kept for reference only; the cells below do not use them.
cluster_centers = kmeans_cluster.cluster_centers

In [None]:
# Turn the clustered data into a DataFrame and add the cluster index as a
# fourth column (a, b, c are the colour channels; d is the cluster label).
cluster_2d = pd.DataFrame(image_2d, columns=list('abc'))
cluster_2d['d'] = cluster_labels

# Loop over each of the clusters and calculate a colour to represent it.
for a in range(0, cluster_count):
    # Boolean mask selecting every pixel that belongs to cluster a.
    in_cluster = (cluster_2d['d'] == a).values
    subset = cluster_2d.loc[in_cluster]
    # Median of every column; the first three entries are the colour
    # channels, the fourth is the median of the label column.
    m = subset.median().values
    print(m)  # Show each of the colours we're going to use

    # Paint every pixel of the cluster in one vectorised assignment instead
    # of testing each pixel in a Python loop.
    image_2dr[in_cluster] = m[0:3]

In [None]:
# Rebuild the picture from the re-coloured pixel rows and rescale the
# 0-255 channel values into the 0-1 range.
image_out = image_2dr.reshape((x, y, z)) / 255.0

# Show the result without axes.
plt.figure(figsize=(15, 8))
plt.axis("off")
plt.imshow(image_out)

# Write the current figure to disk.
plt.savefig('clustered_out.png', bbox_inches='tight')