# Probabilistic Machine Learning: Lecture 24 - Variational Inference

#### Introduction

Welcome to Lecture 24 of Probabilistic Machine Learning! In the previous lecture we studied the Expectation-Maximization (EM) algorithm, a powerful tool for finding maximum likelihood (ML) or maximum a posteriori (MAP) estimates in models with latent variables. EM iteratively maximizes a lower bound on the log evidence $\log p(x | \theta)$: each E-step sets the variational distribution $q(z)$ to the true posterior $p(z|x, \theta)$, which makes the bound tight.

However, what if the true posterior $p(z|x, \theta)$ itself is intractable? This is where **Variational Inference (VI)** comes into play. VI offers a more general and flexible framework for approximating intractable posterior distributions by directly optimizing a lower bound on the evidence. Unlike EM, VI aims to find the *best approximate distribution* $q(z)$ from a chosen family of distributions, rather than relying on the exact posterior in the E-step.

This notebook provides extensive background on VI: the theoretical foundations, the core concepts of the mean-field approximation, and a detailed practical implementation using **JAX** for numerical computations and **Plotly** for interactive visualizations. We focus on its application to **Variational Gaussian Mixture Models (VGMMs)**, illustrating how VI provides a fully Bayesian treatment of the parameters where EM only yields point estimates.

#### 1. Recap: EM Algorithm - General Form

Let's briefly recap the general form of the EM algorithm (Slide 2), as it provides the conceptual stepping stone to Variational Inference. Our goal is to find the maximum likelihood (or MAP) estimate for parameters $\theta$ in a model involving observed data $x$ and latent variables $z$:

$$\theta_{*} = \arg \max_{\theta} [\log p(x | \theta)] = \arg \max_{\theta} [\log (\int p(x, z | \theta) dz)]$$

The EM algorithm iteratively optimizes this by performing:

1.  **E-step**: Compute $q(z) = p(z|x, \theta_{old})$, which effectively sets the KL divergence $D_{KL}(q || p(z|x, \theta_{old}))$ to zero, making the Evidence Lower Bound (ELBO) tight at the current $\theta_{old}$.
2.  **M-step**: Set $\theta_{new}$ to maximize the ELBO $\mathcal{L}(q, \theta) = \int q(z) \log \left( \frac{p(x, z | \theta)}{q(z)} \right) dz$.

The crucial point here is that the E-step *requires* the true posterior $p(z|x, \theta)$ to be tractable. When it's not, we need a more general approach.
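
For a GMM (the running example from Lecture 23), the objective above is $\log p(x | \theta) = \sum_n \log \sum_k \pi_k \mathcal{N}(x_n | \mu_k, \Sigma_k)$. Evaluating it naively can underflow for points far from every component. The following is a minimal NumPy/SciPy sketch that works in log space with the log-sum-exp trick; `gmm_log_likelihood` is our own illustrative helper, not the `compute_log_likelihood` function from Lecture 23.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal


def gmm_log_likelihood(X, weights, means, covariances):
    """log p(X | theta) = sum_n log sum_k pi_k N(x_n | mu_k, Sigma_k) for a Gaussian mixture.

    Works in log space and combines components with log-sum-exp, so tiny densities
    far from every component do not underflow to log(0).
    """
    X = np.atleast_2d(X)
    # log pi_k + log N(x_n | mu_k, Sigma_k) for every point n and component k -> shape (N, K)
    log_terms = np.stack(
        [np.log(w) + np.atleast_1d(multivariate_normal(mean=m, cov=S).logpdf(X))
         for w, m, S in zip(weights, means, covariances)],
        axis=1,
    )
    return logsumexp(log_terms, axis=1).sum()


# Usage: a mixture of two identical components must equal a single Gaussian
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 2))
ll_mixture = gmm_log_likelihood(X, [0.5, 0.5], [np.zeros(2), np.zeros(2)], [np.eye(2), np.eye(2)])
ll_single = multivariate_normal(mean=np.zeros(2), cov=np.eye(2)).logpdf(X).sum()
print(ll_mixture, ll_single)
```

The two printed values agree, because $\log(0.5\,p + 0.5\,p) = \log p$ for every point.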

#### 2. What if we cannot compute the posterior $p(z|x, \theta)$ analytically?

This is the central question that motivates Variational Inference (Slide 5). If the E-step of EM is intractable because $p(z|x, \theta)$ cannot be computed, we need an alternative. Instead of forcing $q(z)$ to be the true posterior, VI proposes to directly optimize the ELBO $\mathcal{L}(q, \theta)$ with respect to *both* $q(z)$ and $\theta$.

Recall the decomposition of the log evidence:
$$\log p(x | \theta) = \mathcal{L}(q, \theta) + D_{KL}(q || p(z | x, \theta))$$

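where $\mathcal{L}(q, \theta) = \int q(z) \log \left( \frac{p(x, z | \theta)}{q(z)} \right) dz$. The left-hand side $\log p(x | \theta)$ does not depend on $q$, so for a fixed $\theta$ any increase of $\mathcal{L}(q, \theta)$ with respect to $q(z)$ is matched by an equal decrease of $D_{KL}(q || p(z | x, \theta))$: maximizing the ELBO over $q$ is equivalent to minimizing the KL divergence between $q(z)$ and the true posterior $p(z|x, \theta)$. Because $D_{KL}(q || p(z | x, \theta)) \ge 0$, $\mathcal{L}(q, \theta)$ is indeed a lower bound on $\log p(x | \theta)$, with equality exactly when $q(z) = p(z|x, \theta)$. Thus, by maximizing the ELBO, we find the approximation $q(z)$ that is as close as possible (in KL divergence) to the true posterior within a chosen family of distributions $Q$.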
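
The decomposition can be checked numerically on a toy model. The sketch below assumes a hypothetical single observation $x$ with a binary latent $z$, and gives the joint $p(x, z | \theta)$ directly as two numbers (chosen for illustration, so that $p(x) = 0.15$ and $p(z | x) = [0.8, 0.2]$). For several choices of $q$ it evaluates the ELBO and the KL divergence.

```python
import numpy as np

# Hypothetical toy model: p(x, z | theta) for the observed x and z in {0, 1}
log_joint = np.log(np.array([0.12, 0.03]))        # log p(x, z=0), log p(x, z=1)
log_evidence = np.logaddexp.reduce(log_joint)     # log p(x) = log sum_z p(x, z)
posterior = np.exp(log_joint - log_evidence)      # p(z | x)


def elbo(q):
    """L(q) = sum_z q(z) log(p(x, z) / q(z))."""
    return np.sum(q * (log_joint - np.log(q)))


def kl(q, p):
    """D_KL(q || p) for two discrete distributions with full support."""
    return np.sum(q * (np.log(q) - np.log(p)))


for q0 in [0.5, 0.7, 0.8]:
    q = np.array([q0, 1.0 - q0])
    print(f"q(z=0)={q0:.1f}  ELBO={elbo(q):.4f}  KL={kl(q, posterior):.4f}  "
          f"ELBO+KL={elbo(q) + kl(q, posterior):.4f}  log p(x)={log_evidence:.4f}")
```

In every row ELBO + KL equals $\log p(x) = \log 0.15$, and the ELBO reaches $\log p(x)$ (KL of zero) only when $q$ equals the posterior, $q(z{=}0) = 0.8$. This is exactly the situation EM's E-step creates.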

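This is an **optimization in the space of probability distributions $q$**: the ELBO is a *functional*, mapping an entire distribution to a number. Finding the function that extremizes a functional is the subject of the calculus of variations, which gives the method its name. In practice, VI makes the problem tractable by restricting $Q$ to a family in which the optimum can be found in closed form or by standard numerical optimization.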

#### 3. Factorizing Approximations: Mean Field Theory

Directly optimizing over the entire space of distributions $Q$ is generally intractable. The key simplifying assumption in many VI applications is the **mean-field approximation** (Slides 7-8). We assume that the variational distribution $q(z)$ factorizes into independent distributions over subsets of the latent variables:

$$q(z) = \prod_{i=1}^M q_i(z_i)$$

where the $z_i$ are disjoint groups of latent variables that together make up $z$. This restriction is what makes the optimization tractable: with this factorization, we can derive an iterative update rule for each factor $q_j(z_j)$ by holding all other factors $q_i(z_i)$ (for $i \ne j$) fixed.

The optimal form for each factor $q_j^*(z_j)$ is given by (Slide 8):

$$\log q_j^*(z_j) = \mathbb{E}_{q, i \ne j} [\log p(x, z)] + \text{const.}$$ 

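This means that the logarithm of the optimal distribution for $z_j$ is the expectation of the log joint distribution $\log p(x, z)$ with respect to all other latent variables, averaged under their current variational distributions $q_i(z_i)$; the constant is fixed by normalizing $q_j^*$. Cycling through the factors and updating each in turn is a form of **coordinate ascent in distribution space**. As a function of a single factor, the ELBO equals $-D_{KL}(q_j || q_j^*)$ plus a constant, so each update maximizes it exactly with respect to that factor. The ELBO therefore never decreases, and the procedure converges to a local optimum of the ELBO.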
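
The update rule can be seen at work on the classic example from Bishop's PRML (Sec. 10.1.2): approximating a correlated 2D Gaussian $p(z) = \mathcal{N}(z | \mu, \Lambda^{-1})$ with a factorized $q_1(z_1) q_2(z_2)$. Taking the expectation of $\log p(z)$ over $z_2$ leaves a quadratic in $z_1$, so the optimal factor is Gaussian, $q_1^*(z_1) = \mathcal{N}(z_1 | m_1, \Lambda_{11}^{-1})$ with $m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12} (m_2 - \mu_2)$, and symmetrically for $z_2$. The NumPy sketch below, with target parameters of our own choosing, simply alternates these two updates; `mean_field_gaussian` is an illustrative name.

```python
import numpy as np


def mean_field_gaussian(mu, Lambda, m_init=(0.0, 0.0), num_iters=50):
    """Coordinate-ascent mean-field VI for p(z) = N(z | mu, Lambda^{-1}), q(z) = q1(z1) q2(z2).

    Each factor is q_j(z_j) = N(z_j | m_j, 1 / Lambda_jj); only the means need iterating.
    """
    m = np.array(m_init, dtype=float)
    for _ in range(num_iters):
        # Update q1 while holding q2 fixed, then q2 while holding q1 fixed
        m[0] = mu[0] - Lambda[0, 1] / Lambda[0, 0] * (m[1] - mu[1])
        m[1] = mu[1] - Lambda[1, 0] / Lambda[1, 1] * (m[0] - mu[0])
    variances = 1.0 / np.diag(Lambda)      # variances of the factors, fixed by Lambda
    return m, variances


mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.9], [0.9, 1.0]])   # strongly correlated target
Lambda = np.linalg.inv(Sigma)
m, q_var = mean_field_gaussian(mu, Lambda)
print("mean-field means:", m)
print("mean-field variances:", q_var, "true marginal variances:", np.diag(Sigma))
```

The means converge to the true mean $\mu$, but each factor's variance is $1/\Lambda_{jj} = 1 - 0.9^2 = 0.19$ instead of the true marginal variance $1.0$: the factorized approximation is markedly overconfident when the variables are strongly correlated.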

In physics, this technique is known as **mean field theory**, where a complex many-body problem is approximated by considering the effect of a single particle interacting with the "mean field" created by all other particles. This analogy beautifully captures the essence of the factorization assumption in VI.

### Practical Steps for Constructing a VI Algorithm (Slide 10):

1.  **Write down the log joint distribution** $\log p(x, z)$.
2.  **Decide on a factorization** for the variational distribution $q(z) = \prod_i q_i(z_i)$. This is a crucial modeling choice that balances tractability and expressiveness.
3.  **Inspect the algebraic form** of $\log q_j^*(z_j) = \mathbb{E}_{q, i \ne j} [\log p(x, z)] + \text{const.}$ and **identify the type of distribution** $q_j$ (e.g., Gaussian, Dirichlet, Categorical). This often involves recognizing the functional form of an exponential family.
4.  Once all $q_j^*$ are identified by type, **find analytic expressions for the expectations** $\mathbb{E}_{q, i \ne j} [\log p(x, z)]$. This is where the "ELBOW grease" comes in, but for models with conjugate priors, these expectations often simplify nicely.
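
As a compact illustration of the four steps outside the GMM setting, consider the univariate Gaussian with unknown mean and precision from Bishop's PRML (Sec. 10.1.3). (1) The log joint combines $x_n \sim \mathcal{N}(\mu, \tau^{-1})$, $\mu | \tau \sim \mathcal{N}(\mu_0, (\lambda_0 \tau)^{-1})$ and $\tau \sim \text{Gam}(a_0, b_0)$. (2) We choose $q(\mu, \tau) = q(\mu) q(\tau)$. (3) $\log q^*(\mu)$ is quadratic in $\mu$, so $q(\mu)$ is Gaussian; $\log q^*(\tau)$ is linear in $\tau$ and $\log \tau$, so $q(\tau)$ is a Gamma distribution. (4) The expectations give $q(\mu) = \mathcal{N}(\mu_N, \lambda_N^{-1})$ with $\mu_N = (\lambda_0 \mu_0 + N \bar{x}) / (\lambda_0 + N)$, $\lambda_N = (\lambda_0 + N) \mathbb{E}[\tau]$, and $q(\tau) = \text{Gam}(a_N, b_N)$ with $a_N = a_0 + (N+1)/2$, $b_N = b_0 + \frac{1}{2} \mathbb{E}_\mu[\sum_n (x_n - \mu)^2 + \lambda_0 (\mu - \mu_0)^2]$. The NumPy sketch below iterates these updates; the function name and the hyperparameter defaults are illustrative assumptions.

```python
import numpy as np


def vi_normal_gamma(x, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0, num_iters=50):
    """Mean-field VI for x_n ~ N(mu, 1/tau), mu | tau ~ N(mu0, 1/(lambda0 tau)), tau ~ Gam(a0, b0),
    with q(mu, tau) = q(mu) q(tau), q(mu) = N(mu_N, 1/lambda_N) and q(tau) = Gam(a_N, b_N)."""
    x = np.asarray(x, dtype=float)
    N, x_bar = x.size, x.mean()
    # mu_N and a_N do not depend on the other factor, so they are computed once
    mu_N = (lambda0 * mu0 + N * x_bar) / (lambda0 + N)
    a_N = a0 + 0.5 * (N + 1)
    E_tau = a0 / b0                      # initial guess for E[tau]
    for _ in range(num_iters):
        # Update q(mu) given the current E[tau]
        lambda_N = (lambda0 + N) * E_tau
        # Update q(tau) given q(mu), using E_mu[(c - mu)^2] = (c - mu_N)^2 + 1 / lambda_N
        E_sq = (np.sum((x - mu_N) ** 2) + N / lambda_N
                + lambda0 * ((mu_N - mu0) ** 2 + 1.0 / lambda_N))
        b_N = b0 + 0.5 * E_sq
        E_tau = a_N / b_N
    return mu_N, lambda_N, a_N, b_N


rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=1000)     # true mean 2.0, true precision 1 / 0.5**2 = 4.0
mu_N, lambda_N, a_N, b_N = vi_normal_gamma(x)
print(f"E[mu] = {mu_N:.3f}, E[tau] = {a_N / b_N:.3f}")
```

With 1000 points the posterior means $\mathbb{E}[\mu] = \mu_N$ and $\mathbb{E}[\tau] = a_N / b_N$ land close to the values used to simulate the data.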

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jax.scipy.stats import multivariate_normal as mvn
from jax.scipy.special import digamma, gammaln # For Dirichlet and Wishart expectations

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)

# --- Utility Functions (from Lecture 23, adapted) ---
def plot_gmm_plotly(X, means, covariances, responsibilities=None, title="Gaussian Mixture Model", colors=['red', 'blue', 'green', 'purple', 'orange']):
    """Plots 2D data, GMM components, and optionally responsibilities."""
    fig = go.Figure()

    # Plot data points, optionally colored by responsibility
    if responsibilities is not None:
        dominant_component = jnp.argmax(responsibilities, axis=1)
        for k in range(means.shape[0]):
            mask = dominant_component == k
            fig.add_trace(go.Scatter(
                x=X[mask, 0],
                y=X[mask, 1],
                mode='markers',
                marker=dict(color=colors[k % len(colors)], size=5, opacity=0.7),
                name=f'Data (Component {k+1})',
                showlegend=True
            ))
    else:
        fig.add_trace(go.Scatter(
            x=X[:, 0],
            y=X[:, 1],
            mode='markers',
            marker=dict(color='gray', size=5, opacity=0.7),
            name='Data Points',
            showlegend=True
        ))

    # Plot Gaussian components (mean and covariance ellipses)
    for k in range(means.shape[0]):
        mean = means[k]
        cov = covariances[k]
        
        # Draw the 95% probability ellipse. The squared Mahalanobis distance of a 2D Gaussian
        # follows a chi-squared distribution with 2 DOF, whose 95% quantile is 5.991.
        # Mapping the unit circle through eigvecs @ diag(sqrt(5.991 * eigvals)) yields a rotated
        # ellipse (a plotly 'circle' shape is axis-aligned and would ignore the correlation).
        vals, vecs = np.linalg.eigh(np.asarray(cov))
        t = np.linspace(0.0, 2.0 * np.pi, 100)
        unit_circle = np.stack([np.cos(t), np.sin(t)])                    # (2, 100)
        ellipse = vecs @ (np.sqrt(5.991 * vals)[:, None] * unit_circle)  # (2, 100), centred at 0
        fig.add_trace(go.Scatter(
            x=float(mean[0]) + ellipse[0],
            y=float(mean[1]) + ellipse[1],
            mode='lines',
            line=dict(color=colors[k % len(colors)], width=2),
            name=f'Component {k+1}',
            showlegend=False
        ))
        # Add mean point
        fig.add_trace(go.Scatter(
            x=[mean[0]],
            y=[mean[1]],
            mode='markers',
            marker=dict(symbol='x', size=10, color=colors[k % len(colors)], line=dict(width=2, color='black')),
            name=f'Mean {k+1}',
            showlegend=False
        ))

    fig.update_layout(title_text=title, title_x=0.5,
                      xaxis_title='Feature 1',
                      yaxis_title='Feature 2',
                      autosize=False, width=700, height=600)
    fig.update_yaxes(scaleanchor="x", scaleratio=1) # Keep aspect ratio square
    fig.show()

def generate_gmm_data(num_samples=300, num_components=3, random_seed=42):
    """Generates synthetic data from a Gaussian Mixture Model."""
    np.random.seed(random_seed)
    
    # True parameters for 3 components
    true_weights = np.array([0.3, 0.4, 0.3])
    true_means = np.array([
        [0, 0],
        [3, 3],
        [0, 4]
    ])
    true_covariances = np.array([
        [[0.5, 0.2], [0.2, 0.5]],
        [[0.8, -0.1], [-0.1, 0.8]],
        [[0.6, 0.3], [0.3, 0.6]]
    ])

    X = []
    for _ in range(num_samples):
        # Choose a component based on weights
        k = np.random.choice(num_components, p=true_weights)
        # Sample from the chosen Gaussian
        sample = np.random.multivariate_normal(true_means[k], true_covariances[k])
        X.append(sample)
    
    return jnp.array(X), true_weights, true_means, true_covariances


#### 4. Example: Variational Gaussian Mixture Models (VGMMs)

Let's apply Variational Inference to the Gaussian Mixture Model (GMM) that we previously tackled with EM. This will give us a **Variational GMM (VGMM)**, which provides a full Bayesian treatment of the GMM parameters (weights, means, covariances), rather than just point estimates (Slides 11-20).

### Full Bayesian GMM Model:

To perform Bayesian inference, we place priors over the GMM parameters:
* **Mixing coefficients $\pi$**: Dirichlet prior $p(\pi | \alpha) = \text{Dir}(\pi | \alpha)$.
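* **Means $\mu_k$ and covariances $\Sigma_k$**: Normal-Wishart prior $p(\mu_k, \Sigma_k | m, \beta, W, \nu) = \mathcal{N}(\mu_k | m, \Sigma_k / \beta) \, \mathcal{W}(\Sigma_k^{-1} | W, \nu)$, i.e. a Gaussian on the mean (with covariance $\Sigma_k / \beta$) times a Wishart distribution on the precision $\Sigma_k^{-1}$. This is a conjugate prior for the mean and precision (inverse covariance) of a Gaussian.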
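
To make the generative structure of this prior concrete, here is a minimal sketch that draws $(\mu_k, \Sigma_k)$ with NumPy and `scipy.stats.wishart`, whose `df`/`scale` parameterization has mean `df * scale`, matching $\mathbb{E}[\Sigma_k^{-1}] = \nu W$. The helper name and the hyperparameters are illustrative assumptions. With the notebook's own choice $\nu = D = 2$, $\mathbb{E}[\Sigma_k]$ does not exist (the inverse-Wishart mean requires $\nu > D + 1$), so the sketch uses $\nu = 8$, for which $\mathbb{E}[\Sigma_k] = W^{-1} / (\nu - D - 1) = 0.2 I$.

```python
import numpy as np
from scipy.stats import wishart


def sample_normal_wishart(m, beta, W, nu, rng):
    """Draws (mu, Sigma): Lambda ~ Wishart(W, nu) with E[Lambda] = nu * W, Sigma = Lambda^{-1},
    then mu | Sigma ~ N(m, Sigma / beta)."""
    Lam = wishart(df=nu, scale=W).rvs(random_state=rng)   # (D, D) precision matrix
    Sigma = np.linalg.inv(Lam)
    mu = rng.multivariate_normal(m, Sigma / beta)
    return mu, Sigma


rng = np.random.default_rng(0)
m, beta, W, nu = np.zeros(2), 1.0, np.eye(2), 8.0    # illustrative hyperparameters
draws = [sample_normal_wishart(m, beta, W, nu, rng) for _ in range(2000)]
mus = np.array([d[0] for d in draws])
Sigmas = np.array([d[1] for d in draws])
print("mean of sampled mu:", mus.mean(axis=0))
print("mean of sampled Sigma:\n", Sigmas.mean(axis=0))
```

The sampled means scatter around $m$, and the average sampled covariance approaches $0.2 I$; a larger $\beta$ pulls each $\mu_k$ more tightly towards $m$.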

The full joint distribution is then:
$$p(x, z, \pi, \mu, \Sigma) = p(x, z | \pi, \mu, \Sigma) \cdot p(\pi | \alpha) \cdot \prod_{k=1}^K p(\mu_k, \Sigma_k | m, \beta, W, \nu)$$

The true posterior $p(z, \pi, \mu, \Sigma | x)$ is intractable due to the coupling between $z$ and the parameters. So, we resort to Variational Inference.

### Mean-Field Factorization for VGMM:

We assume the following mean-field factorization for the variational posterior $q(z, \pi, \mu, \Sigma)$ (Slide 13):
$$q(z, \pi, \mu, \Sigma) = q(z) \cdot q(\pi) \cdot \prod_{k=1}^K q(\mu_k, \Sigma_k)$$

This factorization implies that the latent assignments $z$, the mixing coefficients $\pi$, and the component parameters $(\mu_k, \Sigma_k)$ are variationally independent. We then iteratively update each factor using the general mean-field update rule.

### Variational Updates (Coordinate Ascent):

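1.  **Update $q^*(z)$ (Latent Assignments)** (Slide 13):
    $q^*(z)$ turns out to be a product of categorical distributions, one for each data point $x_n$. Their parameters are the responsibilities $r_{nk}$, as in EM, but instead of plugging in point estimates of the parameters they depend on *expectations* under $q(\pi)$ and $q(\mu_k, \Sigma_k)$. The unnormalized log-responsibilities are
    $$\log \rho_{nk} = \mathbb{E}_{q(\pi)}[\log \pi_k] + \frac{1}{2} \mathbb{E}_{q(\mu_k, \Sigma_k)}[\log |\Sigma_k^{-1}|] - \frac{D}{2} \log(2\pi) - \frac{1}{2} \mathbb{E}_{q(\mu_k, \Sigma_k)}[(x_n - \mu_k)^T \Sigma_k^{-1} (x_n - \mu_k)]$$
    where $D$ is the data dimension and, with $\psi$ the digamma function, $\mathbb{E}[\log \pi_k] = \psi(\alpha_k) - \psi(\sum_j \alpha_j)$, $\mathbb{E}[\log |\Sigma_k^{-1}|] = \sum_{i=1}^{D} \psi\left(\frac{\nu_k + 1 - i}{2}\right) + D \log 2 + \log |W_k|$ and $\mathbb{E}[(x_n - \mu_k)^T \Sigma_k^{-1} (x_n - \mu_k)] = D / \beta_k + \nu_k (x_n - m_k)^T W_k (x_n - m_k)$. Normalizing over components gives $r_{nk} = \rho_{nk} / \sum_j \rho_{nj}$, which we compute stably from $\log \rho_{nk}$ with `jax.nn.softmax`.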

2.  **Update $q^*(\pi)$ (Mixing Coefficients)** (Slide 16):
    $q^*(\pi)$ is a Dirichlet distribution, whose parameters $\alpha_k^{new}$ are updated based on the expected counts from $q(z)$:
    $$\alpha_k^{new} = \alpha_k^{prior} + N_k$$ 
    where $N_k = \sum_{n=1}^N r_{nk}$ is the effective number of data points assigned to component $k$.

3.  **Update $q^*(\mu_k, \Sigma_k)$ (Component Parameters)** (Slide 18):
    $q^*(\mu_k, \Sigma_k)$ is a Normal-Wishart distribution whose parameters combine the prior parameters with the responsibility-weighted mean $\bar{x}_k = \frac{1}{N_k} \sum_{n} r_{nk} x_n$ and covariance $S_k = \frac{1}{N_k} \sum_{n} r_{nk} (x_n - \bar{x}_k)(x_n - \bar{x}_k)^T$ of the data assigned to component $k$:
    * $\beta_k^{new} = \beta^{prior} + N_k$
    * $m_k^{new} = \frac{\beta^{prior} m^{prior} + N_k \bar{x}_k}{\beta^{prior} + N_k}$
    * $\nu_k^{new} = \nu^{prior} + N_k$
    * $(W_k^{new})^{-1} = (W^{prior})^{-1} + N_k S_k + \frac{\beta^{prior} N_k}{\beta^{prior} + N_k} (\bar{x}_k - m^{prior})(\bar{x}_k - m^{prior})^T$
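
The responsibility update relies on two closed-form expectations: $\mathbb{E}[\log \pi_k] = \psi(\alpha_k) - \psi(\sum_j \alpha_j)$ under a Dirichlet, and $\mathbb{E}[\log |\Sigma^{-1}|] = \sum_{i=1}^{D} \psi\left(\frac{\nu + 1 - i}{2}\right) + D \log 2 + \log |W|$ under a Wishart. A quick way to gain confidence in such formulas (and to catch typos in an implementation) is a Monte Carlo check. The sketch below uses NumPy's Dirichlet sampler and `scipy.stats.wishart`; the parameter values are arbitrary illustrative choices.

```python
import numpy as np
from scipy.special import digamma
from scipy.stats import wishart

rng = np.random.default_rng(0)

# E[log pi_k] under Dir(alpha): psi(alpha_k) - psi(sum_j alpha_j)
alpha = np.array([2.0, 5.0, 1.0])                       # illustrative Dirichlet parameters
closed_log_pi = digamma(alpha) - digamma(alpha.sum())
mc_log_pi = np.log(rng.dirichlet(alpha, size=200_000)).mean(axis=0)

# E[log |Lambda|] under Wishart(W, nu): sum_i psi((nu + 1 - i) / 2) + D log 2 + log |W|
D, nu = 2, 5.0
W = np.array([[0.5, 0.1], [0.1, 0.3]])                  # illustrative scale matrix
closed_log_det = (digamma((nu + 1 - np.arange(1, D + 1)) / 2).sum()
                  + D * np.log(2) + np.linalg.slogdet(W)[1])
Lam = wishart(df=nu, scale=W).rvs(size=100_000, random_state=rng)   # (100000, D, D)
mc_log_det = np.linalg.slogdet(Lam)[1].mean()

print("E[log pi]:   closed form", closed_log_pi, " Monte Carlo", mc_log_pi)
print("E[log|Lam|]: closed form", closed_log_det, " Monte Carlo", mc_log_det)
```

The closed forms and the sample averages agree to within Monte Carlo error, which shrinks as the number of samples grows.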

The algorithm alternates these updates until the ELBO converges. A key advantage of VGMMs over EM is that components can effectively be "switched off": a component that receives almost no responsibility keeps $\alpha_k \approx \alpha_k^{prior}$, and if the prior concentration is small this makes $\mathbb{E}[\log \pi_k] = \psi(\alpha_k) - \psi(\sum_j \alpha_j)$ very negative, which in turn drives that component's responsibilities towards zero. This allows VGMMs to perform automatic model selection for the number of components, given a sufficiently large initial $K$.
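
The size of this effect is easy to quantify. The sketch below uses hypothetical posterior Dirichlet parameters for $K = 6$ components with $\alpha^{prior} = 0.01$, where three components absorbed all 500 points and three received none (the numbers are illustrative, not output of the notebook).

```python
import numpy as np
from scipy.special import digamma

# Hypothetical alpha_k = alpha_prior + N_k after fitting: three active, three unused components
alpha_k = np.array([150.01, 200.01, 150.01, 0.01, 0.01, 0.01])

E_log_pi = digamma(alpha_k) - digamma(alpha_k.sum())   # enters log rho_nk directly
E_pi = alpha_k / alpha_k.sum()                          # posterior mean mixing weights

print("E[log pi_k]:", np.round(E_log_pi, 2))
print("E[pi_k]:    ", E_pi)
# Factor by which an unused component's rho_nk is scaled relative to an active one
print("relative weight in rho_nk:", np.exp(E_log_pi[3] - E_log_pi[0]))
```

For the unused components $\mathbb{E}[\log \pi_k]$ is below $-100$, so their unnormalized responsibilities are smaller than those of an active component by a factor far below $10^{-40}$ (all else being equal). Such a component cannot attract data again, which is how it stays switched off.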

Let's implement these updates and visualize the VGMM fitting process.

In [None]:
# --- Variational Inference for Gaussian Mixture Models (VGMM) ---

def initialize_vgmm_params(X, num_components, key, prior_params):
    """Initializes variational parameters for VGMM (alpha_k, beta_k, m_k, W_k, nu_k).

    A common initialization is k-means. As a cheap stand-in, we pick `num_components`
    distinct data points as initial centres, hard-assign every point to its nearest
    centre, and run one variational M-step on these one-hot responsibilities. Each
    component thus starts with a distinct mean and (barring duplicate points) at
    least one assigned point.
    """
    num_samples = X.shape[0]
    key, subkey = jax.random.split(key)

    # Random distinct data points serve as initial centres
    centre_idx = jax.random.choice(subkey, num_samples, shape=(num_components,), replace=False)
    centres = X[centre_idx]                                                   # (K, D)

    # One-hot responsibilities from the nearest-centre assignment
    sq_dist = jnp.sum((X[:, None, :] - centres[None, :, :]) ** 2, axis=-1)   # (N, K)
    initial_responsibilities = jax.nn.one_hot(jnp.argmin(sq_dist, axis=1), num_components)

    # Turn the initial responsibilities into Dirichlet / Normal-Wishart parameters
    return vgmm_m_step(X, initial_responsibilities, prior_params)

# --- Expected values needed for VI updates ---
def E_log_pi(alpha_k):
    """Expected value of log(pi_k) under q(pi) = Dir(alpha_k)."""
    return digamma(alpha_k) - digamma(jnp.sum(alpha_k))

def E_log_det_inv_Sigma(W_k, nu_k, data_dim):
    """Expected value of log(|Sigma_k^-1|) under q(mu_k, Sigma_k) = Normal-Wishart."""
    # This is E[log|Lambda_k|] where Lambda_k = Sigma_k^-1
    # E[log|Lambda|] = sum_{d=1}^D digamma((nu + 1 - d)/2) + D log(2) + log|W|
    sum_digamma_terms = jnp.sum(digamma((nu_k[:, jnp.newaxis] + 1 - jnp.arange(1, data_dim + 1)) / 2), axis=1)
    log_det_W = jnp.linalg.slogdet(W_k)[1] # slogdet returns (sign, log_abs_det)
    return sum_digamma_terms + data_dim * jnp.log(2) + log_det_W

def E_Sigma_inv(W_k, nu_k):
    """Expected value of Sigma_k^-1 under q(mu_k, Sigma_k) = Normal-Wishart."""
    # E[Lambda_k] = nu_k * W_k
    return nu_k[:, jnp.newaxis, jnp.newaxis] * W_k

def E_mu_Sigma_inv_mu(m_k, beta_k, E_Sigma_inv_k):
    """Expected value of mu_k^T Sigma_k^-1 mu_k under q(mu_k, Sigma_k)."""
    # E[mu^T Lambda mu] = m^T Lambda m + D / beta
    # This is for a single component. Need to vectorize over components.
    term1 = jnp.einsum('kd,kdj,kj->k', m_k, E_Sigma_inv_k, m_k) # m_k^T E[Sigma_k^-1] m_k
    term2 = m_k.shape[1] / beta_k # D / beta_k
    return term1 + term2

from functools import partial

# data_dim must be static under jit: E_log_det_inv_Sigma builds jnp.arange(1, data_dim + 1),
# which needs a concrete (non-traced) size.
@partial(jax.jit, static_argnames=('data_dim',))
def vgmm_e_step(X, variational_params, data_dim):
    """Performs the variational E-step for VGMM: updates the responsibilities r_nk of q(z)."""
    alpha_k = variational_params['alpha_k']
    beta_k = variational_params['beta_k']
    m_k = variational_params['m_k']
    W_k = variational_params['W_k']
    nu_k = variational_params['nu_k']

    # Expectations under q(pi) and q(mu_k, Sigma_k)
    E_log_pi_k = E_log_pi(alpha_k)                                      # (K,)
    E_log_det_inv_Sigma_k = E_log_det_inv_Sigma(W_k, nu_k, data_dim)    # (K,)

    # E[(x_n - mu_k)^T Sigma_k^-1 (x_n - mu_k)] = D / beta_k + nu_k (x_n - m_k)^T W_k (x_n - m_k),
    # because E[Sigma_k^-1] = nu_k W_k and mu_k | Sigma_k ~ N(m_k, Sigma_k / beta_k)
    diff = X[:, None, :] - m_k[None, :, :]                              # (N, K, D)
    mahalanobis = jnp.einsum('nkd,kde,nke->nk', diff, W_k, diff)        # (N, K)
    E_quadratic_form = data_dim / beta_k + nu_k * mahalanobis           # (N, K)

    # Unnormalized log-responsibilities log rho_nk; the -D/2 log(2 pi) constant cancels in the softmax
    log_rho_nk = E_log_pi_k + 0.5 * E_log_det_inv_Sigma_k - 0.5 * E_quadratic_form

    # Normalize over components to get the responsibilities r_nk
    return jax.nn.softmax(log_rho_nk, axis=1)

@jax.jit
def vgmm_m_step(X, responsibilities, prior_params):
    """Performs the variational M-step for VGMM: updates q(pi) and q(mu_k, Sigma_k).

    Priors may be shared (scalar beta/nu, (D,) m, (D, D) W) or per component
    ((K,) beta/nu, (K, D) m, (K, D, D) W); both are broadcast to per-component arrays.
    """
    num_samples, data_dim = X.shape
    num_components = responsibilities.shape[1]
    K, D = num_components, data_dim

    alpha_prior = jnp.broadcast_to(prior_params['alpha'], (K,))
    beta_prior = jnp.broadcast_to(prior_params['beta'], (K,))
    m_prior = jnp.broadcast_to(prior_params['m'], (K, D))
    W_prior = jnp.broadcast_to(prior_params['W'], (K, D, D))
    nu_prior = jnp.broadcast_to(prior_params['nu'], (K,))

    # Soft counts N_k, weighted means x_bar_k and weighted covariances S_k
    Nk = jnp.sum(responsibilities, axis=0)                              # (K,)
    Nk_safe = Nk + 1e-10                                                # avoids 0/0 for empty components
    x_bar_k = (responsibilities.T @ X) / Nk_safe[:, None]               # (K, D)
    diff = X[:, None, :] - x_bar_k[None, :, :]                          # (N, K, D)
    Sk = jnp.einsum('nk,nkd,nke->kde', responsibilities, diff, diff) / Nk_safe[:, None, None]

    # Dirichlet and Normal-Wishart updates
    new_alpha_k = alpha_prior + Nk
    new_beta_k = beta_prior + Nk
    new_nu_k = nu_prior + Nk
    new_m_k = (beta_prior[:, None] * m_prior + Nk[:, None] * x_bar_k) / new_beta_k[:, None]

    # (W_k^new)^-1 = (W^prior)^-1 + N_k S_k + beta N_k / (beta + N_k) (x_bar_k - m)(x_bar_k - m)^T
    diff_prior = x_bar_k - m_prior                                      # (K, D)
    new_W_k_inv = (jnp.linalg.inv(W_prior)
                   + Nk[:, None, None] * Sk
                   + (beta_prior * Nk / new_beta_k)[:, None, None]
                   * jnp.einsum('kd,ke->kde', diff_prior, diff_prior))
    new_W_k = jnp.linalg.inv(new_W_k_inv)

    return {'alpha_k': new_alpha_k, 'beta_k': new_beta_k, 'm_k': new_m_k, 'W_k': new_W_k, 'nu_k': new_nu_k}

from jax.scipy.special import xlogy  # xlogy(0, 0) = 0, so empty components do not produce NaNs


def log_wishart_normalizer(W, nu, data_dim):
    """log B(W, nu), the log normalization constant of the Wishart distribution (Bishop PRML, Appendix B).

    Accepts a single (D, D) matrix W with scalar nu, or a stack (K, D, D) with nu of shape (K,).
    """
    nu = jnp.asarray(nu, dtype=float)
    i = jnp.arange(1, data_dim + 1)
    return (-0.5 * nu * jnp.linalg.slogdet(W)[1]
            - 0.5 * nu * data_dim * jnp.log(2.0)
            - 0.25 * data_dim * (data_dim - 1) * jnp.log(jnp.pi)
            - jnp.sum(gammaln((nu[..., None] + 1 - i) / 2.0), axis=-1))


def log_dirichlet_normalizer(alpha):
    """log C(alpha), the log normalization constant of the Dirichlet distribution."""
    return gammaln(jnp.sum(alpha)) - jnp.sum(gammaln(alpha))


def compute_elbo(X, responsibilities, variational_params, prior_params, data_dim):
    """Computes the Evidence Lower Bound (ELBO) for VGMM (Bishop PRML, Eqs. 10.70-10.77).

    Uses E[Sigma_k^-1] = nu_k W_k and the expectation helpers defined above. Priors may be
    shared or per component; they are broadcast to per-component arrays as in vgmm_m_step.
    """
    alpha_k = variational_params['alpha_k']
    beta_k = variational_params['beta_k']
    m_k = variational_params['m_k']
    W_k = variational_params['W_k']
    nu_k = variational_params['nu_k']
    K, D = alpha_k.shape[0], data_dim

    alpha_prior = jnp.broadcast_to(prior_params['alpha'], (K,))
    beta_prior = jnp.broadcast_to(prior_params['beta'], (K,))
    m_prior = jnp.broadcast_to(prior_params['m'], (K, D))
    W_prior = jnp.broadcast_to(prior_params['W'], (K, D, D))
    nu_prior = jnp.broadcast_to(prior_params['nu'], (K,))

    E_log_pi_k = E_log_pi(alpha_k)                                     # (K,)
    E_log_det_Lambda_k = E_log_det_inv_Sigma(W_k, nu_k, D)             # (K,), Lambda_k = Sigma_k^-1

    # E[log p(X | Z, mu, Sigma)] (Eq. 10.71, written per data point)
    diff = X[:, None, :] - m_k[None, :, :]                             # (N, K, D)
    E_quadratic_form = D / beta_k + nu_k * jnp.einsum('nkd,kde,nke->nk', diff, W_k, diff)
    E_log_p_x = jnp.sum(responsibilities * (0.5 * E_log_det_Lambda_k - 0.5 * E_quadratic_form
                                            - 0.5 * D * jnp.log(2 * jnp.pi)))

    # E[log p(Z | pi)] (Eq. 10.72)
    E_log_p_z = jnp.sum(responsibilities * E_log_pi_k)

    # E[log p(pi)] (Eq. 10.73)
    E_log_p_pi = log_dirichlet_normalizer(alpha_prior) + jnp.sum((alpha_prior - 1) * E_log_pi_k)

    # E[log p(mu, Sigma)] (Eq. 10.74): Gaussian part plus Wishart part, summed over components
    dm = m_k - m_prior                                                 # (K, D)
    E_log_p_mu_Sigma = jnp.sum(
        0.5 * (D * jnp.log(beta_prior / (2 * jnp.pi)) + E_log_det_Lambda_k
               - D * beta_prior / beta_k
               - beta_prior * nu_k * jnp.einsum('kd,kde,ke->k', dm, W_k, dm))
        + log_wishart_normalizer(W_prior, nu_prior, D)
        + 0.5 * (nu_prior - D - 1) * E_log_det_Lambda_k
        - 0.5 * nu_k * jnp.trace(jnp.linalg.inv(W_prior) @ W_k, axis1=-2, axis2=-1))

    # E[log q(Z)] (Eq. 10.75)
    E_log_q_z = jnp.sum(xlogy(responsibilities, responsibilities))

    # E[log q(pi)] (Eq. 10.76)
    E_log_q_pi = jnp.sum((alpha_k - 1) * E_log_pi_k) + log_dirichlet_normalizer(alpha_k)

    # E[log q(mu, Sigma)] (Eq. 10.77), using the entropy of the Wishart factor q(Sigma_k^-1)
    H_wishart = (-log_wishart_normalizer(W_k, nu_k, D)
                 - 0.5 * (nu_k - D - 1) * E_log_det_Lambda_k + 0.5 * nu_k * D)
    E_log_q_mu_Sigma = jnp.sum(0.5 * E_log_det_Lambda_k + 0.5 * D * jnp.log(beta_k / (2 * jnp.pi))
                               - 0.5 * D - H_wishart)

    # ELBO = E[log p(X, Z, pi, mu, Sigma)] - E[log q(Z, pi, mu, Sigma)]
    return (E_log_p_x + E_log_p_z + E_log_p_pi + E_log_p_mu_Sigma
            - E_log_q_z - E_log_q_pi - E_log_q_mu_Sigma)

def run_vgmm_vi(X, num_components, prior_params, max_iter=200, tol=1e-5, key=None):
    """Runs the Variational Inference algorithm for Gaussian Mixture Models."""
    if key is None:
        key = jax.random.PRNGKey(0)

    num_samples, data_dim = X.shape

    # Initialize variational parameters
    variational_params = initialize_vgmm_params(X, num_components, key, prior_params)
    elbo_history = []

    print("Starting Variational Inference for VGMM...")
    for i in range(max_iter):
        # E-step (Update responsibilities)
        responsibilities = vgmm_e_step(X, variational_params, data_dim)

        # M-step (Update variational parameters for pi, mu, Sigma)
        variational_params = vgmm_m_step(X, responsibilities, prior_params)

        # Compute ELBO for convergence check
        current_elbo = compute_elbo(X, responsibilities, variational_params, prior_params, data_dim)
        elbo_history.append(current_elbo)

        # Check for convergence
        if i > 0 and jnp.abs(current_elbo - elbo_history[-2]) < tol:
            print(f"VI converged in {i+1} iterations. ELBO: {current_elbo:.4f}")
            break
    else:
        print(f"VI did not converge after {max_iter} iterations. Final ELBO: {current_elbo:.4f}")

    return variational_params, elbo_history, responsibilities


### Example: Fitting a VGMM to Data

Let's generate some synthetic data and use our VGMM implementation to fit the model. We'll visualize the final fit and the ELBO convergence.

In [None]:
# --- Main Execution: VGMM VI Example ---

# 1. Generate synthetic GMM data
num_samples = 500
num_components = 6 # Start with more components than true to see pruning effect
X_vgmm, true_weights, true_means, true_covariances = generate_gmm_data(num_samples, num_components=3, random_seed=42)

print("Generated Data Shape:", X_vgmm.shape)

# 2. Define prior parameters for VGMM
data_dim = X_vgmm.shape[1]
prior_params = {
    'alpha': jnp.ones(num_components) * 0.01, # Dirichlet concentration; alpha0 < 1 favours sparse mixing weights, so redundant components can switch off
    'beta': jnp.array([1.0] * num_components), # Normal-Wishart beta (precision scaling)
    'm': jnp.mean(X_vgmm, axis=0), # Normal-Wishart mean (centered at data mean)
    'W': jnp.array([jnp.eye(data_dim) * 0.1] * num_components), # Normal-Wishart W (inverse covariance scale)
    'nu': jnp.array([data_dim] * num_components) # Normal-Wishart nu (degrees of freedom)
}

# 3. Run the VGMM VI algorithm
vi_key = jax.random.PRNGKey(50)
variational_params, elbo_history, final_responsibilities = \
    run_vgmm_vi(X_vgmm, num_components, prior_params, max_iter=300, tol=1e-6, key=vi_key)

print("\nFinal Variational Parameters (alpha_k, beta_k, m_k, W_k, nu_k):")
for param_name, param_val in variational_params.items():
    print(f"  {param_name}:\n{param_val}")

# 4. Extract estimated GMM parameters (means and covariances) from variational parameters
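# For visualization, we use E[mu_k] = m_k and, as a plug-in covariance, the inverse of the expected
# precision, E[Sigma_k^-1]^-1 = (nu_k * W_k)^-1. (The posterior mean E[Sigma_k] = W_k^-1 / (nu_k - D - 1)
# is somewhat larger and exists only for nu_k > D + 1.)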
estimated_means_vgmm = variational_params['m_k']
estimated_covariances_vgmm = jnp.linalg.inv(variational_params['nu_k'][:, jnp.newaxis, jnp.newaxis] * variational_params['W_k'])

# For weights, we use E[pi_k] = alpha_k / sum(alpha_k)
estimated_weights_vgmm = variational_params['alpha_k'] / jnp.sum(variational_params['alpha_k'])

print("\nEstimated GMM Parameters (for visualization):")
print("  Weights:", estimated_weights_vgmm)
print("  Means:\n", estimated_means_vgmm)
print("  Covariances:\n", estimated_covariances_vgmm)

# Filter out components with very low weights for plotting clarity
threshold = 0.01
active_components_mask = estimated_weights_vgmm > threshold
active_means = estimated_means_vgmm[active_components_mask]
active_covariances = estimated_covariances_vgmm[active_components_mask]
active_responsibilities = final_responsibilities[:, active_components_mask]

# 5. Plot the final VGMM fit (only active components)
plot_gmm_plotly(
    X_vgmm,
    active_means,
    active_covariances,
    responsibilities=active_responsibilities,
    title='VGMM Fit after Variational Inference (Active Components)'
)

# 6. Plot the ELBO history
fig_elbo = go.Figure()
fig_elbo.add_trace(go.Scatter(
    x=jnp.arange(len(elbo_history)),
    y=jnp.array(elbo_history),
    mode='lines',
    name='ELBO'
))
fig_elbo.update_layout(title_text='ELBO during Variational Inference Iterations', title_x=0.5,
                      xaxis_title='Iteration',
                      yaxis_title='ELBO')
fig_elbo.show()


You should observe that the ELBO never decreases from one iteration to the next (as guaranteed for coordinate ascent) and flattens out as the algorithm converges. The VGMM fit plot displays the estimated Gaussian components. Because we started with more components than there are true clusters, you might notice that some components effectively "switch off" (their estimated weights become very small), demonstrating VGMM's ability to perform automatic model selection.
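
Because coordinate-ascent VI cannot decrease the ELBO, any decrease beyond floating-point noise points to a bug in an update equation or in the ELBO formula itself. A small helper for checking this numerically might look as follows; `check_elbo_monotone` and its relative tolerance are our own illustrative choices.

```python
import numpy as np


def check_elbo_monotone(elbo_history, tol=1e-8):
    """Returns the iteration indices at which the ELBO decreased by more than
    tol times its magnitude (an empty array means the history is monotone)."""
    elbo = np.asarray(elbo_history, dtype=float)
    drops = elbo[:-1] - elbo[1:]                       # positive entry = the ELBO went down
    scale = np.maximum(1.0, np.abs(elbo[:-1]))
    return np.nonzero(drops > tol * scale)[0] + 1


print(check_elbo_monotone([-10.0, -5.0, -4.0, -4.5, -4.4]))   # flags the drop at index 3
```

After running the VGMM cell, `check_elbo_monotone(elbo_history)` should return an empty array.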

#### 5. Advantages and Disadvantages of Variational Inference

Variational Inference is a powerful mathematical tool (Slide 27) but comes with its own set of trade-offs:

### Advantages:

* **Speed and Scalability**: Often significantly faster than sampling-based methods like Markov Chain Monte Carlo (MCMC), especially for large datasets, as it transforms inference into an optimization problem.
* **Deterministic**: Provides a deterministic solution, which can be desirable for reproducibility and debugging.
* **Handles Intractability**: Can approximate posteriors that are otherwise intractable, making complex probabilistic models feasible.
* **Full Probabilistic Inference**: Unlike point estimation methods (like ML or MAP), VI provides an *entire distribution* over latent variables and parameters, offering uncertainty quantification.
* **Automatic Model Selection**: As seen with VGMMs, it can automatically prune redundant components or variables if the prior encourages sparsity (for example, a small Dirichlet concentration $\alpha^{prior}$).

### Disadvantages:

* **Mean-Field Approximation**: The factorization assumption can be restrictive. It ignores correlations between latent variables, which can lead to underestimation of posterior variance (i.e., overconfidence in estimates).
* **Derivation Complexity**: Constructing the update equations for the variational factors can be tedious and require significant "ELBOW grease," especially for complex models.
* **Local Optima**: As with EM's objective, the ELBO is generally non-convex, so VI can converge to local optima, making initialization important.
* **Choice of Variational Family**: The choice of the variational distribution family $Q$ (e.g., mean-field, structured mean-field) directly impacts the quality of the approximation. A poor choice can lead to a bad fit, even if the ELBO is maximized.

Despite its drawbacks, VI is an indispensable tool in modern probabilistic machine learning, enabling the application of complex Bayesian models to large-scale problems where exact inference or MCMC would be computationally prohibitive.

#### 6. Summary

To summarize (Slide 26):

* **Variational Inference** is a general framework for constructing approximating probability distributions $q(z)$ to non-analytic posterior distributions $p(z|x)$ by minimizing the KL divergence $D_{KL}(q(z) || p(z|x))$, which is equivalent to maximizing the ELBO $\mathcal{L}(q)$.
* The **mean-field approximation** $q(z) = \prod_i q_i(z_i)$ simplifies the problem, leading to iterative coordinate ascent updates: $\log q_j^*(z_j) = \mathbb{E}_{q, i \ne j}[\log p(x, z)] + \text{const.}$
* Practical implementation involves defining the log joint, choosing a factorization, identifying the type of variational factors, and deriving their analytic update equations.
* VI provides a powerful and efficient alternative to sampling methods for intractable posteriors, offering full probabilistic inference and enabling applications in complex Bayesian models like VGMMs.

#### Exercises

**Exercise 1: Impact of Prior Parameters**
Experiment with different `prior_params` in the VGMM example. For instance:

* Increase `alpha` (e.g., `jnp.ones(num_components) * 10.0`). How does a stronger prior on the mixing coefficients affect the component pruning behavior?
* Increase `beta` (e.g., `jnp.array([10.0] * num_components)`). How does a stronger prior on the precision of the means affect the estimated means and their uncertainty?

Discuss your observations.

**Exercise 2: Visualize Component Pruning**
Modify the plotting code to explicitly show the weights (`estimated_weights_vgmm`) for *all* `num_components` (even those below the threshold). This will make the pruning effect more explicit. Run the example with `num_components=6` and observe which components are effectively removed.

**Exercise 3: Deriving Expected Values (Conceptual)**
The `E_log_pi` and `E_log_det_inv_Sigma` functions compute expected values. For a Dirichlet distribution $\text{Dir}(\pi | \alpha)$ and a Wishart distribution $\text{Wishart}(\Sigma^{-1} | W, \nu)$, conceptually explain how you would derive the formulas for $\mathbb{E}[\log \pi_k]$ and $\mathbb{E}[\log |\Sigma^{-1}|]$ respectively. You don't need to do the full mathematical derivation, but outline the steps and relevant properties of these distributions.

**Exercise 4 (Advanced): Comparing ELBO and Log-Likelihood**
For the GMM example (which EM can also solve), modify the `run_vgmm_vi` function to also compute the exact log-likelihood $\log p(x | \hat{\theta})$ (using the `compute_log_likelihood` function from Lecture 23) at each iteration, using the *expected* parameters $\hat{\theta}$ from your variational distribution. Plot the ELBO and this log-likelihood on the same graph. What do you observe about their relationship throughout the optimization? (Hint: the ELBO is a lower bound on the log *marginal* likelihood $\log p(x) = \log \int p(x | \theta) p(\theta) d\theta$, not on the plug-in log-likelihood $\log p(x | \hat{\theta})$. Besides any mean-field approximation error, the ELBO also pays for the KL divergence between $q(\pi, \mu, \Sigma)$ and the prior, so you should not expect the two curves to meet at convergence.) What does the gap tell you, and what does it *not* tell you, about the quality of the mean-field approximation for GMMs?