# Interactive Dirichlet Process Mixture Model Explorer (GenJAX)

This notebook demonstrates **Bayesian inference for Dirichlet Process Mixture Models (DPMM)** using **GenJAX**, Gen's JAX backend.

**What you'll learn:**
- How to define generative models with GenJAX's `@gen` decorator
- How DPMMs automatically discover the number of clusters in data
- The stick-breaking construction for infinite mixture models
- How the concentration parameter α controls cluster formation
- Posterior inference using importance resampling with GenJAX

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/josephausterweil/probintro/blob/amplify/notebooks/dpmm_interactive.ipynb)

## Setup

First, let's install the required packages if running on Google Colab:

**Note**: After running the installation cell below, you may need to restart the runtime (Runtime → Restart runtime) before proceeding.

In [None]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running on Google Colab - installing dependencies...")
    # Install GenJAX with compatible versions
    # GenJAX requires specific JAX and numpy versions
    !pip install -q --upgrade "jax>=0.4.20,<0.5" "jaxlib>=0.4.20,<0.5" "numpy>=1.22,<2.0" genjax scipy ipywidgets
    print("✓ Dependencies installed")
    print("⚠️  Please restart runtime (Runtime → Restart runtime) before continuing")
else:
    print("Running locally")

Import required libraries:

In [None]:
# Force JAX to use CPU to avoid CUDA plugin conflicts
import os
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
import jax.random as random
from jax.scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import ipywidgets as widgets
from IPython.display import display, clear_output

# Import GenJAX
import genjax
from genjax import gen

# Enable widgets in Colab
try:
    import google.colab
    from google.colab import output
    output.enable_custom_widget_manager()
except Exception:
    pass

# Set random seed for reproducibility
key = random.PRNGKey(42)

print("✓ Imports successful (CPU mode)")
print(f"JAX version: {jax.__version__}")
print(f"GenJAX version: {genjax.__version__}")

## The Dirichlet Process Mixture Model

A DPMM is an infinite mixture model that can automatically discover the number of clusters in data.

### Model Structure

For each cluster $k = 1, 2, \ldots, K_{\text{max}}$:
1. Draw cluster center: $\mu_k \sim \text{Normal}(\mu_0, \sigma_0^2)$
2. Draw stick-breaking weight: $\beta_k \sim \text{Beta}(1, \alpha)$
3. Compute mixture probability: $\pi_k = \beta_k \prod_{j<k}(1-\beta_j)$

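Then normalize the truncated weights, $\tilde{\pi}_k = \pi_k / \sum_{j=1}^{K_{\text{max}}} \pi_j$, and draw the mixture weights $\theta \sim \text{Dirichlet}(c\,\tilde{\pi}_1, \ldots, c\,\tilde{\pi}_{K_{\text{max}}})$. The implementation below uses $c = 100$, so $\theta$ stays close to $\tilde{\pi}$.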

For each observation $i = 1, \ldots, N$:
1. Draw cluster assignment: $z_i \sim \text{Categorical}(\theta)$
2. Draw observation: $x_i \sim \text{Normal}(\mu_{z_i}, \sigma_x^2)$
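To make the generative process concrete, here is a minimal pure-JAX sketch (no GenJAX) that draws one synthetic dataset from the truncated model above. The function name `sample_truncated_dpmm` is hypothetical, its default hyperparameters are assumptions copied from the values used later in `run_dpmm_inference`, and the `1e-6` floor and Dirichlet scale `c = 100` mirror the GenJAX model below.

```python
import jax.numpy as jnp
import jax.random as random


def sample_truncated_dpmm(key, alpha, K, N, mu0=0.0, sig0=4.0, sigx=0.05, c=100.0):
    """Ancestral sample from the truncated stick-breaking DPMM (hypothetical helper)."""
    k_mu, k_beta, k_theta, k_z, k_x = random.split(key, 5)
    # Cluster centers: mu_k ~ Normal(mu0, sig0^2)
    mus = mu0 + sig0 * random.normal(k_mu, (K,))
    # Stick-breaking proportions: beta_k ~ Beta(1, alpha)
    betas = random.beta(k_beta, 1.0, alpha, (K,))
    # Exclusive cumulative product gives prod_{j<k} (1 - beta_j), with 1 for k = 0
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1.0 - betas)[:-1]])
    pis = betas * remaining
    # Truncation leaves some mass unassigned, so floor and renormalize
    pis_tilde = jnp.maximum(pis, 1e-6)
    pis_tilde = pis_tilde / jnp.sum(pis_tilde)
    # Final mixture weights: theta ~ Dirichlet(c * pi_tilde)
    thetas = random.dirichlet(k_theta, c * pis_tilde)
    # Assignments z_i ~ Categorical(theta), observations x_i ~ Normal(mu_{z_i}, sigx^2)
    zs = random.categorical(k_z, jnp.log(thetas), shape=(N,))
    xs = mus[zs] + sigx * random.normal(k_x, (N,))
    return mus, pis, thetas, zs, xs


mus, pis, thetas, zs, xs = sample_truncated_dpmm(random.PRNGKey(0), alpha=2.0, K=10, N=11)
print("occupied clusters:", jnp.unique(zs).size)
```

Because the sticks are drawn independently and combined with a cumulative product, the whole draw is vectorized; the GenJAX model below writes the same steps as explicit Python loops so that each random choice gets its own address.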

### Key Parameters

- **α (concentration)**: Controls how the mixture weight is spread across clusters
  - Small α → Few active clusters
  - Large α → Many clusters, each with a smaller share of the weight
- **$K_{\text{max}}$**: Truncation level (upper bound on clusters)
- **$\mu_0, \sigma_0$**: Prior on cluster centers
- **$\sigma_x$**: Observation noise
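The effect of α and $K_{\text{max}}$ can be checked directly from the prior. Because the $\beta_k$ are independent with $E[\beta_k] = 1/(1+\alpha)$, the expected stick weights are $E[\pi_k] = \frac{1}{1+\alpha}\left(\frac{\alpha}{1+\alpha}\right)^{k-1}$, and the expected mass left beyond the truncation is $\left(\frac{\alpha}{1+\alpha}\right)^{K_{\text{max}}}$. A small sketch (the helper name is hypothetical):

```python
import jax.numpy as jnp


def expected_stick_weights(alpha, K):
    """Prior mean of the first K stick-breaking weights under beta_k ~ Beta(1, alpha).

    E[pi_k] = (1 / (1 + alpha)) * (alpha / (1 + alpha))**(k - 1) for k = 1..K,
    plus the expected mass beyond the truncation, (alpha / (1 + alpha))**K.
    """
    r = alpha / (1.0 + alpha)
    k = jnp.arange(K)  # k - 1 for k = 1..K
    weights = (1.0 / (1.0 + alpha)) * r ** k
    leftover = r ** K
    return weights, leftover


for alpha in (0.5, 2.0, 5.0):
    w, left = expected_stick_weights(alpha, K=10)
    print(f"alpha={alpha}: first 3 weights {[round(float(x), 3) for x in w[:3]]}, "
          f"mass beyond K=10: {float(left):.3f}")
```

With $K_{\text{max}} = 10$, roughly 16% of the prior stick mass lies beyond the truncation when α = 5, versus under 2% when α = 2; the model's normalization step redistributes that leftover mass over the first $K_{\text{max}}$ clusters.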

## GenJAX Implementation

Let's implement the DPMM using **idiomatic GenJAX patterns**:

### Key GenJAX Concepts

1. **`@gen` decorator**: Marks a function as a generative model
2. **`@ "address"` syntax**: Every random choice needs a unique string address
3. **`Target` + `ChoiceMap`**: Define the posterior by conditioning the model on observations
4. **`ImportanceK`**: GenJAX's importance sampling algorithm
5. **`jax.vmap`**: Vectorizes computations across samples (the sampling loop below is written sequentially for readability)

### Why GenJAX?

- **Programmable inference**: Mix different inference methods (importance sampling, MCMC, variational)
- **JAX integration**: JIT compilation and automatic differentiation
- **Composable models**: Build complex models from simple generative functions

In [None]:
from genjax import normal, beta as beta_dist, dirichlet, categorical
from genjax.inference.smc import ImportanceK
from genjax import Target, ChoiceMap

def make_dpmm_model(K, N):
    """Factory function to create a DPMM model with fixed K and N.
    
    This is necessary because GenJAX traces the model body with JAX, so Python
    loop bounds must be concrete values rather than traced arguments.
    We fix K and N at model creation time as closure variables.
    """
    @gen
    def dpmm_model(alpha, mu0, sig0, sigx):
        """DPMM generative model using stick-breaking construction with GenJAX.
        
        K and N are fixed at model creation (closure variables).
        alpha, mu0, sig0, sigx are traced parameters.
        """
        # Sample cluster centers (K is fixed, so this loop is unrollable)
        mus = []
        for k in range(K):
            mu_k = normal(mu0, sig0) @ f"mu_{k}"
            mus.append(mu_k)
        
        # Stick-breaking process for mixture weights
        betas = []
        pis = []
        remaining_stick = 1.0
        for k in range(K):
            beta_k = beta_dist(1.0, alpha) @ f"beta_{k}"
            betas.append(beta_k)
            pi_k = beta_k * remaining_stick
            pis.append(pi_k)
            remaining_stick *= (1.0 - beta_k)
        
        # Normalize mixture weights
        pis_array = jnp.array(pis)
        pis_array = jnp.maximum(pis_array, 1e-6)
        pis_array = pis_array / jnp.sum(pis_array)
        
        # Sample from Dirichlet to get final mixture weights
        thetas = dirichlet(pis_array * 100.0) @ "thetas"
        
        # Sample cluster assignments and observations (N is fixed)
        zs = []
        xs = []
        mus_array = jnp.array(mus)
        for i in range(N):
            z_i = categorical(thetas) @ f"z_{i}"
            x_i = normal(mus_array[z_i], sigx) @ f"x_{i}"
            zs.append(z_i)
            xs.append(x_i)
        
        return {
            'mus': mus_array,
            'thetas': thetas,
            'zs': jnp.array(zs),
            'xs': jnp.array(xs),
            'pis': pis_array,
            'betas': jnp.array(betas)
        }
    
    return dpmm_model

def importance_sampling_genjax(key, obs_xs, alpha, K, mu0, sig0, sigx, num_samples):
    """Perform importance sampling using GenJAX's ImportanceK algorithm."""
    N = len(obs_xs)
    
    # Create the model with fixed K and N
    dpmm_model = make_dpmm_model(K, N)
    
    # Create observations ChoiceMap (condition on observed data)
    obs_dict = {f"x_{i}": float(obs_xs[i]) for i in range(N)}
    observations = ChoiceMap.d(obs_dict)
    
    # Create target distribution (posterior)
    posterior_target = Target(
        dpmm_model,
        (alpha, mu0, sig0, sigx),
        observations
    )
    
    # Run importance sampling with a small number of particles per call.
    # Note: k_particles is the number of particles PER CALL to random_weighted,
    # so the total number of particles is k_particles * num_samples
    # (e.g., 10 * 500 = 5000 with run_dpmm_inference's default num_samples=500).
    alg = ImportanceK(posterior_target, k_particles=10)
    
    # Generate samples sequentially
    all_mus = []
    all_thetas = []
    all_zs = []
    log_weights = []
    
    keys = random.split(key, num_samples)
    
    print(f"   ⏳ Generating {num_samples} samples (10 particles each)...")
    for i in range(num_samples):
        if i % 10 == 0:
            print(f"      Sample {i}/{num_samples}...")
        
        # random_weighted returns (log_weight, choice_map)
        log_weight, choice_map = alg.random_weighted(keys[i], posterior_target)
        
        # Extract values from the choice map
        mus_i = jnp.array([choice_map[f"mu_{k}"] for k in range(K)])
        thetas_i = choice_map["thetas"]
        zs_i = jnp.array([choice_map[f"z_{j}"] for j in range(N)])
        
        all_mus.append(mus_i)
        all_thetas.append(thetas_i)
        all_zs.append(zs_i)
        log_weights.append(float(log_weight))
    
    # Compute normalized importance weights
    log_weights = jnp.array(log_weights)
    log_weights = log_weights - jax.scipy.special.logsumexp(log_weights)
    weights = jnp.exp(log_weights)
    
    return jnp.array(all_mus), jnp.array(all_thetas), jnp.array(all_zs), weights

def importance_resampling_genjax(key, mus_samples, thetas_samples, weights, num_resamples):
    """Resample from importance samples to get posterior samples."""
    indices = random.choice(key, len(weights), shape=(num_resamples,), p=weights)
    return mus_samples[indices], thetas_samples[indices]

print("✓ GenJAX DPMM model factory defined")
print("✓ GenJAX importance sampling functions defined")
print("ℹ️  Using k_particles=10 per sample (avoids timeout)")
print("ℹ️  Progress messages show sampling progress")
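
`importance_sampling_genjax` normalizes the log weights with `logsumexp`, and `importance_resampling_genjax` draws indices in proportion to those weights. The sketch below isolates those two steps in pure JAX on toy log weights and adds one diagnostic the notebook does not compute: the Kish effective sample size $1/\sum_i w_i^2$. When it is close to 1, a single sample dominates and the resampled "posterior" is essentially one draw repeated. All function names here are hypothetical.

```python
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp


def normalize_log_weights(log_weights):
    """Convert unnormalized log importance weights to probabilities (stable via logsumexp)."""
    return jnp.exp(log_weights - logsumexp(log_weights))


def effective_sample_size(weights):
    """Kish effective sample size for normalized weights: 1 / sum_i w_i^2."""
    return 1.0 / jnp.sum(weights ** 2)


def resample(key, samples, weights, num_resamples):
    """Multinomial resampling: draw indices with replacement, proportional to weights."""
    idx = random.choice(key, weights.shape[0], shape=(num_resamples,), p=weights)
    return samples[idx]


# Toy example with made-up log weights (not output from the model above)
log_w = jnp.array([-1.0, -2.0, -0.5, -10.0])
w = normalize_log_weights(log_w)
print("weights:", w, "ESS:", float(effective_sample_size(w)))
print("resampled:", resample(random.PRNGKey(0), jnp.arange(4.0), w, 8))
```

Subtracting `logsumexp` before exponentiating avoids underflow when all log weights are very negative, which is typical for likelihoods with a small observation noise such as `sigx = 0.05`.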

## Interactive Exploration

Now let's create an interactive function to explore how DPMMs work! The cell below defines `run_dpmm_inference()`, which you can call directly with different parameters, and connects it to interactive sliders.

In [None]:
# Default observed data (3 clear clusters)
default_data = np.array([-10.4, -10., -9.4, -10.1, -9.9, 0., 9.5, 9.9, 10., 10.1, 10.5])

def run_dpmm_inference(alpha=2.0, K_max=10, num_samples=500, data_str=None):
    """Run DPMM inference with GenJAX and visualize results.
    
    Parameters:
    -----------
    alpha : float
        Concentration parameter (0.1 to 10.0)
    K_max : int
        Maximum number of clusters (3 to 20)
    num_samples : int
        Number of importance samples (100 to 2000)
    data_str : str, optional
        Comma-separated data points. If None, uses default 3-cluster data.
    """
    # Parse data
    if data_str is None or data_str == "":
        obs_xs = jnp.array(default_data)
    else:
        try:
            obs_xs = jnp.array([float(x.strip()) for x in data_str.split(',')])
        except ValueError:
            print("⚠️  Error parsing data, using default")
            obs_xs = jnp.array(default_data)
    
    N = len(obs_xs)
    print(f"üîÑ Running GenJAX DPMM inference with Œ±={alpha:.1f}, K_max={K_max}, {num_samples} samples...")
    print(f"   Data: {N} observations")
    
    # Fixed hyperparameters
    mu0 = 0.0
    sig0 = 4.0
    sigx = 0.05
    
    # Run GenJAX importance sampling
    print("   ⏳ Sampling from GenJAX model with ImportanceK...")
    key_local = random.PRNGKey(np.random.randint(0, 10000))
    mus_samples, thetas_samples, zs_samples, weights = importance_sampling_genjax(
        key_local, obs_xs, alpha, K_max, mu0, sig0, sigx, num_samples
    )
    
    # Resample to get posterior samples
    print("   ⏳ Resampling to get posterior...")
    key_resample = random.PRNGKey(np.random.randint(0, 10000))
    mus_post, thetas_post = importance_resampling_genjax(
        key_resample, mus_samples, thetas_samples, weights, num_samples
    )
    
    # Generate posterior predictive samples
    print("   ⏳ Generating posterior predictive samples...")
    key_pred = random.PRNGKey(np.random.randint(0, 10000))
    pred_samples = []
    for i in range(num_samples):
        # Split fresh, independent subkeys for the cluster draw and the noise draw
        key_pred, key_z, key_x = random.split(key_pred, 3)
        z = random.categorical(key_z, jnp.log(thetas_post[i]))
        x = random.normal(key_x) * sigx + mus_post[i, z]
        pred_samples.append(x)
    
    pred_samples = jnp.array(pred_samples)
    
    # Flatten for visualization
    mus_flat = mus_post.flatten()
    thetas_flat = thetas_post.flatten()
    
    print("   ✅ Inference complete!\n")
    
    # Create visualization
    fig, ax = plt.subplots(figsize=(12, 6))
    
    # Histogram of observed data
    ax.hist(obs_xs, bins=max(3, int(np.sqrt(N))), density=True, alpha=0.5, 
            color='gray', label='Observed data', edgecolor='black')
    
    # Posterior predictive density
    x_range = np.linspace(float(obs_xs.min()) - 2, float(obs_xs.max()) + 2, 200)
    kde_pred = gaussian_kde(np.array(pred_samples))
    ax.plot(x_range, kde_pred(x_range), 'b-', linewidth=3, 
            label='Posterior predictive p(x̂|data)', alpha=0.8)
    
    # Posterior of cluster centers (weighted by mixture probabilities)
    kde_mus = gaussian_kde(np.array(mus_flat), weights=np.array(thetas_flat))
    ax.plot(x_range, kde_mus(x_range), 'r-', linewidth=3,
            label='Posterior p(μ|data)', alpha=0.8)
    
    ax.set_xlabel('x', fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.set_title(f'GenJAX DPMM Inference (α={alpha:.1f}, K_max={K_max})', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    effective_clusters = jnp.sum(thetas_post.mean(axis=0) > 0.01)
    print(f"üìä Estimated active clusters: {int(effective_clusters)} (with Œ∏ > 0.01)")
    print(f"üìä Top 3 cluster weights: {[f'{w:.3f}' for w in sorted(thetas_post.mean(axis=0), reverse=True)[:3]]}")

print("‚úì GenJAX DPMM inference function defined")

# Interactive widget interface
from ipywidgets import interact, FloatSlider, IntSlider, Text

print("üéõÔ∏è  Interactive DPMM Explorer")
print("="*70)
print("Adjust the sliders below and the visualization will update automatically!\n")

interact(
    run_dpmm_inference,
    alpha=FloatSlider(value=2.0, min=0.1, max=10.0, step=0.1, description='α (concentration):'),
    K_max=IntSlider(value=10, min=3, max=20, step=1, description='K_max (clusters):'),
    num_samples=IntSlider(value=500, min=100, max=2000, step=100, description='Samples:'),
    data_str=Text(value='', description='Custom data:', placeholder='Leave empty for default, or enter: -5, -4.8, 5, 4.8')
);
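
The posterior predictive loop in `run_dpmm_inference` draws one $\hat{x}$ per posterior sample, one iteration at a time. A minimal sketch of the same computation vectorized with `jax.vmap`, assuming `mus_post` and `thetas_post` have shape `(S, K)` as returned by `importance_resampling_genjax`; the function name and the toy posterior arrays are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as random


def posterior_predictive(key, mus_post, thetas_post, sigx):
    """Draw one predictive sample x_hat per posterior sample using jax.vmap.

    mus_post, thetas_post: arrays of shape (S, K).
    """
    keys = random.split(key, mus_post.shape[0])

    def one_draw(k, mus, thetas):
        k_z, k_x = random.split(k)
        z = random.categorical(k_z, jnp.log(thetas))  # pick a cluster
        return mus[z] + sigx * random.normal(k_x)      # add observation noise

    return jax.vmap(one_draw)(keys, mus_post, thetas_post)


# Toy posterior: 500 identical samples with three well-separated clusters
S = 500
centers = jnp.array([-10.0, 0.0, 10.0])
mus_post = jnp.tile(centers, (S, 1))
thetas_post = jnp.tile(jnp.array([0.45, 0.1, 0.45]), (S, 1))
x_hat = posterior_predictive(random.PRNGKey(0), mus_post, thetas_post, sigx=0.05)
print(x_hat.shape)
```

Splitting one key per posterior sample, and then into separate keys for the cluster and noise draws, keeps every draw independent; `vmap` then replaces the Python loop with a single batched computation.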

## Exercises

Use the interactive sliders above or call the function directly to try these experiments:

1. **Effect of Œ±**:
   - Set α=0.5: How many clusters are active?
   - Set α=5.0: How does the posterior change?

2. **Different data** (enter it in the Custom data field or pass it as `data_str`):
   - Two clusters: `-5, -4.8, -5.2, 5, 4.8, 5.2`
   - Four clusters: `-10, -9, 0, 1, 10, 11, 20, 21`
   - Single cluster: `0, 0.1, -0.1, 0.2, -0.2`

3. **Truncation level**:
   - Use default data (3 clusters) but set K_max=20
   - What happens to the unused clusters?

4. **Sample size**:
   - Run with 100 samples vs 1000 samples
   - How does it affect the smoothness of posteriors?

## Key Insights

1. **Automatic discovery**: DPMMs infer the number of clusters from the data instead of requiring K in advance; $K_{\text{max}}$ only caps how many clusters the truncated model can use

2. **Concentration parameter**: α controls the "richness" of the mixture:
   - Small α → Few large clusters (concentrated)
   - Large α → Many small clusters (dispersed)

3. **Posterior uncertainty**: The red curve shows the posterior distribution over cluster locations, weighted by the mixture weights, rather than a single point estimate per cluster

4. **Predictive distribution**: The blue curve shows what new data might look like, accounting for both parameter uncertainty and cluster structure
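The link between α and the number of clusters can be made quantitative through the Chinese Restaurant Process view of the (untruncated) Dirichlet process prior: the $i$-th observation starts a new cluster with probability $\alpha/(\alpha + i - 1)$, so the prior expected number of occupied clusters after $N$ observations is $\sum_{i=1}^{N} \alpha/(\alpha + i - 1)$. A sketch of that prior quantity (not the posterior the notebook computes; the helper name is hypothetical):

```python
import jax.numpy as jnp


def expected_num_clusters(alpha, N):
    """Prior expected number of occupied clusters after N draws from a CRP(alpha)."""
    i = jnp.arange(1, N + 1)
    # Observation i opens a new cluster with probability alpha / (alpha + i - 1)
    return jnp.sum(alpha / (alpha + i - 1))


# N = 11 matches the size of the default dataset
for alpha in (0.5, 2.0, 5.0):
    print(f"alpha={alpha}: E[#clusters] = {float(expected_num_clusters(alpha, 11)):.2f}")
```

The expectation grows roughly like $\alpha \log N$, so for small datasets the prior over the number of clusters is quite sensitive to α; the data then pull the posterior toward the clusters they actually support.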

## Connection to Tutorial

This notebook demonstrates the concepts from **Chapter 6: Dirichlet Process Mixture Models** in the tutorial. See the tutorial for:
- Detailed explanation of stick-breaking
- Chinese Restaurant Process interpretation
- Comparison with fixed-K GMMs
- GenJAX implementation details

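## Advanced: Alternative GenJAX Formulation (Optional)

The implementation above already uses GenJAX (`@gen`, `Target`, `ImportanceK`). The commented-out cell below is an alternative sketch that passes `K` and `N` as model arguments and runs importance sampling by hand. It is kept for reference and will not run as written: `K` and `N` arrive as traced arguments, so loops like `range(K)` cannot be unrolled (which is why `make_dpmm_model` fixes them in a closure), and calls such as `genjax.simulate` and `genjax.condition` should be checked against your installed GenJAX version's API.

**Note**: GenJAX requires `numpy<2.0.0`, which conflicts with some Colab packages; this is why the setup cell pins `numpy<2.0`.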

In [None]:
# GenJAX implementation (requires: pip install genjax)
# Uncomment to install: !pip install "genjax" "numpy<2.0.0"

# import genjax
# from genjax import gen, static_check, Pytree

# @gen
# def dpmm_model(alpha: float, K: int, N: int, mu0: float, sig0: float, sigx: float):
#     """DPMM generative model in GenJAX."""
#     # Sample cluster centers
#     mus = jnp.zeros(K)
#     for k in range(K):
#         mus = mus.at[k].set(genjax.normal(mu0, sig0) @ f"mu_{k}")
    
#     # Stick-breaking process
#     pis = jnp.zeros(K)
#     remaining = 1.0
#     for k in range(K):
#         beta_k = genjax.beta(1.0, alpha) @ f"beta_{k}"
#         pis = pis.at[k].set(beta_k * remaining)
#         remaining *= (1.0 - beta_k)
    
#     # Normalize to get mixture weights
#     pis = pis / jnp.sum(pis)
#     thetas = genjax.dirichlet(pis * 100) @ "thetas"
    
#     # Sample observations
#     xs = jnp.zeros(N)
#     for i in range(N):
#         z = genjax.categorical(thetas) @ f"z_{i}"
#         x = genjax.normal(mus[z], sigx) @ f"x_{i}"
#         xs = xs.at[i].set(x)
    
#     return xs

# # Plain Python driver (not a generative function, so no @gen decorator)
# def dpmm_inference(observed_data, alpha, K):
#     """Run importance resampling inference."""
#     # Create observations dict
#     observations = {f"x_{i}": x for i, x in enumerate(observed_data)}
    
#     # Importance sampling
#     num_samples = 500
#     traces = []
#     log_weights = []
    
#     for _ in range(num_samples):
#         trace = genjax.simulate(dpmm_model, (alpha, K, len(observed_data), 0.0, 4.0, 0.05))
#         conditioned = genjax.condition(trace, observations)
#         traces.append(conditioned)
#         log_weights.append(conditioned.get_score())
    
#     # Normalize weights and resample
#     weights = jax.nn.softmax(jnp.array(log_weights))
#     indices = random.choice(random.PRNGKey(0), len(weights), shape=(num_samples,), p=weights)
    
#     return [traces[i] for i in indices]

# # Example usage:
# # posterior_traces = dpmm_inference(default_data, alpha=2.0, K=10)
# # mus_posterior = jnp.array([t["mu_0"] for t in posterior_traces])

print("Alternative GenJAX sketch shown for reference (commented out)")
print("Use make_dpmm_model and importance_sampling_genjax above for inference.")