In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, normalized_mutual_info_score
from sklearn.manifold import Isomap, SpectralEmbedding, TSNE, LocallyLinearEmbedding
import torch

from dataset import AbstractDataset, \
    UniformRandomDataset, \
    PrototypeDataset, \
    MNISTDataset, \
    FashionMNISTDataset

from clusterings import *

In [None]:
def measurePrototypeInformation(prototypes: np.ndarray, dataset: AbstractDataset) -> dict[str, float]:
    """Summarise how far apart a set of prototypes lie from one another.

    The squared Euclidean distance is computed between every ordered pair of
    distinct prototypes (so each unordered pair is counted twice), and summary
    statistics over those distances are returned.

    Args:
        prototypes: 2-D array with one prototype per row.
        dataset: Not used by this function; accepted so the signature mirrors
            ``measureClusteringPerformance``.

    Returns:
        A dict with the keys ``mean_squared_distance``, ``std_squared_distance``
        and ``min_squared_distance``. All values are NaN when fewer than two
        prototypes are given, because no pairwise distance exists.
    """
    pairwise_difference = prototypes[None, :, :] - prototypes[:, None, :]
    squared_distance: np.ndarray = np.square(pairwise_difference).sum(axis=2)
    # Drop the diagonal: the distance of a prototype to itself is always zero
    # and would otherwise dominate the minimum.
    squared_distance = squared_distance[~np.eye(prototypes.shape[0], dtype=bool)]
    if squared_distance.size == 0:
        return {
            "mean_squared_distance": float("nan"),
            "std_squared_distance": float("nan"),
            "min_squared_distance": float("nan"),
        }
    return {
        # "squared_distance": squared_distance,
        "mean_squared_distance": float(squared_distance.mean()),
        "std_squared_distance": float(squared_distance.std()),
        # Previously this entry reported the standard deviation a second time.
        "min_squared_distance": float(squared_distance.min()),
    }

def measureClusteringPerformance(predicted_labels: np.ndarray, dataset: AbstractDataset) -> dict[str, float]:
    """Score a labelling against the dataset's features and reference labels.

    Args:
        predicted_labels: Cluster assignment for each sample in ``dataset``.
        dataset: Object exposing ``X`` (passed to ``silhouette_score``) and
            ``y`` (passed to ``normalized_mutual_info_score``).

    Returns:
        A dict with the keys ``silhouette`` and ``nmi``, both plain floats.
        ``silhouette`` is NaN when ``silhouette_score`` raises ``ValueError``.
    """
    try:
        silhouette_value = silhouette_score(dataset.X, predicted_labels)
    except ValueError:
        # Best effort: record NaN instead of aborting the whole report.
        silhouette_value = np.nan
    scores = dict(silhouette=float(silhouette_value))
    scores["nmi"] = float(normalized_mutual_info_score(dataset.y, predicted_labels))
    return scores

# Synthetic Prototype Dataset

In [None]:
# Settings for the synthetic prototype dataset. They are handed to
# PrototypeDataset positionally, in the order shown in the call below.
NUM_DATA_PROTOTYPES = 5
INSTANCE_NOISE = 0.4
NUM_POINTS = 1000
NUM_FEATURES = 10
dataset = PrototypeDataset(
    NUM_DATA_PROTOTYPES,
    INSTANCE_NOISE,
    NUM_POINTS,
    NUM_FEATURES,
)

# Number of prototypes / clusters given to every clustering model below.
NUM_CLUSTERING_PROTOTYPES = 20

def visualize_prototypes(prototypes: np.ndarray, dataset: AbstractDataset, embedding_method: str = "TSNE"):
    if NUM_FEATURES < 2:
        return
    if NUM_FEATURES == 2:
        X = dataset.X
    if NUM_FEATURES > 2:
        match embedding_method:
            case "TSNE":
                embedding = TSNE()
            case "LocallyLinear":
                embedding = LocallyLinearEmbedding()
            case "Isomap":
                embedding = Isomap()
            case "SpectralEmbedding":
                embedding = SpectralEmbedding()
            case _:
                print(f"{embedding_method} is not a valid embedding method")
                return
        embedded_data = embedding.fit_transform(np.vstack([prototypes, dataset.X]))
        prototypes, X = embedded_data[:prototypes.shape[0]], embedded_data[prototypes.shape[0]:]

    for prototype_index in range(NUM_DATA_PROTOTYPES):
        prototype_mask = dataset.y==prototype_index
        plt.scatter(X[prototype_mask, 0], X[prototype_mask, 1], label=f"{prototype_index}")
    plt.scatter(prototypes[:, 0], prototypes[:, 1], s=10, c="k")
    plt.legend(title="Prototype Index")

# Full list of embedding names accepted by visualize_prototypes, kept here so
# it can be re-enabled quickly:
# EMBEDDING_METHODS = ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]
# NOTE(review): the cells visible in this notebook iterate over their own
# hard-coded method lists and do not read this variable — confirm whether it
# is still needed.
EMBEDDING_METHODS = []

### K-Means

In [None]:
# Baseline: scikit-learn k-means with NUM_CLUSTERING_PROTOTYPES centroids.
clustering = KMeans(NUM_CLUSTERING_PROTOTYPES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

print(measurePrototypeInformation(clustering.cluster_centers_, dataset))
print(measureClusteringPerformance(predicted_labels, dataset))
for method in ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]:
    # visualize_prototypes takes the dataset as its required second argument;
    # omitting it raised a TypeError.
    visualize_prototypes(clustering.cluster_centers_, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

### Winner Takes All

In [None]:
# Fit WinnerTakesAll on the synthetic data, then report and plot its prototypes.
clustering = WinnerTakesAll(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

print(measurePrototypeInformation(clustering.prototypes, dataset))
print(measureClusteringPerformance(predicted_labels, dataset))
for method in ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]:
    # visualize_prototypes takes the dataset as its required second argument;
    # omitting it raised a TypeError.
    visualize_prototypes(clustering.prototypes, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

### FSCL

In [None]:
# Fit FSCL on the synthetic data, then report and plot its prototypes.
clustering = FSCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

print(measurePrototypeInformation(clustering.prototypes, dataset))
print(measureClusteringPerformance(predicted_labels, dataset))
for method in ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]:
    # visualize_prototypes takes the dataset as its required second argument;
    # omitting it raised a TypeError.
    visualize_prototypes(clustering.prototypes, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

### RPCL

In [None]:
# Fit RPCL on the synthetic data with explicit learning rates for the
# best-matching and rival units, then report and plot its prototypes.
clustering = RPCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(
    dataset.X,
    best_matching_unit_learning_rate=1e-3,
    rival_matching_unit_learning_rate=2.75e-4,
)
predicted_labels = clustering.predict(dataset.X)

print(measurePrototypeInformation(clustering.prototypes, dataset))
print(measureClusteringPerformance(predicted_labels, dataset))
for method in ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]:
    # visualize_prototypes takes the dataset as its required second argument;
    # omitting it raised a TypeError.
    visualize_prototypes(clustering.prototypes, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

# for prototype_index in range(NUM_DATA_PROTOTYPES):
#     prototype_mask = dataset.y==prototype_index
#     plt.scatter(dataset.X[prototype_mask, 0], dataset.X[prototype_mask, 1], label=f"{prototype_index}")
# plt.scatter(clustering.prototypes[:, 0], clustering.prototypes[:, 1], s=10, c="k")
# plt.legend(title="Prototype Index")
# plt.xlim(-3,3)
# plt.ylim(-3,3)
# plt.show()

### Base ClAM

In [None]:
# This specifically will NOT work well with NUM_FEATURES=2 due to the masking of data in ClAM.

clustering = ClAMClustering(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES, beta=1, time_constant=1e0)
torchX = torch.tensor(dataset.X)
# The callback is registered before training and its
# clustering_performance_history is read back afterwards.
clusteringPerformanceHistoryCallback = ClusteringPerformanceHistoryCallback(clustering, torchX, dataset.y)
clustering.add_training_callback(clusteringPerformanceHistoryCallback)

loss = clustering.fit(torchX, num_epochs=1000, batch_size=64, mask_bernoulli_parameter=0.8)
clusteringPerformanceHistory = pd.DataFrame(clusteringPerformanceHistoryCallback.clustering_performance_history)
# NOTE(review): assumes fit returns one loss value per history row — confirm.
clusteringPerformanceHistory["loss"] = loss
display(clusteringPerformanceHistory)
# One figure per recorded quantity, plotted against the row index.
for col in clusteringPerformanceHistory:
    y = clusteringPerformanceHistory[col]
    X = np.arange(len(y))
    plt.plot(X, y)
    plt.xlabel("Epoch")
    plt.ylabel(f"{col}")
    plt.show()

predicted_labels = clustering.predict(torchX).detach().cpu().numpy()
print(measurePrototypeInformation(clustering.prototypes.detach().cpu().numpy(), dataset))
print(measureClusteringPerformance(predicted_labels, dataset))
for method in ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]:
    # visualize_prototypes takes the dataset as its required second argument;
    # omitting it raised a TypeError.
    visualize_prototypes(clustering.prototypes.detach().cpu().numpy(), dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

### Regularized ClAM

In [None]:
# Same workflow as the Base ClAM cell, using RegularizedClAM with its extra
# regularization arguments and a longer training run.
clustering = RegularizedClAM(
    NUM_CLUSTERING_PROTOTYPES,
    NUM_FEATURES,
    regularization_lambda=1e-6,
    regularization_exponent=1,
    beta=1,
    time_constant=1e0
)
torchX = torch.tensor(dataset.X)
# The callback is registered before training and its
# clustering_performance_history is read back afterwards.
clusteringPerformanceHistoryCallback = ClusteringPerformanceHistoryCallback(clustering, torchX, dataset.y)
clustering.add_training_callback(clusteringPerformanceHistoryCallback)

loss = clustering.fit(torchX, num_epochs=2500, batch_size=64, mask_bernoulli_parameter=0.8)
clusteringPerformanceHistory = pd.DataFrame(clusteringPerformanceHistoryCallback.clustering_performance_history)
# NOTE(review): assumes fit returns one loss value per history row — confirm.
clusteringPerformanceHistory["loss"] = loss
display(clusteringPerformanceHistory)
# One figure per recorded quantity, plotted against the row index.
for col in clusteringPerformanceHistory:
    y = clusteringPerformanceHistory[col]
    X = np.arange(len(y))
    plt.plot(X, y)
    plt.xlabel("Epoch")
    plt.ylabel(f"{col}")
    plt.show()

predicted_labels = clustering.predict(torchX).detach().cpu().numpy()
print(measurePrototypeInformation(clustering.prototypes.detach().cpu().numpy(), dataset))
print(measureClusteringPerformance(predicted_labels, dataset))
for method in ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]:
    # visualize_prototypes takes the dataset as its required second argument;
    # omitting it raised a TypeError.
    visualize_prototypes(clustering.prototypes.detach().cpu().numpy(), dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

# MNIST Dataset

In [None]:
# Rebind the notebook-level dataset and its derived constants for the MNIST
# experiments; every cell below reads these globals.
NUM_POINTS = 1000
dataset = MNISTDataset(NUM_POINTS)
# Alternative dataset, swap in by commenting the line above:
# dataset = FashionMNISTDataset(NUM_POINTS)
# NOTE(review): with the next line commented out, NUM_DATA_PROTOTYPES keeps the
# value assigned in the synthetic section — confirm that is intended.
# NUM_DATA_PROTOTYPES = 10
# Feature count now comes from the dataset instead of a hand-set constant.
NUM_FEATURES = dataset.num_features

# Number of prototypes / clusters given to every clustering model below.
NUM_CLUSTERING_PROTOTYPES = 20

def visualize_prototypes(prototypes: np.ndarray, ncols: int = 5, imsize: float=3, image_shape: tuple[int, int] = (28, 28)):
    """Show each prototype as a grayscale image in a grid of subplots.

    Args:
        prototypes: Array reshapeable to ``(-1, *image_shape)``, i.e. one
            flattened image per prototype.
        ncols: Number of grid columns.
        imsize: Width and height given to each grid cell when sizing the
            figure (``figsize`` is ``(ncols*imsize, nrows*imsize)``).
        image_shape: Height and width each prototype is reshaped to. Defaults
            to the previously hard-coded 28x28.

    Returns:
        None. The figure is shown with ``plt.show()``.

    Note:
        This definition reuses the name of the scatter-plot helper defined
        earlier in the notebook and replaces it once this cell has run.
    """
    prototypes = prototypes.reshape(-1, *image_shape)
    num_images = prototypes.shape[0]
    nrows = int(np.ceil(num_images / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without .ravel().
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols*imsize, nrows*imsize),
        squeeze=False,
    )
    for i, ax in enumerate(axes.ravel()):
        ax.set_xticks([])
        ax.set_yticks([])

        # Grid cells beyond the last prototype are left blank and frameless.
        if i >= num_images:
            ax.set_frame_on(False)
            continue

        # "gray" is the canonical colormap name; the "grey" spelling is not
        # registered in every Matplotlib release.
        ax.imshow(prototypes[i], cmap="gray")
    plt.tight_layout()
    plt.show()

### K-Means

In [None]:
# k-means baseline: fit, label the same samples, then report and display.
clustering = KMeans(n_clusters=NUM_CLUSTERING_PROTOTYPES).fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

performance = measureClusteringPerformance(predicted_labels, dataset)
print(performance)
visualize_prototypes(clustering.cluster_centers_)

### Winner Takes All

In [None]:
# Winner-takes-all model trained for 30 epochs on the current dataset.
clustering = WinnerTakesAll(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
fit_options = {"num_epochs": 30}
clustering.fit(dataset.X, **fit_options)
predicted_labels = clustering.predict(dataset.X)

performance = measureClusteringPerformance(predicted_labels, dataset)
print(performance)
visualize_prototypes(clustering.prototypes)

### FSCL

In [None]:
# FSCL model trained for 30 epochs on the current dataset.
clustering = FSCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
fit_options = {"num_epochs": 30}
clustering.fit(dataset.X, **fit_options)
predicted_labels = clustering.predict(dataset.X)

performance = measureClusteringPerformance(predicted_labels, dataset)
print(performance)
visualize_prototypes(clustering.prototypes)

### RPCL

In [None]:
# RPCL model trained for 50 epochs with explicit learning rates for the
# best-matching and rival units.
clustering = RPCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
fit_options = {
    "num_epochs": 50,
    "best_matching_unit_learning_rate": 1e-3,
    "rival_matching_unit_learning_rate": 3.5e-4,
}
clustering.fit(dataset.X, **fit_options)
predicted_labels = clustering.predict(dataset.X)

performance = measureClusteringPerformance(predicted_labels, dataset)
print(performance)
visualize_prototypes(clustering.prototypes)

### Base ClAM

In [None]:
# Base ClAM: the model works on torch tensors, so the data is converted on the
# way in and predictions / prototypes are converted back to NumPy on the way out.
clustering = ClAMClustering(
    NUM_CLUSTERING_PROTOTYPES,
    NUM_FEATURES,
    beta=10,
    time_constant=1e0,
)
torchX = torch.tensor(dataset.X)
fit_options = {
    "num_epochs": 1000,
    "mask_bernoulli_parameter": 0.8,
    "batch_size": 128,
}
clustering.fit(torchX, **fit_options)
predicted_labels = clustering.predict(torchX).detach().cpu().numpy()

performance = measureClusteringPerformance(predicted_labels, dataset)
print(performance)
learned_prototypes = clustering.prototypes.detach().cpu().numpy()
visualize_prototypes(learned_prototypes)

### Regularized ClAM

In [None]:
# Regularized ClAM with regularization_lambda=0; otherwise the same workflow
# and training settings as the Base ClAM cell.
clustering = RegularizedClAM(
    NUM_CLUSTERING_PROTOTYPES,
    NUM_FEATURES,
    regularization_lambda=0,
    beta=10,
    time_constant=1e0,
)
torchX = torch.tensor(dataset.X)
fit_options = {
    "num_epochs": 1000,
    "mask_bernoulli_parameter": 0.8,
    "batch_size": 128,
}
clustering.fit(torchX, **fit_options)
predicted_labels = clustering.predict(torchX).detach().cpu().numpy()

performance = measureClusteringPerformance(predicted_labels, dataset)
print(performance)
learned_prototypes = clustering.prototypes.detach().cpu().numpy()
visualize_prototypes(learned_prototypes)