In [1]:
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.io

Function that returns the index of the closest cluster centroid for each sample in the data

In [2]:
def findClosetCentroids(data,centroids):
    """Return the index of the closest centroid for every data point.

    Closeness is the squared Euclidean distance. When several centroids are
    equally close, the one with the lowest index wins.

    Parameters
    ----------
    data : array-like, shape (nSample, nFeature)
        One sample per row.
    centroids : array-like, shape (K, nFeature)
        One cluster centroid per row.

    Returns
    -------
    numpy.ndarray, shape (nSample, 1), dtype float64
        ``index[i, 0]`` is the row of ``centroids`` nearest to ``data[i]``.

    Raises
    ------
    ValueError
        If ``centroids`` has no rows (there is nothing to assign to).
    """
    # Compute in floating point: with unsigned integer inputs (e.g. uint8
    # pixels) the subtraction and the square would silently wrap around.
    points=np.asarray(data,dtype=float)
    centers=np.asarray(centroids,dtype=float)
    nSample=points.shape[0]
    #number of clusters
    K=centers.shape[0]
    if K==0:
        raise ValueError("centroids must contain at least one row")
    # Loop over the (few) centroids rather than the (many) samples; each pass
    # fills one column with the distances of all samples to that centroid.
    distances=np.empty((nSample,K))
    for j in range(K):
        distances[:,j]=np.sum((points-centers[j,:])**2,axis=1)
    index=np.zeros((nSample,1))
    # A single argmin per sample, taken once all K distances are known.
    index[:,0]=np.argmin(distances,axis=1)
    return index

Function that updates each cluster's centroid to the mean of the members assigned to that cluster

In [3]:
def updateCentroids(data,index,K,verbose=True):
    """Recompute every centroid as the mean of the samples assigned to it.

    Parameters
    ----------
    data : array-like, shape (nSample, nFeature)
        One sample per row.
    index : array-like, shape (nSample, 1) or (nSample,)
        Cluster label of each sample; values are truncated to int.
    K : int
        Number of clusters.
    verbose : bool, optional
        When True (the default) print the per-cluster member counts and the
        new centroids.

    Returns
    -------
    numpy.ndarray, shape (K, nFeature), dtype float64
        The updated centroids.

    Notes
    -----
    A cluster with no members has no mean (0/0 would yield a NaN centroid that
    poisons later distance computations). Each empty cluster is therefore
    re-seeded with the sample lying farthest from its own centroid; a sample
    is used for at most one re-seed while unused samples remain. If there are
    no samples at all, empty clusters stay at the origin.
    """
    points=np.asarray(data,dtype=float)
    # Flatten the labels so (nSample, 1) and (nSample,) inputs behave alike.
    labels=np.asarray(index).reshape(-1).astype(int)
    temp=np.zeros((K,points.shape[1]))
    count=np.zeros((K,1))
    # Unbuffered scatter-add: repeated labels accumulate correctly.
    np.add.at(temp,labels,points)
    np.add.at(count,labels,1)
    filled=count[:,0]>0
    centroids=np.zeros_like(temp)
    centroids[filled]=temp[filled]/count[filled]
    empty=np.flatnonzero(~filled)
    if empty.size and labels.size:
        # Squared distance of every sample to the centroid it belongs to.
        spread=np.sum((points-centroids[labels])**2,axis=1)
        for k in empty:
            far=int(np.argmax(spread))
            centroids[k]=points[far]
            # Distances are >= 0, so -1 keeps this sample from being reused.
            spread[far]=-1.0
    if verbose:
        print(count)
        print(centroids)
    return centroids

Initialize the list of centroids randomly; each one is picked from the data samples

In [None]:
def randomInitCentroids(data,K,verbose=True):
    """Pick K initial centroids at random from the rows of ``data``.

    Parameters
    ----------
    data : numpy.ndarray, shape (nSample, nFeature)
        One sample per row.
    K : int
        Number of centroids to draw.
    verbose : bool, optional
        When True (the default) print ``data.shape``.

    Returns
    -------
    numpy.ndarray, shape (K, nFeature), dtype float64
        Copies of the selected rows.

    Raises
    ------
    ValueError
        If centroids are requested (K > 0) but ``data`` has no rows.

    Notes
    -----
    Row indices are drawn without replacement whenever K <= nSample, so the
    same row is never picked twice (rows holding equal values can still yield
    equal centroids). Only when K > nSample are repeats allowed.
    """
    if verbose:
        print(data.shape)
    nSample=data.shape[0]
    if K>0 and nSample==0:
        raise ValueError("cannot pick centroids from an empty data set")
    # Valid row indices are 0 .. nSample-1; drawing from range(nSample) keeps
    # every pick in bounds.
    picks=np.random.choice(nSample,size=K,replace=K>nSample)
    centroids=np.zeros((K,data.shape[1]))
    centroids[:]=data[picks]
    return centroids

Implement the K-means clustering algorithm

In [None]:
def KmeanClustering(data,K,epoch=500):
    """Cluster ``data`` into K groups with the K-means algorithm.

    Starting from randomly chosen centroids, alternate between assigning each
    sample to its closest centroid and recomputing the centroids from those
    assignments.

    Parameters
    ----------
    data : numpy.ndarray, shape (nSample, nFeature)
        One sample per row.
    K : int
        Number of clusters.
    epoch : int, optional
        Maximum number of assignment/update iterations (default 500).

    Returns
    -------
    numpy.ndarray
        The centroids after the last iteration performed, one per row.

    Notes
    -----
    Iteration stops early once an update leaves the centroids exactly
    unchanged: from that point on every further iteration would reproduce the
    same assignments and the same centroids, so the result is identical to
    running all ``epoch`` iterations.
    """
    centroids=randomInitCentroids(data,K)
    for i in range(epoch):
        print("Number of interations:",i+1)
        index=findClosetCentroids(data,centroids)
        updated=updateCentroids(data,index,centroids.shape[0])
        # Exact comparison on purpose: only a true fixed point guarantees that
        # the remaining iterations are redundant.
        converged=np.array_equal(updated,centroids)
        centroids=updated
        if converged:
            break
    return centroids

Testing with k in {2,4,6,8}

In [None]:
# Seed NumPy's global generator so the random centroid initialisation, and
# therefore the figure, is the same on every run of this cell.
np.random.seed(0)
img=skimage.io.imread("test0.png")
# One row per pixel. Use the image's own channel count so RGBA or grayscale
# files are not scrambled by a hard-coded 3.
channels=img.shape[2] if img.ndim==3 else 1
data=img.reshape(-1,channels)
print(data)
for i,k in enumerate([2,4,6,8]):
    print("K-mean clustering with K=",k)
    centroids=KmeanClustering(data,k,20)
    index=findClosetCentroids(data,centroids)
    # Replace every pixel by its centroid colour in one indexing step, then
    # cast back to the image dtype so imshow receives the original value range.
    labels=index.astype(int).ravel()
    flat_ret=centroids[labels].astype(data.dtype)
    ret=flat_ret.reshape(img.shape)
    print(ret)
    plt.subplot(2,2,i+1)
    plt.title("K=%d"%k)
    plt.imshow(ret)

[[255 255 255]
 [255 255 255]
 [255 255 255]
 ...
 [255 255 255]
 [255 255 255]
 [255 255 255]]
K-mean clustering with K= 2
(110889, 3)
Number of interations: 1
[[54187.]
 [56702.]]
[[238.61900456 200.77051691 192.52440622]
 [156.255476    70.47821946  88.09724525]]
Number of interations: 2
[[44937.]
 [65952.]]
[[242.88459399 216.09671318 208.23702072]
 [164.90085213  78.30952814  92.03758794]]
Number of interations: 3
[[39079.]
 [71810.]]
[[246.11415338 227.17712838 220.31029453]
 [169.50495753  83.51974655  94.946442  ]]
Number of interations: 4
[[36321.]
 [74568.]]
[[247.64970678 232.73524958 226.64370474]
 [171.59050799  86.12583146  96.49828345]]
Number of interations: 5
[[35391.]
 [75498.]]
[[248.29654432 234.63776101 228.72244356]
 [172.22420461  87.03996132  97.12699674]]
Number of interations: 6
[[35097.]
 [75792.]]
[[248.52520158 235.22702795 229.37185514]
 [172.41340775  87.33962687  97.33673739]]
Number of interations: 7
[[35025.]
 [75864.]]
[[248.57695931 235.37187723 229.

Visualize the result with k=2,4,6,8

In [None]:
plt.show()