In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans, AgglomerativeClustering, OPTICS
from sklearn.metrics import silhouette_score, normalized_mutual_info_score
from sklearn.manifold import Isomap, SpectralEmbedding, TSNE, LocallyLinearEmbedding
import torch

from dataset import AbstractDataset, \
    UniformRandomDataset, \
    PrototypeDataset, \
    MNISTDataset, \
    FashionMNISTDataset

from clusterings import *

In [None]:
class ClusteringPerformance():
    """Collects prototype-separation and clustering-quality metrics, one entry per clustering.

    Every successful call to `measure_clustering` appends exactly one value to
    each list in `self.data`, so the lists always have equal length and can be
    turned into a table with `get_dataframe`.
    """

    def __init__(self, dataset: AbstractDataset):
        """Keep a reference to `dataset`; its `X` and `y` attributes are read when scoring."""
        self.dataset = dataset
        # Column name -> list of recorded values (one per measured clustering).
        self.data = {
            "clustering_name": [],
            "prototype_mean_squared_distance": [],
            "prototype_std_squared_distance": [],
            "prototype_mean_min_squared_distance": [],
            "prototype_std_min_squared_distance": [],
            "clustering_silhouette": [],
            "clustering_nmi": [],
        }

    def measure_clustering(self, name: str, prototypes: np.ndarray, predicted_labels: np.ndarray):
        """Compute all metrics for one clustering and record them under `name`.

        Parameters
        ----------
        name:
            Label stored in the "clustering_name" column.
        prototypes:
            2-D array with one prototype per row.
        predicted_labels:
            Cluster assignment for each row of `self.dataset.X`.

        Notes
        -----
        All metrics are computed before anything is stored. If a metric raises,
        `self.data` is left untouched instead of ending up with lists of
        different lengths (which would make `get_dataframe` fail).
        """
        num_prototypes = prototypes.shape[0]

        if num_prototypes > 1:
            # Pairwise squared Euclidean distances between prototypes.
            squared_distance: np.ndarray = np.square(prototypes[None, :, :] - prototypes[:, None, :]).sum(axis=2)
            # Drop the zero diagonal (distance of each prototype to itself).
            off_diagonal = ~np.eye(num_prototypes, dtype=bool)
            squared_distance = squared_distance[off_diagonal].reshape(num_prototypes, num_prototypes - 1)
            # Squared distance from each prototype to its nearest other prototype.
            min_squared_distance = squared_distance.min(axis=1)

            mean_squared = np.mean(squared_distance)
            std_squared = np.std(squared_distance)
            mean_min_squared = np.mean(min_squared_distance)
            std_min_squared = np.std(min_squared_distance)
        else:
            # With fewer than two prototypes there are no pairs to compare; the
            # row-wise minimum would raise on an empty array, so record NaN.
            mean_squared = std_squared = mean_min_squared = std_min_squared = np.nan

        try:
            silhouette = silhouette_score(self.dataset.X, predicted_labels)
        except ValueError:
            # silhouette_score rejected the labelling; record it as missing.
            silhouette = np.nan
        nmi = normalized_mutual_info_score(self.dataset.y, predicted_labels)

        record = {
            "clustering_name": name,
            "prototype_mean_squared_distance": mean_squared,
            "prototype_std_squared_distance": std_squared,
            "prototype_mean_min_squared_distance": mean_min_squared,
            "prototype_std_min_squared_distance": std_min_squared,
            "clustering_silhouette": float(silhouette),
            "clustering_nmi": float(nmi),
        }
        # Append only after every metric succeeded, keeping all columns aligned.
        for column, value in record.items():
            self.data[column].append(value)

    def get_dataframe(self) -> pd.DataFrame:
        """Return the recorded metrics as a DataFrame, one row per measured clustering."""
        return pd.DataFrame(self.data)

# Synthetic Prototype Dataset

In [None]:
# Parameters handed (positionally, in this order) to PrototypeDataset below.
NUM_DATA_PROTOTYPES = 5
INSTANCE_NOISE = 0.4
NUM_POINTS = 1000
NUM_FEATURES = 10

# Number of prototypes requested from every clustering method in the cells that follow.
NUM_CLUSTERING_PROTOTYPES = 20

dataset = PrototypeDataset(
    NUM_DATA_PROTOTYPES,
    INSTANCE_NOISE,
    NUM_POINTS,
    NUM_FEATURES,
)

def visualize_prototypes(prototypes: np.ndarray, dataset: AbstractDataset, embedding_method: str = "TSNE"):
    """Scatter-plot `dataset` coloured by label with `prototypes` overlaid in black.

    The number of features and the set of labels are taken from `dataset`
    itself rather than from notebook-level constants, so the function stays
    correct when the notebook switches to a different dataset.

    Parameters
    ----------
    prototypes:
        2-D array, one prototype per row, with the same number of columns as `dataset.X`.
    dataset:
        Object exposing `X` (2-D data) and `y` (one label per row of `X`).
    embedding_method:
        One of "TSNE", "LocallyLinear", "Isomap", "SpectralEmbedding". Only used
        when the data has more than two features. An unknown name prints a
        message and draws nothing.

    Draws on the current matplotlib figure; the caller is responsible for
    creating and showing it. Nothing is drawn for data with fewer than two features.
    """
    X = np.asarray(dataset.X)
    y = np.asarray(dataset.y)

    if X.ndim < 2 or X.shape[1] < 2:
        return

    if X.shape[1] > 2:
        embedding_classes = {
            "TSNE": TSNE,
            "LocallyLinear": LocallyLinearEmbedding,
            "Isomap": Isomap,
            "SpectralEmbedding": SpectralEmbedding,
        }
        embedding_class = embedding_classes.get(embedding_method)
        if embedding_class is None:
            print(f"{embedding_method} is not a valid embedding method")
            return
        # Embed prototypes and data together so both live in the same projected space.
        num_prototypes = prototypes.shape[0]
        embedded_data = embedding_class().fit_transform(np.vstack([prototypes, X]))
        prototypes, X = embedded_data[:num_prototypes], embedded_data[num_prototypes:]

    # One scatter series per label actually present in the dataset.
    for label in np.unique(y):
        label_mask = y == label
        plt.scatter(X[label_mask, 0], X[label_mask, 1], label=f"{label}")
    plt.scatter(prototypes[:, 0], prototypes[:, 1], s=10, c="k")
    plt.legend(title="Prototype Index")

# Names of the embedding methods passed to visualize_prototypes by the
# per-method cells below. The commented lines are alternative settings: the
# full list plots every supported embedding, the empty list skips plotting.
# EMBEDDING_METHODS = ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]
EMBEDDING_METHODS = ["TSNE"]
# EMBEDDING_METHODS = []

# Metric accumulator for this dataset; each clustering cell below adds one entry.
clustering_performance = ClusteringPerformance(dataset)

### K-Means

In [None]:
# Fit K-Means and assign each data point to a cluster.
clustering = KMeans(n_clusters=NUM_CLUSTERING_PROTOTYPES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, then draw one figure per configured embedding.
kmeans_centers = clustering.cluster_centers_
clustering_performance.measure_clustering("KMeans", kmeans_centers, predicted_labels)

for embedding_name in EMBEDDING_METHODS:
    plt.figure(figsize=(6, 6))
    visualize_prototypes(kmeans_centers, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.tight_layout()
    plt.show()

### Winner Takes All

In [None]:
# Train the Winner-Takes-All model and label every data point.
clustering = WinnerTakesAll(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, then draw one figure per configured embedding.
wta_prototypes = clustering.prototypes
clustering_performance.measure_clustering("WinnerTakesAll", wta_prototypes, predicted_labels)

for embedding_name in EMBEDDING_METHODS:
    plt.figure(figsize=(6, 6))
    visualize_prototypes(wta_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.tight_layout()
    plt.show()

### FSCL

In [None]:
# Train the FSCL model and label every data point.
clustering = FSCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, then draw one figure per configured embedding.
fscl_prototypes = clustering.prototypes
clustering_performance.measure_clustering("FSCL", fscl_prototypes, predicted_labels)

for embedding_name in EMBEDDING_METHODS:
    plt.figure(figsize=(6, 6))
    visualize_prototypes(fscl_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.tight_layout()
    plt.show()

### RPCL

In [None]:
# Train the RPCL model with explicit winner / rival learning rates, then label every point.
clustering = RPCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(
    dataset.X,
    best_matching_unit_learning_rate=1e-3,
    rival_matching_unit_learning_rate=2.75e-4,
)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, then draw one figure per configured embedding.
rpcl_prototypes = clustering.prototypes
clustering_performance.measure_clustering("RPCL", rpcl_prototypes, predicted_labels)

for embedding_name in EMBEDDING_METHODS:
    plt.figure(figsize=(6, 6))
    visualize_prototypes(rpcl_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.tight_layout()
    plt.show()

### Base ClAM

In [None]:
# This specifically will NOT work well with NUM_FEATURES=2 due to the masking of data in ClAM.

torchX = torch.tensor(dataset.X)
clustering = ClAMClustering(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES, beta=1, time_constant=1e0)

# Attach the history callbacks before training (same registration order as before:
# clustering performance first, prototype separation second).
clusteringPerformanceHistoryCallback = ClusteringPerformanceHistoryCallback(clustering, torchX, dataset.y)
clustering.add_training_callback(clusteringPerformanceHistoryCallback)
prototypeSeparationHistoryCallback = PrototypeSeparationHistoryCallback(clustering)
clustering.add_training_callback(prototypeSeparationHistoryCallback)

loss = clustering.fit(
    torchX,
    num_epochs=1000,
    batch_size=64,
    mask_bernoulli_parameter=0.8,
)

# Optional diagnostics (disabled): per-epoch clustering performance and loss.
# clusteringPerformanceHistory = pd.DataFrame(clusteringPerformanceHistoryCallback.clustering_performance_history)
# clusteringPerformanceHistory["loss"] = loss
# display(clusteringPerformanceHistory)
# for col in clusteringPerformanceHistory:
#     y = clusteringPerformanceHistory[col]
#     X = np.arange(len(y))
#     plt.plot(X, y)
#     plt.xlabel("Epoch")
#     plt.ylabel(f"{col}")
#     plt.show()

# Optional diagnostics (disabled): per-epoch prototype separation.
# prototypeSeparationHistory = pd.DataFrame(prototypeSeparationHistoryCallback.prototype_separation_history)
# display(prototypeSeparationHistory)
# for col in prototypeSeparationHistory:
#     y = prototypeSeparationHistory[col]
#     X = np.arange(len(y))
#     plt.plot(X, y)
#     plt.xlabel("Epoch")
#     plt.ylabel(f"{col}")
#     plt.show()


# Convert the torch outputs to NumPy once, then score and plot.
predicted_labels = clustering.predict(torchX).detach().cpu().numpy()
clam_prototypes = clustering.prototypes.detach().cpu().numpy()
clustering_performance.measure_clustering("ClAM", clam_prototypes, predicted_labels)

for embedding_name in EMBEDDING_METHODS:
    plt.figure(figsize=(6, 6))
    visualize_prototypes(clam_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.tight_layout()
    plt.show()

### Regularized ClAM

In [None]:
torchX = torch.tensor(dataset.X)
clustering = RegularizedClAM(
    NUM_CLUSTERING_PROTOTYPES,
    NUM_FEATURES,
    regularization_lambda=1e-4,
    regularization_exponent=2,
    beta=1,
    time_constant=1e0,
)

# Attach the history callbacks before training (same registration order as before:
# clustering performance first, prototype separation second).
clusteringPerformanceHistoryCallback = ClusteringPerformanceHistoryCallback(clustering, torchX, dataset.y)
clustering.add_training_callback(clusteringPerformanceHistoryCallback)
prototypeSeparationHistoryCallback = PrototypeSeparationHistoryCallback(clustering)
clustering.add_training_callback(prototypeSeparationHistoryCallback)

loss = clustering.fit(
    torchX,
    num_epochs=1000,
    batch_size=64,
    mask_bernoulli_parameter=0.8,
)

# Optional diagnostics (disabled): per-epoch clustering performance and loss.
# clusteringPerformanceHistory = pd.DataFrame(clusteringPerformanceHistoryCallback.clustering_performance_history)
# clusteringPerformanceHistory["loss"] = loss
# display(clusteringPerformanceHistory)
# for col in clusteringPerformanceHistory:
#     y = clusteringPerformanceHistory[col]
#     X = np.arange(len(y))
#     plt.plot(X, y)
#     plt.xlabel("Epoch")
#     plt.ylabel(f"{col}")
#     plt.show()

# Optional diagnostics (disabled): per-epoch prototype separation.
# prototypeSeparationHistory = pd.DataFrame(prototypeSeparationHistoryCallback.prototype_separation_history)
# display(prototypeSeparationHistory)
# for col in prototypeSeparationHistory:
#     y = prototypeSeparationHistory[col]
#     X = np.arange(len(y))
#     plt.plot(X, y)
#     plt.xlabel("Epoch")
#     plt.ylabel(f"{col}")
#     plt.show()


# Convert the torch outputs to NumPy once, then score and plot.
predicted_labels = clustering.predict(torchX).detach().cpu().numpy()
regularized_prototypes = clustering.prototypes.detach().cpu().numpy()
clustering_performance.measure_clustering("L2RegularizedClAM", regularized_prototypes, predicted_labels)

for embedding_name in EMBEDDING_METHODS:
    plt.figure(figsize=(6, 6))
    visualize_prototypes(regularized_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.tight_layout()
    plt.show()

### Clustering Performance Data Analysis

In [None]:
df = clustering_performance.get_dataframe()

# One bar chart per metric: (value column, y-axis label, extra keyword arguments for plt.bar).
# The two prototype-distance charts carry error bars taken from the matching std column.
bar_charts = [
    (
        "prototype_mean_squared_distance",
        "Mean Squared Distance between Prototypes",
        {"yerr": df["prototype_std_squared_distance"]},
    ),
    (
        "prototype_mean_min_squared_distance",
        "Mean Squared Minimum Distance between Prototypes",
        {"yerr": df["prototype_std_min_squared_distance"]},
    ),
    (
        "clustering_silhouette",
        "Silhouette Score",
        {"label": "Silhouette Score"},
    ),
    (
        "clustering_nmi",
        "Normalized Mutual Information",
        {"label": "Normalized Mutual Information"},
    ),
]

for value_column, axis_label, bar_kwargs in bar_charts:
    plt.bar(df["clustering_name"], df[value_column], **bar_kwargs)
    plt.xticks(rotation=25)
    plt.ylabel(axis_label)
    plt.tight_layout()
    plt.show()

# MNIST Dataset

In [None]:
NUM_POINTS = 1000
dataset = MNISTDataset(NUM_POINTS)
# dataset = FashionMNISTDataset(NUM_POINTS)

# Re-derive the number of ground-truth classes for the new dataset. Otherwise
# the value from the synthetic section (5) stays in effect, and
# visualize_prototypes, which loops over range(NUM_DATA_PROTOTYPES), would plot
# only labels 0-4.
# Assumes dataset.y holds integer class labels starting at 0 - TODO confirm.
NUM_DATA_PROTOTYPES = int(np.max(dataset.y)) + 1
NUM_FEATURES = dataset.num_features

# Number of prototypes requested from every clustering method in the cells that follow.
NUM_CLUSTERING_PROTOTYPES = 20

def visualize_MNIST_prototypes(prototypes: np.ndarray, ncols: int = 5, imsize: float=3, image_shape: tuple = (28, 28)):
    """Show each prototype as a greyscale image in a grid and display the figure.

    Parameters
    ----------
    prototypes:
        Array whose rows can be reshaped to `image_shape` (one image per prototype).
    ncols:
        Number of grid columns; the number of rows is derived from it.
    imsize:
        Size, in figure units, allotted to each image along both axes.
    image_shape:
        (height, width) of a single prototype image. Defaults to 28x28, the
        shape this function originally hard-coded.

    Grid cells beyond the last prototype are left blank. Nothing is drawn when
    there are no prototypes.
    """
    images = prototypes.reshape(-1, *image_shape)
    num_images = images.shape[0]
    if num_images == 0:
        # A zero-row grid is not a valid subplot layout.
        return

    nrows = int(np.ceil(num_images / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so ravel() always works.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * imsize, nrows * imsize),
        squeeze=False,
    )
    for i, ax in enumerate(axes.ravel()):
        ax.set_xticks([])
        ax.set_yticks([])

        if i >= num_images:
            # Unused trailing cell: hide its frame so it appears empty.
            ax.set_frame_on(False)
            continue

        # "gray" is matplotlib's canonical name for this colormap.
        ax.imshow(images[i], cmap="gray")
    plt.tight_layout()
    plt.show()

# Names of the embedding methods passed to visualize_prototypes by the
# per-method cells below. The commented lines are alternative settings: the
# full list plots every supported embedding, the empty list skips plotting.
# EMBEDDING_METHODS = ["TSNE","LocallyLinear","Isomap","SpectralEmbedding",]
EMBEDDING_METHODS = ["TSNE"]
# EMBEDDING_METHODS = []

# Rebinds clustering_performance to a fresh accumulator for the current dataset;
# results recorded under the previous binding are no longer reachable via this name.
clustering_performance = ClusteringPerformance(dataset)

### K-Means

In [None]:
# Fit K-Means and assign each data point to a cluster.
clustering = KMeans(NUM_CLUSTERING_PROTOTYPES)
clustering.fit(dataset.X)
predicted_labels = clustering.predict(dataset.X)

clustering_performance.measure_clustering("KMeans", clustering.cluster_centers_, predicted_labels)
# Render the cluster centres as images, as every other method cell in this
# section does, so the K-Means baseline can be compared visually too.
visualize_MNIST_prototypes(clustering.cluster_centers_)
for method in EMBEDDING_METHODS:
    visualize_prototypes(clustering.cluster_centers_, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

### Winner Takes All

In [None]:
# Train the Winner-Takes-All model for 30 epochs and label every data point.
clustering = WinnerTakesAll(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(dataset.X, num_epochs=30)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, show the prototypes as images, then plot each embedding.
wta_prototypes = clustering.prototypes
clustering_performance.measure_clustering("WinnerTakesAll", wta_prototypes, predicted_labels)
visualize_MNIST_prototypes(wta_prototypes)

for embedding_name in EMBEDDING_METHODS:
    visualize_prototypes(wta_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.show()

### FSCL

In [None]:
# Train the FSCL model for 30 epochs and label every data point.
clustering = FSCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(dataset.X, num_epochs=30)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, show the prototypes as images, then plot each embedding.
fscl_prototypes = clustering.prototypes
clustering_performance.measure_clustering("FSCL", fscl_prototypes, predicted_labels)
visualize_MNIST_prototypes(fscl_prototypes)

for embedding_name in EMBEDDING_METHODS:
    visualize_prototypes(fscl_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.show()

### RPCL

In [None]:
# Train the RPCL model for 50 epochs with explicit winner / rival learning rates.
clustering = RPCL(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES)
clustering.fit(
    dataset.X,
    num_epochs=50,
    best_matching_unit_learning_rate=1e-3,
    rival_matching_unit_learning_rate=3.5e-4,
)
predicted_labels = clustering.predict(dataset.X)

# Record the metrics, show the prototypes as images, then plot each embedding.
rpcl_prototypes = clustering.prototypes
clustering_performance.measure_clustering("RPCL", rpcl_prototypes, predicted_labels)
visualize_MNIST_prototypes(rpcl_prototypes)

for embedding_name in EMBEDDING_METHODS:
    visualize_prototypes(rpcl_prototypes, dataset, embedding_method=embedding_name)
    plt.title(f"Embedding: {embedding_name}")
    plt.show()

### Base ClAM

In [None]:
clustering = ClAMClustering(NUM_CLUSTERING_PROTOTYPES, NUM_FEATURES, beta=0.1, time_constant=1e0)
torchX = torch.tensor(dataset.X)

# Attach the history callbacks before training.
clusteringPerformanceHistoryCallback = ClusteringPerformanceHistoryCallback(clustering, torchX, dataset.y)
clustering.add_training_callback(clusteringPerformanceHistoryCallback)
prototypeSeparationHistoryCallback = PrototypeSeparationHistoryCallback(clustering)
clustering.add_training_callback(prototypeSeparationHistoryCallback)

loss = clustering.fit(torchX, num_epochs=1000, batch_size=64, mask_bernoulli_parameter=0.8)

# Optional diagnostics (disabled): per-epoch clustering performance and loss.
# clusteringPerformanceHistory = pd.DataFrame(clusteringPerformanceHistoryCallback.clustering_performance_history)
# clusteringPerformanceHistory["loss"] = loss
# display(clusteringPerformanceHistory)
# for col in clusteringPerformanceHistory:
#     y = clusteringPerformanceHistory[col]
#     X = np.arange(len(y))
#     plt.plot(X, y)
#     plt.xlabel("Epoch")
#     plt.ylabel(f"{col}")
#     plt.show()

# Per-epoch prototype separation recorded by the callback: one line plot per column.
prototypeSeparationHistory = pd.DataFrame(prototypeSeparationHistoryCallback.prototype_separation_history)
display(prototypeSeparationHistory)
for col in prototypeSeparationHistory:
    y = prototypeSeparationHistory[col]
    X = np.arange(len(y))
    plt.plot(X, y)
    plt.xlabel("Epoch")
    plt.ylabel(f"{col}")
    plt.show()

# Predict with the model trained in this cell. Without this line,
# predicted_labels would still hold the assignments left over from the
# previously executed cell, and ClAM would be scored against another
# method's labels.
predicted_labels = clustering.predict(torchX).detach().cpu().numpy()
clam_prototypes = clustering.prototypes.detach().cpu().numpy()

clustering_performance.measure_clustering("ClAM", clam_prototypes, predicted_labels)
visualize_MNIST_prototypes(clam_prototypes)
for method in EMBEDDING_METHODS:
    visualize_prototypes(clam_prototypes, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()

### Regularized ClAM

In [None]:
clustering = RegularizedClAM(
    NUM_CLUSTERING_PROTOTYPES, 
    NUM_FEATURES, 
    regularization_exponent=2,
    regularization_lambda=1e-5,
    beta=10, 
    time_constant=1e0
)
torchX = torch.tensor(dataset.X)

# Attach the history callbacks before training.
clusteringPerformanceHistoryCallback = ClusteringPerformanceHistoryCallback(clustering, torchX, dataset.y)
clustering.add_training_callback(clusteringPerformanceHistoryCallback)
prototypeSeparationHistoryCallback = PrototypeSeparationHistoryCallback(clustering)
clustering.add_training_callback(prototypeSeparationHistoryCallback)

loss = clustering.fit(torchX, num_epochs=1000, batch_size=64, mask_bernoulli_parameter=0.8)

# Optional diagnostics (disabled): per-epoch clustering performance and loss.
# clusteringPerformanceHistory = pd.DataFrame(clusteringPerformanceHistoryCallback.clustering_performance_history)
# clusteringPerformanceHistory["loss"] = loss
# display(clusteringPerformanceHistory)
# for col in clusteringPerformanceHistory:
#     y = clusteringPerformanceHistory[col]
#     X = np.arange(len(y))
#     plt.plot(X, y)
#     plt.xlabel("Epoch")
#     plt.ylabel(f"{col}")
#     plt.show()

# Per-epoch prototype separation recorded by the callback: one line plot per column.
prototypeSeparationHistory = pd.DataFrame(prototypeSeparationHistoryCallback.prototype_separation_history)
display(prototypeSeparationHistory)
for col in prototypeSeparationHistory:
    y = prototypeSeparationHistory[col]
    X = np.arange(len(y))
    plt.plot(X, y)
    plt.xlabel("Epoch")
    plt.ylabel(f"{col}")
    plt.show()

# Predict with the model trained in this cell. Without this line,
# predicted_labels would still hold the assignments left over from the
# previously executed cell, and this model would be scored against another
# method's labels.
predicted_labels = clustering.predict(torchX).detach().cpu().numpy()
regularized_prototypes = clustering.prototypes.detach().cpu().numpy()

clustering_performance.measure_clustering("L2RegularizedClAM", regularized_prototypes, predicted_labels)
visualize_MNIST_prototypes(regularized_prototypes)
for method in EMBEDDING_METHODS:
    visualize_prototypes(regularized_prototypes, dataset, embedding_method=method)
    plt.title(f"Embedding: {method}")
    plt.show()