In [1]:
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.io

Function that returns the index of the closest cluster centroid for each sample in the data

In [2]:
def findClosetCentroids(data,centroids):
    """Return the index of the closest centroid for each data point.

    Closeness is measured with the squared Euclidean distance. When a
    sample is equally close to several centroids the lowest centroid
    index is returned (``np.argmin`` picks the first minimum).

    Parameters
    ----------
    data : array-like, shape (nSample, nFeature)
        One sample per row.
    centroids : array-like, shape (K, nFeature)
        One cluster centre per row.

    Returns
    -------
    numpy.ndarray, shape (nSample, 1), dtype float
        ``index[i]`` is the row of ``centroids`` nearest to ``data[i]``.

    Raises
    ------
    ValueError
        If ``centroids`` has no rows.
    """
    # Work in float so integer inputs (e.g. uint8 pixels) cannot wrap
    # around when they are subtracted or squared.
    data=np.asarray(data,dtype=float)
    centroids=np.asarray(centroids,dtype=float)
    nSample=data.shape[0]
    #number of clusters
    K=centroids.shape[0]
    if K==0:
        raise ValueError("centroids must contain at least one row")
    # One column of squared distances per centroid, computed for all
    # samples at once instead of one (sample, centroid) pair at a time.
    dist=np.empty((nSample,K))
    for j in range(K):
        dist[:,j]=np.sum((data-centroids[j,:])**2,axis=1)
    # Pick the minimum once per sample, after every distance is known.
    return np.argmin(dist,axis=1).astype(float).reshape(nSample,1)

Function that updates each cluster's centroid to the mean of the samples assigned to that cluster

In [3]:
def updateCentroids(data,index,K):
    """Recompute every centroid as the mean of the samples assigned to it.

    A cluster that received no samples would otherwise evaluate 0/0 and
    become NaN; a NaN centroid then wins every later ``np.argmin`` and
    collapses the clustering. Such empty clusters are therefore re-seeded
    with a randomly chosen row of ``data``. If ``data`` itself has no
    rows there is nothing to re-seed from and the centroid stays NaN.

    Parameters
    ----------
    data : array-like, shape (nSample, nFeature)
        One sample per row.
    index : array-like, shape (nSample, 1) or (nSample,)
        Cluster label of each sample; values are truncated to int.
    K : int
        Number of clusters.

    Returns
    -------
    numpy.ndarray, shape (K, nFeature), dtype float
        The updated centroids.

    Side effects
    ------------
    Prints the per-cluster sample counts and the new centroids.
    """
    data=np.asarray(data,dtype=float)
    labels=np.asarray(index).reshape(-1).astype(int)
    temp=np.zeros((K,data.shape[1]))
    count=np.zeros((K,1))
    # Unbuffered scatter-add: repeated labels accumulate correctly,
    # which plain fancy-index "+=" would not do.
    np.add.at(temp,labels,data[:labels.shape[0]])
    np.add.at(count,labels,1)
    empty=count[:,0]==0
    centroids=np.full(temp.shape,np.nan)
    centroids[~empty]=temp[~empty]/count[~empty]
    nEmpty=int(empty.sum())
    if nEmpty and data.shape[0]>0:
        # randint's upper bound is exclusive, so every pick is a valid row.
        centroids[empty]=data[np.random.randint(0,data.shape[0],size=nEmpty)]
    print("Number of samples each cluster:\n",count)
    print("Centroids:\n",centroids)
    print("------------------------------------------------")
    return centroids

Initialize the list of centroids randomly; each one is picked from the data samples

In [None]:
def randomInitCentroids(data,K):
    """Pick K rows of ``data`` at random (with replacement) as initial centroids.

    Parameters
    ----------
    data : numpy.ndarray, shape (nSample, nFeature)
        One sample per row.
    K : int
        Number of centroids to draw.

    Returns
    -------
    numpy.ndarray, shape (K, nFeature), dtype float
        The selected rows. The same row may be chosen more than once.

    Raises
    ------
    ValueError
        If ``K`` is positive and ``data`` has no rows.

    Side effects
    ------------
    Prints ``data.shape``.
    """
    print(data.shape)
    nSample=data.shape[0]
    if K>0 and nSample==0:
        raise ValueError("cannot initialise centroids from an empty data set")
    centroids=np.zeros((K,data.shape[1]))
    for i in range(K):
        # np.random.randint excludes its upper bound, so passing nSample
        # (not nSample+1) yields a valid row index in 0..nSample-1.
        centroids[i]=data[np.random.randint(0,nSample)]
    return centroids

Implement the K-means clustering algorithm

In [None]:
def KmeanClustering(data,K,epoch=500):
    """Cluster ``data`` into K groups with K-means.

    Starts from centroids drawn by ``randomInitCentroids`` and alternates
    the assignment step (``findClosetCentroids``) with the update step
    (``updateCentroids``). The loop ends after ``epoch`` iterations, or
    earlier once an update leaves every centroid exactly unchanged: from
    that point the assignment step would reproduce the same labels, so
    further passes over the data are wasted work.

    Parameters
    ----------
    data : numpy.ndarray, shape (nSample, nFeature)
        One sample per row.
    K : int
        Number of clusters.
    epoch : int, optional
        Maximum number of assignment/update iterations (default 500).

    Returns
    -------
    numpy.ndarray, shape (K, nFeature)
        The final centroids.

    Side effects
    ------------
    Prints the iteration number each round, in addition to whatever the
    helper functions print.
    """
    centroids=randomInitCentroids(data,K)
    for i in range(epoch):
        print("Number of interations:",i+1)
        index=findClosetCentroids(data,centroids)
        newCentroids=updateCentroids(data,index,centroids.shape[0])
        # Exact comparison on purpose: NaN never compares equal, so a
        # degenerate run is not mistaken for convergence.
        converged=np.array_equal(newCentroids,centroids)
        centroids=newCentroids
        if converged:
            break
    return centroids

Testing with k in {2,4,6,8}

In [None]:
# Load the test image and flatten it to one RGB sample per pixel.
img=skimage.io.imread("test0.png")
if img.ndim==3 and img.shape[-1]>3:
    # Drop the alpha channel; reshaping RGBA data to 3 columns would
    # silently mix channels of neighbouring pixels.
    img=img[...,:3]
data=img.reshape(-1,3)
print(data)
# One panel per K, laid out on a 2x2 grid.
fig,axes=plt.subplots(2,2)
for ax,k in zip(axes.ravel(),[2,4,6,8]):
    print("K-mean clustering with K=",k)
    centroids=KmeanClustering(data,k,20)
    index=findClosetCentroids(data,centroids)
    # Recolour every pixel with its cluster centroid in one indexing
    # operation; the cast back to the image dtype truncates exactly as
    # assigning into an integer array does.
    flat_ret=centroids[index.astype(int).ravel()].astype(data.dtype)
    ret=flat_ret.reshape(img.shape)
    print(ret)
    ax.set_title("K=%d"%k)
    ax.imshow(ret)

[[255 255 255]
 [255 255 255]
 [255 255 255]
 ...
 [255 255 255]
 [255 255 255]
 [255 255 255]]
K-mean clustering with K= 2
(110889, 3)
Number of interations: 1
Number of samples each cluster:
 [[70299.]
 [40590.]]
Centroids:
 [[168.2740295   82.14287543  94.20271981]
 [245.39418576 224.21399359 216.93158413]]
------------------------------------------------
Number of interations: 2
Number of samples each cluster:
 [[73962.]
 [36927.]]
Centroids:
 [[171.15960899  85.54697007  96.11618128]
 [247.26457606 231.48869391 225.2732418 ]]
------------------------------------------------
Number of interations: 3
Number of samples each cluster:
 [[75296.]
 [35593.]]
Centroids:
 [[172.09058914  86.83827826  96.98499256]
 [248.14747282 234.22675807 228.27600933]]
------------------------------------------------
Number of interations: 4
Number of samples each cluster:
 [[75721.]
 [35168.]]
Centroids:
 [[172.36844469  87.26748194  97.28462382]
 [248.46835191 235.08379777 229.21749886]]
-------------

Visualize the result with k=2,4,6,8

In [None]:
# Display the open matplotlib figure holding the 2x2 grid of K=2,4,6,8 panels.
plt.show()