In [None]:
from typing import List
import numpy as np
import scipy.special as sp_spec
import scipy.stats as sp_stats
import numpy.random as np_rand
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
import seaborn as sns
import tqdm

# Bernoulli Mixture Model
See the solution PDF for the derivation of the CAVI updates and the ELBO, and for the notation used throughout the notebook. See the theory part of the exercise problems for the algorithm pseudocode and more. Some updates are written with for-loops to show the relation to the derivations more clearly; in practice, however, you are recommended to use matrix operations.


### Generate data

In [None]:
def generate_data_and_priors(N: int, D: int, K: int, theta_a: float, theta_b: float, pi_alpha: List[int]):
  """
  Draws the hidden variables from their priors and then samples a dataset from the resulting mixture.

  :param N: Number of datapoints
  :param D: Data dimension.
  :param K: Number of mixture components.
  :param theta_a: "a" parameter of the Beta(a,b) prior on theta.
  :param theta_b: "b" parameter of the Beta(a,b) prior on theta.
  :param pi_alpha: parameter of the Dirichlet(alpha) prior on pi (one entry per component).
  :return: Tuple (x, z, theta, pi). x is an N x D float array of 0/1 observations, z a size N float
    array holding the sampled component index of every datapoint, theta the sampled K x D matrix
    and pi the sampled size K weight vector.
  """
  # Hidden parameters: Bernoulli success probabilities first, then the mixture weights
  theta = np_rand.beta(theta_a, theta_b, size=(K, D))
  pi = np_rand.dirichlet(pi_alpha)

  x = np.zeros((N, D))
  z = np.zeros(N)
  for n in range(N):
    # A single multinomial trial gives a one-hot vector; argmax recovers the component index
    component = np_rand.multinomial(1, pi).argmax()
    z[n] = component
    # x_nd ~ Bernoulli(theta_kd), drawn one dimension at a time for d = 0..D-1
    x[n, :] = [np_rand.binomial(1, theta[component, d]) for d in range(D)]

  return x, z, theta, pi

def generate_data(N: int, D: int, K: int, theta: np.ndarray , pi: np.ndarray):
  """
  Samples a dataset from the mixture for fixed, user-supplied hidden variables.
  Useful when an experiment needs full control over theta and pi.

  :param N: Number of datapoints
  :param D: Data dimension.
  :param K: Number of mixture components (not read by the sampler itself).
  :param theta: Values of theta, indexed as theta[k, d].
  :param pi: Value of pi (component probabilities).
  :return: Tuple (x, z, theta, pi). x is an N x D float array of 0/1 observations, z a size N float
    array of sampled component indices; theta and pi are returned unchanged.
  """
  x = np.zeros((N, D))
  z = np.zeros(N)
  for n in range(N):
    # A single multinomial trial gives a one-hot vector; argmax recovers the component index
    component = np_rand.multinomial(1, pi).argmax()
    z[n] = component
    # One Bernoulli draw per dimension, using the row of theta that belongs to the component
    x[n, :] = [np_rand.binomial(1, theta[component, d]) for d in range(D)]

  return x, z, theta, pi

# Test the functions by running for some simple cases and verify that data and variables look as expected.
# Small, hand-picked configuration: 2 components, 3 binary features.
N = 1000
D = 3
K = 2

# Row k of theta holds the Bernoulli success probabilities of component k.
theta = np.array([[0.1, 0.7, 0.2], [0.6, 0.2, 0.2]])
pi = np.array([0.4, 0.6])
x, z, theta, pi = generate_data(N, D, K, theta, pi)
# The column slice 0:5 is wider than D = 3, so all features of the first 3 rows are printed.
print(f"First 3 datapoints: {x[0:3, 0:5]}")
print(f"First 3 z: {z[0:3]}")

# Empirical feature means should be close to the mixture mean pi_0 * theta_0 + pi_1 * theta_1.
x_mean = np.mean(x, axis=0)
print(f"Mean of x: {x_mean}")
print(f"Expected x: {pi[0] * theta[0] + pi[1] * theta[1]}")
# z is stored as floats; cast to int to index the identity matrix and get one-hot rows.
z_one_hot = np.eye(K)[z.astype(int)]
# The column means of the one-hot matrix are the empirical component frequencies.
print(f"Mean of z: {np.mean(z_one_hot, axis=0)}")
print(f"Expected z: {pi}")

#### Visualize generated data

In [None]:
import matplotlib.pyplot as plt
from itertools import product
import pandas as pd # Using pandas for easy counting and plotting

# --- This code cell visualizes the data generated in the cell above ---

# We will use the variables: x, z, theta, pi, K, D, N

# 1. Create the figure with 3 subplots
fig = plt.figure(figsize=(14, 12))
fig.suptitle('Bernoulli Mixture Model Data Visualization', fontsize=20)

# 2. Plot True Mixture Weights (pi)
ax_pi = fig.add_subplot(2, 2, 1)
colors = plt.cm.jet(np.linspace(0, 1, K))
ax_pi.bar(range(K), pi, color=colors, alpha=0.8)
# Raw string: '\p' is not a valid escape sequence in a normal string literal.
ax_pi.set_title(r'True Mixture Weights ($\pi$)', fontsize=14)
ax_pi.set_xlabel('Cluster (k)', fontsize=12)
ax_pi.set_ylabel('Probability', fontsize=12)
ax_pi.set_xticks(range(K))
ax_pi.set_ylim(0, 1.0)
for i, v in enumerate(pi):
    ax_pi.text(i, v + 0.02, f'{v:.2f}', ha='center', fontweight='bold')

# 3. Plot True Cluster Probabilities (theta)
ax_theta = fig.add_subplot(2, 2, 2)
im = ax_theta.imshow(theta, cmap='viridis', aspect='auto', vmin=0, vmax=1)
# Raw string is required: in a normal literal '\t' is a TAB character, which would
# turn '$\theta$' into '$<TAB>heta$' and break the math rendering.
ax_theta.set_title(r'True Cluster Probabilities ($\theta$)', fontsize=14)
ax_theta.set_xlabel('Feature (d)', fontsize=12)
ax_theta.set_ylabel('Cluster (k)', fontsize=12)
ax_theta.set_xticks(range(D))
ax_theta.set_yticks(range(K))
# Add text labels for the probabilities
for k in range(K):
    for d in range(D):
        # Dark cells (low probability in viridis) get white text, bright cells black text
        color = 'white' if theta[k, d] < 0.5 else 'black'
        ax_theta.text(d, k, f'{theta[k, d]:.2f}', ha='center', va='center', color=color, fontweight='bold')
fig.colorbar(im, ax=ax_theta, label='P(x_d = 1 | z=k)', fraction=0.046, pad=0.04)


# 4. Plot the distribution of the generated data (x)
ax_data = fig.add_subplot(2, 1, 2)

# To visualize the data, we will count the occurrences of each unique binary pattern
# and group them by their true cluster assignment (z)

# Create a DataFrame for easy grouping
df = pd.DataFrame(x, columns=[f'd{i}' for i in range(D)])
df['z'] = z.astype(int)

# Create a 'pattern' column (e.g., '010') for easy counting
df['pattern'] = df.apply(lambda row: ''.join(row.iloc[:D].astype(int).astype(str)), axis=1)

# Ensure all 2^D patterns are present, even if count is 0
all_patterns = [''.join(map(str, p)) for p in product([0, 1], repeat=D)]

# Count occurrences of each pattern, grouped by cluster. Reindexing the columns to
# range(K) keeps one column per cluster even when a cluster has no samples, so that
# colors[k] always belongs to cluster k.
counts = (
    df.groupby('pattern')['z']
    .value_counts()
    .unstack(fill_value=0)
    .reindex(index=all_patterns, columns=range(K), fill_value=0)
    .sort_index()
)

# Plot the grouped bar chart
counts.plot(kind='bar', ax=ax_data, color=colors, alpha=0.8, rot=0)

# Build the feature list from D instead of hard-coding three features
feature_names = ', '.join(f'd{i}' for i in range(D))
ax_data.set_title(f'Distribution of {N} Simulated Data Points (D={D})', fontsize=16)
ax_data.set_xlabel(f'Binary Pattern ({feature_names})', fontsize=12)
ax_data.set_ylabel('Count', fontsize=12)
ax_data.legend(title='True Cluster (z)', loc='upper left')
ax_data.grid(axis='y', linestyle='--', alpha=0.7)

# Add count labels on top of bars
for c in ax_data.containers:
    ax_data.bar_label(c, label_type='edge', fontsize=9, padding=2)

plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust for main title
plt.show()

# CAVI Algorithm

## CAVI updates

### $q(\pi)$ update

In [None]:
def update_q_pi(E_Z, alpha_prior):
    """
    CAVI update of q(pi): adds the expected number of datapoints per component to the prior.

    :param E_Z: N x K matrix of responsibilities E[z_nk].
    :param alpha_prior: Dirichlet prior parameter (anything that broadcasts against a size K array).
    :return: size K array with the updated Dirichlet parameters.
    """
    expected_counts = np.sum(E_Z, axis=0)  # size K array
    return expected_counts + alpha_prior

### $q(\theta)$ update

In [None]:
def update_q_theta(x, r_q, a_prior, b_prior):
    """
    CAVI update of q(theta): returns the parameters of the Beta factors q(theta_kd).

    Element-wise, the update reads
        a_star[k, d] = sum_n r_q[n, k] * x[n, d]       + a_prior
        b_star[k, d] = sum_n r_q[n, k] * (1 - x[n, d]) + b_prior
    which is evaluated here for all (k, d) at once with einsum.

    :param x: N x D matrix of binary observations.
    :param r_q: N x K matrix of responsibilities E[z_nk].
    :param a_prior: "a" parameter of the Beta prior on theta.
    :param b_prior: "b" parameter of the Beta prior on theta.
    :return: Tuple (a_star, b_star) of K x D matrices.
    """
    # Responsibility-weighted counts of ones and of zeros per (component, feature)
    a_star = np.einsum('nk,nd->kd', r_q, x) + a_prior
    b_star = np.einsum('nk,nd->kd', r_q, 1 - x) + b_prior
    return a_star, b_star

### $q(Z)$ update


In [None]:
def update_q_Z(x, a_q, b_q, alpha_q):
    """
    CAVI update of q(Z), following the update equation derived in the solution of the exercise.

    :param x: N x D matrix of binary observations.
    :param a_q: K x D matrix, "a" parameters of q(theta).
    :param b_q: K x D matrix, "b" parameters of q(theta).
    :param alpha_q: size K array, Dirichlet parameters of q(pi).
    :return: r_star, an N x K matrix of responsibilities whose rows sum to one.
    """
    # Expectations of the log-parameters under the current Beta / Dirichlet factors
    psi_a_plus_b = sp_spec.digamma(a_q + b_q)
    E_log_theta = sp_spec.digamma(a_q) - psi_a_plus_b                       # K x D
    E_log_1_minus_theta = sp_spec.digamma(b_q) - psi_a_plus_b               # K x D
    E_log_pi = sp_spec.digamma(alpha_q) - sp_spec.digamma(np.sum(alpha_q))  # K array

    # Unnormalised log-responsibilities: ones, zeros and the mixture weights
    loglik_ones = np.einsum('nd,kd->nk', x, E_log_theta)
    loglik_zeros = np.einsum('nd,kd->nk', 1 - x, E_log_1_minus_theta)
    log_rho = loglik_ones + loglik_zeros + E_log_pi

    # Normalise each row in log-space to avoid underflow for large D
    log_norm = sp_spec.logsumexp(log_rho, axis=1, keepdims=True)
    return np.exp(log_rho - log_norm)

## ELBO

### ELBO closed form calculation

In [None]:
def log_multivariate_beta_function(a):
    """
    Log of the multivariate beta function, log B(a) = sum_i log Gamma(a_i) - log Gamma(sum_i a_i).

    Working in log-space avoids the underflow that B(a) itself suffers from for large a.

    :param a: array-like of positive parameters.
    :return: scalar log B(a).
    """
    a = np.asarray(a, dtype=float)
    # gammaln(sum) is subtracted exactly once, outside the sum over the components
    return np.sum(sp_spec.gammaln(a)) - sp_spec.gammaln(np.sum(a))


def multivariate_beta_function(a):
    """
    Multivariate beta function B(a) = prod_i Gamma(a_i) / Gamma(sum_i a_i).

    NOTE: underflows to 0 for large parameters; prefer log_multivariate_beta_function.

    :param a: array-like of positive parameters.
    :return: scalar B(a).
    """
    return np.exp(log_multivariate_beta_function(a))


def calculate_elbo(x, r_q, a_q, b_q, alpha_q, a_prior=1.0, b_prior=1.0, alpha_prior=1.0):
    """
    Closed-form ELBO of the Bernoulli mixture model for the mean-field family
    q(Z) q(pi) q(theta) with Categorical, Dirichlet and Beta factors:

        ELBO = E[log p(x|Z,theta)] + E[log p(Z|pi)] + E[log p(pi)] + E[log p(theta)]
               + H[q(Z)] + H[q(pi)] + H[q(theta)]

    :param x: N x D matrix of binary observations.
    :param r_q: N x K matrix of responsibilities (parameters of q(Z)).
    :param a_q: K x D matrix, "a" parameters of q(theta).
    :param b_q: K x D matrix, "b" parameters of q(theta).
    :param alpha_q: size K array, Dirichlet parameters of q(pi).
    :param a_prior: "a" parameter of the Beta prior on theta (scalar or K x D). Defaults to 1 (uniform).
    :param b_prior: "b" parameter of the Beta prior on theta (scalar or K x D). Defaults to 1 (uniform).
    :param alpha_prior: Dirichlet prior parameter (scalar or size K array). Defaults to 1 (uniform).
        Pass the priors that were used in the updates whenever they are not uniform; otherwise the
        returned value is the ELBO of the model with uniform priors.
    :return: scalar ELBO.
    """
    K = r_q.shape[1]
    alpha_0 = np.sum(alpha_q)
    alpha_prior = np.broadcast_to(np.asarray(alpha_prior, dtype=float), np.shape(alpha_q))

    # Expectations of the log-parameters under q
    psi_a_plus_b = sp_spec.digamma(a_q + b_q)
    E_log_theta = sp_spec.digamma(a_q) - psi_a_plus_b
    E_log_1_minus_theta = sp_spec.digamma(b_q) - psi_a_plus_b
    E_log_pi = sp_spec.digamma(alpha_q) - sp_spec.digamma(alpha_0)

    # E_q[log p(x | Z, theta)]
    CR_likelihood = np.sum(r_q * (np.einsum('nd,kd->nk', x, E_log_theta)
                                  + np.einsum('nd,kd->nk', 1 - x, E_log_1_minus_theta)))
    # E_q[log p(Z | pi)]
    CR_Z = np.sum(r_q * E_log_pi)
    # E_q[log p(pi)]: uses the PRIOR parameters, not the variational ones
    CR_pi = -log_multivariate_beta_function(alpha_prior) + np.sum((alpha_prior - 1) * E_log_pi)
    # E_q[log p(theta)]: the normaliser enters as log Beta(a, b)
    CR_theta = np.sum((a_prior - 1) * E_log_theta + (b_prior - 1) * E_log_1_minus_theta
                      - sp_spec.betaln(a_prior, b_prior))

    # Entropies. entr(r) = -r * log(r) with entr(0) = 0, so no epsilon is needed.
    H_Z = np.sum(sp_spec.entr(r_q))
    H_pi = log_multivariate_beta_function(alpha_q) + (alpha_0 - K) * sp_spec.digamma(alpha_0) - \
           np.sum((alpha_q - 1) * sp_spec.digamma(alpha_q))
    H_theta = np.sum(sp_spec.betaln(a_q, b_q) - (a_q - 1) * sp_spec.digamma(a_q) -
                     (b_q - 1) * sp_spec.digamma(b_q) + (a_q + b_q - 2) * psi_a_plus_b)

    elbo = CR_likelihood + CR_Z + CR_pi + CR_theta + H_Z + H_pi + H_theta
    return elbo

## Initialization

In [None]:
def initialize_q(x, K):
    """
    Random initialisation of the variational parameters.

    A previous version also fitted KMeans (30 restarts) to derive a_q / b_q, but those values were
    overwritten by the random draws below before being returned, so the fit was pure overhead and
    has been removed. The returned values and the consumption of the global NumPy random stream
    are unchanged.

    :param x: N x D matrix of binary observations (only its shape is used).
    :param K: Number of mixture components.
    :return: Tuple (r_q_init, a_q_init, b_q_init, alpha_q_init):
        r_q_init is N x K with rows normalised to sum to one,
        a_q_init and b_q_init are K x D with integer values in [1, 99] stored as floats,
        alpha_q_init is a size K array of ones.
    """
    N, D = x.shape

    r_q_init = np.random.rand(N, K)
    r_q_init = r_q_init / np.sum(r_q_init, axis=1, keepdims=True)  # normalize

    # Cast to float so that later (possibly in-place) fractional updates are never truncated
    a_q_init = np.random.randint(1, 100, size=(K, D)).astype(float)
    b_q_init = np.random.randint(1, 100, size=(K, D)).astype(float)
    alpha_q_init = np.ones(K)
    return r_q_init, a_q_init, b_q_init, alpha_q_init

## CAVI algorithm

In [None]:
def CAVI_algorithm(x, K, n_iter, a_prior, b_prior, alpha_prior, step_size=0.01, tol=1e-6):
    """
    Damped coordinate-ascent variational inference for the Bernoulli mixture model.

    Every iteration computes the CAVI optimum of each factor and moves the current parameters a
    fraction `step_size` towards it (step_size=1 gives plain CAVI).

    :param x: N x D matrix of binary observations.
    :param K: Number of mixture components.
    :param n_iter: Maximum number of iterations.
    :param a_prior: "a" parameter of the Beta prior on theta.
    :param b_prior: "b" parameter of the Beta prior on theta.
    :param alpha_prior: Dirichlet prior parameter on pi.
    :param step_size: Damping factor applied to every update.
    :param tol: Stop when the ELBO changes by less than this between two consecutive iterations.
    :return: dict with keys "r_q", "a_q", "b_q", "alpha_q" and "elbo". Each value is an array whose
        first axis indexes the iteration; entry 0 is the initialisation, the last entry the result.
    """
    r_q, a_q, b_q, alpha_q = initialize_q(x, K)

    # Store output per iteration (entry 0 = state after initialisation)
    elbo = [calculate_elbo(x, r_q, a_q, b_q, alpha_q)]
    r_q_out = [r_q]
    a_q_out = [a_q]
    b_q_out = [b_q]
    alpha_q_out = [alpha_q]

    pbar = tqdm.tqdm(range(n_iter))
    for _ in pbar:
        # All updates below create NEW arrays. In-place "+=" would also modify the arrays that
        # are already stored in the history lists, because the lists hold references.

        # q(Z) update
        r_star = update_q_Z(x, a_q, b_q, alpha_q)
        r_q = r_q + step_size * (r_star - r_q)
        r_q = r_q / np.sum(r_q, axis=1, keepdims=True)  # normalize

        # q(pi) update
        alpha_star = update_q_pi(r_q, alpha_prior)
        alpha_q = alpha_q + step_size * (alpha_star - alpha_q)

        # q(theta) update
        a_star, b_star = update_q_theta(x, r_q, a_prior, b_prior)
        a_q = a_q + step_size * (a_star - a_q)
        b_q = b_q + step_size * (b_star - b_q)

        # ELBO
        elbo.append(calculate_elbo(x, r_q, a_q, b_q, alpha_q))

        # outputs
        r_q_out.append(r_q)
        a_q_out.append(a_q)
        b_q_out.append(b_q)
        alpha_q_out.append(alpha_q)

        # elbo[-1] is the value of THIS iteration (the list is offset by the initial entry)
        pbar.set_description(f"ELBO: {elbo[-1]:.2f}")

        # Converged when the two most recent ELBO values agree up to tol
        if np.abs(elbo[-1] - elbo[-2]) < tol:
            break

    r_q_out = np.array(r_q_out)
    a_q_out = np.array(a_q_out)
    b_q_out = np.array(b_q_out)
    alpha_q_out = np.array(alpha_q_out)
    elbo = np.array(elbo)
    out = {"r_q": r_q_out, "a_q": a_q_out, "b_q": b_q_out, "alpha_q": alpha_q_out, "elbo": elbo}
    return out

## Run optimization on simulated data

In [None]:
# Experiment configuration
N = 1000
D = 100
K = 5
n_iter = 3000
a_prior = 1.
b_prior = 1.
alpha_prior = np.ones(K) * 1.0
# Alternative: fix the hidden variables by hand instead of sampling them from the priors
# theta = np.array([[0.1, 0.9, 0.0, 0.0], [0.0, 0.05, 0.45, 0.5]])
# pi = np.array([0.6, 0.4])
# x, z, theta, pi = generate_data(N, D, K, theta, pi)
x, z, theta, pi = generate_data_and_priors(N, D, K, a_prior, b_prior, alpha_prior)
out = CAVI_algorithm(x, K, n_iter, a_prior, b_prior, alpha_prior, step_size=0.01, tol=1e-6)
r_q_out = out["r_q"]
a_q_out = out["a_q"]
b_q_out = out["b_q"]
alpha_q_out = out["alpha_q"]
elbo = out["elbo"]
# np.printoptions is a context manager and has no effect when called as a plain statement;
# set_printoptions is what actually changes the print precision.
np.set_printoptions(precision=2)
print("Print results (check for label switching).")
print(f"z :{z[0:10]}")
# As q(Z) is our variational posterior and is a Categorical, argmax corresponds to our MAP estimates
MAP_assignments = r_q_out[-1].argmax(axis=1)
print(f"r_q :{MAP_assignments[0:10]}")
print(f"ARI :{adjusted_rand_score(z, MAP_assignments)}")  # clustering metric that is independent of label switching

# Mean of the Dirichlet q(pi): alpha_k / sum(alpha)
print(f"Expected value pi: {alpha_q_out[-1] / np.sum(alpha_q_out[-1])}")
print(f"True pi: {pi}")

In [None]:
# ELBO trace of the run above; out["elbo"] holds one value per stored iteration.
elbo = out["elbo"]
plt.plot(elbo)
plt.title("ELBO")
plt.xlabel("Iteration")
plt.ylabel("ELBO")

### Visualize optimization

Plots the parameters of $q(\theta_{kd})$ for a set of k and a particular d for different iterations.

The corresponding Beta distribution is also plotted for the same parameters.

This illustrates that when we update the parameters during the CAVI optimization, we are implicitly updating the optimal $q(\theta_{kd})$ distribution that we derived on the board.

In [None]:
# Snapshots of q(theta_kd) for one feature d and all components k at selected iterations.
# Top row: the (a, b) parameters; bottom row: the corresponding Beta densities.
xlim_a_min = a_q_out.min()
xlim_a_max = a_q_out.max()
ylim_b_min = b_q_out.min()
ylim_b_max = b_q_out.max()

# Keep only iterations that exist (the run may have stopped before iteration 10)
n_stored = a_q_out.shape[0]
iterations_to_plot = [j for j in (0, 10, -1) if -n_stored <= j < n_stored]
n_iter_to_plot = len(iterations_to_plot)
# squeeze=False keeps axs two-dimensional even if only one column remains
fig, axs = plt.subplots(2, n_iter_to_plot, figsize=(10, 6), squeeze=False)
d_to_plot = 0  # a_kd
k_to_plot = range(0, K)
pi_axis = np.linspace(0, 1, 100)  # grid of theta values on which the densities are evaluated
for i, j in enumerate(iterations_to_plot):
    # i = subplot column, j = iteration index into the stored history
    axs[0, i].scatter(a_q_out[j, k_to_plot, d_to_plot], b_q_out[j, k_to_plot, d_to_plot])
    axs[0, i].set_title(f"q(theta) a, b: {j}")
    axs[0, i].set_xlim(int(xlim_a_min) - 100, int(xlim_a_max) + 100)
    axs[0, i].set_ylim(int(ylim_b_min) - 100, int(ylim_b_max) + 100)

    for k in k_to_plot:
        # Index with the iteration j (not the column i) so both rows show the same iteration
        q_theta_kd = sp_stats.beta(a_q_out[j, k, d_to_plot], b_q_out[j, k, d_to_plot])
        axs[1, i].plot(pi_axis, q_theta_kd.pdf(pi_axis))

    if i == 0:
        axs[0, i].set_xlabel('a')
        axs[0, i].set_ylabel('b')
        axs[1, i].set_ylabel('Probability density')

plt.show()

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML  # For display in Jupyter/Colab

# --- 1. Create Mock Data (Replace this with your data) ---
# This simulates 50 iterations of CAVI for K=3 clusters and d=1 feature.
# We assume the parameters (a, b) start at (1, 1) and converge.
# WARNING: this section reassigns the notebook-level names K, a_q_out and b_q_out.
# Any CAVI result stored under those names before this cell is replaced by the mock
# arrays below; read the result from `out` again if it is needed afterwards.
n_iterations = 50
K = 3  # Number of clusters
d_features = 1  # Number of features

# Define convergence targets for our 3 clusters
a_targets = np.array([50.0, 10.0, 80.0])
b_targets = np.array([20.0, 70.0, 40.0])

# Create empty arrays
# Layout: (iteration, cluster, feature)
a_q_out = np.zeros((n_iterations, K, d_features))
b_q_out = np.zeros((n_iterations, K, d_features))

# Fill with data converging from (1, 1) to the targets
for k in range(K):
    # Use linspace for a smooth convergence path
    a_path = np.linspace(1.0, a_targets[k], n_iterations)
    b_path = np.linspace(1.0, b_targets[k], n_iterations)

    # Add some random noise to make it look like a real optimization
    # The noise is scaled by (n_iterations - t) / n_iterations, so it shrinks linearly over time.
    a_noise = np.random.randn(n_iterations) * (n_iterations - np.arange(n_iterations)) / n_iterations * 0.5
    b_noise = np.random.randn(n_iterations) * (n_iterations - np.arange(n_iterations)) / n_iterations * 0.5

    a_q_out[:, k, 0] = a_path + a_noise
    b_q_out[:, k, 0] = b_path + b_noise

# Ensure parameters are > 0 (Beta params must be positive)
# Values below 0.1 (including negative ones produced by the noise) are clipped to 0.1.
a_q_out[a_q_out < 0.1] = 0.1
b_q_out[b_q_out < 0.1] = 0.1
# --- End of Mock Data Section ---


# --- 2. Setup Animation Constants and Data ---
d_to_plot = 0  # The feature dimension to plot (same as your script)
K = a_q_out.shape[1] # Number of clusters
n_iterations = a_q_out.shape[0] # Total number of iterations
k_to_plot = range(0, K)

# Generate distinct colors for each cluster
colors = plt.cm.jet(np.linspace(0, 1, K))
pi_axis = np.linspace(0, 1, 100) # Axis for plotting PDFs (grid of theta values)

# --- 3. Calculate Global Axis Limits (for stable animation) ---
xlim_a_min = a_q_out[:, :, d_to_plot].min()
xlim_a_max = a_q_out[:, :, d_to_plot].max()
ylim_b_min = b_q_out[:, :, d_to_plot].min()
ylim_b_max = b_q_out[:, :, d_to_plot].max()

# Add 10% padding
a_pad = (xlim_a_max - xlim_a_min) * 0.1
b_pad = (ylim_b_max - ylim_b_min) * 0.1

xlim_a = (xlim_a_min - a_pad, xlim_a_max + a_pad)
ylim_b = (ylim_b_min - b_pad, ylim_b_max + b_pad)

# --- 4. Setup the Figure ---
# We now use a 2x1 grid, as we are animating one state over time
fig, (ax_params, ax_pdf) = plt.subplots(2, 1, figsize=(8, 10))
plt.tight_layout(pad=4.0) # Add padding for titles

# --- 5. Create Empty Plot "Artists" ---
# We create the plot elements once, and the animation will just update their data.
# This is much faster than clearing and redrawing.

# Top plot (Parameters)
param_trails = [] # The lines for the history
param_heads = []  # The scatter points for the current position
for k in k_to_plot:
    # Trail line
    line, = ax_params.plot([], [], marker='o', markersize=2, alpha=0.3, color=colors[k], label=f'Cluster {k}')
    param_trails.append(line)
    # Head point
    head, = ax_params.plot([], [], marker='o', markersize=8, color=colors[k], markeredgecolor='black')
    param_heads.append(head)

ax_params.set_xlim(xlim_a)
ax_params.set_ylim(ylim_b)
ax_params.set_xlabel('Parameter a')
ax_params.set_ylabel('Parameter b')
ax_params.set_title('q(theta) Parameter Convergence')
ax_params.legend()
ax_params.grid(True, linestyle='--', alpha=0.5)

# Bottom plot (PDFs)
pdf_lines = []
for k in k_to_plot:
    line, = ax_pdf.plot([], [], color=colors[k], lw=2, label=f'Cluster {k}')
    pdf_lines.append(line)

ax_pdf.set_xlim(0, 1)
# We'll set the Y-limit dynamically in the update function, or you can find the global max
# The horizontal axis is theta (the density plotted is q(theta)). A raw string is used
# because backslash sequences such as '\p' or '\t' are escapes in normal string literals.
ax_pdf.set_xlabel(r'$\theta$ value')
ax_pdf.set_ylabel('Probability Density')
ax_pdf.set_title('q(theta) PDF')
ax_pdf.legend()
ax_pdf.grid(True, linestyle='--', alpha=0.5)

# Global iteration title
iter_title = fig.suptitle('Iteration 0', fontsize=16)

# --- 6. Define Animation Update Function ---
# This function is called for each frame (iteration)
def update(i):
    """
    Animation callback: draws the state of q(theta) at iteration i.

    Updates the parameter trails / heads in the top plot and the Beta densities in the bottom plot,
    using the notebook-level arrays a_q_out and b_q_out and the artists created beforehand.

    :param i: Frame number, used as the iteration index.
    :return: List of all artists that were updated.
    """
    max_pdf_y = 0.0  # Largest finite density over ALL clusters in this frame

    for k in k_to_plot:
        # --- Update Top Plot (Parameters) ---

        # History up to and including the current iteration i
        a_history = a_q_out[0:i+1, k, d_to_plot]
        b_history = b_q_out[0:i+1, k, d_to_plot]

        param_trails[k].set_data(a_history, b_history)
        # set_data expects sequences, so the single head point is wrapped in lists
        param_heads[k].set_data([a_history[-1]], [b_history[-1]])

        # --- Update Bottom Plot (PDFs) ---
        a_current = a_q_out[i, k, d_to_plot]
        b_current = b_q_out[i, k, d_to_plot]
        pdf_values = sp_stats.beta(a_current, b_current).pdf(pi_axis)
        pdf_lines[k].set_data(pi_axis, pdf_values)

        # Running maximum over clusters. Infinite values (possible at the end points when a
        # parameter is below 1) are ignored, and the maximum is never lowered by a later cluster.
        finite_pdf = pdf_values[np.isfinite(pdf_values)]
        if finite_pdf.size:
            max_pdf_y = max(max_pdf_y, float(finite_pdf.max()))

    # Update the PDF y-axis limit (skipped if nothing finite was drawn, to avoid a 0-height axis)
    if max_pdf_y > 0:
        ax_pdf.set_ylim(0, max_pdf_y * 1.1)

    # Update the main title
    iter_title.set_text(f'CAVI Iteration {i}')

    # Return all the artists that were updated
    return param_trails + param_heads + pdf_lines + [iter_title]

# --- 7. Create and Display Animation ---

# Create the animation object
# interval=100 means 100ms per frame (10 FPS)
# blit=False is often more stable, though blit=True is faster if it works
# One frame is rendered per stored iteration (frames=n_iterations).
ani = FuncAnimation(fig, update, frames=n_iterations,
                    interval=100, blit=False)

# To display in Jupyter Notebook or Google Colab:
plt.close(fig) # Prevent static plot from showing
# to_jshtml() renders every frame, so this call takes longer the more frames there are.
HTML(ani.to_jshtml())

# To display in a standard Python script, uncomment the following line:
# plt.show()

# To save the animation as a GIF (requires 'imagemagick' or 'pillow'):
# print("Saving animation... (this may take a moment)")
# ani.save('cavi_animation.gif', writer='pillow', fps=10)
# print("Done.")

## MNIST experiment

In [None]:
from sklearn.datasets import fetch_openml
# Download MNIST from OpenML and keep a balanced subset of a few digits.
# NOTE(review): the code below calls X.to_numpy(), so it assumes fetch_openml returns pandas
# objects for this dataset — confirm for the installed scikit-learn version.
mnist = fetch_openml('mnist_784')
X = mnist.data
y = mnist.target
X = X.to_numpy()

selected_digits = [0, 1, 2, 3, 4]
train_filter = []
N_k = [1000, 1000, 1000, 1000, 1000]  # number of images to keep per selected digit
for k,s in enumerate(selected_digits):
    # Labels are compared as strings; keep the first N_k[k] row positions of digit s
    train_filter.append(np.where(y == str(s))[0][0:N_k[k]])

train_filter = np.concatenate(train_filter)
X = X[train_filter, :]
y = y[train_filter]
N = len(train_filter)
# X already has exactly N rows after the filtering above, so this slice keeps everything.
X = X[0:N, :]
# Pixels above 127.5 become 1, all others 0, giving the binary data the model expects.
X = (X > 127.5).astype(int)  # binarize

In [None]:
# Fit the mixture to the binarized MNIST subset with one component per selected digit.
N, D = X.shape
K = len(selected_digits)
n_iter = 500
a_prior = 1.
b_prior = 1.
# Dirichlet prior with N/K pseudo-counts per component.
alpha_prior = np.ones(K) * N/K

print(f"Priors: {alpha_prior}")
print(f"Number of iterations: {n_iter}")
print(f"Number of datapoints: {N}")
# tol is not passed, so the default of CAVI_algorithm is used.
out = CAVI_algorithm(X, K, n_iter,
                     a_prior, b_prior, alpha_prior,
                     step_size=0.1)

In [None]:
# ELBO trace of the MNIST run; out["elbo"] holds one value per stored iteration.
elbo = out["elbo"]
plt.plot(elbo)
plt.title("ELBO")
plt.xlabel("Iteration")
plt.ylabel("ELBO")

In [None]:
# Unpack the per-iteration history of the MNIST run
r_q_out = out["r_q"]
a_q_out = out["a_q"]
b_q_out = out["b_q"]
alpha_q_out = out["alpha_q"]
elbo = out["elbo"]
# np.printoptions is a context manager and has no effect when called as a plain statement;
# set_printoptions is what actually changes the print precision.
np.set_printoptions(precision=2)

# Baseline: KMeans on the same binarized data
sklearn_kmeans = KMeans(n_clusters=K, random_state=0, n_init=30).fit(X)
labels = sklearn_kmeans.labels_
print(f"ARI KMeans:{adjusted_rand_score(y[0:N], labels)}")  # clustering metric that is independent of label switching

# MAP cluster assignment under the final q(Z)
MAP_assignments = r_q_out[-1].argmax(axis=1)
print(f"ARI CAVI:{adjusted_rand_score(y[0:N], MAP_assignments)}")

# Empirical class frequencies versus the mean of the final Dirichlet q(pi)
y_labels, y_counts = np.unique(y[0:N], return_counts=True)
print(f"True pi: {y_counts / np.sum(y_counts)}")
print(f"Expected value pi: {alpha_q_out[-1] / np.sum(alpha_q_out[-1])}")

In [None]:
# Show the posterior mean of theta for every component as a 28 x 28 image.
for k in range(K):
    # Two rows; the number of columns is K/2 rounded up.
    plt.subplot(2, int(K/2) + (K % 2), k + 1)
    # Mean of Beta(a, b) is a / (a + b), evaluated per pixel for the final iteration.
    E_theta_k = a_q_out[-1][k] / (a_q_out[-1][k] + b_q_out[-1][k])
    plt.imshow(E_theta_k.reshape(28, 28), cmap="gray")
    plt.axis("off")

### Visualize learned cluster parameters

In [None]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML  # For display in Jupyter/Colab

# --- Assume a_q_out and b_q_out are populated ---
# Example dummy data (replace with your actual data):
# import numpy as np
# L_iterations = 50  # Number of iterations
# K_components = 10  # Number of components
# a_q_out = [np.random.rand(K_components, 28*28) for _ in range(L_iterations)]
# b_q_out = [np.random.rand(K_components, 28*28) for _ in range(L_iterations)]
# -------------------------------------------------

# 1. Get total iterations (L) and components (K)
# Animate at most the first 200 stored iterations. Clamping to len(a_q_out) prevents an
# IndexError in the frame callback when the optimisation stopped early.
L = min(200, len(a_q_out))
K = len(a_q_out[0])

# 2. Set up the figure and subplot grid
# We create a 2-row grid, calculating columns needed
nrows = 2
ncols = (K + 1) // 2  # This is a cleaner way to get (K/2) rounded up
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2 + 0.5))
axes_flat = axes.flatten() # Flatten grid for easy indexing

# Turn off any extra, unused subplots
for k in range(K, nrows * ncols):
    axes_flat[k].axis("off")

# 3. Define the update function (called for each frame)
def update(i):
    """
    Animation callback: redraws the posterior-mean image of every component for frame i.

    Reads the notebook-level a_q_out / b_q_out histories and draws into axes_flat.

    :param i: Frame number, used as the iteration index.
    """
    a_i, b_i = a_q_out[i], b_q_out[i]
    for k, ax in enumerate(axes_flat[:K]):
        ax.clear()  # Start from an empty axis for the new frame

        # Mean of Beta(a, b), per pixel, for component k at iteration i
        mean_theta = a_i[k] / (a_i[k] + b_i[k])

        ax.imshow(mean_theta.reshape(28, 28), cmap="gray", vmin=0, vmax=1)
        ax.set_title(f"k = {k}", fontsize=10)
        ax.axis("off")

    # Figure-level title with the frame counter; leave room for it in the layout
    fig.suptitle(f"Iteration {i + 1} / {L}", fontsize=14)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

# 4. Create the animation
# interval=200 means 200ms per frame
# One frame per iteration index 0..L-1.
ani = FuncAnimation(fig, update, frames=L, interval=200, blit=False)

# 5. Display the animation (in Jupyter/Colab)
plt.close(fig) # Prevent the static plot from showing
#HTML(ani.to_html5_video())
# to_jshtml() renders every frame, so this call takes longer the more frames there are.
HTML(ani.to_jshtml())



In [None]:
# Plot most certain vs uncertain images
# Top row: the sample with the highest entropy of q(z_n) in each cluster (most uncertain).
# Bottom row: the sample with the lowest entropy (most certain).
# squeeze=False keeps axs two-dimensional even for K = 1.
fig, axs = plt.subplots(2, K, figsize=(20, 10), squeeze=False)
np.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)})
r_final = r_q_out[-1]
for k in range(K):
    axs[0, k].set_title(f"Cluster {k}: most uncertain")
    axs[1, k].set_title(f"Cluster {k}: most certain")

    idx_qZ_k = np.where(MAP_assignments==k)[0]
    if idx_qZ_k.size == 0:
        # argmax / argmin raise on empty input, so clusters without members are skipped
        print(f"Cluster {k} has no MAP-assigned samples.")
        continue

    # Entropy of every responsibility vector in this cluster, computed once
    r_k = r_final[idx_qZ_k]
    entropies = sp_stats.entropy(r_k, axis=1)
    max_entropy_idx = np.argmax(entropies)
    print(f"Probability of most uncertain sample of cluster {k}: {r_k[max_entropy_idx]}")
    min_entropy_idx = np.argmin(entropies)
    print(f"Probability of most certain sample of cluster {k}: {r_k[min_entropy_idx]}")

    max_entr_img = X[idx_qZ_k][max_entropy_idx]
    min_entr_img = X[idx_qZ_k][min_entropy_idx]
    axs[0, k].imshow(max_entr_img.reshape(28, 28), cmap="gray")
    axs[1, k].imshow(min_entr_img.reshape(28, 28), cmap="gray")