In [1]:
import numpy as np
from numpy.linalg import norm

In [2]:
class KMeans:
    """K-means clustering using Lloyd's iterations.

    After ``fit`` the instance exposes:

    * ``centroids`` -- array of shape ``(n_clusters, n_features)``
    * ``labels``    -- cluster index assigned to each training sample
    * ``error``     -- sum of squared distances of samples to their centroid

    Parameters
    ----------
    n_clusters : int
        Number of clusters to form.
    max_iter : int, default 100
        Upper bound on the number of assignment/update iterations.
    random_state : int, default 42
        Seed for the private random generator used to pick the initial
        centroids, so repeated fits on the same data are reproducible.
    """

    def __init__(self, n_clusters, max_iter=100, random_state=42):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.random_state = random_state

    def initialize_centroids(self, x):
        """Pick ``n_clusters`` distinct rows of ``x`` as starting centroids.

        Raises
        ------
        ValueError
            If ``x`` has fewer rows than ``n_clusters``.
        """
        n_samples = x.shape[0]
        if self.n_clusters > n_samples:
            raise ValueError(
                f"n_clusters={self.n_clusters} exceeds the number of "
                f"samples ({n_samples})"
            )
        # Use a private generator: the seed must actually drive the shuffle,
        # and the global numpy random state must stay untouched.
        rng = np.random.RandomState(self.random_state)
        random_idx = rng.permutation(n_samples)
        return x[random_idx[:self.n_clusters]]

    def compute_centroids(self, x, labels, previous=None):
        """Return the mean of the samples assigned to each cluster.

        Parameters
        ----------
        x : ndarray of shape (n_samples, n_features)
        labels : ndarray of shape (n_samples,)
            Cluster index of every sample.
        previous : ndarray of shape (n_clusters, n_features), optional
            Centroids from the preceding iteration. A cluster that received
            no samples keeps its row from ``previous``; when ``previous`` is
            omitted such a row is filled with NaN.
        """
        centroids = np.zeros((self.n_clusters, x.shape[1]))
        for k in range(self.n_clusters):
            members = x[labels == k]
            if members.shape[0]:
                centroids[k] = members.mean(axis=0)
            elif previous is not None:
                # Empty cluster: the mean is undefined, so keep the centroid
                # where it was instead of poisoning the distances with NaN.
                centroids[k] = previous[k]
            else:
                centroids[k] = np.nan
        return centroids

    def compute_distance(self, x, centroids):
        """Return squared Euclidean distances, shape (n_samples, n_clusters)."""
        distance = np.zeros((x.shape[0], self.n_clusters))
        for k in range(self.n_clusters):
            row_norm = norm(x - centroids[k, :], axis=1)
            distance[:, k] = np.square(row_norm)
        return distance

    def find_closest_cluster(self, distance):
        """Return, per sample, the index of the column with least distance."""
        return np.argmin(distance, axis=1)

    def compute_sse(self, x, labels, centroids):
        """Return the sum of squared distances of samples to their centroid."""
        distance = np.zeros(x.shape[0])
        for k in range(self.n_clusters):
            mask = labels == k
            distance[mask] = norm(x[mask] - centroids[k], axis=1)
        return np.sum(np.square(distance))

    def fit(self, x):
        """Cluster ``x`` and store ``centroids``, ``labels`` and ``error``.

        Iterates until the centroids stop changing or ``max_iter``
        iterations have run.
        """
        x = np.asarray(x)
        self.centroids = self.initialize_centroids(x)
        for _ in range(self.max_iter):
            old_centroids = self.centroids
            distance = self.compute_distance(x, old_centroids)
            self.labels = self.find_closest_cluster(distance)
            self.centroids = self.compute_centroids(
                x, self.labels, old_centroids
            )
            if np.array_equal(old_centroids, self.centroids):
                break
        self.error = self.compute_sse(x, self.labels, self.centroids)

    def predict(self, x):
        """Return the index of the nearest fitted centroid for each row.

        ``fit`` must have been called first; otherwise ``self.centroids``
        does not exist and an ``AttributeError`` is raised.
        """
        x = np.asarray(x)
        distance = self.compute_distance(x, self.centroids)
        return self.find_closest_cluster(distance)