In [25]:
# fixed prior softmax update

import numpy as np 
import torch as t
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, silhouette_score
from sklearn.datasets import make_blobs
from torch.distributions import MultivariateNormal as N

# Seed both global RNGs once, before any random draw, so the dataset
# (in particular the uniform noise below) is identical in every cell.
np.random.seed(40)
t.manual_seed(40)

# Ground-truth mixture: four 2-D Gaussian blobs with very different spreads.
NUM_SAMPLES, NUM_FEATURES, NUM_CLASSES = 1000, 2, 4
centers = [[0, 0], [1, 1], [1, -1], [-1, -1]]
cluster_std = [0.5, 0.7, 0.9, 0.11]  # one standard deviation per blob

# random_state=0 pins the blob draw independently of the global seeds above.
X_np, y_np = make_blobs(n_samples=NUM_SAMPLES, centers=centers,
                        cluster_std=cluster_std, random_state=0)

# Background clutter: NUM_NOISE points uniform on [-3, 3) x [-3, 3), labelled -1
# so they can be masked out at evaluation time.
NUM_NOISE = 100
noise = t.rand(NUM_NOISE, NUM_FEATURES) * 6 - 3
X = t.cat((t.tensor(X_np, dtype=t.float32), noise), dim=0)
y = t.cat((t.tensor(y_np, dtype=t.int64),
           t.full((NUM_NOISE,), -1, dtype=t.int64)), dim=0)

# Model-side constants. BATCH includes the noise points, which are fit too.
BATCH, DIM = X.shape
GUESS_CLASSES = 4  # number of mixture components the model may use

# Independent EM restarts; each seed changes only the parameter initialisation.
seeds = [0, 1, 2, 3, 4]
num_seeds = len(seeds)  # derived so it cannot drift from the list above
NUM_EPOCHS = 200        # EM iterations per restart
COV_JITTER = 1e-6       # diagonal ridge, same value as the learnable-prior cells

# The prior is fixed (uniform) in this variant, so build it once, not every epoch.
prior = t.distributions.Categorical(logits=t.zeros(GUESS_CLASSES))
log_p_z = prior.probs.log()[None, :]  # Shape (1, GUESS_CLASSES)

# The evaluation subset does not depend on the seed: noise points (label -1)
# take part in fitting but are excluded from scoring.
mask = y.numpy() != -1
y_eval = y.numpy()[mask]
X_eval = X.numpy()[mask]

# Initialize accumulators for metrics
metrics_gmm = {'ami': [], 'ari': [], 'silhouette': []}


def _posterior_fixed_prior(data, mu, cov, log_prior):
    """E-step: posterior q(z | x) for every row of `data` under a fixed prior.

    Args:
        data: (BATCH, DIM) points.
        mu: (GUESS_CLASSES, DIM) component means.
        cov: (GUESS_CLASSES, DIM, DIM) component covariances.
        log_prior: (1, GUESS_CLASSES) log prior probabilities.

    Returns:
        (BATCH, GUESS_CLASSES) tensor of responsibilities; each row sums to one.
    """
    log_p_x_given_z = N(mu, cov).log_prob(data[:, None])  # Shape (BATCH, GUESS_CLASSES)
    return t.softmax(log_p_x_given_z + log_prior, dim=1)


for seed in seeds:
    # Set seed for reproducibility
    t.manual_seed(seed)
    np.random.seed(seed)

    ################################
    # Gaussian Mixture Model (GMM) #
    ################################

    # Means near the origin; covariances = identity + small random PSD term.
    mu_gmm = t.rand(GUESS_CLASSES, DIM) * 0.1
    s_gmm = t.rand(GUESS_CLASSES, DIM, DIM) * 0.1
    s_gmm = s_gmm @ s_gmm.transpose(-2, -1) + t.einsum('ij,k->kij', t.eye(DIM), t.ones(GUESS_CLASSES))

    for epoch in range(NUM_EPOCHS):
        # E-step
        q_gmm = _posterior_fixed_prior(X, mu_gmm, s_gmm, log_p_z)

        # M-step: the prior stays uniform, so only means and covariances move.
        N_k = q_gmm.sum(0)  # effective number of points per component
        mu_gmm = (q_gmm[:, :, None] * X[:, None, :]).sum(0) / N_k[:, None]
        x_minus_mu = X[:, None, :] - mu_gmm[None, :, :]
        s_gmm = ((x_minus_mu[:, :, :, None] @ x_minus_mu[:, :, None, :]) * q_gmm[:, :, None, None]).sum(0) / N_k[:, None, None]
        # Keep covariances positive definite (MultivariateNormal needs a Cholesky factor)
        # and match the regularisation used by the other variants.
        s_gmm = s_gmm + t.eye(DIM)[None, :, :] * COV_JITTER

    # Final E-step so the labels reflect the parameters produced by the last
    # M-step rather than the ones from before it.
    q_gmm = _posterior_fixed_prior(X, mu_gmm, s_gmm, log_p_z)
    labels_gmm = q_gmm.argmax(1).numpy()
    labels_gmm_eval = labels_gmm[mask]

    # sklearn expects (labels_true, labels_pred); AMI and ARI are symmetric, so
    # only the argument order changes, not the values.
    metrics_gmm['ami'].append(adjusted_mutual_info_score(y_eval, labels_gmm_eval))
    metrics_gmm['ari'].append(adjusted_rand_score(y_eval, labels_gmm_eval))
    metrics_gmm['silhouette'].append(silhouette_score(X_eval, labels_gmm_eval))

def average_metrics(metrics):
    """Map each metric name to the mean of its recorded per-seed values."""
    averaged = {}
    for name, values in metrics.items():
        averaged[name] = np.mean(values)
    return averaged

# Compute standard deviation metrics
def std_metrics(metrics, ddof=0):
    """Map each metric name to the standard deviation of its recorded values.

    Args:
        metrics: dict of metric name -> list of per-seed values.
        ddof: forwarded to ``np.std``. The default 0 gives the population
            standard deviation (the original behaviour). Pass ``ddof=1`` for
            the sample standard deviation, the usual choice when summarising a
            handful of seeds.

    Returns:
        dict with the same keys, each mapped to its standard deviation.
    """
    return {k: np.std(v, ddof=ddof) for k, v in metrics.items()}

avg_gmm = average_metrics(metrics_gmm)
std_gmm = std_metrics(metrics_gmm)

# Print averaged metrics with standard deviation. The seed count is taken from
# the number of recorded runs so the header stays correct if `seeds` changes.
print('Averaged over {} seeds:\n'.format(len(metrics_gmm['ami'])))

# GMM Evaluation
print('GMM adjusted MI: {:.4f} ± {:.4f}'.format(avg_gmm['ami'], std_gmm['ami']))
print('GMM adjusted Rand Index: {:.4f} ± {:.4f}'.format(avg_gmm['ari'], std_gmm['ari']))
print('GMM Silhouette Score: {:.4f} ± {:.4f}'.format(avg_gmm['silhouette'], std_gmm['silhouette']))


Averaged over 5 seeds:

GMM adjusted MI: 0.5439 ± 0.0040
GMM adjusted Rand Index: 0.4854 ± 0.0043
GMM Silhouette Score: 0.3301 ± 0.0073


In [19]:
# learnable prior, totally classical update

import numpy as np 
import torch as t
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, silhouette_score
from sklearn.datasets import make_blobs
from torch.distributions import MultivariateNormal as N

# Seed both global RNGs once, before any random draw, so the dataset
# (in particular the uniform noise below) is identical in every cell.
np.random.seed(40)
t.manual_seed(40)

# Ground-truth mixture: four 2-D Gaussian blobs with very different spreads.
NUM_SAMPLES, NUM_FEATURES, NUM_CLASSES = 1000, 2, 4
centers = [[0, 0], [1, 1], [1, -1], [-1, -1]]
cluster_std = [0.5, 0.7, 0.9, 0.11]  # one standard deviation per blob

# random_state=0 pins the blob draw independently of the global seeds above.
X_np, y_np = make_blobs(n_samples=NUM_SAMPLES, centers=centers,
                        cluster_std=cluster_std, random_state=0)

# Background clutter: NUM_NOISE points uniform on [-3, 3) x [-3, 3), labelled -1
# so they can be masked out at evaluation time.
NUM_NOISE = 100
noise = t.rand(NUM_NOISE, NUM_FEATURES) * 6 - 3
X = t.cat((t.tensor(X_np, dtype=t.float32), noise), dim=0)
y = t.cat((t.tensor(y_np, dtype=t.int64),
           t.full((NUM_NOISE,), -1, dtype=t.int64)), dim=0)

# Model-side constants. BATCH includes the noise points, which are fit too.
BATCH, DIM = X.shape
GUESS_CLASSES = 4  # number of mixture components the model may use

# Independent EM restarts; each seed changes only the parameter initialisation.
seeds = [0, 1, 2, 3, 4]
num_seeds = len(seeds)  # derived so it cannot drift from the list above
NUM_EPOCHS = 200        # EM iterations per restart
COV_JITTER = 1e-6       # diagonal ridge to prevent singular covariances

# The evaluation subset does not depend on the seed: noise points (label -1)
# take part in fitting but are excluded from scoring.
mask = y.numpy() != -1
y_eval = y.numpy()[mask]
X_eval = X.numpy()[mask]

# Initialize accumulators for metrics
metrics_gmm = {'ami': [], 'ari': [], 'silhouette': []}


def _responsibilities_classical(data, mu, cov, pi):
    """E-step in probability space: gamma_nk = pi_k p(x_n|k) / sum_j pi_j p(x_n|j).

    This is deliberately kept in the textbook (non-log) form so it can be
    contrasted with the softmax variant. exp() of a very negative log-density
    underflows to 0 in float32. If that happens for every component of a row,
    that row becomes 0/0 = NaN. The log-space/softmax form avoids this.

    Args:
        data: (BATCH, DIM) points.
        mu: (GUESS_CLASSES, DIM) component means.
        cov: (GUESS_CLASSES, DIM, DIM) component covariances.
        pi: (GUESS_CLASSES,) mixing coefficients.

    Returns:
        (BATCH, GUESS_CLASSES) tensor of responsibilities.
    """
    p_x_given_z = N(mu, cov).log_prob(data[:, None]).exp()  # Shape (BATCH, GUESS_CLASSES)
    p_xz = p_x_given_z * pi[None, :]                        # Shape (BATCH, GUESS_CLASSES)
    return p_xz / p_xz.sum(dim=1, keepdim=True)


for seed in seeds:
    # Set seed for reproducibility
    t.manual_seed(seed)
    np.random.seed(seed)

    ########################################################
    # Gaussian Mixture Model (GMM) with Standard EM Updates #
    ########################################################

    # Means near the origin; covariances = identity + small random PSD term.
    mu_gmm = t.rand(GUESS_CLASSES, DIM) * 0.1
    s_gmm = t.rand(GUESS_CLASSES, DIM, DIM) * 0.1
    s_gmm = s_gmm @ s_gmm.transpose(-2, -1) + t.einsum('ij,k->kij', t.eye(DIM), t.ones(GUESS_CLASSES))
    pi_k = t.ones(GUESS_CLASSES) / GUESS_CLASSES  # uniform initial mixing coefficients

    for epoch in range(NUM_EPOCHS):
        # E-step
        gamma_nk = _responsibilities_classical(X, mu_gmm, s_gmm, pi_k)

        # M-step
        N_k = gamma_nk.sum(dim=0)  # Shape (GUESS_CLASSES,): effective points per component
        pi_k = N_k / BATCH         # update mixing coefficients

        mu_gmm = (gamma_nk[:, :, None] * X[:, None, :]).sum(0) / N_k[:, None]
        x_minus_mu = X[:, None, :] - mu_gmm[None, :, :]
        s_gmm = ((gamma_nk[:, :, None, None] * (x_minus_mu[:, :, :, None] * x_minus_mu[:, :, None, :])).sum(0)) / N_k[:, None, None]

        # Add a small value to the diagonal to prevent singularity
        s_gmm = s_gmm + t.eye(DIM)[None, :, :] * COV_JITTER

    # Final E-step so the labels reflect the parameters produced by the last
    # M-step rather than the ones from before it.
    gamma_nk = _responsibilities_classical(X, mu_gmm, s_gmm, pi_k)
    labels_gmm = gamma_nk.argmax(1).numpy()
    labels_gmm_eval = labels_gmm[mask]

    # sklearn expects (labels_true, labels_pred); AMI and ARI are symmetric, so
    # only the argument order changes, not the values.
    metrics_gmm['ami'].append(adjusted_mutual_info_score(y_eval, labels_gmm_eval))
    metrics_gmm['ari'].append(adjusted_rand_score(y_eval, labels_gmm_eval))
    metrics_gmm['silhouette'].append(silhouette_score(X_eval, labels_gmm_eval))

def average_metrics(metrics):
    """Map each metric name to the mean of its recorded per-seed values."""
    averaged = {}
    for name, values in metrics.items():
        averaged[name] = np.mean(values)
    return averaged

# Compute standard deviation metrics
def std_metrics(metrics, ddof=0):
    """Map each metric name to the standard deviation of its recorded values.

    Args:
        metrics: dict of metric name -> list of per-seed values.
        ddof: forwarded to ``np.std``. The default 0 gives the population
            standard deviation (the original behaviour). Pass ``ddof=1`` for
            the sample standard deviation, the usual choice when summarising a
            handful of seeds.

    Returns:
        dict with the same keys, each mapped to its standard deviation.
    """
    return {k: np.std(v, ddof=ddof) for k, v in metrics.items()}

avg_gmm = average_metrics(metrics_gmm)
std_gmm = std_metrics(metrics_gmm)

# Print averaged metrics with standard deviation. The seed count is taken from
# the number of recorded runs so the header stays correct if `seeds` changes.
print('Averaged over {} seeds:\n'.format(len(metrics_gmm['ami'])))

# GMM Evaluation
print('GMM adjusted MI: {:.4f} ± {:.4f}'.format(avg_gmm['ami'], std_gmm['ami']))
print('GMM adjusted Rand Index: {:.4f} ± {:.4f}'.format(avg_gmm['ari'], std_gmm['ari']))
print('GMM Silhouette Score: {:.4f} ± {:.4f}'.format(avg_gmm['silhouette'], std_gmm['silhouette']))


Averaged over 5 seeds:

GMM adjusted MI: 0.6056 ± 0.0123
GMM adjusted Rand Index: 0.5305 ± 0.0049
GMM Silhouette Score: 0.3454 ± 0.0319


In [20]:
# learnable prior, softmax update

import numpy as np 
import torch as t
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, silhouette_score
from sklearn.datasets import make_blobs
from torch.distributions import MultivariateNormal as N

# Seed both global RNGs once, before any random draw, so the dataset
# (in particular the uniform noise below) is identical in every cell.
np.random.seed(40)
t.manual_seed(40)

# Ground-truth mixture: four 2-D Gaussian blobs with very different spreads.
NUM_SAMPLES, NUM_FEATURES, NUM_CLASSES = 1000, 2, 4
centers = [[0, 0], [1, 1], [1, -1], [-1, -1]]
cluster_std = [0.5, 0.7, 0.9, 0.11]  # one standard deviation per blob

# random_state=0 pins the blob draw independently of the global seeds above.
X_np, y_np = make_blobs(n_samples=NUM_SAMPLES, centers=centers,
                        cluster_std=cluster_std, random_state=0)

# Background clutter: NUM_NOISE points uniform on [-3, 3) x [-3, 3), labelled -1
# so they can be masked out at evaluation time.
NUM_NOISE = 100
noise = t.rand(NUM_NOISE, NUM_FEATURES) * 6 - 3
X = t.cat((t.tensor(X_np, dtype=t.float32), noise), dim=0)
y = t.cat((t.tensor(y_np, dtype=t.int64),
           t.full((NUM_NOISE,), -1, dtype=t.int64)), dim=0)

# Model-side constants. BATCH includes the noise points, which are fit too.
BATCH, DIM = X.shape
GUESS_CLASSES = 4  # number of mixture components the model may use

# Independent EM restarts; each seed changes only the parameter initialisation.
seeds = [0, 1, 2, 3, 4]
num_seeds = len(seeds)  # derived so it cannot drift from the list above
NUM_EPOCHS = 200        # EM iterations per restart
COV_JITTER = 1e-6       # diagonal ridge to avoid singular covariance matrices

# The evaluation subset does not depend on the seed: noise points (label -1)
# take part in fitting but are excluded from scoring.
mask = y.numpy() != -1
y_eval = y.numpy()[mask]
X_eval = X.numpy()[mask]

# Initialize accumulators for metrics
metrics_gmm = {'ami': [], 'ari': [], 'silhouette': []}


def _posterior_softmax(data, mu, cov, prior_logits):
    """E-step in log space: q(z | x) = softmax_z(log p(x | z) + log_softmax(prior_logits)).

    Args:
        data: (BATCH, DIM) points.
        mu: (GUESS_CLASSES, DIM) component means.
        cov: (GUESS_CLASSES, DIM, DIM) component covariances.
        prior_logits: (GUESS_CLASSES,) unnormalised log prior.

    Returns:
        (BATCH, GUESS_CLASSES) tensor of responsibilities; each row sums to one.
    """
    log_p_x_given_z = N(mu, cov).log_prob(data[:, None])         # Shape (BATCH, GUESS_CLASSES)
    log_p_z = t.log_softmax(prior_logits, dim=0)[None, :]       # Shape (1, GUESS_CLASSES)
    return t.softmax(log_p_x_given_z + log_p_z, dim=1)


for seed in seeds:
    # Set seed for reproducibility
    t.manual_seed(seed)
    np.random.seed(seed)

    ####################################################
    # Gaussian Mixture Model (GMM) with Trainable Prior #
    ####################################################

    # Means near the origin; covariances = identity + small random PSD term.
    mu_gmm = t.rand(GUESS_CLASSES, DIM) * 0.1
    s_gmm = t.rand(GUESS_CLASSES, DIM, DIM) * 0.1
    s_gmm = s_gmm @ s_gmm.transpose(-2, -1) + t.einsum('ij,k->kij', t.eye(DIM), t.ones(GUESS_CLASSES))

    # Uniform initial prior (zero logits).
    prior_logits = t.zeros(GUESS_CLASSES)

    for epoch in range(NUM_EPOCHS):
        # E-step
        q_gmm = _posterior_softmax(X, mu_gmm, s_gmm, prior_logits)

        # M-step
        N_k = q_gmm.sum(dim=0)  # Effective number of data points assigned to each cluster

        # Each row of q_gmm sums to one, so sum_k N_k == BATCH and
        # log_softmax(log N_k) == log(N_k / BATCH). This is the same prior as the
        # classical pi_k = N_k / BATCH update, which is why both variants score alike.
        prior_logits = N_k.log()

        # Update means
        mu_gmm = (q_gmm[:, :, None] * X[:, None, :]).sum(0) / N_k[:, None]

        # Update covariances
        x_minus_mu = X[:, None, :] - mu_gmm[None, :, :]
        s_gmm = ((q_gmm[:, :, None, None] * (x_minus_mu[:, :, :, None] * x_minus_mu[:, :, None, :])).sum(0)) / N_k[:, None, None]

        # Avoid singular covariance matrices
        s_gmm = s_gmm + t.eye(DIM)[None, :, :] * COV_JITTER

    # Final E-step so the labels reflect the parameters produced by the last
    # M-step rather than the ones from before it.
    q_gmm = _posterior_softmax(X, mu_gmm, s_gmm, prior_logits)
    labels_gmm = q_gmm.argmax(1).numpy()
    labels_gmm_eval = labels_gmm[mask]

    # sklearn expects (labels_true, labels_pred); AMI and ARI are symmetric, so
    # only the argument order changes, not the values.
    metrics_gmm['ami'].append(adjusted_mutual_info_score(y_eval, labels_gmm_eval))
    metrics_gmm['ari'].append(adjusted_rand_score(y_eval, labels_gmm_eval))
    metrics_gmm['silhouette'].append(silhouette_score(X_eval, labels_gmm_eval))

def average_metrics(metrics):
    """Map each metric name to the mean of its recorded per-seed values."""
    averaged = {}
    for name, values in metrics.items():
        averaged[name] = np.mean(values)
    return averaged

# Compute standard deviation metrics
def std_metrics(metrics, ddof=0):
    """Map each metric name to the standard deviation of its recorded values.

    Args:
        metrics: dict of metric name -> list of per-seed values.
        ddof: forwarded to ``np.std``. The default 0 gives the population
            standard deviation (the original behaviour). Pass ``ddof=1`` for
            the sample standard deviation, the usual choice when summarising a
            handful of seeds.

    Returns:
        dict with the same keys, each mapped to its standard deviation.
    """
    return {k: np.std(v, ddof=ddof) for k, v in metrics.items()}

avg_gmm = average_metrics(metrics_gmm)
std_gmm = std_metrics(metrics_gmm)

# Print averaged metrics with standard deviation. The seed count is taken from
# the number of recorded runs so the header stays correct if `seeds` changes.
print('Averaged over {} seeds:\n'.format(len(metrics_gmm['ami'])))

# GMM Evaluation
print('GMM adjusted MI: {:.4f} ± {:.4f}'.format(avg_gmm['ami'], std_gmm['ami']))
print('GMM adjusted Rand Index: {:.4f} ± {:.4f}'.format(avg_gmm['ari'], std_gmm['ari']))
print('GMM Silhouette Score: {:.4f} ± {:.4f}'.format(avg_gmm['silhouette'], std_gmm['silhouette']))


Averaged over 5 seeds:

GMM adjusted MI: 0.6056 ± 0.0123
GMM adjusted Rand Index: 0.5305 ± 0.0049
GMM Silhouette Score: 0.3454 ± 0.0319
