In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import multivariate_normal, dirichlet, linregress
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KNeighborsClassifier
from tqdm.auto import tqdm

from jcm import ClusterModel, run_gibbs
from jcm.utils import estimate_cluster

# Directory that every plt.savefig call in this notebook writes into.
PLOTS_DIR = "./plots/"
# Create it up front so saving figures does not fail on a fresh checkout.
os.makedirs(PLOTS_DIR, exist_ok=True)

# Seed NumPy's global random state so the simulations below are repeatable.
np.random.seed(42)

### Define helper functions and basic test cases

In [None]:
def generate_words(N=100, V=50, doc_length=20, eta=0.1, z_true=None):
    """Simulate a document-term count matrix from cluster-specific word distributions.

    One word distribution per cluster is drawn from a symmetric Dirichlet(eta)
    over a vocabulary of size V; each document then receives `doc_length`
    tokens drawn from the distribution of its own cluster.

    Parameters
    ----------
    N : int
        Number of documents; must equal ``len(z_true)``.
    V : int
        Vocabulary size.
    doc_length : int
        Number of tokens per document.
    eta : float
        Concentration of the symmetric Dirichlet prior.
    z_true : sequence of int
        Cluster label of every document.

    Returns
    -------
    W : ndarray of int, shape (N, V)
        Word counts; every row sums to ``doc_length``.
    word_probs : ndarray, shape (K, V)
        The sampled word distributions, with ``K = max(2, max(z_true) + 1)``.

    Raises
    ------
    ValueError
        If ``z_true`` is missing or its length differs from ``N``.
    """
    if z_true is None or len(z_true) != N:
        raise ValueError("z_true must be provided and match N")

    labels = np.asarray(z_true)
    # Always draw at least two distributions (the previous fixed behaviour),
    # but grow with the labels so three or more clusters no longer raise IndexError.
    K = max(2, int(labels.max()) + 1) if labels.size else 2
    word_probs = dirichlet.rvs([eta] * V, size=K)

    W = np.zeros((N, V), dtype=int)
    for i, k in enumerate(labels):
        W[i] = np.random.multinomial(doc_length, word_probs[k])

    return W, word_probs

def generate_embeddings(N=100, separation=2.0, sigma=1.0, z_true=None):
    """Sample 2-D embeddings from two isotropic Gaussians centred on the line y = x.

    The two cluster means sit symmetrically about the origin, `separation`
    apart, and both clusters share the covariance ``sigma**2 * I``.

    Parameters
    ----------
    N : int
        Number of points; must equal ``len(z_true)``.
    separation : float
        Euclidean distance between the two cluster means.
    sigma : float
        Per-coordinate standard deviation.
    z_true : sequence of int
        Cluster label (0 or 1) of every point.

    Returns
    -------
    Y : ndarray, shape (N, 2)
        The sampled embeddings, one row per point.
    means : list of ndarray
        The two cluster means, in label order.

    Raises
    ------
    ValueError
        If ``z_true`` is missing or its length differs from ``N``.
    """
    if z_true is None or len(z_true) != N:
        raise ValueError("z_true must be provided and match N")

    # Unit vector along y = x; the means are placed at -/+ half the separation.
    unit = np.array([1.0, 1.0]) / np.sqrt(2)
    means = [sign * 0.5 * separation * unit for sign in (-1.0, 1.0)]

    cov = np.eye(2) * sigma**2

    # One draw per point, in label order, so the random stream is consumed sequentially.
    Y = np.empty((N, 2))
    for row, k in enumerate(z_true):
        Y[row] = multivariate_normal.rvs(mean=means[k], cov=cov)

    return Y, means


def plot_2d_embeddings(Y, z, title="Embeddings (2D)", ax=None):
    """Scatter the first two columns of `Y`, coloured by the labels `z`.

    Parameters
    ----------
    Y : array indexable as ``Y[:, 0]`` and ``Y[:, 1]``
        Point coordinates.
    z : sequence
        Value used to colour each point (passed to ``scatter`` as ``c``).
    title : str
        Axes title.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    The object returned by ``ax.scatter``.
    """
    if ax is None:
        _, ax = plt.subplots()

    xs, ys = Y[:, 0], Y[:, 1]
    marker_style = dict(cmap="tab10", edgecolors="k", alpha=0.7)
    points = ax.scatter(xs, ys, c=z, **marker_style)

    ax.set_title(title)
    ax.grid(True)
    return points

In [None]:
# Smoke test for generate_words on random two-cluster labels.
z = np.random.randint(0, 2, size=100)  # Random cluster assignments
W, word_probs = generate_words(N=100, V=10, doc_length=10, eta=0.1, z_true=z)

# Invariants of the generator: one row per document, doc_length tokens per row,
# and each sampled word distribution is a probability vector.
assert W.shape == (100, 10)
assert (W.sum(axis=1) == 10).all()
assert word_probs.shape == (2, 10)
assert np.allclose(word_probs.sum(axis=1), 1.0)

In [None]:
# Baseline simulation: two balanced clusters of 100 points each.
N = 200
d = 2
sep = 1.5
z_true = np.repeat([0, 1], N // 2)

# Words are generated before embeddings; keep this order so the seeded
# random stream is consumed the same way on every run.
W, word_probs = generate_words(z_true=z_true, N=N, V=50, doc_length=10, eta=0.1)
X, means = generate_embeddings(z_true=z_true, N=N, separation=sep, sigma=0.5)

# Embeddings coloured by their true cluster
fig, ax = plt.subplots()
plot_2d_embeddings(
    X,
    z_true,
    title=f"Embeddings with true labels, separation={sep}",
    ax=ax,
)
plt.show()

Use GMM to set initial cluster assignments based on embeddings.

In [None]:
# Fit a K-component full-covariance GMM to the embeddings only; its hard
# assignments are used below as the starting labels for the joint model.
K = 2
gmm = GaussianMixture(n_components=K, covariance_type='full', random_state=42)
gmm.fit(X)
labels = gmm.predict(X)

fig, ax = plt.subplots()
# The helper defined above is plot_2d_embeddings, and `ax` must be passed by
# keyword because its third positional parameter is `title`.
plot_2d_embeddings(X, labels, ax=ax, title="GMM Labels")
print(f"ARI: {adjusted_rand_score(z_true, labels)}")

In [None]:
# Prior scale: per-dimension variance of the embeddings within each GMM
# cluster, averaged over the K clusters.
v = np.zeros(d)
for k in range(K):
    members = X[labels == k]
    if members.shape[0] == 0:
        raise ValueError(f"Cluster {k} has no points assigned.")
    v += members.var(axis=0) / K

# Prior hyperparameters handed to ClusterModel
S_0 = np.diag(v)      # diagonal scale matrix, shared across all clusters
m_0 = np.zeros(d)     # mean vector for the prior
kappa_0 = 0.05

# Start the model from the GMM labelling
model = ClusterModel(X, W, K, labels, m_0, S_0, kappa_0=kappa_0)

# Sampler settings passed to run_gibbs
M, burnin = 500, 200

samples, likelihoods = run_gibbs(
    model, M, burnin, track_likelihood=True, show_progress=True
)
z_pred = estimate_cluster(samples, K)

In [None]:
# Embeddings coloured by the labels estimated from the Gibbs samples.
fig, ax = plt.subplots()
# The helper defined above is plot_2d_embeddings, and `ax` must be passed by
# keyword because its third positional parameter is `title`.
plot_2d_embeddings(X, z_pred, ax=ax, title="M4R model Labels")
print(f"ARI: {adjusted_rand_score(z_true, z_pred)}")

### Varying separation

In [None]:
# Simulation parameters
N = 200
K = 2
V = 50
doc_length = 10
eta = 0.5
sigma = 0.5
kappa_0 = 0.05  # set here so the cell does not rely on a value left over from an earlier cell
z_true = np.array([0]*100 + [1]*100)

# Separation values (decreasing)
separations = [3.0, 2.0, 1.5, 1.0, 0.5]

# Set up the plot grid: 2 x 3 panels, one per separation (the spare panel is removed below)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Generate data and run model for each separation
for idx, sep in enumerate(separations):
    # Generate data
    W, word_probs = generate_words(N=N, V=V, doc_length=doc_length, eta=eta, z_true=z_true)
    X, means = generate_embeddings(N=N, separation=sep, sigma=sigma, z_true=z_true)

    # Fit a GMM on this iteration's embeddings for the initial cluster assignments
    gmm = GaussianMixture(n_components=K, covariance_type='full', random_state=42)
    gmm.fit(X)
    initial_labels = gmm.predict(X)

    # Prior hyperparameters, sized from the data rather than from a leftover global
    dim = X.shape[1]
    m_0 = np.zeros(dim)  # Mean vector for the prior
    v = np.zeros(dim)
    for k in range(K):
        # Use the labels fitted to *this* X, not the `labels` array of an earlier cell
        cluster_k = X[initial_labels == k]
        if len(cluster_k) == 0:
            raise ValueError(f"Cluster {k} has no points assigned.")
        v += np.var(cluster_k, axis=0) / K

    S_0 = np.diag(v)  # Shared across all clusters

    # Initialise model using this iteration's GMM labels
    model = ClusterModel(X, W, K, initial_labels, m_0, S_0, kappa_0=kappa_0)
    samples, _ = run_gibbs(model, M=500, burnin=200)
    z_pred = estimate_cluster(samples, K)

    # Compute ARI
    ari = adjusted_rand_score(z_true, z_pred)

    # Colour the points by the same estimate the ARI in the title refers to
    plot_2d_embeddings(X, z_pred, title=f"Separation = {sep}, ARI = {ari:.3f}", ax=axes[idx])

# Remove empty subplots (if any)
for unused_ax in axes[len(separations):]:
    fig.delaxes(unused_ax)

# Adjust layout
plt.tight_layout()
plt.savefig(PLOTS_DIR + 'separation_grid.png')

## Plotting separation against ARI, with lines for GMM model and M4R model

In [None]:
# Experiment configuration
N, K, V = 200, 2, 50
doc_length = 20
sigma = 0.5

# Separation grid (increasing from 0 in steps of 0.1, stopping before 3),
# with M replicate data sets per (eta, separation) pair
separations = np.arange(0, 3, 0.1)
M = 20
eta_values = [0.5, 1.0, 5.0, 10.0]

# ARI arrays indexed as [eta, separation, replicate]
result_shape = (len(eta_values), len(separations), M)
ari_m4r = np.zeros(result_shape)
ari_gmm = np.zeros(result_shape)

# Sampler and prior settings are the same for every run
n_iter, burnin = 10, 2
kappa_0 = 0.05

for i, eta in enumerate(eta_values):
    # Labels and word counts are fixed per eta; only the embeddings are redrawn below
    z_true = np.repeat([0, 1], 100)
    W, word_probs = generate_words(z_true=z_true, N=N, V=V, doc_length=doc_length, eta=eta)

    for j, sep in enumerate(separations):
        for k in range(M):
            # Fresh embeddings for this replicate
            X, means = generate_embeddings(z_true=z_true, N=N, separation=sep, sigma=sigma)

            # GMM on the embeddings: both the baseline and the model's starting labels
            gmm = GaussianMixture(n_components=K, covariance_type='full', random_state=42)
            gmm.fit(X)
            initial_labels = gmm.predict(X)
            ari_gmm[i, j, k] = adjusted_rand_score(z_true, initial_labels)

            # Prior scale: within-cluster variance per dimension, averaged over clusters
            v = np.zeros(d)
            for c in range(K):
                members = X[initial_labels == c]
                if members.shape[0] == 0:
                    raise ValueError(f"Cluster {c} has no points assigned.")
                v += members.var(axis=0) / K

            S_0 = np.diag(v)   # shared across all clusters
            m_0 = np.zeros(d)  # mean vector for the prior

            # Run the joint model from the GMM labelling and score its estimate
            model = ClusterModel(X, W, K, initial_labels, m_0, S_0, kappa_0=kappa_0)
            samples, _ = run_gibbs(model, n_iter, burnin)
            z_pred = estimate_cluster(samples, K)
            ari_m4r[i, j, k] = adjusted_rand_score(z_true, z_pred)



In [None]:
# Plotting ARI results: mean +/- one standard deviation over the replicates
fig, ax = plt.subplots(figsize=(10, 6))

# Plot a separate M4R curve for each eta value
for i, eta_val in enumerate(eta_values):
    mean_ari = np.mean(ari_m4r[i], axis=1)
    std_ari = np.std(ari_m4r[i], axis=1)
    # Raw f-string: '\e' is not a valid escape sequence in a normal string literal
    ax.plot(separations, mean_ari, label=rf'M4R $\eta = $ {eta_val:.1f}')
    ax.fill_between(separations, mean_ari - std_ari, mean_ari + std_ari, alpha=0.2)

# Plot GMM, pooled over all eta values and replicates
mean_ari_gmm = np.mean(ari_gmm, axis=(0, 2))
std_ari_gmm = np.std(ari_gmm, axis=(0, 2))
ax.plot(separations, mean_ari_gmm, label='GMM', color='black')
ax.fill_between(separations, mean_ari_gmm - std_ari_gmm, mean_ari_gmm + std_ari_gmm, color='black', alpha=0.2)

ax.set_xlabel('Separation')
ax.set_ylabel('ARI')
ax.legend()
plt.tight_layout()
plt.savefig(PLOTS_DIR + 'ari_vs_separation_eta.png')
plt.show()

### Demonstrate likelihood weighting by $\alpha$


In [None]:
# Simulation parameters
N = 200
K = 2
V = 50
doc_length = 20
sigma = 0.5

# Separation values (finer grid for 0 to 0.5).
# np.concatenate is used because the np.concat alias only exists in NumPy >= 2.0.
separations = np.concatenate([np.arange(0, 0.5, 0.025), np.arange(0.5, 3, 0.25)])
M = 20
alpha_values = [0.05, 0.1, 0.15, 0.2, 0.5, 1.0]

# ARI arrays indexed as [alpha, separation, replicate]
ari_m4r = np.zeros((len(alpha_values), len(separations), M))
ari_gmm = np.zeros((len(alpha_values), len(separations), M))

# Fix cluster labels and word distributions (weakly informative)
z_true = np.array([0]*100 + [1]*100)
eta = 0.5
W, word_probs = generate_words(N=N, V=V, doc_length=doc_length, eta=eta, z_true=z_true)

# Sampler and prior settings are the same for every run
n_iter = 10
burnin = 2
kappa_0 = 0.05

# Generate data and run model for each weighting
for i, alpha in enumerate(alpha_values):
    for j, sep in enumerate(separations):
        for k in range(M):
            # Generate data
            X, means = generate_embeddings(N=N, separation=sep, sigma=sigma, z_true=z_true)

            # Initialize GMM for initial cluster assignments
            gmm = GaussianMixture(n_components=K, covariance_type='full', random_state=42)
            gmm.fit(X)
            initial_labels = gmm.predict(X)

            ari_gmm[i, j, k] = adjusted_rand_score(z_true, initial_labels)

            # Prior hyperparameters, sized from the data rather than a leftover global `d`
            dim = X.shape[1]
            v = np.zeros(dim)
            for c in range(K):
                cluster_k = X[initial_labels == c]
                if len(cluster_k) == 0:
                    raise ValueError(f"Cluster {c} has no points assigned.")
                v += np.var(cluster_k, axis=0) / K

            S_0 = np.diag(v)     # Shared across all clusters
            m_0 = np.zeros(dim)  # Mean vector for the prior

            # Initialise model using GMM labels; alpha is passed as the weight_W argument
            model = ClusterModel(X, W, K, initial_labels, m_0, S_0, kappa_0=kappa_0, weight_W=alpha)

            samples, _ = run_gibbs(model, n_iter, burnin)
            z_pred = estimate_cluster(samples, K)

            # Compute ARI
            ari_m4r[i, j, k] = adjusted_rand_score(z_true, z_pred)



In [None]:
# ARI against separation for each word-likelihood weight, with the GMM labels as baseline.
fig, ax = plt.subplots(figsize=(10, 6))


def _mean_with_band(scores, label, **colour):
    """Draw the replicate mean of `scores` plus a +/- 1 std band clipped to [0, 1]."""
    centre = np.mean(scores, axis=1)
    spread = np.std(scores, axis=1)
    ax.plot(separations, centre, label=label, **colour)
    ax.fill_between(
        separations,
        np.clip(centre - spread, 0, 1),
        np.clip(centre + spread, 0, 1),
        alpha=0.2,
        **colour,
    )


# GMM baseline, taken from the first alpha slice of ari_gmm
_mean_with_band(ari_gmm[0], r'GMM ($\alpha_\mathbf{{W}} = 0$)', color='black')

# One M4R curve per alpha value
for i, alpha in enumerate(alpha_values):
    _mean_with_band(ari_m4r[i], rf'$\alpha_\mathbf{{W}} = $ {alpha:.2f}')

ax.set_xlabel('Separation')
ax.set_ylabel('Adjusted Rand Index (ARI)')
ax.legend()

plt.tight_layout()
plt.savefig(PLOTS_DIR + 'ari_vs_separation_alpha.png')
plt.show()

TODO: use a finer grid between 0 and 0.5,
with more separations between 0 and 0.1.

TODO: run on the Cora dataset and plot ARI against $\alpha$.


### Varying Dirichlet prior hyperparameter.

1. $\eta$ (controls $\phi$): Dirichlet prior for the word distributions; determines whether the word distribution within each cluster is spiky or close to uniform.
2. $\gamma$ (controls $\psi$): Dirichlet prior for the cluster assignments; a smaller $\gamma$ favours uneven cluster sizes, a larger $\gamma$ encourages more balanced cluster sizes.

Vary the document length and show how, as it increases, the word component comes to dominate the likelihood. We could compare the likelihood components $l_w$ vs $l_e$, or use a weighted mix of $l_w$ and $l_e$.

Weight the posterior probability of $z$ by document length / embedding dimension, i.e. rescale.


In [None]:
# Experiment settings
N, d, K, V = 200, 2, 2, 100
cluster_sep, doc_length = 1, 100
sigma = 0.5  # set here so the cell does not rely on a value left over from an earlier cell

M = 50  # Number of simulations per eta
eta_values = np.arange(0.1, 10.5, 0.5)
ari_values = np.zeros((len(eta_values), M))
js_divergence = np.zeros((len(eta_values), M))
# Fixed cluster labels: 100 points in each of the two clusters
z_true = np.array([0]*100 + [1]*100)

# Sampler and prior settings are the same for every run
n_iter = 100
burnin = 50
kappa_0 = 0.05

# Loop over data-generating eta values
for i, eta in enumerate(tqdm(eta_values, desc="Eta grid")):
    for j in range(M):
        # Generate synthetic data
        X, means = generate_embeddings(N=N, separation=cluster_sep, sigma=sigma, z_true=z_true)
        W, word_probs = generate_words(N=N, V=V, doc_length=doc_length, eta=eta, z_true=z_true)
        # scipy's jensenshannon returns the JS *distance* (square root of the
        # divergence); square it so the stored value is the divergence that the
        # variable name and the later plot refer to.
        js_divergence[i, j] = jensenshannon(word_probs[0], word_probs[1]) ** 2

        # GMM on the embeddings gives the initial cluster assignments
        gmm = GaussianMixture(n_components=K, covariance_type='full', random_state=42)
        gmm.fit(X)
        initial_labels = gmm.predict(X)

        # Prior scale: within-cluster variance per dimension, averaged over clusters
        v = np.zeros(d)
        for k in range(K):
            cluster_k = X[initial_labels == k]
            if len(cluster_k) == 0:
                raise ValueError(f"Cluster {k} has no points assigned.")
            v += np.var(cluster_k, axis=0) / K

        S_0 = np.diag(v)   # Shared across all clusters
        m_0 = np.zeros(d)  # Mean vector for the prior

        # Initialise model using GMM labels as predictions
        model = ClusterModel(X, W, K, initial_labels, m_0, S_0, kappa_0=kappa_0)

        samples, _ = run_gibbs(model, n_iter, burnin)
        z_pred = estimate_cluster(samples, K)

        # Compute ARI between true and inferred clusters
        ari_values[i, j] = adjusted_rand_score(z_true, z_pred)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Left panel: median ARI over the replicates at each eta
ari_medians = np.median(ari_values, axis=1)

ax1.plot(eta_values, ari_medians, 'o', color='orange')
ax1.set_xlabel(r'$\eta$')
ax1.set_ylabel('Median ARI')
ax1.grid(True)

# Right panel: mean ARI over the replicates at each eta
# (kept in its own name instead of overwriting the medians)
ari_means = np.mean(ari_values, axis=1)

# Linear regression of mean ARI on eta.
# NOTE(review): the fit and fit_line are computed but never drawn; either plot
# fit_line on ax2 or drop the regression.
slope, intercept, r_value, p_value, std_err = linregress(eta_values, ari_means)
fit_line = slope * eta_values + intercept

ax2.plot(eta_values, ari_means, 'o', label='Mean ARI', color='orange')

ax2.set_xlabel(r'$\eta$')
ax2.set_ylabel('Mean ARI')
ax2.legend()

# Apply the layout before saving so the PNG matches what is displayed
fig.tight_layout()
plt.savefig(PLOTS_DIR + 'ari_vs_eta.png')
plt.show()

In [None]:
# Distribution of ARI over the replicates at each eta: one box per eta value,
# positioned at that eta on the x-axis.
fig, ax = plt.subplots(figsize=(12, 6))

box_style = dict(facecolor='orange', alpha=0.6)
median_style = dict(color='black')
outlier_style = dict(marker='o', markerfacecolor='red', markersize=4, linestyle='none')

ax.boxplot(
    ari_values.T,  # transpose so each column holds one eta's replicates
    positions=eta_values,
    widths=0.3,
    patch_artist=True,
    boxprops=box_style,
    medianprops=median_style,
    flierprops=outlier_style,
)

ax.set_xlabel(r'$\eta$')
ax.set_ylabel('ARI')
ax.grid(True)
fig.tight_layout()
fig.savefig(PLOTS_DIR + 'ari_boxplot_vs_eta.png')
plt.show()

## Relationship between pairwise JS divergence of word distributions and ARI


In [None]:
# One point per simulated data set: JS value of its two word distributions against its ARI
flat_js = js_divergence.ravel()
flat_ari = ari_values.ravel()

fig, ax = plt.subplots(figsize=(9, 6))
ax.scatter(flat_js, flat_ari, alpha=0.4, s=20, c='royalblue', edgecolors='none')

# Optional smoothing or trendline if desired
# import seaborn as sns
# sns.regplot(x=flat_js, y=flat_ari, lowess=True, scatter=False, color='black', line_kws={'linewidth': 1.5})

ax.set_xlabel(r'JS($\phi_1 \mid \phi_2)$', fontsize=12)
ax.set_ylabel('ARI', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.4)
fig.tight_layout()
fig.savefig(PLOTS_DIR + 'ari_vs_js_scatter.png')
plt.show()

Using random initialisation, even for larger JS divergences, the model can sometimes be unable to recover the clusters.

In [None]:
# Experiment settings
# NOTE(review): this cell uses a different model/sampler API from the rest of
# the notebook (keyword-form ClusterModel, CollapsedGibbsSampler) and looks
# stale -- confirm against the current jcm package before running.
N, d, K, V = 200, 2, 2, 100
cluster_sep, doc_length = 0.5, 100
sigma = 0.5

M = 50 # Number of simulations per eta
eta_values = np.arange(0.1, 10.5, 0.5)
ari_values = np.zeros((len(eta_values), M))
js_divergence = np.zeros((len(eta_values), M))
# Fixed cluster labels: 100 points in each of the two clusters
z_true = np.array([0]*100 + [1]*100)

# Loop over data-generating eta values
for i, eta in enumerate(eta_values):
    for j in range(M):
        # Generate synthetic data
        Y, means = generate_embeddings(N=N, separation=cluster_sep, sigma=sigma, z_true=z_true)
        W, word_probs = generate_words(N=N, V=V, doc_length=doc_length, eta=eta, z_true=z_true)
        # NOTE(review): scipy's jensenshannon returns the JS distance (square
        # root of the divergence), although it is stored as a divergence here.
        js_divergence[i, j] = jensenshannon(word_probs[0], word_probs[1]) # use this as proxy for eta on later plot

        # Fixed NIW prior: centred on the data mean, scale = d times the sample covariance
        m_0 = np.mean(Y, axis=0)
        kappa_0 = 5
        nu_0 = d + 2
        S_0 = np.cov(Y, rowvar=False) * d

        # Initial cluster assignments.
        # NOTE(review): despite the variable names this is not k-means.
        # KNeighborsClassifier is a supervised classifier and is fit on z_true,
        # so the "initial" labels are derived from the true labels. This
        # contradicts the "random initialisation" described in the text above.
        kmeans = KNeighborsClassifier(n_neighbors=K)
        kmeans.fit(Y, z_true)
        labels_kmeans = kmeans.predict(Y)

        # Fit your model (using the true eta used for data generation)
        # NOTE(review): other cells call ClusterModel(X, W, K, labels, m_0, S_0, kappa_0=...);
        # confirm that the keywords gamma, eta, nu_0 and initial_z are accepted.
        model = ClusterModel(Y, W, K, gamma=1.0, eta=eta, m_0=m_0, kappa_0=kappa_0, nu_0=nu_0, S_0=S_0, initial_z=labels_kmeans)
        # NOTE(review): CollapsedGibbsSampler is not among this notebook's
        # visible imports, so this line raises NameError on a fresh kernel.
        # The meaning of run's three arguments cannot be confirmed from here.
        sampler = CollapsedGibbsSampler(model, z_true)
        sampler.run(5000, 1e-4, 2000)

        # Compute ARI between true and inferred clusters (uses the model's
        # current assignment model.z rather than an estimate over samples)
        ari_values[i, j] = adjusted_rand_score(z_true, model.z)

### Testing effectiveness of model with imbalanced clusters


In [None]:
# Constants
N = 200
K = 2
d, V = 2, 100
doc_length = 20
separation = 1.0
sigma = 0.5
eta = 1.0
gamma = 1.0  # not used anywhere in this cell
M = 30  # simulations per setting

# Sampler and prior settings are the same for every run
n_iter = 100
burnin = 50
kappa_0 = 0.05

# Cluster imbalance settings (as proportions for cluster 0)
cluster_0_props = np.linspace(0.5, 0.95, 10)  # from 50:50 to 95:5
ari_scores = np.zeros((len(cluster_0_props), M))

# Loop over imbalance settings
for i, p0 in enumerate(tqdm(cluster_0_props)):
    # Round rather than truncate: N * p0 can land just below the intended integer
    n0 = int(round(N * p0))
    n1 = N - n0
    z_true = np.array([0]*n0 + [1]*n1)

    for j in range(M):
        # Shuffle to avoid ordering effects
        np.random.shuffle(z_true)

        # Generate data
        Y, means = generate_embeddings(N=N, separation=separation, sigma=sigma, z_true=z_true)
        W, word_probs = generate_words(N=N, V=V, doc_length=doc_length, eta=eta, z_true=z_true)

        # Initial clustering from a GMM on this replicate's embeddings
        gmm = GaussianMixture(n_components=K, random_state=42)
        gmm.fit(Y)
        labels_init = gmm.predict(Y)

        # Prior scale: within-cluster variance per dimension, averaged over clusters.
        # Everything below uses Y (this replicate's embeddings), not the `X`
        # array left over from an earlier cell.
        v = np.zeros(d)
        for k in range(K):
            cluster_k = Y[labels_init == k]
            if len(cluster_k) == 0:
                raise ValueError(f"Cluster {k} has no points assigned.")
            v += np.var(cluster_k, axis=0) / K

        S_0 = np.diag(v)   # Shared across all clusters
        m_0 = np.zeros(d)  # Mean vector for the prior

        # Initialise model using GMM labels as predictions
        model = ClusterModel(Y, W, K, labels_init, m_0, S_0, kappa_0=kappa_0)

        samples, _ = run_gibbs(model, n_iter, burnin)
        z_pred = estimate_cluster(samples, K)

        # Evaluate ARI
        ari_scores[i, j] = adjusted_rand_score(z_true, z_pred)

In [None]:
# Mean ARI over the replicates against the share of points in cluster 0,
# with a +/- 1 standard deviation band clipped to [0, 1].
fig, ax = plt.subplots(figsize=(8, 6))

mean_ari = ari_scores.mean(axis=1)
std_ari = ari_scores.std(axis=1)
lower_bound, upper_bound = (
    np.clip(mean_ari - std_ari, 0, 1),
    np.clip(mean_ari + std_ari, 0, 1),
)

ax.plot(cluster_0_props, mean_ari, marker='o', label='Mean ARI')
ax.fill_between(cluster_0_props, lower_bound, upper_bound, alpha=0.2)

ax.set_xlabel('Proportion of Nodes in Cluster 0')
ax.set_ylabel('Adjusted Rand Index (ARI)')
ax.set_ylim(0, 1.05)  # small margin above 1 so the top of the band is not cut off
ax.grid(True)
fig.tight_layout()
fig.savefig(PLOTS_DIR + 'ari_vs_cluster_imbalance.png')
plt.show()