# Interactive Dirichlet Process Mixture Model Explorer

This notebook demonstrates **Bayesian inference for Dirichlet Process Mixture Models (DPMM)** using JAX.

**What you'll learn:**
- How DPMMs automatically discover the number of clusters in data
- The stick-breaking construction for infinite mixture models
- How the concentration parameter α controls cluster formation
- Posterior inference using importance resampling

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/josephausterweil/probintro/blob/amplify/notebooks/dpmm_interactive.ipynb)

## Setup

First, let's install the required packages if running on Google Colab:

**Note**: After running the installation cell below, you may need to restart the runtime (Runtime → Restart runtime) before proceeding.

In [None]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running on Google Colab - installing dependencies...")
    # Install compatible versions: numpy 2.0.x works with most Colab packages
    # Upgrade JAX to satisfy flax and orbax requirements
    !pip install -q --upgrade "jax>=0.6.0" "jaxlib>=0.6.0" "numpy>=2.0,<2.1" scipy ipywidgets
    print("✓ Dependencies installed")
    print("⚠️  Please restart runtime (Runtime → Restart runtime) before continuing")
else:
    print("Running locally")

Import required libraries:

In [None]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.stats import norm, beta, dirichlet
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import ipywidgets as widgets
from IPython.display import display, clear_output

# Enable widgets in Colab
try:
    import google.colab
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass

# Set random seed for reproducibility
key = random.PRNGKey(42)

print("✓ Imports successful")
print(f"JAX version: {jax.__version__}")

## The Dirichlet Process Mixture Model

A DPMM is an infinite mixture model that can automatically discover the number of clusters in data.

### Model Structure

For each cluster $k = 1, 2, \ldots, K_{\text{max}}$:
1. Draw cluster center: $\mu_k \sim \text{Normal}(\mu_0, \sigma_0^2)$
2. Draw stick-breaking weight: $\beta_k \sim \text{Beta}(1, \alpha)$
3. Compute mixture probability: $\pi_k = \beta_k \prod_{j<k}(1-\beta_j)$

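Because the construction is truncated at $K_{\text{max}}$, renormalize the $\pi_k$ to sum to 1, then draw the final mixture weights $\theta \sim \text{Dirichlet}(c\,\pi_1, \ldots, c\,\pi_{K_{\text{max}}})$, where the scale $c$ is set to 100 in the implementation below.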
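
To make the stick-breaking step concrete, here is a minimal NumPy sketch that maps a fixed set of fractions $\beta_k$ to weights $\pi_k$. The helper name `stick_breaking` and the hand-picked `betas` are illustrative assumptions, not part of the notebook's implementation; the same cumulative-product idea should carry over to `jnp.cumprod` in JAX.

```python
import numpy as np

def stick_breaking(betas):
    """Hypothetical helper: pi_k = beta_k * prod_{j<k} (1 - beta_j)."""
    betas = np.asarray(betas, dtype=float)
    # Stick length remaining before each break: 1, (1-b1), (1-b1)(1-b2), ...
    remaining = np.concatenate(([1.0], np.cumprod(1.0 - betas)[:-1]))
    return betas * remaining

# Illustrative fractions: break off half of what is left each time
betas = [0.5, 0.5, 0.5, 0.5]
pis = stick_breaking(betas)
leftover = 1.0 - pis.sum()  # mass not assigned to any of the K components
print("weights:", pis.tolist())
print("leftover:", leftover)
```

With these fractions the weights are 0.5, 0.25, 0.125, and 0.0625, leaving 0.0625 of the stick unassigned. That leftover mass is why a truncated construction needs a renormalization step.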

For each observation $i = 1, \ldots, N$:
1. Draw cluster assignment: $z_i \sim \text{Categorical}(\theta)$
2. Draw observation: $x_i \sim \text{Normal}(\mu_{z_i}, \sigma_x^2)$
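
Summing out the assignments $z_i$ in the two steps above gives the marginal likelihood of the data given the cluster centers and weights. When samples are proposed from the prior, this is the quantity each importance weight is proportional to:

$$p(x_{1:N} \mid \mu, \theta) = \prod_{i=1}^{N} \sum_{k=1}^{K_{\text{max}}} \theta_k \, \text{Normal}(x_i \mid \mu_k, \sigma_x^2)$$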

### Key Parameters

- **α (concentration)**: Controls how spread out the mixture is
  - Small α → Few active clusters
  - Large α → Many clusters with similar weights
- **$K_{\text{max}}$**: Truncation level (upper bound on clusters)
- **$\mu_0, \sigma_0$**: Prior on cluster centers
- **$\sigma_x$**: Observation noise
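
As a rough illustration of the role of α, the sketch below computes the expected stick-breaking weights before truncation and renormalization. It relies on the fact that the $\beta_k$ are independent with mean $1/(1+\alpha)$, so $E[\pi_k] = \frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{k-1}$. The function name `expected_stick_weights` is a hypothetical helper introduced only for this example.

```python
def expected_stick_weights(alpha, K):
    """Hypothetical helper: E[pi_k] for k = 1..K under Beta(1, alpha) breaks."""
    mean_beta = 1.0 / (1.0 + alpha)  # mean of Beta(1, alpha)
    # E[pi_k] = mean_beta * (1 - mean_beta)^(k-1), using independence of the breaks
    return [mean_beta * (1.0 - mean_beta) ** k for k in range(K)]

for alpha in (0.5, 5.0):
    w = expected_stick_weights(alpha, K=10)
    first_three = [round(x, 3) for x in w[:3]]
    print(f"alpha={alpha}: first three = {first_three}, mass in first 10 = {sum(w):.3f}")
```

For α = 0.5 the first component is expected to take about two thirds of the mass, whereas for α = 5.0 the expected weights decay slowly and a noticeable share of the mass lies beyond the first ten components.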

## Implementation

Let's implement the DPMM using pure JAX (without GenJAX for now, to keep it simple and debuggable):

In [None]:
def stick_breaking_weights(key, alpha, K):
    """Generate stick-breaking weights for DPMM.
    
    Args:
        key: JAX random key
        alpha: Concentration parameter
        K: Number of components
    
    Returns:
        pis: Mixture probabilities (sum to ~1)
    """
    betas = random.beta(key, 1.0, alpha, shape=(K,))
    
    # Stick-breaking: π_k = β_k * ∏_{j<k}(1-β_j)
    pis = jnp.zeros(K)
    remaining_stick = 1.0
    
    for k in range(K):
        pis = pis.at[k].set(betas[k] * remaining_stick)
        remaining_stick *= (1.0 - betas[k])
    
    # Ensure no zero probabilities (for numerical stability)
    pis = jnp.maximum(pis, 1e-6)
    pis = pis / jnp.sum(pis)  # Renormalize
    
    return pis

def sample_dpmm_prior(key, alpha, K, mu0, sig0, sigx, N):
    """Sample from DPMM prior.
    
    Returns:
        mus: Cluster centers
        thetas: Mixture probabilities
        zs: Cluster assignments
        xs: Generated data
    """
    # One independent subkey per random draw (never reuse a consumed key)
    keys = random.split(key, 5)
    
    # Sample cluster centers
    mus = random.normal(keys[0], shape=(K,)) * sig0 + mu0
    
    # Generate stick-breaking weights
    pis = stick_breaking_weights(keys[1], alpha, K)
    
    # Sample from Dirichlet to get final mixture weights
    # (This adds additional variability on top of stick-breaking)
    thetas = random.dirichlet(keys[2], pis * 100)  # Scale up for concentration
    
    # Sample cluster assignments
    zs = random.categorical(keys[3], jnp.log(thetas), shape=(N,))
    
    # Sample observations: Gaussian noise around each assigned cluster center
    xs = mus[zs] + sigx * random.normal(keys[4], shape=(N,))
    
    return mus, thetas, zs, xs

def log_likelihood(xs, mus, thetas, sigx, K):
    """Compute log likelihood of data given parameters."""
    N = len(xs)
    ll = 0.0
    
    for i in range(N):
        # p(x_i | mus, thetas) = ∑_k θ_k * N(x_i | μ_k, σ_x²)
        component_lls = jnp.array([norm.logpdf(xs[i], mus[k], sigx) for k in range(K)])
        ll += jax.scipy.special.logsumexp(jnp.log(thetas) + component_lls)
    
    return ll

def importance_sampling(key, obs_xs, alpha, K, mu0, sig0, sigx, num_samples):
    """Perform importance sampling for posterior inference.
    
    Args:
        obs_xs: Observed data
        alpha, K, mu0, sig0, sigx: Model parameters
        num_samples: Number of importance samples
    
    Returns:
        mus_samples: Posterior samples of cluster centers
        thetas_samples: Posterior samples of mixture weights
        weights: Importance weights (normalized)
    """
    N = len(obs_xs)
    keys = random.split(key, num_samples)
    
    # Sample from prior
    all_mus = []
    all_thetas = []
    log_weights = []
    
    for i in range(num_samples):
        mus, thetas, _, _ = sample_dpmm_prior(keys[i], alpha, K, mu0, sig0, sigx, N)
        all_mus.append(mus)
        all_thetas.append(thetas)
        
        # Compute likelihood as importance weight
        ll = log_likelihood(obs_xs, mus, thetas, sigx, K)
        log_weights.append(ll)
    
    # Normalize weights
    log_weights = jnp.array(log_weights)
    log_weights = log_weights - jax.scipy.special.logsumexp(log_weights)
    weights = jnp.exp(log_weights)
    
    return jnp.array(all_mus), jnp.array(all_thetas), weights

def importance_resampling(key, mus_samples, thetas_samples, weights, num_resamples):
    """Resample from importance samples to get posterior samples."""
    indices = random.choice(key, len(weights), shape=(num_resamples,), p=weights)
    return mus_samples[indices], thetas_samples[indices]

print("✓ DPMM functions defined")
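
The loop-based `log_likelihood` above is easy to read but slow for larger inputs. As a sketch of one possible alternative, the NumPy function below computes the same mixture log likelihood with broadcasting and a hand-written log-sum-exp. The name `mixture_log_likelihood` is an assumption for illustration. A JAX version would presumably swap `np` for `jnp` and use `jax.scipy.special.logsumexp`, and it should agree with the loop version up to floating-point error.

```python
import numpy as np

def mixture_log_likelihood(xs, mus, thetas, sigx):
    """Hypothetical helper: sum_i log sum_k theta_k * Normal(x_i | mu_k, sigx^2)."""
    xs, mus, thetas = (np.asarray(a, dtype=float) for a in (xs, mus, thetas))
    # (N, K) matrix of Gaussian log densities, one column per component
    z = (xs[:, None] - mus[None, :]) / sigx
    comp = -0.5 * z**2 - np.log(sigx) - 0.5 * np.log(2.0 * np.pi)
    a = np.log(thetas)[None, :] + comp
    # Numerically stable log-sum-exp over components (axis 1)
    m = a.max(axis=1, keepdims=True)
    per_point = m[:, 0] + np.log(np.exp(a - m).sum(axis=1))
    return float(per_point.sum())

# Tiny check against a direct evaluation with two unit-variance components
ll = mixture_log_likelihood([0.0, 1.0], [0.0, 1.0], [0.5, 0.5], 1.0)
print(ll)
```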

## Interactive Exploration

Now let's create an interactive function to explore how DPMMs work! The cell below defines `run_dpmm_inference()` which you can call with different parameters.

In [None]:
# Default observed data (3 clear clusters)
default_data = np.array([-10.4, -10., -9.4, -10.1, -9.9, 0., 9.5, 9.9, 10., 10.1, 10.5])

def run_dpmm_inference(alpha=2.0, K_max=10, num_samples=500, data_str=None):
    """Run DPMM inference and visualize results.
    
    Parameters:
    -----------
    alpha : float
        Concentration parameter (0.1 to 10.0)
    K_max : int
        Maximum number of clusters (3 to 20)
    num_samples : int
        Number of importance samples (100 to 2000)
    data_str : str, optional
        Comma-separated data points. If None, uses default 3-cluster data.
    """
    # Parse data
    if data_str is None or data_str == "":
        obs_xs = default_data
    else:
        try:
            obs_xs = jnp.array([float(x.strip()) for x in data_str.split(',')])
        except ValueError:
            print("⚠️  Error parsing data, using default")
            obs_xs = default_data
    
    N = len(obs_xs)
    print(f"🔄 Running inference with α={alpha:.1f}, K_max={K_max}, {num_samples} samples...")
    print(f"   Data: {N} observations")
    
    # Fixed hyperparameters
    mu0 = 0.0
    sig0 = 4.0
    sigx = 0.05
    
    # Run importance sampling
    print(f"   ⏳ Sampling from prior and computing likelihoods...")
    key_local = random.PRNGKey(np.random.randint(0, 10000))
    mus_samples, thetas_samples, weights = importance_sampling(
        key_local, obs_xs, alpha, K_max, mu0, sig0, sigx, num_samples
    )
    
    # Resample to get posterior samples
    print(f"   ⏳ Resampling to get posterior...")
    key_resample = random.PRNGKey(np.random.randint(0, 10000))
    mus_post, thetas_post = importance_resampling(
        key_resample, mus_samples, thetas_samples, weights, num_samples
    )
    
    # Generate posterior predictive samples
    print(f"   ⏳ Generating posterior predictive samples...")
    key_pred = random.PRNGKey(np.random.randint(0, 10000))
    pred_samples = []
    for i in range(num_samples):
        # Fresh subkeys so the assignment and the noise draw are independent
        key_pred, key_z, key_x = random.split(key_pred, 3)
        z = random.categorical(key_z, jnp.log(thetas_post[i]))
        x = random.normal(key_x) * sigx + mus_post[i, z]
        pred_samples.append(x)
    
    pred_samples = jnp.array(pred_samples)
    
    # Flatten for visualization
    mus_flat = mus_post.flatten()
    thetas_flat = thetas_post.flatten()
    
    print("   ✅ Inference complete!\n")
    
    # Create visualization
    fig, ax = plt.subplots(figsize=(12, 6))
    
    # Histogram of observed data
    ax.hist(obs_xs, bins=max(3, int(np.sqrt(N))), density=True, alpha=0.5, 
            color='gray', label='Observed data', edgecolor='black')
    
    # Posterior predictive density
    x_range = np.linspace(float(obs_xs.min()) - 2, float(obs_xs.max()) + 2, 200)
    kde_pred = gaussian_kde(np.array(pred_samples))
    ax.plot(x_range, kde_pred(x_range), 'b-', linewidth=3, 
            label='Posterior predictive p(x̂|data)', alpha=0.8)
    
    # Posterior of cluster centers (weighted by mixture probabilities)
    kde_mus = gaussian_kde(np.array(mus_flat), weights=np.array(thetas_flat))
    ax.plot(x_range, kde_mus(x_range), 'r-', linewidth=3,
            label='Posterior p(μ|data)', alpha=0.8)
    
    ax.set_xlabel('x', fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.set_title(f'DPMM Inference (α={alpha:.1f}, K_max={K_max})', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    effective_clusters = jnp.sum(thetas_post.mean(axis=0) > 0.01)
    print(f"📊 Estimated active clusters: {int(effective_clusters)} (with θ > 0.01)")
    print(f"📊 Top 3 cluster weights: {[f'{w:.3f}' for w in sorted(thetas_post.mean(axis=0), reverse=True)[:3]]}")

print("✓ DPMM inference function defined")

# Interactive widget interface
from ipywidgets import interact, FloatSlider, IntSlider, Text

print("🎛️  Interactive DPMM Explorer")
print("="*70)
print("Adjust the sliders below and the visualization will update automatically!\n")

interact(
    run_dpmm_inference,
    alpha=FloatSlider(value=2.0, min=0.1, max=10.0, step=0.1, description='α (concentration):'),
    K_max=IntSlider(value=10, min=3, max=20, step=1, description='K_max (clusters):'),
    num_samples=IntSlider(value=500, min=100, max=2000, step=100, description='Samples:'),
    data_str=Text(value='', description='Custom data:', placeholder='Leave empty for default, or enter: -5, -4.8, 5, 4.8')
);

## Exercises

Use the interactive sliders above or call the function directly to try these experiments:

1. **Effect of α**:
   - Set α=0.5: How many clusters are active?
   - Set α=5.0: How does the posterior change?

2. **Different data** (enter in Custom data field or call function):
   - Two clusters: `-5, -4.8, -5.2, 5, 4.8, 5.2`
   - Four clusters: `-10, -9, 0, 1, 10, 11, 20, 21`
   - Single cluster: `0, 0.1, -0.1, 0.2, -0.2`

3. **Truncation level**:
   - Use default data (3 clusters) but set K_max=20
   - What happens to the unused clusters?

4. **Sample size**:
   - Run with 100 samples vs 1000 samples
   - How does it affect the smoothness of posteriors?
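
When comparing sample sizes, one common diagnostic (not computed by the notebook itself) is the effective sample size of the normalized importance weights, $\text{ESS} = 1 / \sum_i w_i^2$. The sketch below is a minimal version under that assumption, and `effective_sample_size` is a hypothetical helper name. A value near 1 suggests a single prior sample dominates the resampled posterior, in which case adding samples may help more than smoothing the plots.

```python
import numpy as np

def effective_sample_size(weights):
    """Hypothetical helper: ESS = 1 / sum(w_i^2) for normalized weights w_i."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalize defensively in case the input does not sum to 1
    return 1.0 / np.sum(w**2)

ess_uniform = effective_sample_size([0.25, 0.25, 0.25, 0.25])  # all samples count equally
ess_skewed = effective_sample_size([0.97, 0.01, 0.01, 0.01])   # one sample dominates
print(ess_uniform, ess_skewed)
```

Inside `run_dpmm_inference`, the same function could be applied to the `weights` array returned by `importance_sampling`.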

## Key Insights

1. **Automatic discovery**: DPMMs infer the number of active clusters from the data rather than fixing K in advance; this truncated implementation only sets an upper bound, K_max

2. **Concentration parameter**: α controls the "richness" of the mixture:
   - Small α → Few large clusters (concentrated)
   - Large α → Many small clusters (dispersed)

3. **Posterior uncertainty**: The red curve shows uncertainty about cluster locations, not just point estimates

4. **Predictive distribution**: The blue curve shows what new data might look like, accounting for both parameter uncertainty and cluster structure
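
The predictive distribution can be sketched as a two-stage draw per posterior sample: pick a component from that sample's weights, then add observation noise around that component's center. The NumPy version below is an illustrative, vectorized stand-in for the loop in `run_dpmm_inference`. The name `posterior_predictive_draws` and the toy posterior arrays are assumptions made for this example.

```python
import numpy as np

def posterior_predictive_draws(rng, mus_post, thetas_post, sigx):
    """Hypothetical helper: one draw x ~ Normal(mu_z, sigx^2), z ~ Categorical(theta), per row."""
    mus_post = np.asarray(mus_post, dtype=float)
    thetas_post = np.asarray(thetas_post, dtype=float)
    S, K = mus_post.shape
    # Inverse-CDF sampling: count how many cumulative weights fall below u
    cdf = np.cumsum(thetas_post, axis=1)
    u = rng.random(S)[:, None] * cdf[:, -1:]
    zs = np.minimum((u > cdf).sum(axis=1), K - 1)
    # Gaussian noise around the selected center of each posterior sample
    return mus_post[np.arange(S), zs] + sigx * rng.standard_normal(S)

# Toy posterior: every sample has centers at -10 and 10 with equal weight
rng = np.random.default_rng(0)
mus_post = np.tile([-10.0, 10.0], (1000, 1))
thetas_post = np.tile([0.5, 0.5], (1000, 1))
draws = posterior_predictive_draws(rng, mus_post, thetas_post, sigx=0.05)
print(draws[:5])
```

Because each row of `mus_post` and `thetas_post` is a different posterior sample, the spread of `draws` reflects both parameter uncertainty and observation noise.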

## Connection to Tutorial

This notebook demonstrates the concepts from **Chapter 6: Dirichlet Process Mixture Models** in the tutorial. See the tutorial for:
- Detailed explanation of stick-breaking
- Chinese Restaurant Process interpretation
- Comparison with fixed-K GMMs
- GenJAX implementation details