# Clustering Algorithms

In [None]:
import math
import torch
import matplotlib.pyplot as plt
from torch.distributions import MultivariateNormal

In [None]:
# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

## 1. Randomly-Generated Data

In [None]:
# Number of cluster centers to draw.
n_clusters = 10
# Number of points sampled around each center.
n_samples = 250

In [None]:
# One 2-D center per cluster, drawn uniformly from [-50, 50) on each axis.
random_centers = 100 * torch.rand(n_clusters, 2) - 50

In [None]:
# Quick look at where the centers landed (first column = x, second = y).
plt.scatter(random_centers[:, 0], random_centers[:, 1], marker="x");

In [None]:
def generate_data(clusters: torch.Tensor, n: int, variance: float = 5.0):
    """Sample ``n`` points from an isotropic Gaussian around every center.

    Args:
        clusters: Tensor of shape ``(k, d)`` holding one center per row.
        n: Number of points to draw around each center.
        variance: Value placed on the diagonal of the covariance matrix that
            is shared by all centers. Must be positive.

    Returns:
        A ``(points, labels)`` tuple. ``points`` has shape ``(k * n, d)`` and
        lists the samples of center 0 first, then center 1, and so on.
        ``labels`` is an int64 tensor of shape ``(k * n,)`` holding, for every
        point, the row index of the center it was drawn from.
    """
    n_centers, dim = clusters.shape
    # Sampling needs floating-point parameters; keep the caller's precision
    # when the centers are already floating point.
    dtype = clusters.dtype if clusters.is_floating_point() else torch.get_default_dtype()
    centers = clusters.to(dtype)
    # Build the covariance from the data dimension instead of assuming 2-D.
    covariance = variance * torch.eye(dim, dtype=dtype, device=clusters.device)
    points = torch.cat([
        MultivariateNormal(center, covariance).sample((n,))
        for center in centers
    ])
    # [0]*n + [1]*n + ... without going through Python lists.
    labels = torch.arange(n_centers).repeat_interleave(n)
    return points, labels

In [None]:
# `data` holds the sampled points; `clusters` holds, per point, the index of
# the center it was generated from (the ground-truth label).
data, clusters = generate_data(random_centers, n_samples)

In [None]:
def plot_data(X, clusters, centers):
    """Scatter-plot points coloured by cluster label with the centers on top.

    Args:
        X: Points to draw; the last axis is indexed as (x, y).
        clusters: One label per point, used as the colour value (``tab10`` colormap).
        centers: Center coordinates; the last axis is indexed as (x, y).

    Raises:
        AssertionError: If ``X`` and ``clusters`` differ in length.
    """
    assert len(X) == len(clusters)
    axes = plt.subplots(1, 1)[1]
    axes.scatter(X[..., 0], X[..., 1], c=clusters, cmap="tab10", marker=".", s=2)
    center_x, center_y = centers[..., 0], centers[..., 1]
    # Draw each center twice: a black "X" with a smaller white "x" on top of it.
    for colour, marker, size in (("black", "X", 20), ("white", "x", 10)):
        axes.scatter(center_x, center_y, c=colour, marker=marker, s=size)

In [None]:
# Generated samples coloured by their ground-truth label, with the true centers overlaid.
plot_data(data, clusters, random_centers)

In [None]:
#         _, ax = plt.subplots(1, 1)
#         ax.scatter(X[...,0], X[...,1], c=colors, cmap="tab10", marker=".", s=2)
#         ax.scatter(centers[..., 0], centers[..., 1], c="black", marker="X", s=20)
#         ax.scatter(centers[..., 0], centers[..., 1], c="white", marker="x", s=10)

In [None]:
# from dataclasses import dataclass, field
# @dataclass
# class RandomClusters:
#     n_clusters: int
#     n_per_cluster: int
#     a: float = 100
#     b: float = -50
    
#     _centers: torch.Tensor = field(init=False)
#     _points: torch.Tensor = field(init=False)
    
#     @property
#     def centers(self): return self._centers

#     @property
#     def points(self): return self._points

#     def __repr__(self): return (
#         f"RandomCluster(centers={self._centers.shape}, points={self._points.shape})\n"
#         f"Centers:\n{self.centers}"
#     )
    
#     def __post_init__(self):
#         uniform = torch.rand((self.n_clusters, 2))
#         centers = self.a*uniform + self.b
#         self._centers = centers
#         self._points = generate_data(centers, self.n_per_cluster)
        
#     def plot_centers(self):
#         plt.scatter(self.centers[...,0], self.centers[...,1], marker="x")
        
#     def plot_data(self):
#         X, centers, colors = self._points, self._centers, self.true_clusters()
#         _, ax = plt.subplots(1, 1)
#         ax.scatter(X[...,0], X[...,1], c=colors, cmap="tab10", marker=".", s=2)
#         ax.scatter(centers[..., 0], centers[..., 1], c="black", marker="X", s=20)
#         ax.scatter(centers[..., 0], centers[..., 1], c="white", marker="x", s=10)
        
#     def true_clusters(self):
#         from itertools import chain
#         colors = list(chain(*[[i]*self.n_per_cluster for i in range(self.n_clusters)]))
#         return colors

## 2. Mean Shift Clustering

In [None]:
def gaussian(data: torch.Tensor, bw: float):
    """Evaluate a zero-mean Gaussian density with standard deviation ``bw``.

    Args:
        data: Values at which to evaluate the density (element-wise).
        bw: Bandwidth, i.e. the standard deviation of the Gaussian.

    Returns:
        Tensor shaped like ``data`` containing
        ``exp(-0.5 * (data / bw) ** 2) / (bw * sqrt(2 * pi))``.
    """
    normaliser = bw * math.sqrt(2 * math.pi)
    scaled = data / bw
    return torch.exp(-0.5 * scaled ** 2) / normaliser

In [None]:
# Plot the kernel on 0..19 for several bandwidths to show how `bw` controls
# its spread. A plain loop is used because the calls are made only for their
# side effect on the current figure.
for bw in (1., 2.5, 5.):
    plt.plot(gaussian(torch.arange(20), bw), label=f"bw={bw}")
plt.legend()
plt.show()

### 2.1. One Sample

In [None]:
# Use the first sample as the query point; the last line displays it together
# with its shape and the shape of the full data set.
X = data
x = X[0]
x, x.shape, X.shape

In [None]:
# Gaussian weight (bandwidth 2.5) of every sample with respect to `x`, based on
# the Euclidean distance between them.
weight_to_x = gaussian(((x - X) ** 2).sum(dim=1).sqrt(), 2.5)

In [None]:
# Shape check: one weight per sample.
weight_to_x.shape, X.shape

In [None]:
# Weighted mean of all samples using the weights relative to `x`. Despite the
# name, this is a position (the weighted mean), not an offset from `x`.
delta = (weight_to_x.unsqueeze(1) * X).sum(dim=0) / weight_to_x.sum()
delta

### 2.2. All Samples

In [None]:
# Start from a fresh copy of the samples.
X = data.clone()

In [None]:
# Pairwise kernel weights: W[i, j] is the Gaussian weight (bandwidth 2.5) of the
# Euclidean distance between samples i and j. The kernel must be applied here:
# raw distances would give far-away points the largest weight when W is used
# to average the samples.
W = gaussian((X[None] - X[:,None]).pow(2).sum(2).sqrt(), 2.5)

In [None]:
# Show the pairwise matrix W as an image.
plt.imshow(W);

In [None]:
# Shape check before the matrix product below.
W.shape, X.shape

In [None]:
# For every row of W: average of all samples weighted by that row, normalised
# by the row sum.
(W @ X) / W.sum(dim=1, keepdim=True)

In [None]:
# Mean shift: for 15 rounds, replace every point by the Gaussian-weighted
# (bandwidth 2.5) average of all current points.
X = data.clone()
for _ in range(15):
    # Euclidean distance between every pair of current points.
    distances = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(dim=2).sqrt()
    W = gaussian(distances, 2.5)
    X = (W @ X) / W.sum(dim=1, keepdim=True)

In [None]:
# Shifted points coloured by ground-truth label. The true centers are drawn
# offset by +1 on both axes -- presumably so the markers do not hide the
# collapsed points; confirm intent.
plot_data(X, clusters, random_centers + 1)

In [None]:
# Distinct converged positions are the cluster centers; the inverse index maps
# every sample to its center. Exact float equality is fragile (points that
# reached the same mode can still differ in the last bits), so positions are
# snapped to a 0.01 grid first -- far finer than the 2.5 bandwidth.
ms_centers, ms_clusters = ((X * 100).round() / 100).unique(dim=0, return_inverse=True)

In [None]:
# Original samples coloured by their mean-shift label, with the found centers overlaid.
plot_data(data, ms_clusters, ms_centers)