In [101]:
import numpy as np
from PIL import Image
import time, re, os, sys
import matplotlib.pyplot as plt
import cv2
from scipy.spatial.distance import pdist,squareform
from array2gif import write_gif
#from util import *


In [102]:
FOLDER = ''
POINT_NUM = 10000
LENGTH = 100
CLUSTER_NUM = 4
GIF_path = './GIF'
colormap= np.random.choice(range(256),size=(100,3))


In [103]:
def openImage(path):
	image = cv2.imread(path)
	H, W, C = image.shape
	image_flat = np.zeros((W * H, C))
	for h in range(H):
		image_flat[h * W:(h + 1) * W] = image[h]
	return image_flat,H,W

In [104]:
def precomputed_kernel(X, gamma_s, gamma_c):
    n=len(X)
    # S(x) spacial information
    S=np.zeros((n,2))
    for i in range(n):
        S[i]=[i//100,i%100]
    print(pdist(S,'sqeuclidean').shape)
    K=squareform(np.exp(-gamma_s*pdist(S,'sqeuclidean')))*squareform(np.exp(-gamma_c*pdist(X,'sqeuclidean')))
    print(K.shape)
    return K

In [105]:
def visualize(C,k,H,W):
    '''
    @param C: (10000) belonging classes ndarray
    @param k: #clusters
    @param H: image_H
    @param W: image_W
    @return : (H,W,3) ndarray
    '''
    colors= colormap[:k,:]
    res=np.zeros((H,W,3))
    for h in range(H):
        for w in range(W):
            #np.array(res)[h,w,:] = np.array(colors)[C[h*W+w]]
            res[h,w,:] = colors[C[h*W+w]]

    return res.astype(np.uint8)

In [106]:
def kmeans(CLUSTER_NUM, Gram, H, W, init=None):
    Cluster = np.zeros((CLUSTER_NUM, Gram.shape[1]))
    if init == 'random':
        random_pick=np.random.randint(low=0,high=Gram.shape[0],size=CLUSTER_NUM)
        Cluster = Gram[random_pick,:]
    if init == 'gaussian':
        X_mean=np.mean(Gram,axis=0)
        X_std=np.std(Gram,axis=0)
        for c in range(Gram.shape[1]):
            Cluster[:,c]=np.random.normal(X_mean[c],X_std[c],size=CLUSTER_NUM)
      

    # kmeans++ init
    
    Cluster[0]=Gram[np.random.randint(low=0,high=Gram.shape[0],size=1),:]
    for c in range(1,CLUSTER_NUM):
            Dist=np.zeros((len(Gram),c))
            for i in range(len(Gram)):
                for j in range(c):
                    Dist[i,j]=np.sqrt(np.sum((Gram[i]-Cluster[j])**2))
            Dist_min=np.min(Dist,axis=1)
            sum=np.sum(Dist_min)*np.random.rand()
            for i in range(len(Gram)):
                sum-=Dist_min[i]
                if sum<=0:
                    Cluster[c]=Gram[i]
                    break
    # kmeans++
    diff = 1e9
    eps = 1e-9
    count = 1
    # Classes of each Xi
    C=np.zeros(len(Gram),dtype=np.uint8)
    segments=[]
    while diff > eps:
        # E-step
        for i in range(len(Gram)):
            dist=[]
            for j in range(CLUSTER_NUM):
                dist.append(np.sqrt(np.sum((Gram[i]-Cluster[j])**2)))
            C[i]=np.argmin(dist)
        
        #M-step
        New_Mean=np.zeros(Cluster.shape)
        for i in range(CLUSTER_NUM):
            belong=np.argwhere(C==i).reshape(-1)
            for j in belong:
                New_Mean[i]=New_Mean[i]+Gram[j]
            if len(belong)>0:
                New_Mean[i]=New_Mean[i]/len(belong)

        diff = np.sum((New_Mean - Cluster)**2)
        Cluster=New_Mean
        # visualize
        segment = visualize(C, CLUSTER_NUM, H, W)
        segments.append(segment)
        print('iteration {}'.format(count))
        for i in range(CLUSTER_NUM):
            print('k={}: {}'.format(i + 1, np.count_nonzero(C == i)))
        print('diff {}'.format(diff))
        print('-------------------')
        cv2.imshow('', segment)
        cv2.waitKey(1)
    return C, segments

In [107]:
def kernelkmeans(path):
    image_flat, H, W = openImage(path)
    gamma_s = 0.001
    gamma_c = 0.001
    Gram = precomputed_kernel(image_flat, gamma_s, gamma_c)
    C, segments = kmeans(CLUSTER_NUM, Gram, H, W, init='gaussian')
    # save_gif
    for i in range(len(segments)):
        segments[i] = segments[i].transpose(1, 0, 2)
    filename = path.split('.')[0] + '.gif'
    gif_path = os.path.join(GIF_path, filename)
    write_gif(segments, gif_path, fps=2)
    
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    

        




In [108]:
kernelkmeans('image1.png')


(49995000,)
(10000, 10000)
iteration 1
k=1: 732
k=2: 2032
k=3: 868
k=4: 6368
diff 185.11460974798092
-------------------
iteration 1
k=1: 966
k=2: 2379
k=3: 875
k=4: 5780
diff 11.971091360371217
-------------------
iteration 1
k=1: 985
k=2: 2452
k=3: 874
k=4: 5689
diff 1.1560538511088327
-------------------
iteration 1
k=1: 989
k=2: 2493
k=3: 873
k=4: 5645
diff 0.12683931597907114
-------------------
iteration 1
k=1: 990
k=2: 2504
k=3: 873
k=4: 5633
diff 0.009181589889564041
-------------------
iteration 1
k=1: 991
k=2: 2509
k=3: 873
k=4: 5627
diff 0.00044583598139736485
-------------------
iteration 1
k=1: 991
k=2: 2510
k=3: 873
k=4: 5626
diff 4.1407888002866564e-05
-------------------
iteration 1
k=1: 991
k=2: 2511
k=3: 873
k=4: 5625
diff 3.423012230074337e-05
-------------------
iteration 1
k=1: 991
k=2: 2511
k=3: 873
k=4: 5625
diff 0.0
-------------------


In [109]:
kernelkmeans('image2.png')

(49995000,)
(10000, 10000)
iteration 1
k=1: 6972
k=2: 1455
k=3: 806
k=4: 767
diff 229.35228226977898
-------------------
iteration 1
k=1: 6086
k=2: 2150
k=3: 840
k=4: 924
diff 3.877502664278032
-------------------
iteration 1
k=1: 4017
k=2: 4121
k=3: 859
k=4: 1003
diff 9.169870651099924
-------------------
iteration 1
k=1: 1912
k=2: 6246
k=3: 875
k=4: 967
diff 55.53941830116036
-------------------
iteration 1
k=1: 1722
k=2: 6441
k=3: 881
k=4: 956
diff 1.607604856264227
-------------------
iteration 1
k=1: 1672
k=2: 6493
k=3: 882
k=4: 953
diff 0.10374828123864306
-------------------
iteration 1
k=1: 1656
k=2: 6511
k=3: 882
k=4: 951
diff 0.012181739675775655
-------------------
iteration 1
k=1: 1654
k=2: 6520
k=3: 882
k=4: 944
diff 0.004675892980404283
-------------------
iteration 1
k=1: 1652
k=2: 6528
k=3: 882
k=4: 938
diff 0.004396360992127396
-------------------
iteration 1
k=1: 1651
k=2: 6536
k=3: 882
k=4: 931
diff 0.003831860241225232
-------------------
iteration 1
k=1: 1651
k=2: 