This code generates images that are segmented using pixel clustering. 10, 20 and 50 clusters are used for segmentation.

In [300]:
from scipy import misc
from scipy.misc import toimage
import numpy as np 
from scipy.spatial import distance
from scipy.misc import logsumexp

# read image file into an array of pixels
# The image name below needs to be changed as per requirement

arr = misc.imread('smallsunset.jpg') 
print(arr.shape)
# transform to 3*n array
arr = arr.transpose(2,0,1).reshape(3,-1)
arr.shape

(330, 600, 3)


(3, 198000)

In [301]:
# Initialize variables

first = 1
k = 20  # no. of segments

#pixels = 307200   # for robert image
#pixels = 239400   # for strelitzia image
pixels = 198000    # for sunset image

prev_wij = np.zeros((pixels,k))

In [302]:
# Kmeans to find initial cluster centers

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=k)
data = arr.transpose()
kmeans.fit(data)

KMeans(algorithm='auto', copy_x=True, init='k-means++', max_iter=300,
    n_clusters=20, n_init=10, n_jobs=1, precompute_distances='auto',
    random_state=None, tol=0.0001, verbose=0)

In [303]:
# Initialize pi and means (mu_j). Compute log(pi)

pi = []
clusterids = kmeans.predict(data)
mu_j = kmeans.cluster_centers_.reshape(k,3)
for i in range(0,k):
    pi.append(np.count_nonzero(clusterids==i)/pixels)
print(sum(pi))
logpi = np.log(pi).reshape(1,k)
print(mu_j.shape)
print(mu_j[5])

1.0
(20, 3)
[  58.91398608   92.25310074  208.36714732]


In [304]:
# Run EM algorithm until convergence of matrix wij

while(first == 1 or wij_diff > 1e-7):
    first = 0
    dist_matrix = distance.cdist(data,mu_j,'sqeuclidean')
    log_wij_num = logpi - 0.5*dist_matrix
    #print(log_wij_num.shape)
    log_wij_den = np.apply_along_axis(logsumexp,1,log_wij_num).reshape(pixels,1)
    wij = np.exp(log_wij_num - log_wij_den)
    wij_diff = np.amax(np.absolute(wij - prev_wij))
    print(wij_diff)
    prev_wij = wij
    mu_j_num = np.dot(data.transpose(),wij)
    #print(mu_j_num)
    mu_j_den = np.apply_along_axis(sum,0,wij).reshape(1,k)
    #print(mu_j_den)
    new_mu = mu_j_num / mu_j_den
    mu_j = new_mu.transpose()
    #print(mu_j)
    new_pi = np.log(np.apply_along_axis(sum,0,wij)/pixels)
    logpi = new_pi
print('done')

1.0
0.983860867789
0.975788840792
0.949045522113
0.800019819779
0.660193004323
0.724949306298
0.666671801341
0.620146413176
0.598568360669
0.649136662275
0.857138814011
0.813839074119
0.79988425665
0.738832250909
0.689407062737
0.69251950421
0.794238714407
0.798989340873
0.791675593092
0.636327923325
0.449080845559
0.267540789157
0.219202677661
0.17576684878
0.115518255737
0.0908285097203
0.0653874630262
0.0511138907084
0.0391264568543
0.0291299650933
0.0213897301136
0.0156278865722
0.0114187898709
0.00836573357945
0.00774613585082
0.00732471756256
0.00696968030502
0.00667078670696
0.00642041130983
0.00621281902505
0.0060436706933
0.00590968836485
0.00580843085263
0.00573814529478
0.00569767161856
0.00568638463554
0.00570416435857
0.00575138827644
0.00586506092566
0.00614468717601
0.00647589762078
0.0068668501999
0.00732768273335
0.0078710239302
0.00851265480372
0.00927236395166
0.0101750486544
0.0112521179993
0.0125432449678
0.0140984712456
0.0159805461598
0.018267087339
0.02105150100

In [290]:
# identify cluster id for each pixel and replace that pixel with the center of the cluster it belongs to

segmented_img = np.zeros((pixels,3))
new_clusterids = np.apply_along_axis(np.argmax,1,wij)
for i in range(0,pixels):
    segmented_img[i] = mu_j.astype(int)[new_clusterids[i]]

In [291]:
#print((mu_j).astype(int))
print(segmented_img.shape)
print(np.unique(new_clusterids))

(198000, 3)
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]


In [292]:
# reshape array to 3 dimensions and display
#newimg = segmented_img.transpose().reshape(3,399,600)
newimg = segmented_img.transpose().reshape(3,330,600)
newimg.shape
toimage(newimg).save('smallsunset20.jpg')