In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist

In [49]:
def display_data(X, labels):
    X0 = X[labels==0, :]
    X1 = X[labels==1, :]
    plt.plot(X0[:,0], X0[:,1], 'b^')
    plt.plot(X1[:,0], X1[:,1], 'ro')
    plt.show()
    
def kmeans_init_centers(X, k):
    # replace = False : counter duplicate number
    return X[np.random.choice(X.shape[0], k, replace=False)]

def kmeans_assign_labels(X, centers):
    d = cdist(X, centers)
    #return d
    # each point in centers?
    return np.argmax(d, axis=1)

def kmeans_update_centers(X, labels, k):
    # k = unique(labels)
    new_centers = np.zeros((k, X.shape[1]))
    for j in range(k):
        xj = X[labels==j, :]
        new_centers[j] = np.mean(xj, axis=0)
    return new_centers

def has_converged(new_center, old_center):
    #return np.array_equal(new_center, old_center)
    return set([tuple(a) for a in new_center]) == set([tuple(a) for a in old_center])

def has_converged_labels(new_labels, original_labels):
    invert_new_labels = [1 if x == 0 else 1 for x in new_labels]
    return np.array_equal(new_labels, original_labels) or np.array_equal(invert_new_labels, original_labels)

def kmeans(X, k, original_labels):
    centers = [kmeans_init_centers(X, k)]
    #print(type(centers[-1]))
    #print(centers)
    labels = []
    it = 0
    while True:
        it += 1
        labels.append(kmeans_assign_labels(X, centers[-1]))
        new_center = kmeans_update_centers(X, labels[-1], k)
        #print(type(new_center))
        print('time:', it)
        print(new_center)
        print(centers[-1])
        if has_converged(new_center, centers[-1]):
            break
        #if has_converged_labels(labels[-1], original_labels):
        #    break
        centers.append(new_center)
    return centers, labels, it
    
# generate data
means = [[1,2], [4,7]]
cov = [[1,0], [0,1]]
N = 50
np.random.seed(2)
X0 = np.random.multivariate_normal(means[0], cov, N)
X1 = np.random.multivariate_normal(means[1], cov, N)
X = np.concatenate((X0,X1), axis=0)
original_labels = np.asarray([0]*N + [1]*N)
K = 2
# show
#display_data(X, original_labels)

# algorithm
centers, labels, it = kmeans(X, K, original_labels)
print(centers[-1])


time: 1
[[-0.39126444  2.4801604 ]
 [ 3.00830695  4.95630813]]
[[0.6864918  2.77101174]
 [0.34674973 2.84245628]]
time: 2
[[4.09569661 6.83497913]
 [0.62480069 1.99486288]]
[[-0.39126444  2.4801604 ]
 [ 3.00830695  4.95630813]]
time: 3
[[0.70218049 2.09033724]
 [4.22657056 7.02991175]]
[[4.09569661 6.83497913]
 [0.62480069 1.99486288]]
time: 4
[[4.22657056 7.02991175]
 [0.70218049 2.09033724]]
[[0.70218049 2.09033724]
 [4.22657056 7.02991175]]
[[0.70218049 2.09033724]
 [4.22657056 7.02991175]]
