In [18]:
from skimage import io
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
np.random.seed(12)
import cv2

In [19]:
img = cv2.imread('tiger1.jfif')


In [20]:
rows = img.shape[0]
cols = img.shape[1]

In [21]:
flat_img = img.reshape(img.shape[1]*img.shape[0], 3)

In [27]:
flat_img.shape

(50373, 3)

In [23]:
#number of cluster 
def kmeans_init_centers(X, k):
    # randomly pick k rows of X as initial centers
    return X[np.random.choice(X.shape[0], k, replace=False)]

def kmeans_assign_labels(X, centers):
    # calculate pairwise distances btw data and centers
    D = cdist(X, centers)
    # return index of the closest center
    return np.argmin(D, axis = 1)

def kmeans_update_centers(X, labels, K):
    centers = np.zeros((K, X.shape[1]))
    for k in range(K):
        # collect all points assigned to the k-th cluster 
        Xk = X[labels == k, :]
        # take average
        centers[k,:] = np.mean(Xk, axis = 0)
    return centers

def has_converged(centers, new_centers):
    # return True if two sets of centers are the same
    return (set([tuple(a) for a in centers]) == 
        set([tuple(a) for a in new_centers]))


In [24]:
def kmeans(X, K):
    centers = [kmeans_init_centers(X, K)]
    labels = []
    it = 0 
    while True:
        labels.append(kmeans_assign_labels(X, centers[-1]))
        new_centers = kmeans_update_centers(X, labels[-1], K)
        if has_converged(centers[-1], new_centers):
            break
        centers.append(new_centers)
        it += 1
    return (centers, labels, it)

    



In [25]:
(centers, labels, it) = kmeans(flat_img, 10)




In [26]:
centers

[array([[233, 226, 229],
        [243, 236, 239],
        [241, 234, 237],
        [182, 198, 221],
        [247, 238, 241],
        [184, 193, 207],
        [244, 237, 240],
        [117, 116, 118],
        [117, 134, 155],
        [240, 234, 235]], dtype=uint8),
 array([[226.93991943, 224.23734454, 227.41250657],
        [242.55084018, 236.19134061, 238.06782779],
        [240.81893266, 234.72935197, 236.42217281],
        [190.63647643, 204.43920596, 220.50806452],
        [247.68094468, 240.70231681, 243.69710848],
        [178.10921886, 184.65925405, 194.79718508],
        [244.21132075, 237.53790738, 240.1670669 ],
        [ 74.66679063,  78.95425809,  87.63666791],
        [136.18681983, 145.6429867 , 162.18863362],
        [238.62287951, 232.70552414, 234.22292301]]),
 array([[225.61690141, 223.34084507, 226.77183099],
        [242.55243124, 236.19204322, 238.06335953],
        [240.54998978, 234.39501125, 236.00347577],
        [196.4942813 , 204.26800618, 214.16228748],
     