# k-means

In [8]:
import math
import sys
import random
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

In [20]:
class k_means:
    """k-means clustering driven one E step / M step at a time.

    Attributes:
        points: list of points, each a list of floats parsed from one
            non-blank line of the data file.
        labels: one float per non-blank line of the ground-truth file
            (loaded for reference; the algorithm itself does not use it).
        cluster_num: number of clusters k.
        predicts: cluster index assigned to each point by the last E step.
        centroids: current centroid coordinates, one list per cluster.
        error: total squared distance the centroids moved in the last M step.
        converged: set to True once an M step leaves every centroid in place.
    """

    def __init__(self, data, ground, cluster_num):
        """Load the data and ground-truth files and pick initial centroids.

        Args:
            data: path to a file with one whitespace-separated point per line.
            ground: path to a file with one numeric label per line.
            cluster_num: number of clusters; random.sample raises ValueError
                if it exceeds the number of points.
        """
        # read files; split() tolerates repeated spaces/tabs, and blank
        # lines (e.g. a trailing newline) are skipped instead of crashing
        with open(data) as f:
            self.points = [[float(v) for v in line.split()]
                           for line in f if line.strip()]

        with open(ground) as f:
            self.labels = [float(line) for line in f if line.strip()]

        # init variables
        self.cluster_num = cluster_num
        self.predicts = [0] * len(self.points)
        # start from k distinct random points; copy them so the centroids
        # never alias rows of self.points
        self.centroids = [list(self.points[idx])
                          for idx in random.sample(range(len(self.points)), cluster_num)]

        self.error = 0
        self.converged = False

    def E_step(self):
        """Assign every point to its nearest centroid (updates self.predicts).

        Uses squared Euclidean distance over all coordinates. On ties the
        lowest-index centroid wins, because only a strictly smaller distance
        replaces the current best.
        """
        for i, point in enumerate(self.points):
            # float("inf") works on Python 2 and 3 (sys.maxint is Python 2 only)
            min_distance = float("inf")
            for j, centroid in enumerate(self.centroids):
                # squared Euclidean distance; ordering matches the true distance
                distance = sum((p - c) ** 2 for p, c in zip(point, centroid))
                if distance < min_distance:
                    self.predicts[i] = j
                    min_distance = distance

    def M_step(self):
        """Move each centroid to the mean of its assigned points.

        A cluster with no assigned points keeps its previous centroid.
        Afterwards self.error holds the total squared centroid movement, and
        self.converged becomes True when that movement is exactly zero.
        """
        prev_centroids = self.centroids

        # sum up all points in each cluster
        sums = [[0.0] * len(c) for c in prev_centroids]
        counts = [0] * len(prev_centroids)
        for point, cluster in zip(self.points, self.predicts):
            acc = sums[cluster]
            for d, value in enumerate(point):
                acc[d] += value
            counts[cluster] += 1

        # mean per cluster; an empty cluster stays where it was instead of
        # being reset to the origin
        self.centroids = [
            [s / n for s in total] if n else list(prev)
            for total, n, prev in zip(sums, counts, prev_centroids)
        ]

        # total squared distance moved by all centroids since the last M step
        movement = sum(
            (new - old) ** 2
            for cur, prev in zip(self.centroids, prev_centroids)
            for new, old in zip(cur, prev)
        )

        # zero movement means no centroid changed, so assignments are stable
        if movement == 0:
            self.converged = True
        self.error = movement
            

In [21]:
color_list = ["green", "blue", "red", "orange", "yellow", "purple", "black"]


def cluster_color(idx):
    """Return the plot colour for cluster *idx*, cycling through color_list.

    Cycling (rather than indexing directly) keeps the plot working when
    there are more clusters than entries in color_list.
    """
    return color_list[idx % len(color_list)]


model = k_means("data/test1_data.txt", "data/test1_ground.txt", 2)
x, y = zip(*model.points)

fig, ax = plt.subplots()
fig.set_size_inches((8, 4.5))

centroid_colors = [cluster_color(idx) for idx in range(len(model.centroids))]


def update(i):
    """Draw animation frame *i*, advancing the global model as a side effect.

    Frame 0 shows the raw points in black with the initial centroids.
    Odd frames run an E step and even frames (> 0) run an M step. After
    either step the points are coloured by their assigned cluster.
    Returns the centroid scatter artist.
    """
    ax.clear()

    if i == 0:
        ax.scatter(x, y, color='black', marker='.')
    else:
        if i % 2 == 1:
            model.E_step()
        else:
            model.M_step()
        # the comprehension variable must not be named `i`: in Python 2 it
        # leaks into, and would overwrite, the frame-index parameter
        point_colors = [cluster_color(c) for c in model.predicts]
        ax.scatter(x, y, color=point_colors, marker='.')

    centroids_x, centroids_y = zip(*model.centroids)
    return ax.scatter(centroids_x, centroids_y, color=centroid_colors, marker='^', s=400)
    
def generator():
    """Yield consecutive frame numbers 0, 1, 2, ... for FuncAnimation.

    The model's `converged` flag is re-checked before every yield, so
    iteration stops once the model reports convergence.
    """
    frame = 0
    while True:
        if model.converged:
            return
        yield frame
        frame += 1


ani = animation.FuncAnimation(
    fig=fig,
    func=update,
    frames=generator,
    interval=300,
    blit=False,
)
# leaving HTML(...) as the cell's final expression lets the notebook render it
video_html = ani.to_html5_video()
HTML(video_html)