In [5]:
import numpy as np 
import torch as t
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, silhouette_score
from sklearn.datasets import make_blobs
from torch.distributions import MultivariateNormal as N

# Set manual seeds for reproducibility. The torch seed also fixes the uniform
# noise points drawn below; make_blobs is seeded separately via random_state.
t.manual_seed(40)
np.random.seed(40)

# Generate synthetic data once (consistent across seeds)
NUM_SAMPLES = 1000
NUM_FEATURES = 2
NUM_CLASSES = 4

# Generate blobs with centers close to each other
centers = [[0, 0], [1, 1], [1, -1], [-1, -1]]
# Per-blob standard deviations.
# NOTE(review): 0.11 breaks the otherwise increasing 0.5/0.7/0.9 pattern —
# confirm it is intended (vs. 1.1); the reported metrics depend on it.
cluster_std = [0.5, 0.7, 0.9, 0.11]
assert len(centers) == len(cluster_std) == NUM_CLASSES

X_np, y_np = make_blobs(
    n_samples=NUM_SAMPLES,
    centers=centers,
    cluster_std=cluster_std,
    random_state=0,
)
X = t.tensor(X_np, dtype=t.float32)
y = t.tensor(y_np, dtype=t.int64)

# Add uniform background noise on [-3, 3)^2, labelled -1 so it can be
# excluded from evaluation later.
NUM_NOISE = 100
noise = t.rand(NUM_NOISE, NUM_FEATURES) * 6 - 3
X = t.cat([X, noise], dim=0)
y = t.cat([y, t.full((NUM_NOISE,), -1, dtype=t.int64)], dim=0)
assert X.shape == (NUM_SAMPLES + NUM_NOISE, NUM_FEATURES)
assert y.shape == (NUM_SAMPLES + NUM_NOISE,)

# Constants
BATCH = X.shape[0]
DIM = X.shape[1]
GUESS_CLASSES = 4  # number of mixture components fitted (a modelling choice)

# Seeds for the repeated runs; derived from num_seeds so the two cannot drift.
num_seeds = 5
seeds = list(range(num_seeds))

# Per-seed metric accumulators
metrics_gmm = {'ami': [], 'ari': [], 'silhouette': []}

def _random_gmm_init(num_components, dim):
    """Draw one random GMM initialisation and return ``(means, covariances)``.

    Means are ``t.rand(...) * 0.1``. Covariances are ``A @ A^T + I`` with
    ``A = t.rand(...) * 0.1``, which is symmetric positive definite by
    construction. Torch RNG is consumed in the order: means, then ``A``.
    """
    mu = t.rand(num_components, dim) * 0.1
    a = t.rand(num_components, dim, dim) * 0.1
    return mu, a @ a.transpose(-2, -1) + t.eye(dim)


NUM_EPOCHS = 200

# Ridge added to every M-step covariance. Without it, a component that
# collapses onto <= DIM (near-)collinear points gets a singular covariance
# and MultivariateNormal raises. Same value as sklearn's default reg_covar.
COV_RIDGE = 1e-6

# Evaluation targets do not depend on the seed. Noise points (label -1) are
# excluded from every metric.
mask = y.numpy() != -1
y_eval = y.numpy()[mask]
X_eval = X.numpy()[mask]

for seed in seeds:
    # Set seed for reproducibility
    t.manual_seed(seed)
    np.random.seed(seed)

    ############################
    # Hard-EM GMM with a FIXED uniform prior (the prior is not trained here)
    ############################

    # The first two draws are discarded. They only advance the RNG so that the
    # third draw matches the init used when all models ran one after another.
    mu_gmm0, s_gmm0 = _random_gmm_init(GUESS_CLASSES, DIM)
    mu_gmm1, s_gmm1 = _random_gmm_init(GUESS_CLASSES, DIM)
    mu_gmm, s_gmm = _random_gmm_init(GUESS_CLASSES, DIM)

    # All-zero logits => uniform prior. They are never updated in this cell,
    # so log pi_k is the same constant for every k and cannot change the argmax.
    prior_logits = t.zeros(GUESS_CLASSES)

    for epoch in range(NUM_EPOCHS):
        # E-step: hard-assign each point to argmax_k [log p(x | k) + log pi_k].
        prior = t.distributions.Categorical(logits=prior_logits)
        dis = N(mu_gmm, s_gmm)
        log_p_x_given_z = dis.log_prob(X[:, None])  # Shape: (BATCH, GUESS_CLASSES)
        log_pi = t.log(prior.probs)[None, :]  # Shape: (1, GUESS_CLASSES)
        log_p_xz = log_p_x_given_z + log_pi  # Shape: (BATCH, GUESS_CLASSES)
        z_hard = log_p_xz.argmax(-1)

        # M-step: per-component sample mean and (biased, 1/n) covariance of the
        # assigned points. An empty component is re-seeded near the origin.
        for k in range(GUESS_CLASSES):
            X_k = X[z_hard == k]
            if len(X_k) > 0:
                mu_gmm[k] = X_k.mean(0)
                x_minus_mu = X_k - mu_gmm[k]
                s_gmm[k] = (
                    (x_minus_mu[:, :, None] @ x_minus_mu[:, None, :]).mean(0)
                    + COV_RIDGE * t.eye(DIM)
                )
            else:
                mu_gmm[k] = t.rand(DIM) * 0.1
                s_k = t.rand(DIM, DIM) * 0.1
                s_gmm[k] = s_k @ s_k.transpose(-2, -1) + t.eye(DIM)

    # Evaluate Hard EM on the final E-step assignments
    labels_gmm = z_hard.numpy()
    labels_gmm_eval = labels_gmm[mask]

    # sklearn's signature is (labels_true, labels_pred). AMI and ARI are
    # symmetric, so this order fixes readability without changing the values.
    metrics_gmm['ami'].append(adjusted_mutual_info_score(y_eval, labels_gmm_eval))
    metrics_gmm['ari'].append(adjusted_rand_score(y_eval, labels_gmm_eval))
    # silhouette_score raises unless 2 <= n_labels <= n_samples - 1. Record NaN
    # instead, so a degenerate seed shows up as nan in the summary rather than
    # aborting the whole sweep.
    n_labels = len(np.unique(labels_gmm_eval))
    metrics_gmm['silhouette'].append(
        silhouette_score(X_eval, labels_gmm_eval)
        if 2 <= n_labels <= len(labels_gmm_eval) - 1
        else np.nan
    )

def average_metrics(metrics):
    """Average each metric over the seeds it was recorded for.

    Args:
        metrics: mapping of metric name -> list of per-seed scores.

    Returns:
        A new dict with the same keys (same order), each mapped to
        ``np.mean`` of its score list.
    """
    averaged = {}
    for name, scores in metrics.items():
        averaged[name] = np.mean(scores)
    return averaged

# Compute standard deviation metrics
def std_metrics(metrics, ddof=0):
    """Standard deviation of each metric over the seeds.

    Args:
        metrics: mapping of metric name -> list of per-seed scores.
        ddof: delta degrees of freedom forwarded to ``np.std``. The default 0
            (population std) reproduces the previous output. Pass 1 for the
            sample std, which is less biased with only a handful of seeds.

    Returns:
        A new dict with the same keys mapped to ``np.std(scores, ddof=ddof)``.
    """
    return {k: np.std(v, ddof=ddof) for k, v in metrics.items()}

avg_gmm = average_metrics(metrics_gmm)
std_gmm = std_metrics(metrics_gmm)

# Print averaged metrics with standard deviation across seeds.
# The seed count is taken from `seeds` so the header cannot go stale.
print(f'Averaged over {len(seeds)} seeds:\n')

# GMM Evaluation
print('GMM adjusted MI: {:.4f} ± {:.4f}'.format(avg_gmm['ami'], std_gmm['ami']))
print('GMM adjusted Rand Index: {:.4f} ± {:.4f}'.format(avg_gmm['ari'], std_gmm['ari']))
print('GMM Silhouette Score: {:.4f} ± {:.4f}'.format(avg_gmm['silhouette'], std_gmm['silhouette']))


Averaged over 5 seeds:

GMM adjusted MI: 0.6462 ± 0.0159
GMM adjusted Rand Index: 0.6312 ± 0.0243
GMM Silhouette Score: 0.4217 ± 0.0138


In [14]:
# learnable prior

import numpy as np 
import torch as t
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, silhouette_score
from sklearn.datasets import make_blobs
from torch.distributions import MultivariateNormal as N

# Set manual seeds for reproducibility. The torch seed also fixes the uniform
# noise points drawn below; make_blobs is seeded separately via random_state.
t.manual_seed(40)
np.random.seed(40)

# Generate synthetic data once (consistent across seeds)
NUM_SAMPLES = 1000
NUM_FEATURES = 2
NUM_CLASSES = 4

# Generate blobs with centers close to each other
centers = [[0, 0], [1, 1], [1, -1], [-1, -1]]
# Per-blob standard deviations.
# NOTE(review): 0.11 breaks the otherwise increasing 0.5/0.7/0.9 pattern —
# confirm it is intended (vs. 1.1); the reported metrics depend on it.
cluster_std = [0.5, 0.7, 0.9, 0.11]
assert len(centers) == len(cluster_std) == NUM_CLASSES

X_np, y_np = make_blobs(
    n_samples=NUM_SAMPLES,
    centers=centers,
    cluster_std=cluster_std,
    random_state=0,
)
X = t.tensor(X_np, dtype=t.float32)
y = t.tensor(y_np, dtype=t.int64)

# Add uniform background noise on [-3, 3)^2, labelled -1 so it can be
# excluded from evaluation later.
NUM_NOISE = 100
noise = t.rand(NUM_NOISE, NUM_FEATURES) * 6 - 3
X = t.cat([X, noise], dim=0)
y = t.cat([y, t.full((NUM_NOISE,), -1, dtype=t.int64)], dim=0)
assert X.shape == (NUM_SAMPLES + NUM_NOISE, NUM_FEATURES)
assert y.shape == (NUM_SAMPLES + NUM_NOISE,)

# Constants
BATCH = X.shape[0]
DIM = X.shape[1]
GUESS_CLASSES = 4  # number of mixture components fitted (a modelling choice)

# Seeds for the repeated runs; derived from num_seeds so the two cannot drift.
num_seeds = 5
seeds = list(range(num_seeds))

# Per-seed metric accumulators
metrics_gmm = {'ami': [], 'ari': [], 'silhouette': []}

def _random_gmm_init(num_components, dim):
    """Draw one random GMM initialisation and return ``(means, covariances)``.

    Means are ``t.rand(...) * 0.1``. Covariances are ``A @ A^T + I`` with
    ``A = t.rand(...) * 0.1``, which is symmetric positive definite by
    construction. Torch RNG is consumed in the order: means, then ``A``.
    """
    mu = t.rand(num_components, dim) * 0.1
    a = t.rand(num_components, dim, dim) * 0.1
    return mu, a @ a.transpose(-2, -1) + t.eye(dim)


NUM_EPOCHS = 200

# Ridge added to every M-step covariance. Without it, a component that
# collapses onto <= DIM (near-)collinear points gets a singular covariance
# and MultivariateNormal raises. Same value as sklearn's default reg_covar.
COV_RIDGE = 1e-6

# Evaluation targets do not depend on the seed. Noise points (label -1) are
# excluded from every metric.
mask = y.numpy() != -1
y_eval = y.numpy()[mask]
X_eval = X.numpy()[mask]

for seed in seeds:
    # Set seed for reproducibility
    t.manual_seed(seed)
    np.random.seed(seed)

    ############################
    # Hard-EM GMM with a learned (count-based) prior
    ############################

    # The first two draws are discarded. They only advance the RNG so that the
    # third draw matches the init used when all models ran one after another.
    mu_gmm0, s_gmm0 = _random_gmm_init(GUESS_CLASSES, DIM)
    mu_gmm1, s_gmm1 = _random_gmm_init(GUESS_CLASSES, DIM)
    mu_gmm, s_gmm = _random_gmm_init(GUESS_CLASSES, DIM)

    # Start from a uniform prior; it is re-estimated from hard counts every epoch.
    prior_logits = t.zeros(GUESS_CLASSES)

    for epoch in range(NUM_EPOCHS):
        # E-step: hard-assign each point to argmax_k [log p(x | k) + log pi_k].
        prior = t.distributions.Categorical(logits=prior_logits)
        dis = N(mu_gmm, s_gmm)
        log_p_x_given_z = dis.log_prob(X[:, None])  # Shape: (BATCH, GUESS_CLASSES)
        log_pi = t.log(prior.probs)[None, :]  # Shape: (1, GUESS_CLASSES)
        log_p_xz = log_p_x_given_z + log_pi  # Shape: (BATCH, GUESS_CLASSES)
        z_hard = log_p_xz.argmax(-1)

        # M-step: per-component sample mean and (biased, 1/n) covariance of the
        # assigned points. An empty component is re-seeded near the origin.
        for k in range(GUESS_CLASSES):
            X_k = X[z_hard == k]
            if len(X_k) > 0:
                mu_gmm[k] = X_k.mean(0)
                x_minus_mu = X_k - mu_gmm[k]
                s_gmm[k] = (
                    (x_minus_mu[:, :, None] @ x_minus_mu[:, None, :]).mean(0)
                    + COV_RIDGE * t.eye(DIM)
                )
            else:
                mu_gmm[k] = t.rand(DIM) * 0.1
                s_k = t.rand(DIM, DIM) * 0.1
                s_gmm[k] = s_k @ s_k.transpose(-2, -1) + t.eye(DIM)

        # Prior update: maximum-likelihood mixing weights from the hard counts.
        # minlength keeps exactly one entry per component. Without it, bincount
        # returns max(z_hard) + 1 entries, so an empty highest-index component
        # shrinks prior_logits and the next E-step's broadcast fails.
        N_k = t.bincount(z_hard, minlength=GUESS_CLASSES)
        # A zero count gives a -inf logit (prior 0). That component can no
        # longer win the argmax, although it is still re-seeded every epoch.
        prior_logits = N_k.log()

    # Evaluate Hard EM on the final E-step assignments
    labels_gmm = z_hard.numpy()
    labels_gmm_eval = labels_gmm[mask]

    # sklearn's signature is (labels_true, labels_pred). AMI and ARI are
    # symmetric, so this order fixes readability without changing the values.
    metrics_gmm['ami'].append(adjusted_mutual_info_score(y_eval, labels_gmm_eval))
    metrics_gmm['ari'].append(adjusted_rand_score(y_eval, labels_gmm_eval))
    # silhouette_score raises unless 2 <= n_labels <= n_samples - 1. Record NaN
    # instead, so a degenerate seed shows up as nan in the summary rather than
    # aborting the whole sweep.
    n_labels = len(np.unique(labels_gmm_eval))
    metrics_gmm['silhouette'].append(
        silhouette_score(X_eval, labels_gmm_eval)
        if 2 <= n_labels <= len(labels_gmm_eval) - 1
        else np.nan
    )

def average_metrics(metrics):
    """Per-metric mean across seeds, keyed (and ordered) like ``metrics``.

    Args:
        metrics: mapping of metric name -> list of per-seed scores.

    Returns:
        A new dict mapping each metric name to ``np.mean`` of its scores.
    """
    return dict((name, np.mean(scores)) for name, scores in metrics.items())

# Compute standard deviation metrics
def std_metrics(metrics, ddof=0):
    """Standard deviation of each metric over the seeds.

    Args:
        metrics: mapping of metric name -> list of per-seed scores.
        ddof: delta degrees of freedom forwarded to ``np.std``. The default 0
            (population std) reproduces the previous output. Pass 1 for the
            sample std, which is less biased with only a handful of seeds.

    Returns:
        A new dict with the same keys mapped to ``np.std(scores, ddof=ddof)``.
    """
    return {k: np.std(v, ddof=ddof) for k, v in metrics.items()}

avg_gmm = average_metrics(metrics_gmm)
std_gmm = std_metrics(metrics_gmm)

# Print averaged metrics with standard deviation across seeds.
# The seed count is taken from `seeds` so the header cannot go stale.
print(f'Averaged over {len(seeds)} seeds:\n')

# GMM Evaluation
print('GMM adjusted MI: {:.4f} ± {:.4f}'.format(avg_gmm['ami'], std_gmm['ami']))
print('GMM adjusted Rand Index: {:.4f} ± {:.4f}'.format(avg_gmm['ari'], std_gmm['ari']))
print('GMM Silhouette Score: {:.4f} ± {:.4f}'.format(avg_gmm['silhouette'], std_gmm['silhouette']))


Averaged over 5 seeds:

GMM adjusted MI: 0.5374 ± 0.0877
GMM adjusted Rand Index: 0.3480 ± 0.0432
GMM Silhouette Score: 0.2074 ± 0.1277
