# In [1]:
import numpy as np


class KMeans:
    """K-means clustering (Lloyd iterations) started from caller-supplied centers.

    Each iteration assigns every sample to its nearest center (Euclidean
    distance) and then moves each center to the mean of the samples assigned
    to it. A center that receives no samples keeps its previous position.

    Parameters
    ----------
    K
        Number of clusters; must equal the number of rows in ``init``.
    init
        Initial cluster centers, array-like of shape ``(K, n_features)``.
    tol
        Iteration stops once no center moves by ``tol`` or more (Euclidean
        distance) between two consecutive iterations.
    max_iter
        Optional upper bound on the number of iterations. ``None`` (the
        default) iterates until the ``tol`` criterion is met.

    Attributes
    ----------
    cluster_centers_
        Current centers, float64 array of shape ``(K, n_features)``.

    Raises
    ------
    ValueError
        If ``K`` is not positive, ``init`` is not a 2-D array with ``K``
        rows, or ``max_iter`` is given but smaller than 1.
    """

    def __init__(self, K: int, init: np.ndarray, tol: float = 0.001,
                 max_iter: "int | None" = None):
        centers = np.asarray(init, dtype=np.float64)
        if K < 1:
            raise ValueError(f"K must be a positive integer, got {K}")
        # A mismatch between K and init used to leave rows of the distance
        # matrix uninitialised (K too large) or raise an opaque IndexError
        # (K too small); reject it up front instead.
        if centers.ndim != 2 or centers.shape[0] != K:
            raise ValueError(
                f"init must have shape (K, n_features) with K={K}, "
                f"got shape {centers.shape}"
            )
        if max_iter is not None and max_iter < 1:
            raise ValueError(f"max_iter must be >= 1 or None, got {max_iter}")

        self.K = K
        self.tol = tol
        self.max_iter = max_iter
        self.cluster_centers_ = centers

    def _check_samples(self, X) -> np.ndarray:
        """Return ``X`` as a 2-D float64 array whose width matches the centers."""
        samples = np.asarray(X, dtype=np.float64)
        n_features = np.asarray(self.cluster_centers_).shape[1]
        # Without this check a single-column X would silently broadcast
        # against the centers and produce meaningless distances.
        if samples.ndim != 2 or samples.shape[1] != n_features:
            raise ValueError(
                f"X must have shape (n_samples, {n_features}), "
                f"got shape {samples.shape}"
            )
        return samples

    def _distances(self, X: np.ndarray) -> np.ndarray:
        """Return the ``(n_centers, n_samples)`` matrix of Euclidean distances."""
        return np.stack(
            [np.linalg.norm(X - center, axis=1) for center in self.cluster_centers_]
        )

    def fit(self, X: np.ndarray) -> "KMeans":
        """Iteratively refine ``cluster_centers_`` on the samples ``X``.

        Parameters
        ----------
        X
            Training samples, array-like of shape ``(n_samples, n_features)``.

        Returns
        -------
        KMeans
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` has the wrong shape or contains NaN or infinite values.
        """
        X = self._check_samples(X)
        # NaN/inf propagate into the centers and make the shift NaN; since
        # ``NaN < tol`` is always False the loop would never terminate.
        if not np.isfinite(X).all():
            raise ValueError("X must not contain NaN or infinite values")

        iterations = 0
        while self.max_iter is None or iterations < self.max_iter:
            iterations += 1
            labels = np.argmin(self._distances(X), axis=0)

            old_centers = np.asarray(self.cluster_centers_, dtype=np.float64)
            # Start from a copy so that empty clusters keep their old center.
            new_centers = old_centers.copy()
            for k in range(new_centers.shape[0]):
                members = labels == k
                if members.any():
                    new_centers[k] = X[members].mean(axis=0)

            # Largest distance any single center moved in this iteration.
            shift = np.max(np.linalg.norm(new_centers - old_centers, axis=1))
            self.cluster_centers_ = new_centers
            if shift < self.tol:
                break

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return, for each row of ``X``, the index of the nearest center.

        Parameters
        ----------
        X
            Samples to label, array-like of shape ``(n_samples, n_features)``.

        Returns
        -------
        np.ndarray
            Integer array of shape ``(n_samples,)``; ties go to the center
            with the lowest index.

        Raises
        ------
        ValueError
            If ``X`` does not have shape ``(n_samples, n_features)``.
        """
        return np.argmin(self._distances(self._check_samples(X)), axis=0)