In [None]:
import torch
from sklearn.cluster import KMeans


In [None]:
def tensor_generator(shape, distribution='normal'):
    """Return an endless generator of random tensors of the given shape.

    Args:
        shape: Size passed straight to ``torch.randn`` / ``torch.rand``
            (an int or a tuple of ints).
        distribution: ``'normal'`` draws with ``torch.randn``;
            ``'uniform'`` draws with ``torch.rand``.

    Returns:
        A generator that yields a freshly sampled tensor on every ``next()``.

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
            The check runs when this function is called, so a typo is
            reported where the generator is created rather than at the
            first ``next()``.
    """
    if distribution == 'normal':
        sample = torch.randn
    elif distribution == 'uniform':
        sample = torch.rand
    else:
        raise ValueError(f'Unknown distribution: {distribution}')

    def _generate():
        # The sampler is resolved once above; the loop only draws.
        while True:
            yield sample(shape)

    return _generate()

In [None]:
def noise_generator(centroid, std=0.1):
    """Endlessly yield noisy copies of ``centroid``.

    Each yielded tensor is ``centroid`` plus standard-normal noise (drawn
    with ``torch.randn_like``) scaled by ``std``.

    Args:
        centroid: Tensor that every sample is centered on.
        std: Scale factor applied to the standard-normal noise.

    Yields:
        A new tensor shaped like ``centroid`` on every iteration.
    """
    while True:
        perturbation = torch.randn_like(centroid) * std
        yield centroid + perturbation

In [None]:
# Experiment configuration -- everything a reader might want to tune.
SEED = 0            # seed for torch's global RNG so a fresh run reproduces the data
N_CENTROIDS = 10    # number of ground-truth centroids the data is drawn around
N = 9000            # total number of noisy samples (spread round-robin over the centroids)
STD = 0.5           # scale of the Gaussian noise added to each centroid
SHAPE = (512,)      # shape of every tensor; the trailing comma makes this a tuple, `(512)` is just the int 512
DIST = 'uniform'    # distribution the centroids are drawn from: 'normal' or 'uniform'

# Seed once, before any tensor is sampled, so Restart & Run All is reproducible.
torch.manual_seed(SEED)

In [None]:
# Draw N_CENTROIDS independent centroids from the configured distribution.
centroid_generator = tensor_generator(shape=SHAPE, distribution=DIST)
# range() comes first in the zip so the generator is advanced exactly
# N_CENTROIDS times and never over-consumed.
centroids = [centroid for _, centroid in zip(range(N_CENTROIDS), centroid_generator)]

In [None]:
# One endless stream of noisy samples per centroid.
noise_generators = [noise_generator(center, std=STD) for center in centroids]

# Draw N samples in total, cycling round-robin through the centroids:
# sample i is taken from the stream of centroid i % N_CENTROIDS.
noisy_samples = [next(noise_generators[idx % N_CENTROIDS]) for idx in range(N)]
noisy_tensors = torch.stack(noisy_samples)

In [None]:
def k_means_clustering(data, n_clusters, random_state=0):
    """Cluster the rows of ``data`` with scikit-learn's ``KMeans``.

    Args:
        data: Tensor whose rows are the samples to cluster. It may live on
            any device and may require grad; it is detached and copied to
            the CPU before being handed to scikit-learn.
        n_clusters: Number of clusters to fit.
        random_state: Seed forwarded to ``KMeans`` so repeated runs give
            the same assignment. Defaults to 0.

    Returns:
        A tensor built from ``KMeans.labels_`` (the cluster index assigned
        to each sample).
    """
    # Tensor.numpy() raises on tensors that require grad, so detach first;
    # cpu() moves the data to host memory if it is on another device.
    data_np = data.detach().cpu().numpy()

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(data_np)

    # Hand the labels back as a tensor so callers stay in torch.
    return torch.from_numpy(kmeans.labels_)

In [None]:
# Number of clusters K-means is asked to find.
# NOTE(review): the data is generated around N_CENTROIDS (10) centroids but
# clustered into 5 groups -- confirm the mismatch is intentional.
n_clusters = 5
labels = k_means_clustering(data=noisy_tensors, n_clusters=n_clusters)

In [None]:
# plot the data in 2d
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Project the samples onto their first two principal components so the
# K-means assignment can be inspected visually.
pca = PCA(n_components=2)
# scikit-learn and matplotlib operate on NumPy arrays; convert explicitly
# instead of relying on implicit tensor -> array coercion.
noisy_tensors_2d = pca.fit_transform(noisy_tensors.detach().cpu().numpy())

fig, ax = plt.subplots()
ax.scatter(
    noisy_tensors_2d[:, 0],
    noisy_tensors_2d[:, 1],
    c=labels.numpy(),
    cmap='viridis',
    label='samples (color = K-means label)',
)

# Overlay the ground-truth centroids, projected with the same fitted PCA.
centroids_2d = pca.transform(torch.stack(centroids).detach().cpu().numpy())
ax.scatter(
    centroids_2d[:, 0],
    centroids_2d[:, 1],
    c='red',
    s=100,
    alpha=0.5,
    label='true centroids',
)

ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_title('K-means labels in 2D PCA space')
ax.legend()

plt.show()

In [None]:
# plot in 3d
from mpl_toolkits.mplot3d import Axes3D

# Refit PCA with three components for the 3D view (this rebinds `pca`).
pca = PCA(n_components=3)
# scikit-learn and matplotlib operate on NumPy arrays; convert explicitly
# instead of relying on implicit tensor -> array coercion.
noisy_tensors_3d = pca.fit_transform(noisy_tensors.detach().cpu().numpy())

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(
    noisy_tensors_3d[:, 0],
    noisy_tensors_3d[:, 1],
    noisy_tensors_3d[:, 2],
    c=labels.numpy(),
    cmap='viridis',
    label='samples (color = K-means label)',
)

# Overlay the ground-truth centroids, projected with the same fitted PCA.
centroids_3d = pca.transform(torch.stack(centroids).detach().cpu().numpy())
ax.scatter(
    centroids_3d[:, 0],
    centroids_3d[:, 1],
    centroids_3d[:, 2],
    c='red',
    s=100,
    alpha=0.5,
    label='true centroids',
)

ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
ax.set_title('K-means labels in 3D PCA space')
ax.legend()

plt.show()