In [1]:
import numpy as np
from PIL import Image
import time, re, os, sys
import matplotlib.pyplot as plt
import cv2
from scipy.spatial.distance import pdist,squareform
from array2gif import write_gif

In [2]:
FOLDER = ''
POINT_NUM = 10000
LENGTH = 100
CLUSTER_NUM = 2
GIF_path = '.\GIF'
colormap= np.random.choice(range(256),size=(100,3))

In [3]:
def openImage(path):
    image = cv2.imread(path)
    H, W, C = image.shape
    image_flat = np.zeros((W * H, C))
    for h in range(H):
        image_flat[h * W:(h + 1) * W] = image[h]
    return image_flat,H,W

In [4]:
def precomputed_kernel(X, gamma_s, gamma_c):
    n=len(X)
    # S(x) spacial information
    S=np.zeros((n,2))
    for i in range(n):
        S[i]=[i//100,i%100]
    print(pdist(S,'sqeuclidean').shape)
    K=squareform(np.exp(-gamma_s*pdist(S,'sqeuclidean')))*squareform(np.exp(-gamma_c*pdist(X,'sqeuclidean')))
    print(K.shape)
    return K

In [5]:
def visualize(C,k,H,W):
    '''
    @param C: (10000) belonging classes ndarray
    @param k: #clusters
    @param H: image_H
    @param W: image_W
    @return : (H,W,3) ndarray
    '''
    colors= colormap[:k,:]
    res=np.zeros((H,W,3))
    for h in range(H):
        for w in range(W):
            #np.array(res)[h,w,:] = np.array(colors)[C[h*W+w]]
            res[h,w,:] = colors[C[h*W+w]]

    return res.astype(np.uint8)

In [6]:
def kmeans(CLUSTER_NUM, Gram, H, W, init=None):
    Cluster = np.zeros((CLUSTER_NUM, Gram.shape[1]))
    if init == 'random':
        random_pick=np.random.randint(low=0,high=Gram.shape[0],size=CLUSTER_NUM)
        Cluster = Gram[random_pick,:]
    if init == 'gaussian':
        X_mean=np.mean(Gram,axis=0)
        X_std=np.std(Gram,axis=0)
        for c in range(Gram.shape[1]):
            Cluster[:,c]=np.random.normal(X_mean[c],X_std[c],size=CLUSTER_NUM)
      

    # kmeans++ init
    
    Cluster[0]=Gram[np.random.randint(low=0,high=Gram.shape[0],size=1),:]
    for c in range(1,CLUSTER_NUM):
            Dist=np.zeros((len(Gram),c))
            for i in range(len(Gram)):
                for j in range(c):
                    Dist[i,j]=np.sqrt(np.sum((Gram[i]-Cluster[j])**2))
            Dist_min=np.min(Dist,axis=1)
            sum=np.sum(Dist_min)*np.random.rand()
            for i in range(len(Gram)):
                sum-=Dist_min[i]
                if sum<=0:
                    Cluster[c]=Gram[i]
                    break
    # kmeans++
    diff = 1e9
    eps = 1e-9
    count = 1
    # Classes of each Xi
    C=np.zeros(len(Gram),dtype=np.uint8)
    segments=[]
    while diff > eps:
        # E-step
        for i in range(len(Gram)):
            dist=[]
            for j in range(CLUSTER_NUM):
                dist.append(np.sqrt(np.sum((Gram[i]-Cluster[j])**2)))
            C[i]=np.argmin(dist)
        
        #M-step
        New_Mean=np.zeros(Cluster.shape)
        for i in range(CLUSTER_NUM):
            belong=np.argwhere(C==i).reshape(-1)
            for j in belong:
                New_Mean[i]=New_Mean[i]+Gram[j]
            if len(belong)>0:
                New_Mean[i]=New_Mean[i]/len(belong)

        diff = np.sum((New_Mean - Cluster)**2)
        Cluster=New_Mean
        # visualize
        segment = visualize(C, CLUSTER_NUM, H, W)
        segments.append(segment)
        print('iteration {}'.format(count))
        for i in range(CLUSTER_NUM):
            print('k={}: {}'.format(i + 1, np.count_nonzero(C == i)))
        print('diff {}'.format(diff))
        print('-------------------')
        cv2.imshow('', segment)
        cv2.waitKey(1)
    return C, segments

In [7]:
def kernelkmeans(path):
    image_flat, H, W = openImage(path)
    gamma_s = 0.001
    gamma_c = 0.001
    Gram = precomputed_kernel(image_flat, gamma_s, gamma_c)
    C, segments = kmeans(CLUSTER_NUM, Gram, H, W, init='random')
    # save_gif
    for i in range(len(segments)):
        segments[i] = segments[i].transpose(1, 0, 2)
    filename = path.split('.')[0] + '.gif'
    gif_path = os.path.join(GIF_path, filename)
    write_gif(segments, gif_path, fps=2)
    
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [8]:
kernelkmeans('image1.png')

(49995000,)
(10000, 10000)
iteration 1
k=1: 8201
k=2: 1799
diff 127.51598452629786
-------------------
iteration 1
k=1: 7721
k=2: 2279
diff 4.189636475122056
-------------------
iteration 1
k=1: 7535
k=2: 2465
diff 0.5412200578035654
-------------------
iteration 1
k=1: 7463
k=2: 2537
diff 0.06405410416724847
-------------------
iteration 1
k=1: 7436
k=2: 2564
diff 0.008056391954235216
-------------------
iteration 1
k=1: 7425
k=2: 2575
diff 0.0013064366597941067
-------------------
iteration 1
k=1: 7421
k=2: 2579
diff 0.00019374428997362043
-------------------
iteration 1
k=1: 7421
k=2: 2579
diff 0.0
-------------------


In [9]:
kernelkmeans('image2.png')

(49995000,)
(10000, 10000)
iteration 1
k=1: 8693
k=2: 1307
diff 169.92870862300688
-------------------
iteration 1
k=1: 8669
k=2: 1331
diff 0.9868861522136408
-------------------
iteration 1
k=1: 8650
k=2: 1350
diff 0.3013295353780692
-------------------
iteration 1
k=1: 8618
k=2: 1382
diff 0.10173290881516872
-------------------
iteration 1
k=1: 8581
k=2: 1419
diff 0.0336957868644991
-------------------
iteration 1
k=1: 8534
k=2: 1466
diff 0.03787926166633537
-------------------
iteration 1
k=1: 8480
k=2: 1520
diff 0.038745550738433665
-------------------
iteration 1
k=1: 8402
k=2: 1598
diff 0.06512746628645556
-------------------
iteration 1
k=1: 8299
k=2: 1701
diff 0.0962600295149068
-------------------
iteration 1
k=1: 8164
k=2: 1836
diff 0.13466060127623597
-------------------
iteration 1
k=1: 7899
k=2: 2101
diff 0.3778572918134071
-------------------
iteration 1
k=1: 7265
k=2: 2735
diff 1.209239402173455
-------------------
iteration 1
k=1: 5918
k=2: 4082
diff 2.810154613115306
-