In [1]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, Latex, Image
from sympy import Derivative, Indexed, Sum, init_printing, lambdify, symbols, latex
from celluloid import Camera

# Turn on SymPy's pretty-printing so symbolic expressions render nicely in cell output.
init_printing()

In [2]:
# --- Experiment configuration ---
k = 7        # number of clusters / centroids
w = 1600     # canvas width: x coordinates are drawn from [0, w)
h = 900      # canvas height: y coordinates are drawn from [0, h)
nums = 100   # number of sample points to cluster

# Seed the global NumPy RNG so that "Restart Kernel -> Run All" reproduces the
# same colours, points and (in later cells that use np.random) centroids.
SEED = 42
np.random.seed(SEED)

# k rows of 3 uniform values in [0, 1); used as one RGB colour per cluster.
colors = np.random.rand(k, 3)

# Random integer sample points, stacked into an array of (x, y) rows.
x = np.random.randint(0, w, size=nums)
y = np.random.randint(0, h, size=nums)
pts = np.column_stack((x, y))

print(x.shape)
print(pts.shape)

(100,)
(100, 2)


In [3]:
# Draw k random integer starting positions inside the w x h canvas:
# x coordinates first, then y coordinates.
centroids_x = np.random.randint(0, w, size=k)
centroids_y = np.random.randint(0, h, size=k)

# Pair the coordinates up so that each row is one (x, y) centroid.
centroids = np.stack((centroids_x, centroids_y), axis=1)

print(centroids.shape)

(7, 2)


In [4]:
def create_plots():
    """Build the figure used to animate the K-Means run.

    The figure starts as a 1x3 grid of axes. The first cell becomes the
    distance-per-epoch panel, the other two grid cells are switched off, and
    a new axes at position 2 of a 1x2 grid is added for the point/centroid
    plot.

    Returns:
        A 3-tuple ``(distance_axes, scatter_axes, camera)`` where ``camera``
        is a celluloid ``Camera`` created for the figure.
    """
    figure, grid = plt.subplots(
        1, 3, figsize=(16 / 9.0 * 4, 4 * 1), layout="constrained"
    )
    figure.suptitle("K-Means Clustering Unsupervised")

    # Left panel: distance curve over the epochs.
    distance_axes = grid[0]
    distance_axes.set_xlabel("Epoch", fontweight="normal")
    distance_axes.set_ylabel("Euclidean Distance", fontweight="normal")
    distance_axes.set_title("Centroid Distance")

    # The remaining two grid cells are not drawn into; hide their frames.
    for unused_axes in grid[1:]:
        unused_axes.axis("off")

    # Right panel: a separate axes laid out on a 1x2 grid, second slot.
    scatter_axes = figure.add_subplot(1, 2, 2)
    scatter_axes.set_xlabel("X")
    scatter_axes.set_ylabel("Y")
    scatter_axes.set_title("Centroids")

    return distance_axes, scatter_axes, Camera(figure)

In [5]:
ax0, ax1, camera = create_plots()
epochs = 10

# Mean point-to-assigned-centroid distance per epoch, and the matching
# 1-based epoch numbers used as the x axis of the distance curve.
dists = np.zeros(epochs)
dists_idx = np.arange(1, epochs+1)

# The initial centroids are an integer array; assigning a group mean into an
# integer array silently drops the fractional part on every update. Convert
# to float once so the update step stores the exact mean.
centroids = centroids.astype(np.float64)

for e in range(epochs):
    # Assignment step: Euclidean distance from every point to every centroid,
    # shape (nums, k). argmin picks the first minimum, matching a strict
    # "dist < min_dist" scan over the centroids in order.
    all_dists = np.linalg.norm(pts[:, None, :] - centroids[None, :, :], axis=2)
    labels = all_dists.argmin(axis=1)
    groups = [pts[labels == g] for g in range(k)]

    # Record the mean distance to the nearest centroid and redraw the curve
    # up to and including the current epoch.
    dists[e] = all_dists.min(axis=1).mean()
    ax0.plot(dists_idx[:e+1], dists[:e+1], color="red")

    for g in range(k):
        # Unpack into scalars first: centroids[g] is overwritten below, and
        # this frame must show the centroid the points were assigned to.
        cx, cy = centroids[g]
        ax1.scatter([cx], [cy], color=colors[g])

        group_pts = groups[g]
        if len(group_pts) != 0:
            # Draw lines between points and the centroids
            for px, py in group_pts:
                ax1.plot([px, cx], [py, cy], color="black", alpha=0.3)

            # Update step: move the centroid to the mean of its points.
            # Empty groups keep their previous position.
            centroids[g] = group_pts.mean(axis=0)

    ax1.scatter(pts[:,0], pts[:,1], alpha=0.1, c="blue")
    # Capture everything drawn during this epoch as one animation frame.
    camera.snap()

# Close the figure so only the saved animation is shown, not a static plot.
plt.close()
animation = camera.animate()
animation.save("k_means.gif", writer="pillow")
Image(url="k_means.gif")