In [3]:

# Import necessary libraries
import torch
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import numpy as np
from sklearn.metrics import silhouette_score
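# NOTE: the scikit-learn KMeans imported above is shadowed by the PyTorch KMeans class defined below; scikit-learn's KMeans is imported again further down, right before it is used, so the two classes are not confused.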

class KMeans:
    """K-means clustering (Lloyd's algorithm) on PyTorch tensors.

    Attributes set by ``fit``:
        cluster_centers_: tensor of shape (n_clusters, D) holding the centroids.
        labels_: tensor of shape (N,) holding, for each training point, the
            index of its closest centroid in the final ``cluster_centers_``.
    """

    def __init__(self, n_clusters: int, max_iter: int = 10, tol: float = 0.0,
                 random_state=None):
        """
        Initializes the KMeans object.

        Parameters:
            n_clusters: The number of clusters to create.
            max_iter: The maximum number of iterations to run the algorithm for.
            tol: Stop early once no centroid moves by more than this Euclidean
                distance during one iteration. The default 0.0 stops only when
                the centroids stop changing entirely, which cannot alter the
                result (further iterations would be no-ops).
            random_state: Optional integer seed for the centroid
                initialization. When None, torch's global generator is used.
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state

    def fit(self, X):
        """
        Runs the k-means algorithm on input data X.

        Parameters:
            X: A tensor (or anything ``torch.as_tensor`` accepts) of shape
                (N, D), where N is the number of data points and D is the
                dimensionality of the data points. Non-floating input is
                converted to float32.

        Returns:
            A tuple ``(labels_, cluster_centers_)``.

        Raises:
            ValueError: if X is not 2-D or n_clusters is not in [1, N].
        """
        X = torch.as_tensor(X)
        if X.dim() != 2:
            raise ValueError(f"X must have shape (N, D); got {tuple(X.shape)}")
        if not X.is_floating_point():
            # cdist / mean require a floating dtype
            X = X.float()
        N, _ = X.shape
        if not 1 <= self.n_clusters <= N:
            raise ValueError(
                f"n_clusters must be between 1 and the number of samples ({N}); "
                f"got {self.n_clusters}"
            )

        generator = None
        if self.random_state is not None:
            generator = torch.Generator(device=X.device)
            generator.manual_seed(self.random_state)

        # Pick distinct rows as the initial centroids; sampling with
        # replacement (torch.randint) could start two centroids on the same
        # point and leave a cluster permanently empty.
        init_idx = torch.randperm(N, generator=generator, device=X.device)[: self.n_clusters]
        centers = X[init_idx]  # advanced indexing copies, so X is never modified

        for _ in range(self.max_iter):
            # Assign each data point to its closest centroid
            labels = torch.cdist(X, centers).argmin(dim=1)

            # Move each centroid to the mean of the points assigned to it
            new_centers = centers.clone()
            for j in range(self.n_clusters):
                mask = labels == j
                # An empty cluster keeps its previous centroid
                if mask.any():
                    new_centers[j] = X[mask].mean(dim=0)

            shift = (new_centers - centers).norm(dim=1).max().item()
            centers = new_centers
            if shift <= self.tol:
                break

        self.cluster_centers_ = centers
        # Label against the *final* centroids so labels_ matches predict(X);
        # this also defines labels_ when max_iter == 0.
        self.labels_ = torch.cdist(X, centers).argmin(dim=1)
        # Return the labels and centroids of the fitted model
        return self.labels_, self.cluster_centers_

    def predict(self, X):
        """
        Assigns each data point in X to its closest centroid.

        Parameters:
            X: A tensor (or anything ``torch.as_tensor`` accepts) of shape
                (N, D). It is cast to the dtype/device of the fitted centroids.

        Returns:
            A tensor of shape (N,) containing the index of the closest centroid for each data point.

        Raises:
            AttributeError: if called before ``fit``.
        """
        if not hasattr(self, "cluster_centers_"):
            raise AttributeError("This KMeans instance is not fitted yet; call fit() first.")
        centers = self.cluster_centers_
        # cdist requires both operands to share dtype and device
        X = torch.as_tensor(X, dtype=centers.dtype, device=centers.device)
        return torch.cdist(X, centers).argmin(dim=1)



# Generate sample data: 1000 points in 10 dimensions drawn around 5 centers
X, _ = make_blobs(n_samples=1000, centers=5, n_features=10, random_state=42)

X_torch = torch.tensor(X, dtype=torch.float32)


# Set parameters
n_clusters = 3
max_iter = 3000

# Seed torch's global generator so the random centroid initialization of the
# PyTorch implementation (and therefore the printed output) is reproducible.
torch.manual_seed(0)

# Run KMeans using Torch implementation
kmeans_torch = KMeans(n_clusters=n_clusters, max_iter=max_iter)
kmeans_torch.fit(X_torch)


# Re-import scikit-learn's KMeans: the name was shadowed by the PyTorch class
# above. Note this rebinds `KMeans` to scikit-learn's class from here on.
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

# Run KMeans using scikit-learn implementation (seeded for reproducibility)
kmeans_sklearn = KMeans(n_clusters=n_clusters, max_iter=max_iter, random_state=0)
kmeans_sklearn.fit(X)


sklearn_labels = kmeans_sklearn.labels_
pytorch_labels = kmeans_torch.labels_

sklearn_centers = kmeans_sklearn.cluster_centers_
pytorch_centers = kmeans_torch.cluster_centers_


# Show both results. Cluster indices are arbitrary, so the centroids may be
# listed in a different order and the labels cannot be compared element-wise.
print("Scikit-learn centroids:\n", sklearn_centers)
print("Torch centroids:\n", pytorch_centers)

print("Scikit-learn labels:\n", sklearn_labels)
print("Torch labels:\n", pytorch_labels)

# Permutation-invariant agreement check: an adjusted Rand index of 1.0 means
# both implementations produced the same partition of the data.
agreement = adjusted_rand_score(sklearn_labels, pytorch_labels.cpu().numpy())
print(f"Adjusted Rand index between the two labelings: {agreement}")


# Calculate the silhouette score for each clustering. sklearn works on NumPy
# arrays, so the torch tensors are converted explicitly.
silhouette_sklearn = silhouette_score(X, sklearn_labels)
silhouette_torch = silhouette_score(X_torch.cpu().numpy(), pytorch_labels.cpu().numpy())


# Print the comparison results
print(f"Silhouette score for Sklearn: {silhouette_sklearn}")
print(f"Silhouette score for Pytorch: {silhouette_torch}")




Scikit-learn centroids:
 [[-6.02406665  9.22941675  5.63335675 -1.90723062 -6.60980967 -6.58155471
  -6.43125394  3.91169493  0.40542788  0.09865564]
 [ 2.30849878 -6.59015916 -8.7580347   9.01674696  9.34368999  6.1359172
  -3.88906042 -8.07901218  3.58800216 -1.25549066]
 [-2.73332177 -3.59523284 -6.75295915  2.65192424 -2.83850103  4.40081092
  -4.84708312  0.36964827  1.30863565 -7.65435708]]
Torch centroids:
 tensor([[-2.7333, -3.5952, -6.7530,  2.6519, -2.8385,  4.4008, -4.8471,  0.3696,
          1.3086, -7.6544],
        [-6.0241,  9.2294,  5.6334, -1.9072, -6.6098, -6.5816, -6.4313,  3.9117,
          0.4054,  0.0987],
        [ 2.3085, -6.5902, -8.7580,  9.0167,  9.3437,  6.1359, -3.8891, -8.0790,
          3.5880, -1.2555]])
Scikit-learn labels:
 [2 0 0 1 0 2 0 1 0 2 0 0 1 0 0 0 2 1 1 0 0 2 2 1 2 2 2 2 0 2 0 0 0 0 1 0 2
 0 2 2 1 0 0 2 2 2 2 0 0 2 2 1 0 2 2 0 0 2 2 0 0 0 0 2 1 1 1 2 0 0 0 0 0 1
 0 2 1 2 2 1 1 2 0 0 2 0 2 1 2 0 2 0 0 1 0 0 0 0 0 2 0 1 0 2 0 2 1 1 1 2 2
 2 2 0 