# Probabilistic Machine Learning: Lecture 23 - EM Algorithm

#### Introduction

Welcome to Lecture 23 of Probabilistic Machine Learning! This lecture introduces the **Expectation-Maximization (EM) algorithm**, a powerful iterative method for finding maximum likelihood (ML) or maximum a posteriori (MAP) estimates of parameters in probabilistic models, especially when the model involves **latent variables**. We will delve into why EM works by understanding its connection to maximizing a lower bound on the model evidence.

This notebook will provide detailed explanations and practical code illustrations using **JAX** for numerical computations and **Plotly** for interactive visualizations, focusing on the classic application of EM to **Gaussian Mixture Models (GMMs)**.

#### 1. The EM Algorithm: Maximizing Model Evidence

As discussed in Lecture 22 (and recapped on Slide 2), the general recipe for hyperparameter inference is to maximize the marginal (log-)likelihood, also known as the **model evidence**:

$$\log p(y | \theta) = \log \left( \int p(y, z | \theta) dz \right)$$

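where $y$ denotes the observed data, $\theta$ the model parameters, and $z$ the latent variables. This integral is often intractable. Even when it is not, the logarithm of an integral (or sum) over $z$ is usually hard to maximize directly. Laplace approximations offer one way around this. For many latent-variable models, the **EM algorithm** provides an iterative alternative that never decreases the evidence from one step to the next.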
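When $z$ is discrete, the integral becomes a sum over its values. The log evidence is then best evaluated with the log-sum-exp trick, which avoids underflow.

The sketch below uses NumPy and assumes a hypothetical 1D two-component Gaussian mixture. All parameter values are made up for illustration.

```python
import numpy as np

def log_normal_pdf(x, mean, std):
    """Log density of a univariate Gaussian N(x | mean, std^2)."""
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2

def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along an axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def log_evidence(y, weights, means, stds):
    """log p(y | theta) = sum_n log sum_z p(z | theta) p(y_n | z, theta).

    Hypothetical model: z is a discrete component index with p(z = k) = weights[k]
    and p(y_n | z = k) = N(y_n | means[k], stds[k]^2).
    """
    # log p(y_n, z = k | theta) for every data point n and latent value k
    log_joint = np.log(weights) + log_normal_pdf(y[:, None], means, stds)
    # Marginalize z per data point, then sum over the independent data points
    return np.sum(logsumexp(log_joint, axis=1))

y = np.array([-1.2, 0.3, 2.8, 3.1])
weights = np.array([0.5, 0.5])
means = np.array([0.0, 3.0])
stds = np.array([1.0, 1.0])
print(log_evidence(y, weights, means, stds))
```

Subtracting the maximum before exponentiating keeps the computation finite even when individual log densities are very negative.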

The core idea of EM (Slide 3) is to iteratively improve the parameter estimates by performing two steps:

* **E-step (Expectation)**: Compute the expected complete-data log-likelihood $Q(\theta, \theta_t)$. The expectation is taken under the posterior over the latent variables, given the current parameter estimate $\theta_t$:
    $$Q(\theta, \theta_t) = \int p(z | y, \theta_t) \log p(y, z | \theta) dz$$

* **M-step (Maximization)**: Set the new parameter estimate $\theta_{t+1}$ to the maximizer of this expected complete-data log-likelihood:
    $$\theta_{t+1} = \arg \max_{\theta} Q(\theta, \theta_t)$$
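The two steps can be made concrete with a deliberately simple hypothetical model. It is a 1D mixture of two *known* unit-variance Gaussians, $\mathcal{N}(0, 1)$ and $\mathcal{N}(3, 1)$ (values chosen for illustration). The only unknown parameter is the weight $\pi$ of the second component.

- **E-step:** compute $r_n = p(z_n = 1 | y_n, \pi_t)$.
- **M-step:** the expected complete-data log-likelihood is $\sum_n [r_n \log \pi + (1 - r_n) \log(1 - \pi)]$ plus terms that do not depend on $\pi$. It therefore has the closed-form maximizer $\pi_{t+1} = \frac{1}{N} \sum_n r_n$.

A minimal NumPy sketch:

```python
import numpy as np

def normal_pdf(y, mean):
    """Density of a unit-variance Gaussian N(y | mean, 1)."""
    return np.exp(-0.5 * (y - mean) ** 2) / np.sqrt(2 * np.pi)

def log_likelihood(y, pi):
    """log p(y | pi) for the mixture pi * N(3, 1) + (1 - pi) * N(0, 1)."""
    return np.sum(np.log(pi * normal_pdf(y, 3.0) + (1 - pi) * normal_pdf(y, 0.0)))

def em_mixing_weight(y, pi0=0.5, num_iters=200):
    """EM for the weight pi of the N(3, 1) component; both Gaussians are known."""
    pi = pi0
    history = [log_likelihood(y, pi)]
    for _ in range(num_iters):
        # E-step: r_n = p(z_n = 1 | y_n, pi_t), the posterior probability of N(3, 1)
        p1 = pi * normal_pdf(y, 3.0)
        p0 = (1 - pi) * normal_pdf(y, 0.0)
        r = p1 / (p1 + p0)
        # M-step: sum_n [r_n log pi + (1 - r_n) log(1 - pi)] is maximized at mean(r)
        pi = np.mean(r)
        history.append(log_likelihood(y, pi))
    return pi, history

# Synthetic data with true mixing weight 0.7
rng = np.random.default_rng(0)
n = 2000
z = rng.random(n) < 0.7
y = np.where(z, rng.normal(3.0, 1.0, n), rng.normal(0.0, 1.0, n))
pi_hat, history = em_mixing_weight(y, pi0=0.2)
print(pi_hat)
```

The values in `history` should be non-decreasing, which previews the guarantee discussed in Section 2.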

The algorithm alternates these two steps until either the log-likelihood or the parameters $\theta$ stop changing, up to a tolerance. Let's start by setting up our imports, a plotting utility, and a function that generates synthetic GMM data.

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jax.scipy.stats import multivariate_normal as mvn

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)

def plot_gmm_plotly(X, means, covariances, responsibilities=None, title="Gaussian Mixture Model", colors=['red', 'blue', 'green', 'purple', 'orange']):
    """Plots 2D data, GMM components, and optionally responsibilities."""
    fig = go.Figure()

    # Plot data points, optionally colored by responsibility
    if responsibilities is not None:
        # For responsibilities, we can use a scatter plot with customdata for hover text
        # and color based on the component with highest responsibility
        dominant_component = jnp.argmax(responsibilities, axis=1)
        for k in range(means.shape[0]):
            mask = dominant_component == k
            fig.add_trace(go.Scatter(
                x=X[mask, 0],
                y=X[mask, 1],
                mode='markers',
                marker=dict(color=colors[k % len(colors)], size=5, opacity=0.7),
                name=f'Data (Component {k+1})',
                showlegend=True
            ))
    else:
        fig.add_trace(go.Scatter(
            x=X[:, 0],
            y=X[:, 1],
            mode='markers',
            marker=dict(color='gray', size=5, opacity=0.7),
            name='Data Points',
            showlegend=True
        ))

    # Plot Gaussian components (mean and 95% covariance ellipse)
    t = np.linspace(0, 2 * np.pi, 100)
    unit_circle = np.stack([np.cos(t), np.sin(t)])  # shape (2, 100)
    for k in range(means.shape[0]):
        mean = np.asarray(means[k])
        cov = np.asarray(covariances[k])

        # The 95% contour of a 2D Gaussian is {x : (x - mu)^T Sigma^{-1} (x - mu) = 5.991},
        # where 5.991 is the 95% quantile of the chi-squared distribution with 2 DOF.
        # Mapping the unit circle through the eigendecomposition of Sigma traces this
        # contour, rotated along Sigma's principal axes (an axis-aligned shape would not be).
        vals, vecs = np.linalg.eigh(cov)
        ellipse = mean[:, None] + vecs @ np.diag(np.sqrt(5.991 * vals)) @ unit_circle

        fig.add_trace(go.Scatter(
            x=ellipse[0],
            y=ellipse[1],
            mode='lines',
            line=dict(color=colors[k % len(colors)], width=2),
            opacity=0.8,
            name=f'Component {k+1}',
            showlegend=False
        ))
        # Add mean point
        fig.add_trace(go.Scatter(
            x=[mean[0]],
            y=[mean[1]],
            mode='markers',
            marker=dict(symbol='x', size=10, color=colors[k % len(colors)], line=dict(width=2, color='black')),
            name=f'Mean {k+1}',
            showlegend=False
        ))

    fig.update_layout(title_text=title, title_x=0.5,
                      xaxis_title='Feature 1',
                      yaxis_title='Feature 2',
                      autosize=False, width=700, height=600)
    fig.update_yaxes(scaleanchor="x", scaleratio=1) # Keep aspect ratio square
    fig.show()

def generate_gmm_data(num_samples=300, num_components=3, random_seed=42):
    """Generates synthetic data from a Gaussian Mixture Model."""
    np.random.seed(random_seed)
    
    # True parameters for 3 components
    true_weights = np.array([0.3, 0.4, 0.3])
    true_means = np.array([
        [0, 0],
        [3, 3],
        [0, 4]
    ])
    true_covariances = np.array([
        [[0.5, 0.2], [0.2, 0.5]],
        [[0.8, -0.1], [-0.1, 0.8]],
        [[0.6, 0.3], [0.3, 0.6]]
    ])

    X = []
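    for _ in range(num_samples):
        # Choose one of the 3 true components according to the true weights.
        # `num_components` is deliberately not used here, so the data always comes from
        # the 3-component mixture even when a different number of components is fitted.
        k = np.random.choice(len(true_weights), p=true_weights)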
        # Sample from the chosen Gaussian
        sample = np.random.multivariate_normal(true_means[k], true_covariances[k])
        X.append(sample)
    
    return jnp.array(X), true_weights, true_means, true_covariances


#### 2. Why EM Works: Maximizing a Lower Bound

The EM algorithm works by iteratively maximizing a lower bound on the model evidence, known as the **Evidence Lower Bound (ELBO)** (Slides 4-5).

For any distribution $q(z)$ over the latent variables $z$, the log evidence decomposes as shown below. From here on, the observed data is written as $x$.

$$\log p(x | \theta) = \mathcal{L}(q, \theta) + D_{KL}(q || p(z | x, \theta))$$

where:
* $\mathcal{L}(q, \theta) = \int q(z) \log \left( \frac{p(x, z | \theta)}{q(z)} \right) dz$ is the **ELBO**.
* $D_{KL}(q || p(z | x, \theta)) = -\int q(z) \log \left( \frac{p(z | x, \theta)}{q(z)} \right) dz$ is the **Kullback-Leibler (KL) divergence** between $q(z)$ and the true posterior $p(z | x, \theta)$.

Since KL divergence is always non-negative ($D_{KL}(q || p) \ge 0$), the ELBO $\mathcal{L}(q, \theta)$ is a lower bound on the log evidence: $\log p(x | \theta) \ge \mathcal{L}(q, \theta)$.
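For a discrete latent variable the integrals become sums, which makes the decomposition easy to check numerically. The NumPy sketch below uses a made-up joint $p(x, z | \theta)$ for one fixed observation, with $z \in \{0, 1, 2\}$. It does two things:

- it evaluates the ELBO and the KL divergence for an arbitrary $q$;
- it confirms that choosing $q$ equal to the posterior closes the gap.

```python
import numpy as np

# Made-up joint p(x, z | theta) for one fixed observation x and latent z in {0, 1, 2}
log_joint = np.log(np.array([0.02, 0.10, 0.05]))
log_ev = np.log(np.sum(np.exp(log_joint)))   # log p(x | theta)
log_post = log_joint - log_ev                # log p(z | x, theta)

def elbo(q):
    """L(q, theta) = sum_z q(z) log(p(x, z | theta) / q(z))."""
    return np.sum(q * (log_joint - np.log(q)))

def kl(q):
    """D_KL(q || p(z | x, theta)) = sum_z q(z) log(q(z) / p(z | x, theta))."""
    return np.sum(q * (np.log(q) - log_post))

# An arbitrary q: the ELBO and the KL divergence always add up to the log evidence
q_arbitrary = np.array([0.6, 0.3, 0.1])
print(elbo(q_arbitrary) + kl(q_arbitrary), log_ev)

# q equal to the posterior: the KL term vanishes and the bound is tight
q_post = np.exp(log_post)
print(kl(q_post), elbo(q_post), log_ev)
```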

The EM algorithm's steps can be understood in terms of this decomposition (Slides 7-9):

* **E-step**: With $\theta_{old}$ fixed, we set $q(z)$ to be the true posterior $p(z | x, \theta_{old})$. This makes the KL divergence term $D_{KL}(q || p(z | x, \theta_{old})) = 0$, thus making the ELBO equal to the log evidence at $\theta_{old}$: $\mathcal{L}(q, \theta_{old}) = \log p(x | \theta_{old})$. This step tightens the lower bound at the current parameter values.

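* **M-step**: With $q(z)$ fixed (from the E-step), we maximize the ELBO $\mathcal{L}(q, \theta)$ with respect to $\theta$ to obtain $\theta_{new}$. Note that $\mathcal{L}(q, \theta) = \int q(z) \log p(x, z | \theta) dz - \int q(z) \log q(z) dz$, and the second term does not depend on $\theta$. Maximizing the ELBO is therefore equivalent to maximizing the expected complete-data log-likelihood $\int q(z) \log p(x, z | \theta) dz$. Combining the decomposition above with the tightness of the bound at $\theta_{old}$ shows that the log evidence cannot decrease:
    $$\log p(x | \theta_{new}) = \mathcal{L}(q, \theta_{new}) + D_{KL}(q || p(z | x, \theta_{new})) \ge \mathcal{L}(q, \theta_{new}) \ge \mathcal{L}(q, \theta_{old}) = \log p(x | \theta_{old})$$
    The first inequality holds because the KL divergence is non-negative. The second holds because $\theta_{new}$ maximizes $\mathcal{L}(q, \cdot)$.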

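This iterative process ensures that the log evidence never decreases from one iteration to the next. Under mild regularity conditions, EM therefore converges to a stationary point of the evidence. In practice this is usually a local, not necessarily global, maximum.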
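The chain of (in)equalities behind the E- and M-steps can be checked numerically on a small example. The NumPy sketch below assumes a hypothetical 1D mixture of two unit-variance Gaussians with fixed, equal weights, where only the means are updated. It runs a single EM step and evaluates, in order:

1. $\log p(x | \theta_{old})$
2. $\mathcal{L}(q, \theta_{old})$
3. $\mathcal{L}(q, \theta_{new})$
4. $\log p(x | \theta_{new})$

```python
import numpy as np

def log_joint(x, means, weights):
    """log p(x_n, z_n = k | theta) for unit-variance 1D Gaussian components."""
    return np.log(weights) - 0.5 * np.log(2 * np.pi) - 0.5 * (x[:, None] - means) ** 2

def log_evidence(x, means, weights):
    """log p(x | theta), marginalizing z with a stable log-sum-exp."""
    lj = log_joint(x, means, weights)
    m = lj.max(axis=1, keepdims=True)
    return np.sum(m[:, 0] + np.log(np.sum(np.exp(lj - m), axis=1)))

def elbo(x, q, means, weights):
    """L(q, theta) = sum_n sum_k q_nk [log p(x_n, z_n = k | theta) - log q_nk]."""
    return np.sum(q * (log_joint(x, means, weights) - np.log(q)))

x = np.array([-0.5, 0.2, 0.9, 2.7, 3.4, 4.1])
weights = np.array([0.5, 0.5])            # held fixed in this sketch
means_old = np.array([1.0, 2.0])          # hypothetical starting point

# E-step: q(z) = p(z | x, theta_old), which makes the bound tight at theta_old
lj = log_joint(x, means_old, weights)
q = np.exp(lj - lj.max(axis=1, keepdims=True))
q /= q.sum(axis=1, keepdims=True)

# M-step (means only): maximize L(q, theta) over the means with q held fixed
means_new = (q * x[:, None]).sum(axis=0) / q.sum(axis=0)

chain = [
    log_evidence(x, means_old, weights),  # log p(x | theta_old)
    elbo(x, q, means_old, weights),       # L(q, theta_old): equal to the line above
    elbo(x, q, means_new, weights),       # L(q, theta_new) >= L(q, theta_old)
    log_evidence(x, means_new, weights),  # log p(x | theta_new) >= L(q, theta_new)
]
print(chain)
```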

#### 3. Example: EM for Gaussian Mixture Models (GMMs)

A classic application of the EM algorithm is to learn the parameters of a **Gaussian Mixture Model (GMM)** (Slide 13). A GMM assumes that the data $x$ is generated from a mixture of $K$ Gaussian distributions:

$$p(x | \pi, \mu, \Sigma) = \sum_{k=1}^K \pi_k \mathcal{N}(x | \mu_k, \Sigma_k)$$

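where $\pi_k$ are the mixing coefficients (weights), $\mu_k$ are the means, and $\Sigma_k$ are the covariances for each component $k$. For $N$ data points, the log-likelihood $\sum_{n=1}^N \log \sum_{k=1}^K \pi_k \mathcal{N}(x_n | \mu_k, \Sigma_k)$ contains the logarithm of a sum. This couples the parameters of all components, so setting its gradient to zero does not yield closed-form solutions.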

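We introduce a one-hot latent variable $z_n$ for each data point $x_n$, with $z_{nk}=1$ if $x_n$ was generated by component $k$ and $z_{nk}=0$ otherwise. Its prior is $p(z_{nk}=1) = \pi_k$. The log of the "complete-data" likelihood then decomposes into a sum over data points and components (Slide 15):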

$$\log p(x, z | \pi, \mu, \Sigma) = \sum_{n=1}^N \sum_{k=1}^K z_{nk} (\log \pi_k + \log \mathcal{N}(x_n | \mu_k, \Sigma_k))$$

This form is much easier to optimize because the logarithm is inside the sum over components.
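For a single data point, $z_n$ can take only $K$ one-hot values. Summing the complete-data likelihood over them therefore recovers the mixture density $p(x_n | \pi, \mu, \Sigma)$.

The NumPy sketch below checks this for a hypothetical 1D, three-component mixture. All parameter values are made up for illustration, and in 1D each $\Sigma_k$ is a variance.

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    """Log density of a univariate Gaussian N(x | mean, var)."""
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

def complete_data_log_lik(x, z, weights, means, variances):
    """log p(x, z | pi, mu, Sigma) = sum_n sum_k z_nk (log pi_k + log N(x_n | mu_k, Sigma_k)).

    x: (N,) observations, z: (N, K) one-hot component indicators.
    """
    per_component = np.log(weights) + log_normal_pdf(x[:, None], means, variances)
    return np.sum(z * per_component)

weights = np.array([0.3, 0.4, 0.3])
means = np.array([0.0, 3.0, 6.0])
variances = np.array([0.5, 0.8, 0.6])
x = np.array([2.5])                      # a single observation
K = len(weights)

# Sum the complete-data likelihood over all K one-hot settings of z_1
marginal = sum(np.exp(complete_data_log_lik(x, np.eye(K)[[k]], weights, means, variances))
               for k in range(K))
mixture = np.sum(weights * np.exp(log_normal_pdf(x[0], means, variances)))
print(marginal, mixture)
```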

### EM Algorithm for GMMs

Let $\theta = (\pi, \mu, \Sigma)$ denote all the GMM parameters.

**E-step (Compute Responsibilities)** (Slide 17):
The E-step requires computing the posterior probability of the latent variable $z_n$ given the observed data $x_n$ and current parameters $\theta_t$. This is the "responsibility" of component $k$ for data point $x_n$, denoted $r_{nk}$:

$$r_{nk} = p(z_{nk}=1 | x_n, \theta_t) = \frac{\pi_k \mathcal{N}(x_n | \mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j \mathcal{N}(x_n | \mu_j, \Sigma_j)}$$
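Evaluating this ratio directly can underflow in higher dimensions. A common approach is to compute the numerator in log space and normalize after subtracting the row-wise maximum.

The NumPy sketch below takes that approach. The helper names `mvn_logpdf` and `compute_responsibilities` are illustrative, not taken from the lecture, and the parameters are made up.

```python
import numpy as np

def mvn_logpdf(X, mean, cov):
    """Log density of N(x | mean, cov) for each row of X (shape (N, D))."""
    D = X.shape[1]
    diff = X - mean
    _, logdet = np.linalg.slogdet(cov)
    # Mahalanobis terms (x_n - mu)^T Sigma^{-1} (x_n - mu) via a linear solve
    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
    return -0.5 * (D * np.log(2 * np.pi) + logdet + maha)

def compute_responsibilities(X, weights, means, covs):
    """r_nk = pi_k N(x_n | mu_k, Sigma_k) / sum_j pi_j N(x_n | mu_j, Sigma_j), in log space."""
    log_joint = np.stack([np.log(w) + mvn_logpdf(X, m, c)
                          for w, m, c in zip(weights, means, covs)], axis=1)
    log_joint -= log_joint.max(axis=1, keepdims=True)   # stabilize before exponentiating
    r = np.exp(log_joint)
    return r / r.sum(axis=1, keepdims=True)

X = np.array([[0.1, -0.2], [2.9, 3.2], [1.5, 1.5]])
weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [3.0, 3.0]])
covs = np.array([np.eye(2), np.eye(2)])
print(np.round(compute_responsibilities(X, weights, means, covs), 3))
```

The third point is equidistant from both means under identical covariances and weights, so it should be split evenly between the two components.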

**M-step (Update Parameters)** (Slides 18-20):
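The M-step maximizes the expected complete-data log-likelihood with respect to $\pi, \mu, \Sigma$. Since $\mathbb{E}_{q(z)}[z_{nk}] = r_{nk}$, this objective is $\sum_{n=1}^N \sum_{k=1}^K r_{nk} (\log \pi_k + \log \mathcal{N}(x_n | \mu_k, \Sigma_k))$. Setting its derivatives to zero, with a Lagrange multiplier enforcing $\sum_k \pi_k = 1$, gives the closed-form updates: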

* **Update Means ($\mu_k$)**:
    $$\mu_k^{new} = \frac{\sum_{n=1}^N r_{nk} x_n}{\sum_{n=1}^N r_{nk}}$$

* **Update Covariances ($\Sigma_k$)**:
    $$\Sigma_k^{new} = \frac{\sum_{n=1}^N r_{nk} (x_n - \mu_k^{new})(x_n - \mu_k^{new})^T}{\sum_{n=1}^N r_{nk}}$$

* **Update Weights ($\pi_k$)**:
    $$\pi_k^{new} = \frac{\sum_{n=1}^N r_{nk}}{N}$$
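The three updates can be written compactly with array operations. The NumPy sketch below is a standalone variant of the notebook's JAX `m_step` and is named `m_step_numpy` to avoid a clash. It uses `np.einsum` instead of a loop over components and cross-checks the result against NumPy's weighted mean and biased weighted covariance. The responsibilities are random but valid, made up for illustration.

```python
import numpy as np

def m_step_numpy(X, R):
    """Closed-form GMM M-step.

    X: (N, D) data, R: (N, K) responsibilities.
    Returns weights (K,), means (K, D), covariances (K, D, D).
    """
    N = X.shape[0]
    Nk = R.sum(axis=0)                                   # sum_n r_nk for each component
    weights = Nk / N                                     # pi_k
    means = (R.T @ X) / Nk[:, None]                      # mu_k
    diff = X[None, :, :] - means[:, None, :]             # (K, N, D): x_n - mu_k^new
    covs = np.einsum('nk,kni,knj->kij', R, diff, diff) / Nk[:, None, None]
    return weights, means, covs

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 2))
R = rng.dirichlet(np.ones(3), size=50)                   # arbitrary valid responsibilities
weights, means, covs = m_step_numpy(X, R)

# Cross-check against NumPy's weighted mean and (biased) weighted covariance
for k in range(3):
    print(np.allclose(means[k], np.average(X, axis=0, weights=R[:, k])),
          np.allclose(covs[k], np.cov(X.T, aweights=R[:, k], bias=True)))
```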

Let's implement the EM algorithm for GMMs and visualize its iterative process.

In [None]:
# --- EM Algorithm for Gaussian Mixture Models ---

def initialize_gmm_params(X, num_components, key):
    """Randomly initializes GMM parameters (means, covariances, weights)."""
    num_samples, data_dim = X.shape

    key, subkey_means, subkey_covs, subkey_weights = jax.random.split(key, 4)

    # Initialize means by randomly picking data points
    random_indices = jax.random.choice(subkey_means, num_samples, (num_components,), replace=False)
    means = X[random_indices]

    # Initialize covariances to small isotropic matrices (0.1 * identity)
    covariances = jnp.array([jnp.eye(data_dim) * 0.1 for _ in range(num_components)])

    # Initialize weights uniformly
    weights = jnp.ones(num_components) / num_components

    return means, covariances, weights

@jax.jit
def e_step(X, means, covariances, weights):
    """Computes responsibilities (rnk) for the GMM."""
    num_samples, _ = X.shape
    num_components = means.shape[0]

    # Compute log-likelihood for each data point under each component
    log_likelihoods = jnp.zeros((num_samples, num_components))
    for k in range(num_components):
        # Add a small regularization to covariance to prevent singularity
        reg_cov = covariances[k] + jnp.eye(covariances[k].shape[0]) * 1e-6
        log_likelihoods = log_likelihoods.at[:, k].set(mvn.logpdf(X, means[k], reg_cov))

    # Compute log(p(x_n, z_n=k)) = log(pi_k) + log(N(x_n | mu_k, Sigma_k))
    log_joint = log_likelihoods + jnp.log(weights)

    # Compute log(sum_k p(x_n, z_n=k)) for normalization
    log_sum_exp = jax.nn.logsumexp(log_joint, axis=1, keepdims=True)

    # Compute responsibilities (rnk) = p(z_n=k | x_n) = exp(log_joint - log_sum_exp)
    responsibilities = jnp.exp(log_joint - log_sum_exp)

    return responsibilities

@jax.jit
def m_step(X, responsibilities):
    """Updates GMM parameters (means, covariances, weights)."""
    num_samples, data_dim = X.shape
    num_components = responsibilities.shape[1]

    # Update weights
    Nk = jnp.sum(responsibilities, axis=0) # Sum of responsibilities for each component
    new_weights = Nk / num_samples

    # Update means
    new_means = jnp.dot(responsibilities.T, X) / Nk[:, jnp.newaxis]

    # Update covariances
    new_covariances = jnp.zeros((num_components, data_dim, data_dim))
    for k in range(num_components):
        diff = X - new_means[k]
        # Weighted outer product sum
        weighted_outer_product = jnp.dot((responsibilities[:, k] * diff.T), diff)
        new_covariances = new_covariances.at[k].set(weighted_outer_product / Nk[k])
        # Add a small diagonal term for numerical stability, especially if a component gets very few points
        new_covariances = new_covariances.at[k].set(new_covariances[k] + jnp.eye(data_dim) * 1e-6)

    return new_means, new_covariances, new_weights

def compute_log_likelihood(X, means, covariances, weights):
    """Computes the total log-likelihood of the data under the GMM."""
    num_samples, _ = X.shape
    num_components = means.shape[0]

    log_likelihoods_per_component = jnp.zeros((num_samples, num_components))
    for k in range(num_components):
        reg_cov = covariances[k] + jnp.eye(covariances[k].shape[0]) * 1e-6
        log_likelihoods_per_component = log_likelihoods_per_component.at[:, k].set(mvn.logpdf(X, means[k], reg_cov))

    # log(sum_k pi_k * N(x_n | mu_k, Sigma_k))
    log_weighted_likelihoods = log_likelihoods_per_component + jnp.log(weights)
    total_log_likelihood = jnp.sum(jax.nn.logsumexp(log_weighted_likelihoods, axis=1))
    return total_log_likelihood

def run_gmm_em(X, num_components, max_iter=100, tol=1e-4, key=None):
    """
    Runs the EM algorithm for Gaussian Mixture Models.
    Returns final parameters, log-likelihood history, and responsibilities.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    means, covariances, weights = initialize_gmm_params(X, num_components, key)
    log_likelihood_history = []

    print("Starting EM for GMM...")
    for i in range(max_iter):
        # E-step: posterior responsibilities under the current parameters
        responsibilities = e_step(X, means, covariances, weights)

        # M-step: closed-form updates given the responsibilities
        means, covariances, weights = m_step(X, responsibilities)

        # Log-likelihood under the updated parameters, used for the convergence check
        current_log_likelihood = compute_log_likelihood(X, means, covariances, weights)
        log_likelihood_history.append(current_log_likelihood)

        # Stop once the log-likelihood changes by less than `tol`
        if i > 0 and jnp.abs(current_log_likelihood - log_likelihood_history[-2]) < tol:
            print(f"EM converged in {i+1} iterations. Log-likelihood: {current_log_likelihood:.4f}")
            break
    else:
        print(f"EM did not converge after {max_iter} iterations. Final log-likelihood: {current_log_likelihood:.4f}")

    # Recompute responsibilities so they match the returned (final) parameters
    responsibilities = e_step(X, means, covariances, weights)

    return means, covariances, weights, log_likelihood_history, responsibilities


### Example: Fitting a GMM to Data

Let's generate some synthetic data from a known GMM and then use our EM implementation to recover the parameters. We'll visualize the progress.

In [None]:
# --- Main Execution: GMM EM Example ---

# 1. Generate synthetic GMM data
num_samples = 500
num_components = 3
X_gmm, true_weights, true_means, true_covariances = generate_gmm_data(num_samples, num_components)

print("True GMM Parameters:")
print("  Weights:", true_weights)
print("  Means:\n", true_means)
print("  Covariances:\n", true_covariances)

# 2. Run the EM algorithm
em_key = jax.random.PRNGKey(10)
estimated_means, estimated_covariances, estimated_weights, ll_history, final_responsibilities = \
    run_gmm_em(X_gmm, num_components, max_iter=200, tol=1e-5, key=em_key)

print("\nEstimated GMM Parameters after EM:")
print("  Weights:", estimated_weights)
print("  Means:\n", estimated_means)
print("  Covariances:\n", estimated_covariances)

# 3. Plot the final GMM fit
plot_gmm_plotly(
    X_gmm,
    estimated_means,
    estimated_covariances,
    responsibilities=final_responsibilities,
    title='GMM Fit after EM Algorithm'
)

# 4. Plot the log-likelihood history
fig_ll = go.Figure()
fig_ll.add_trace(go.Scatter(
    x=jnp.arange(len(ll_history)),
    y=jnp.array(ll_history),
    mode='lines',
    name='Log-Likelihood'
))
fig_ll.update_layout(title_text='Log-Likelihood during EM Iterations', title_x=0.5,
                     xaxis_title='Iteration',
                     yaxis_title='Log-Likelihood')
fig_ll.show()


The plots should show the data points clustered according to the estimated Gaussian components, with ellipses representing their covariances. The log-likelihood plot should demonstrate a non-decreasing trend, indicating the EM algorithm's convergence towards a local maximum of the evidence.
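A small helper can check the non-decreasing claim on a history such as `ll_history`. The name `check_monotone` and the default tolerance are choices made here, not part of the notebook.

A tolerance is useful because `m_step` adds a `1e-6` jitter to each covariance. The update is then not an exact M-step, so in principle tiny decreases at the level of numerical noise can appear.

```python
import numpy as np

def check_monotone(history, atol=1e-8):
    """Return (largest_drop, ok) for a sequence of log-likelihood values.

    largest_drop is the biggest decrease between consecutive entries (0.0 if none);
    ok is True when that decrease is within the tolerance `atol`.
    """
    h = np.array([float(v) for v in history])   # also accepts JAX scalars
    diffs = np.diff(h)
    largest_drop = max(0.0, float(-diffs.min())) if diffs.size else 0.0
    return largest_drop, largest_drop <= atol

# Made-up histories: the first never decreases, the second drops at the last step
print(check_monotone([-812.4, -790.1, -785.3, -785.3]))
print(check_monotone([-812.4, -790.1, -791.0]))
```

After running the EM cell, `check_monotone(ll_history)` applies the same check to the actual run.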

#### 4. The EM Algorithm for MAP Estimation

It is straightforward to extend the EM algorithm to find the **Maximum A Posteriori (MAP) estimate** instead of the maximum likelihood estimate (Slide 11). This involves simply adding a log-prior term for the parameters $\theta$ to the objective function in the M-step:

$$\theta_{new} = \arg \max_{\theta} \left( \int q(z) \log \frac{p(x, z | \theta)}{q(z)} dz + \log p(\theta) \right)$$

This effectively maximizes $\mathcal{L}(q, \theta) + \log p(\theta)$. That is a lower bound on $\log p(x | \theta) + \log p(\theta) = \log p(\theta | x) + \log p(x)$, i.e. on the log posterior up to the constant $\log p(x)$. The E-step is unchanged, because the posterior over the latent variables, $p(z | x, \theta_{old})$, does not involve the prior $p(\theta)$. The M-step then optimizes the sum of the expected complete-data log-likelihood and the log-prior over parameters.
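As an illustration (a choice made here, not taken from the slides), place a Dirichlet prior $p(\pi) = \mathrm{Dir}(\pi | \alpha)$ on the GMM mixing weights.

- The weight-dependent part of the M-step objective becomes $\sum_k (\sum_n r_{nk} + \alpha_k - 1) \log \pi_k$.
- For $\alpha_k \ge 1$, its maximizer under $\sum_k \pi_k = 1$ is $\pi_k = (\sum_n r_{nk} + \alpha_k - 1) / (N + \sum_j \alpha_j - K)$.
- The mean and covariance updates are unchanged.

A NumPy sketch with made-up responsibilities:

```python
import numpy as np

def map_weights(R, alpha):
    """MAP M-step for mixing weights under a Dirichlet(alpha) prior.

    Maximizes sum_k (N_k + alpha_k - 1) log pi_k subject to sum_k pi_k = 1,
    where N_k = sum_n r_nk. Assumes alpha_k >= 1 so that all weights are non-negative.
    """
    Nk = R.sum(axis=0)
    N, K = R.shape
    return (Nk + alpha - 1) / (N + alpha.sum() - K)

R = np.array([[1.0, 0.0, 0.0],
              [0.9, 0.1, 0.0],
              [0.8, 0.2, 0.0]])          # component 3 receives no responsibility
print(map_weights(R, np.ones(3)))       # alpha = 1 (flat prior): same as the ML update N_k / N
print(map_weights(R, np.full(3, 2.0)))  # alpha = 2: every weight stays strictly positive
```

With $\alpha_k > 1$ the numerator $N_k + \alpha_k - 1$ is strictly positive, so no component's weight collapses to zero even when it receives no responsibility.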

#### 5. Why is EM Useful? (Analytic Updates)

One of the main reasons EM is so useful is that, for many models, the M-step updates can be found **analytically** (Slide 12). This applies in particular to models whose complete-data likelihood $p(x, z | \theta)$ belongs to an exponential family. In that case the M-step depends on the data only through the expected sufficient statistics computed in the E-step.

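For instance, suppose $p(x, z | \theta) = \exp(\phi(x, z)^T \theta - \log Z(\theta))$. The ELBO is then $\mathbb{E}_{q(z)}[\phi(x, z)]^T \theta - \log Z(\theta) + H[q]$, where the entropy $H[q] = -\int q(z) \log q(z) dz$ does not depend on $\theta$. Setting the gradient with respect to $\theta$ to zero gives the moment-matching condition $\nabla_\theta \log Z(\theta) = \mathbb{E}_{q(z)}[\phi(x, z)]$. This condition often has a closed-form solution, as the GMM updates in Section 3 illustrate.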
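The mixing weights of a GMM give a concrete instance:

- The sufficient statistics are the one-hot vectors $z_n$, with natural parameters $\theta \in \mathbb{R}^K$ and $\log Z(\theta) = \log \sum_k e^{\theta_k}$.
- The $\theta$-dependent part of the ELBO for $N$ data points is $\sum_n \mathbb{E}_q[z_n]^T \theta - N \log Z(\theta)$.
- The gradient of the log-partition function is $\nabla_\theta \log Z(\theta) = \mathrm{softmax}(\theta)$.

The NumPy sketch below uses made-up responsibilities. Plain gradient ascent appears only to locate the optimum and is not part of EM. The sketch confirms that the maximizer satisfies $\mathrm{softmax}(\theta) = \frac{1}{N}\sum_n r_{nk}$, which is exactly the closed-form weight update.

```python
import numpy as np

def softmax(theta):
    """Softmax of theta, i.e. the gradient of log Z(theta) = log sum_k exp(theta_k)."""
    e = np.exp(theta - theta.max())
    return e / e.sum()

# Made-up E-step output: responsibilities r_nk for N = 4 points and K = 3 components
R = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.8, 0.1],
              [0.2, 0.3, 0.5],
              [0.6, 0.3, 0.1]])
N = R.shape[0]
expected_stats = R.sum(axis=0)          # sum_n E_q[z_n], i.e. sum_n r_nk

# Gradient ascent on (expected_stats^T theta - N log Z(theta)) / N, for illustration only
theta = np.zeros(R.shape[1])
for _ in range(2000):
    theta += expected_stats / N - softmax(theta)

# At the optimum, grad log Z(theta) = softmax(theta) matches the expected statistics / N
print(softmax(theta), expected_stats / N)
```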

#### 6. Summary

In summary (Slide 26):

* The **EM algorithm** is an iterative procedure to find maximum likelihood (or MAP) estimates for models involving latent variables.
* It works by alternating between an **E-step** (computing the posterior over latent variables given current parameters) and an **M-step** (maximizing the expected complete-data log-likelihood with respect to parameters).
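* EM never decreases the model evidence (marginal likelihood). Each E-step makes the lower bound (ELBO) tight at the current parameters, and each M-step increases that bound. Under mild conditions, EM therefore converges to a stationary point of the evidence, usually a local maximum.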
* For many models, especially those whose complete-data likelihood belongs to an exponential family, the M-step updates are **analytic**, making EM a computationally attractive solution.

The EM algorithm is a cornerstone of many probabilistic models, including mixture models, Hidden Markov Models, and factor analysis, providing a robust framework for learning in the presence of unobserved variables.

#### Exercises

**Exercise 1: Impact of Initialization**
The EM algorithm is sensitive to initialization. In the main execution block, change the seed used to create `em_key`, e.g. `jax.random.PRNGKey(20)` or `jax.random.PRNGKey(30)` instead of `jax.random.PRNGKey(10)`. Observe how the final estimated parameters and the log-likelihood history change. Plot the final GMM fits for a few different initializations. Discuss why initialization matters for EM.

**Exercise 2: Varying Number of Components**
Change `num_components` in the main execution block (e.g., to 2 or 4). How does the fit change? What happens if `num_components` is too low or too high compared to the true number of components in the data? (You might need to adjust `max_iter` or `tol` for convergence.)

**Exercise 3: Implementing Log-Likelihood for Convergence**
The provided `compute_log_likelihood` function calculates the log-likelihood. Explain, in detail, why this specific formula correctly represents the log-likelihood of the observed data $X$ under the GMM, given the parameters $\pi, \mu, \Sigma$. Relate it back to the definition of the GMM likelihood.

**Exercise 4 (Advanced): EM for a Simple HMM (Conceptual)**
Briefly research how the EM algorithm (specifically, the Baum-Welch algorithm) is applied to Hidden Markov Models (HMMs). Conceptually describe what the E-step and M-step would involve for an HMM with discrete states and Gaussian observations. How do these steps relate to the general EM framework?