<a href="https://colab.research.google.com/github/handley-lab/workshop-blackjax-nested-sampling/blob/main/blackjax_nested_sampling_workshop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BlackJAX Nested Sampling Workshop

## GPU-Native Nested Sampling for Modern SBI

Welcome to this hands-on workshop on nested sampling with BlackJAX! We'll explore how to leverage GPU-native nested sampling for simulation-based inference.

**Workshop Overview:**
- Why nested sampling for SBI?
- BlackJAX nested sampling basics
- Professional visualization with Anesthetic
- Performance comparison: nested sampling vs. ensemble samplers

**Duration:** ~45 minutes

## 1. Setup & Installation

First, let's install the required packages for this workshop:

In [None]:
# Install required packages
!pip install blackjax[all] anesthetic emcee corner

# Standard imports
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
import time
from typing import Callable

# Enable 64-bit floats for numerical precision (JAX defaults to float32;
# float64 is slower on most GPUs, so disable this if speed matters more)
jax.config.update("jax_enable_x64", True)
print(f"JAX devices: {jax.devices()}")
print(f"JAX version: {jax.__version__}")

## 2. Why Nested Sampling for SBI?

**Key Insight:** Apart from NPE, most SBI methods (e.g. NLE, NRE, NJE) learn a likelihood or a related quantity rather than the posterior itself, so they still need a **sampler** to draw posterior samples.

**The Problem:** Traditional samplers are either:
- Legacy Fortran codes (MultiNest, PolyChord): hard to modify or extend
- CPU-bound Python implementations (dynesty, UltraNest, nautilus): unable to exploit GPU parallelism

**BlackJAX Solution:**
- GPU-native implementation
- Open source & community-driven
- JAX benefits: autodiff + JIT compilation
- Seamless integration with modern ML workflows
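The JAX benefits above are what make a GPU-native sampler practical: a log-likelihood written once for a single parameter vector can be compiled with `jax.jit`, evaluated for a whole batch of live points with `jax.vmap`, and differentiated with `jax.grad`. A minimal sketch with a toy least-squares log-likelihood (the function and data are illustrative, not part of BlackJAX):

```python
import jax
import jax.numpy as jnp

# Toy noiseless straight-line data (illustrative only)
x_obs = jnp.linspace(0.0, 5.0, 50)
y_obs = 2.5 * x_obs + 1.0

def loglike(theta):
    """Gaussian log-likelihood (unit noise, up to a constant) for one parameter vector."""
    slope, intercept = theta
    resid = y_obs - (slope * x_obs + intercept)
    return -0.5 * jnp.sum(resid**2)

# vmap batches the single-point function; jit compiles it for CPU, GPU or TPU
batched_loglike = jax.jit(jax.vmap(loglike))

# Evaluate 1000 "live points" in a single call
key = jax.random.PRNGKey(0)
live_points = jax.random.normal(key, (1000, 2))
logL = batched_loglike(live_points)

# Autodiff: the gradient vanishes at the exact parameters
grad_loglike = jax.grad(loglike)
print(logL.shape, grad_loglike(jnp.array([2.5, 1.0])))
```

The same pattern (write for one point, `vmap` over many) is what lets a sampler update many live points in parallel on an accelerator.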

## 3. Example 1: Linear Regression with Nested Sampling

Let's start with a simple but illustrative example: fitting a straight line to noisy data.

In [None]:
# Generate synthetic data
key = random.PRNGKey(42)
n_data = 50

# True parameters
true_slope = 2.5
true_intercept = 1.0
true_sigma = 0.3

# Generate data
x = jnp.linspace(0, 5, n_data)
key, subkey = random.split(key)
y_true = true_slope * x + true_intercept
y = y_true + true_sigma * random.normal(subkey, (n_data,))

# Plot the data
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.7, label='Data')
plt.plot(x, y_true, 'r--', label=f'True: y = {true_slope}x + {true_intercept}')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.title('Linear Regression Data')
plt.grid(True, alpha=0.3)
plt.show()
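Before sampling, a closed-form least-squares fit gives a useful reference point: with Gaussian noise of constant σ it coincides with the maximum-likelihood slope and intercept. A self-contained sketch that regenerates comparable data (same settings as the cell above; exact values may differ slightly depending on the `jax_enable_x64` setting):

```python
import jax
import jax.numpy as jnp

# Regenerate straight-line data with the workshop's settings
key = jax.random.PRNGKey(42)
x = jnp.linspace(0, 5, 50)
key, subkey = jax.random.split(key)
y = 2.5 * x + 1.0 + 0.3 * jax.random.normal(subkey, (50,))

# Design matrix with columns [x, 1]; least squares = maximum likelihood for slope, intercept
A = jnp.column_stack([x, jnp.ones_like(x)])
coef, residuals, rank, singular_values = jnp.linalg.lstsq(A, y)
slope_ml, intercept_ml = coef

# Maximum-likelihood noise estimate (divides by N, not N - 2)
sigma_ml = jnp.sqrt(jnp.mean((y - A @ coef) ** 2))
print(f"slope={float(slope_ml):.3f}, intercept={float(intercept_ml):.3f}, sigma={float(sigma_ml):.3f}")
```

These point estimates are also a sensible place to start MCMC walkers.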

### Define the Bayesian Model

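We need to define:
1. **Likelihood function**: how probable the observed data are for given parameters
2. **Prior distributions**: our beliefs about the parameters before seeing the data
3. **Log posterior**: log-likelihood + log-prior (up to a normalizing constant)

Nested sampling uses the first two separately: it draws points from the prior and progressively restricts them to regions of higher likelihood. The combined log posterior is what MCMC samplers such as emcee need.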
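For reference, Bayes' theorem ties these quantities together, and the normalizing constant is the evidence $\mathcal{Z}$ that nested sampling estimates. Writing $X(\lambda)$ for the prior mass with likelihood above $\lambda$ turns the evidence into a one-dimensional integral over prior mass:

$$
\log P(\theta \mid D) = \log \mathcal{L}(\theta) + \log \pi(\theta) - \log \mathcal{Z},
\qquad
\mathcal{Z} = \int \mathcal{L}(\theta)\,\pi(\theta)\,\mathrm{d}\theta = \int_0^1 \mathcal{L}(X)\,\mathrm{d}X,
\qquad
X(\lambda) = \int_{\mathcal{L}(\theta) > \lambda} \pi(\theta)\,\mathrm{d}\theta .
$$

Here $\mathcal{L}(X)$ is the likelihood of the contour enclosing prior mass $X$. With $n_{\rm live}$ live points, the prior mass after $i$ iterations satisfies $\log X_i \approx -i/n_{\rm live}$ on average, so $\mathcal{Z} \approx \sum_i \mathcal{L}_i\,(X_{i-1} - X_i)$, where $\mathcal{L}_i$ is the likelihood of the $i$-th discarded ("dead") point.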

In [None]:
import blackjax
from jax.scipy import stats

def log_likelihood(params, x_data, y_data):
    """Log likelihood for linear regression."""
    slope, intercept, log_sigma = params
    sigma = jnp.exp(log_sigma)  # Ensure sigma > 0
    
    # Model prediction
    y_pred = slope * x_data + intercept
    
    # Gaussian likelihood
    return jnp.sum(stats.norm.logpdf(y_data, y_pred, sigma))

def log_prior(params):
    """Log prior for parameters."""
    slope, intercept, log_sigma = params
    
    # Priors: slope ~ N(0, 10), intercept ~ N(0, 10), log_sigma ~ N(0, 1)
    return (stats.norm.logpdf(slope, 0, 10) + 
            stats.norm.logpdf(intercept, 0, 10) + 
            stats.norm.logpdf(log_sigma, 0, 1))

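def log_probability(params):
    """Unnormalized log posterior: log prior + log likelihood."""
    lp = log_prior(params)
    # jnp.where instead of a Python `if`, so the function also works under jax.jit
    return jnp.where(jnp.isfinite(lp), lp + log_likelihood(params, x, y), -jnp.inf)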

# Test the log probability function
test_params = jnp.array([2.0, 1.0, jnp.log(0.5)])
print(f"Test log probability: {log_probability(test_params):.3f}")
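Samplers written in JAX usually call the log-density inside `jax.jit`, where a Python `if` on a traced value raises an error. A minimal sketch of the JIT-safe pattern with `jnp.where`, using a hypothetical uniform box prior so that the `-inf` branch actually matters:

```python
import jax
import jax.numpy as jnp

def log_uniform_prior(theta, low=-5.0, high=5.0):
    """Uniform box prior: constant log density inside, -inf outside."""
    inside = jnp.all((theta >= low) & (theta <= high))
    log_volume = theta.shape[0] * jnp.log(high - low)
    return jnp.where(inside, -log_volume, -jnp.inf)

def log_like(theta):
    return -0.5 * jnp.sum(theta**2)

@jax.jit
def log_post(theta):
    lp = log_uniform_prior(theta)
    # `if not jnp.isfinite(lp)` would fail here because lp is a tracer under jit.
    # jnp.where still evaluates the likelihood outside the support, which is fine
    # as long as it stays finite (no NaNs) there.
    return jnp.where(jnp.isfinite(lp), lp + log_like(theta), -jnp.inf)

print(log_post(jnp.array([0.0, 0.0])))  # -2 * log(10), the log prior density
print(log_post(jnp.array([6.0, 0.0])))  # -inf: outside the prior box
```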

### Run BlackJAX Nested Sampling

Now let's use BlackJAX to sample from our posterior distribution:

In [None]:
# Define prior sampling function for nested sampling
def prior_sampler(key, n_samples):
    """Sample from the prior distribution."""
    keys = random.split(key, 3)
    slope = random.normal(keys[0], (n_samples,)) * 10.0
    intercept = random.normal(keys[1], (n_samples,)) * 10.0 
    log_sigma = random.normal(keys[2], (n_samples,))
    return jnp.column_stack([slope, intercept, log_sigma])

# Configure nested sampling
n_live = 500  # Number of live points
n_dim = 3     # Number of parameters

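# Initialize nested sampling algorithm.
# Nested sampling draws from the prior (via prior_sampler) and climbs the likelihood,
# so it takes the log-likelihood, not the log posterior: passing the posterior would
# count the prior twice.
def loglikelihood_fn(params):
    return log_likelihood(params, x, y)

ns_algorithm = blackjax.nested_sampling(
    loglikelihood_fn,
    prior_sampler,
    n_live_points=n_live
)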

# Initialize state
key, subkey = random.split(key)
initial_state = ns_algorithm.init(subkey)

print("Starting nested sampling...")
start_time = time.time()

# Run nested sampling
key, subkey = random.split(key)
samples, log_evidence = blackjax.nested_sampling_inference(
    ns_algorithm,
    initial_state,
    subkey,
    max_samples=5000
)

end_time = time.time()
print(f"Nested sampling completed in {end_time - start_time:.2f} seconds")
print(f"Log evidence: {log_evidence:.3f}")
print(f"Number of samples: {samples.shape[0]}")
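To make the mechanics concrete, here is a deliberately simple nested sampler in plain NumPy; it is not BlackJAX's implementation. It assumes a 2D Gaussian likelihood (σ = 0.5) under a uniform prior on [-5, 5]², for which the evidence is almost exactly 1/100, and it replaces the worst live point by rejection sampling from the prior, which is only practical for easy, low-dimensional problems:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.5
n_live = 100
n_iter = 8 * n_live  # fixed number of iterations, then add the remaining live points

def log_like(theta):
    """Normalized 2D Gaussian centered at the origin; theta has shape (..., 2)."""
    return -0.5 * np.sum(theta**2, axis=-1) / sigma**2 - np.log(2 * np.pi * sigma**2)

def sample_prior(n):
    return rng.uniform(-5.0, 5.0, size=(n, 2))

live = sample_prior(n_live)
live_logL = log_like(live)
log_Z = -np.inf
log_X_prev = 0.0  # log prior mass, starts at log(1)

for i in range(1, n_iter + 1):
    worst = np.argmin(live_logL)
    logL_min = live_logL[worst]
    # Expected shrinkage of the prior mass: log X_i ~ -i / n_live
    log_X = -i / n_live
    # Evidence increment L_i * (X_{i-1} - X_i), accumulated in log space
    log_dX = np.log(np.exp(log_X_prev) - np.exp(log_X))
    log_Z = np.logaddexp(log_Z, logL_min + log_dX)
    log_X_prev = log_X
    # Replace the worst point with a prior draw satisfying L > L_min (rejection sampling)
    while True:
        cand = sample_prior(4096)
        cand_logL = log_like(cand)
        ok = np.flatnonzero(cand_logL > logL_min)
        if ok.size > 0:
            live[worst] = cand[ok[0]]
            live_logL[worst] = cand_logL[ok[0]]
            break

# The remaining live points share the final prior mass equally
log_Z = np.logaddexp(log_Z, log_X_prev + np.logaddexp.reduce(live_logL) - np.log(n_live))
print(f"log Z estimate: {log_Z:.3f} (expected about {np.log(1 / 100):.3f})")
```

Even with exact constrained sampling, the estimate scatters by roughly sqrt(H / n_live), where H is the prior-to-posterior KL divergence (about 3 nats here, so roughly 0.2); this is why nested sampling results usually quote an uncertainty on log Z. Real implementations replace the rejection step with something that scales, such as slice sampling.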

### Analyze Results

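Let's examine the posterior samples and compare them with the true values. The summaries below treat the samples as equally weighted posterior draws; raw nested sampling "dead points" are not, and must be weighted (or resampled) before computing means and quantiles.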

In [None]:
# Extract parameter samples
slope_samples = samples[:, 0]
intercept_samples = samples[:, 1]
log_sigma_samples = samples[:, 2]
sigma_samples = jnp.exp(log_sigma_samples)

# Calculate summary statistics
def summarize_parameter(samples, true_value, name):
    mean = jnp.mean(samples)
    std = jnp.std(samples)
    q16, q84 = jnp.percentile(samples, [16, 84])
    print(f"{name:12s}: {mean:.3f} ± {std:.3f} [{q16:.3f}, {q84:.3f}] (true: {true_value:.3f})")
    return mean, std

print("\nPosterior Summary:")
print("=" * 50)
slope_mean, slope_std = summarize_parameter(slope_samples, true_slope, "Slope")
intercept_mean, intercept_std = summarize_parameter(intercept_samples, true_intercept, "Intercept") 
sigma_mean, sigma_std = summarize_parameter(sigma_samples, true_sigma, "Sigma")
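If your sampler returns raw dead points together with their log-likelihoods (check what the BlackJAX call you use actually returns), the summaries above should be importance-weighted. A minimal sketch, assuming the dead points are sorted by increasing likelihood and were produced with a constant number of live points, using the expected shrinkage log X_i ≈ -i/n_live so that the weight of dead point i is proportional to L_i (X_{i-1} - X_i). The helper names are hypothetical, not BlackJAX API, and the demo data are placeholders:

```python
import numpy as np

def ns_log_weights(dead_logL, n_live):
    """Normalized log posterior weights for dead points sorted by increasing logL."""
    i = np.arange(1, len(dead_logL) + 1)
    log_X = -i / n_live
    log_X_prev = np.concatenate([[0.0], log_X[:-1]])
    # log(X_{i-1} - X_i), computed stably as log X_{i-1} + log(1 - X_i / X_{i-1})
    log_dX = log_X_prev + np.log1p(-np.exp(log_X - log_X_prev))
    log_w = dead_logL + log_dX
    return log_w - np.logaddexp.reduce(log_w)

def weighted_mean_std(samples, log_w):
    w = np.exp(log_w)
    mean = np.sum(w * samples)
    std = np.sqrt(np.sum(w * (samples - mean) ** 2))
    return mean, std

def resample_equal_weight(samples, log_w, n, rng):
    """Draw n equally weighted posterior samples by importance resampling."""
    idx = rng.choice(len(samples), size=n, p=np.exp(log_w))
    return samples[idx]

# Demo with placeholder dead points (illustration only)
rng = np.random.default_rng(1)
dead_logL = np.sort(rng.normal(size=500))
dead_theta = rng.normal(size=500)
log_w = ns_log_weights(dead_logL, n_live=50)
mean, std = weighted_mean_std(dead_theta, log_w)
posterior_draws = resample_equal_weight(dead_theta, log_w, 1000, rng)
```

A complete treatment would also append the final live points, each with weight proportional to L_j X_final / n_live.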

## 4. Professional Visualization with Anesthetic

Anesthetic is a Python library built for analysing and plotting nested sampling output (it also handles MCMC chains):

In [None]:
import anesthetic

# Wrap the samples in an anesthetic object. MCMCSamples treats them as equally
# weighted posterior draws; raw dead points would instead need
# anesthetic.NestedSamples(..., logL=..., logL_birth=...) so their weights can be computed.
parameter_names = ['slope', 'intercept', 'log_sigma']
anesthetic_samples = anesthetic.MCMCSamples(
    data=np.column_stack([slope_samples, intercept_samples, log_sigma_samples]),
    columns=parameter_names
)

# Create a triangle (corner) plot; make_2d_axes returns a figure and a
# pandas DataFrame of matplotlib axes indexed by parameter name
fig, axes = anesthetic.make_2d_axes(parameter_names, figsize=(10, 10))
anesthetic_samples.plot_2d(axes)

# Add true values: dashed lines on the diagonal, crosses in the lower triangle
true_values = [true_slope, true_intercept, np.log(true_sigma)]
for i in range(len(parameter_names)):
    axes.iloc[i, i].axvline(true_values[i], color='red', linestyle='--')
    for j in range(i):
        axes.iloc[i, j].scatter(true_values[j], true_values[i],
                                color='red', marker='x', s=100)

fig.suptitle('Posterior Distribution - Linear Regression', fontsize=16)
plt.show()

### Posterior Predictive Checks

Let's visualize how well our model fits the data:

In [None]:
# Generate posterior predictive samples
n_pred_samples = 100
x_test = jnp.linspace(0, 5, 100)

# Sample random posterior draws
key, subkey = random.split(key)
idx = random.choice(subkey, len(slope_samples), (n_pred_samples,))

plt.figure(figsize=(10, 6))

# Plot posterior predictive lines
for i in range(min(50, n_pred_samples)):  # Plot subset for clarity
    y_pred = slope_samples[idx[i]] * x_test + intercept_samples[idx[i]]
    plt.plot(x_test, y_pred, 'b-', alpha=0.1)

# Plot data and true line
plt.scatter(x, y, alpha=0.7, color='black', label='Data', zorder=5)
plt.plot(x_test, true_slope * x_test + true_intercept, 'r--', 
         linewidth=2, label='True model', zorder=4)

# Plot posterior mean
y_mean = slope_mean * x_test + intercept_mean
plt.plot(x_test, y_mean, 'g-', linewidth=2, label='Posterior mean', zorder=3)

plt.xlabel('x')
plt.ylabel('y')
plt.title('Posterior Predictive Check')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
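Individual lines get cluttered; a pointwise credible band summarizes the same draws more compactly. A self-contained sketch in which placeholder Gaussian draws stand in for `slope_samples` and `intercept_samples` (assumed equally weighted); the band reflects parameter uncertainty only, not the observation noise σ:

```python
import jax
import jax.numpy as jnp

key_s, key_i = jax.random.split(jax.random.PRNGKey(0))
# Placeholder posterior draws; replace with the real slope/intercept samples
slope_draws = 2.5 + 0.03 * jax.random.normal(key_s, (2000,))
intercept_draws = 1.0 + 0.08 * jax.random.normal(key_i, (2000,))

x_test = jnp.linspace(0, 5, 100)
# Broadcasting gives every draw's prediction at every x: shape (n_draws, n_x)
y_pred = slope_draws[:, None] * x_test[None, :] + intercept_draws[:, None]

# Pointwise 68% band (16th/84th percentiles) and median along the draw axis
lo, median, hi = jnp.percentile(y_pred, jnp.array([16.0, 50.0, 84.0]), axis=0)
```

To draw it, `plt.fill_between(x_test, lo, hi, alpha=0.3)` followed by `plt.plot(x_test, median)` replaces the loop over individual lines.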

## 5. Example 2: Multimodal Distribution

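Nested sampling copes well with multimodal distributions: its live points start spread over the whole prior and contract toward every high-likelihood region at once, rather than following a single chain that can get stuck in one mode. Let's demonstrate with a mixture of two Gaussians: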

In [None]:
def log_prob_mixture(params):
    """Log probability of a mixture of two Gaussians."""
    x, y = params
    
    # Two Gaussian components
    comp1 = stats.multivariate_normal.logpdf(
        jnp.array([x, y]), 
        jnp.array([-2.0, -2.0]), 
        jnp.eye(2) * 0.5
    )
    comp2 = stats.multivariate_normal.logpdf(
        jnp.array([x, y]), 
        jnp.array([2.0, 2.0]), 
        jnp.eye(2) * 0.3
    )
    
    # Log sum of exponentials (mixture)
    return jnp.logaddexp(comp1 + jnp.log(0.6), comp2 + jnp.log(0.4))

def mixture_prior_sampler(key, n_samples):
    """Sample from uniform prior on [-5, 5] x [-5, 5]."""
    return random.uniform(key, (n_samples, 2), minval=-5.0, maxval=5.0)

# Run nested sampling on mixture
ns_mixture = blackjax.nested_sampling(
    log_prob_mixture,
    mixture_prior_sampler,
    n_live_points=1000
)

key, subkey = random.split(key)
initial_state_mixture = ns_mixture.init(subkey)

print("Running nested sampling on multimodal distribution...")
start_time = time.time()

key, subkey = random.split(key)
mixture_samples, mixture_log_evidence = blackjax.nested_sampling_inference(
    ns_mixture,
    initial_state_mixture, 
    subkey,
    max_samples=8000
)

end_time = time.time()
print(f"Completed in {end_time - start_time:.2f} seconds")
print(f"Log evidence: {mixture_log_evidence:.3f}")
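Because this "likelihood" is a normalized density lying almost entirely inside the uniform prior box, the evidence is known in advance: Z = ∫ L π dθ ≈ 1/100, i.e. log Z ≈ -4.605. A brute-force Monte Carlo average of the likelihood over prior draws gives an independent check on the printed value (feasible here only because the problem is two-dimensional):

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats
from jax.scipy.special import logsumexp

def log_prob_mixture(params):
    """Same two-component mixture as above (a normalized density, used as the likelihood)."""
    x, y = params
    point = jnp.array([x, y])
    comp1 = stats.multivariate_normal.logpdf(point, jnp.array([-2.0, -2.0]), jnp.eye(2) * 0.5)
    comp2 = stats.multivariate_normal.logpdf(point, jnp.array([2.0, 2.0]), jnp.eye(2) * 0.3)
    return jnp.logaddexp(comp1 + jnp.log(0.6), comp2 + jnp.log(0.4))

# Z = E_prior[L]: average the likelihood over draws from the uniform prior on [-5, 5]^2
key = jax.random.PRNGKey(0)
n = 200_000
theta = jax.random.uniform(key, (n, 2), minval=-5.0, maxval=5.0)
log_L = jax.vmap(log_prob_mixture)(theta)
log_Z_mc = logsumexp(log_L) - jnp.log(n)

# Expected value: density integrating to ~1 inside the box, times prior density 1/100
log_Z_expected = jnp.log(1.0 / 100.0)
print(f"Monte Carlo log Z: {float(log_Z_mc):.3f}, expected about {float(log_Z_expected):.3f}")
```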

In [None]:
# Visualize the multimodal posterior
plt.figure(figsize=(12, 5))

# Plot 1: Scatter plot of samples
plt.subplot(1, 2, 1)
plt.scatter(mixture_samples[:, 0], mixture_samples[:, 1], 
           alpha=0.6, s=10, c='blue')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Nested Sampling: Multimodal Distribution')
plt.grid(True, alpha=0.3)

# Plot 2: Contour plot
plt.subplot(1, 2, 2)
x_grid = jnp.linspace(-5, 5, 50)
y_grid = jnp.linspace(-5, 5, 50)
X, Y = jnp.meshgrid(x_grid, y_grid)
# Evaluate the density on the whole grid in one vectorized call (jax.vmap)
# instead of a Python double loop
grid_points = jnp.stack([X.ravel(), Y.ravel()], axis=-1)
Z = jnp.exp(jax.vmap(log_prob_mixture)(grid_points)).reshape(X.shape)

plt.contour(X, Y, Z, levels=10, colors='red', alpha=0.7)
plt.scatter(mixture_samples[:, 0], mixture_samples[:, 1], 
           alpha=0.4, s=5, c='blue')
plt.xlabel('x')
plt.ylabel('y')
plt.title('True Distribution (contours) vs Samples')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Performance Comparison: Nested Sampling vs. Ensemble Sampling

Let's compare BlackJAX nested sampling with a traditional ensemble sampler, emcee's affine-invariant ensemble sampler (AIES):

In [None]:
# First, let's run emcee on our linear regression problem
import emcee

def log_prob_emcee(params):
    """Log probability function for emcee (numpy version)."""
    return float(log_probability(jnp.array(params)))

# Set up emcee sampler
n_walkers = 32
n_dim = 3
n_steps = 2000
n_burn = 500

# Initialize walkers in a small ball around a rough starting guess
rng = np.random.default_rng(0)
initial_guess = np.array([2.0, 1.0, np.log(0.5)])
pos = initial_guess + 0.1 * rng.standard_normal((n_walkers, n_dim))

# Run emcee
print("Running emcee (AIES)...")
start_time = time.time()

sampler = emcee.EnsembleSampler(n_walkers, n_dim, log_prob_emcee)
sampler.run_mcmc(pos, n_steps, progress=True)

emcee_time = time.time() - start_time
print(f"Emcee completed in {emcee_time:.2f} seconds")

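# Extract samples after burn-in. Successive steps are correlated, so this is the
# raw number of samples, not the effective sample size.
emcee_samples = sampler.get_chain(discard=n_burn, flat=True)
print(f"emcee samples after burn-in: {len(emcee_samples)}")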
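The flattened chain length overstates the information content, because successive MCMC steps are correlated; the effective sample size is roughly N/τ, where τ is the integrated autocorrelation time. The self-contained sketch below estimates τ with an FFT autocorrelation and an automatic (Sokal-style) window, the same approach emcee's documentation describes, and checks it on an AR(1) chain whose exact value τ = (1 + ρ)/(1 - ρ) is known:

```python
import numpy as np

def autocorr_func_1d(x):
    """Normalized autocorrelation function via FFT (zero-padded to avoid wrap-around)."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    f = np.fft.rfft(x - x.mean(), n=2 * n)
    acf = np.fft.irfft(f * np.conjugate(f))[:n]
    return acf / acf[0]

def integrated_time(x, c=5.0):
    """Integrated autocorrelation time with an automatic window M >= c * tau(M)."""
    taus = 2.0 * np.cumsum(autocorr_func_1d(x)) - 1.0
    below = np.arange(len(taus)) < c * taus
    window = np.argmin(below) if not below.all() else len(taus) - 1
    return taus[window]

# AR(1) test chain: x_t = rho * x_{t-1} + noise, with tau = (1 + rho) / (1 - rho) = 19
rng = np.random.default_rng(0)
rho, n = 0.9, 200_000
noise = rng.standard_normal(n)
chain = np.empty(n)
chain[0] = noise[0]
for t in range(1, n):
    chain[t] = rho * chain[t - 1] + noise[t]

tau_est = integrated_time(chain)
ess = n / tau_est
print(f"tau estimate: {tau_est:.1f} (exact {(1 + rho) / (1 - rho):.1f}), ESS ~ {ess:.0f}")
```

For the emcee run, `sampler.get_autocorr_time(discard=n_burn, quiet=True)` returns τ per parameter (with `quiet=True`, emcee warns instead of raising if the chain is too short for a reliable estimate); dividing the flattened sample count by τ gives a rough effective sample size.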

In [None]:
# Now let's compare the results
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

param_names = ['Slope', 'Intercept', 'log(σ)']
true_vals = [true_slope, true_intercept, jnp.log(true_sigma)]
colors = ['blue', 'orange']
labels = ['BlackJAX NS', 'emcee (AIES)']

# Plot marginal distributions
for i in range(3):
    # Histograms
    axes[0, i].hist(samples[:, i], bins=50, alpha=0.7, 
                   density=True, color=colors[0], label=labels[0])
    axes[0, i].hist(emcee_samples[:, i], bins=50, alpha=0.7, 
                   density=True, color=colors[1], label=labels[1])
    axes[0, i].axvline(true_vals[i], color='red', linestyle='--', 
                      label='True value')
    axes[0, i].set_xlabel(param_names[i])
    axes[0, i].set_ylabel('Density')
    axes[0, i].legend()
    axes[0, i].grid(True, alpha=0.3)

# Plot traces (show convergence)
chain = sampler.get_chain()
for i in range(3):
    # Show a few walker traces for emcee
    for j in range(min(5, n_walkers)):
        axes[1, i].plot(chain[:, j, i], color=colors[1], alpha=0.3)
    axes[1, i].axhline(true_vals[i], color='red', linestyle='--')
    axes[1, i].set_xlabel('Step')
    axes[1, i].set_ylabel(param_names[i])
    axes[1, i].set_title('emcee Chains')
    axes[1, i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

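# Time a repeat BlackJAX run on the linear-regression problem: the earlier
# start_time/end_time were overwritten by the mixture and emcee cells, and a
# repeat run may also avoid one-off JIT compilation cost.
key, subkey = random.split(key)
ns_start = time.time()
jax.block_until_ready(blackjax.nested_sampling_inference(
    ns_algorithm, initial_state, subkey, max_samples=5000
))
ns_time = time.time() - ns_start

# Print comparison summary
print("\nPerformance Comparison:")
print("=" * 40)
print(f"BlackJAX NS time:  {ns_time:.2f}s")
print(f"emcee time:        {emcee_time:.2f}s")
print(f"BlackJAX samples:  {len(samples)}")
print(f"emcee samples:     {len(emcee_samples)}")
print(f"\nBlackJAX also provides log evidence: {log_evidence:.3f}")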
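Wall-clock comparisons with JAX need two precautions: the first call of a jitted function includes tracing and compilation, and JAX dispatches work asynchronously, so the clock should stop only after `.block_until_ready()`. A self-contained sketch with a toy batched log-likelihood (illustrative only; absolute timings depend on your hardware):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def batched_loglike(theta):
    # Toy Gaussian log-likelihood for a batch of parameter vectors
    return -0.5 * jnp.sum(theta**2, axis=-1)

theta = jax.random.normal(jax.random.PRNGKey(0), (100_000, 3))

# First call: tracing + XLA compilation + execution
t0 = time.perf_counter()
batched_loglike(theta).block_until_ready()
first_call = time.perf_counter() - t0

# Second call with the same shapes reuses the compiled executable
t0 = time.perf_counter()
out = batched_loglike(theta).block_until_ready()
run_only = time.perf_counter() - t0

print(f"first call: {first_call * 1e3:.1f} ms, repeat call: {run_only * 1e3:.1f} ms")
```

Reporting both numbers makes it clear how much of a sampler's runtime is one-off compilation versus actual sampling.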

## 7. Your Turn: Experiment!

Now it's time to experiment with BlackJAX nested sampling. Here are some suggestions:

### Option 1: Modify the Linear Model
Try adding more complexity to our linear regression:

In [None]:
# TODO: Try a polynomial model
# y = a*x^2 + b*x + c + noise

# Your code here!
pass
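As a starting point, and as a check on the evidence your nested sampler reports, note that polynomial models with Gaussian coefficient priors and a known noise level σ have a closed-form evidence: marginalizing the coefficients gives y ~ N(0, σ²I + s²ΦΦᵀ), with Φ the design matrix and s the prior standard deviation. The sketch assumes σ is known (unlike the workshop model, which fits log σ) and uses N(0, 10²) coefficient priors to match the earlier slope and intercept priors; `log_evidence_poly` is a hypothetical helper:

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# The marginal covariance is badly conditioned in float32, so use float64
jax.config.update("jax_enable_x64", True)

def log_evidence_poly(x, y, degree, sigma, prior_sd=10.0):
    """Exact log evidence of a polynomial model with N(0, prior_sd^2) coefficients."""
    Phi = jnp.vander(x, degree + 1)  # columns x^degree, ..., x, 1
    cov = sigma**2 * jnp.eye(len(x)) + prior_sd**2 * Phi @ Phi.T
    return stats.multivariate_normal.logpdf(y, jnp.zeros(len(x)), cov)

# Data from a straight line, as in Example 1
key = jax.random.PRNGKey(0)
x = jnp.linspace(0, 5, 50)
y = 2.5 * x + 1.0 + 0.3 * jax.random.normal(key, (50,))

log_Z_lin = log_evidence_poly(x, y, degree=1, sigma=0.3)
log_Z_quad = log_evidence_poly(x, y, degree=2, sigma=0.3)
print(f"log Bayes factor (linear vs quadratic): {float(log_Z_lin - log_Z_quad):.2f}")
```

For data generated from a line, the extra quadratic coefficient buys almost no improvement in fit while paying an Occam penalty for its broad prior, so the log Bayes factor should typically favour the linear model.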

### Option 2: Bring Your Own Model
If you have JAX code from Viraj's workshop, try applying nested sampling to it:

In [None]:
# TODO: Apply nested sampling to your own problem
# 1. Define your log_probability function
# 2. Define your prior_sampler function  
# 3. Run nested sampling
# 4. Visualize with anesthetic

# Your code here!
pass

### Option 3: Advanced Features
Explore more BlackJAX nested sampling features:

In [None]:
# TODO: Try different nested sampling variants
# - Slice sampling within nested sampling
# - Different proposal mechanisms
# - Adaptive live point allocation

# Example: Using slice sampling for proposals
# ns_algorithm = blackjax.nested_sampling(
#     log_probability,
#     prior_sampler,
#     n_live_points=n_live,
#     mcmc=blackjax.slice
# )

# Your experiments here!
pass
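Slice sampling is a common way to generate new live points without rejection sampling from the whole prior: starting from an existing live point, pick a random direction and sample uniformly along the part of that line where the likelihood exceeds the current threshold. A minimal NumPy sketch of one such hit-and-run slice step with stepping out and shrinkage, assuming a uniform box prior so that the constraint is simply "inside the box and logL > logL_min"; it illustrates the idea and is not BlackJAX's implementation:

```python
import numpy as np

def slice_step(x0, log_like, logL_min, low, high, rng, w=1.0):
    """One hit-and-run slice step within {low <= x <= high, log_like(x) > logL_min}.

    x0 must already satisfy the constraint; w is the initial bracket width.
    """
    def ok(x):
        return np.all((x >= low) & (x <= high)) and log_like(x) > logL_min

    d = rng.standard_normal(x0.shape)
    d /= np.linalg.norm(d)  # random unit direction
    # Step out: place a bracket of width w randomly around x0, widen until both ends fail
    left = -w * rng.uniform()
    right = left + w
    while ok(x0 + left * d):
        left -= w
    while ok(x0 + right * d):
        right += w
    # Shrink: sample uniformly in the bracket, shrinking toward x0 on each rejection
    while True:
        t = rng.uniform(left, right)
        x_new = x0 + t * d
        if ok(x_new):
            return x_new
        if t < 0:
            left = t
        else:
            right = t

def log_like(x):
    return -0.5 * np.sum(x**2)

rng = np.random.default_rng(0)
x0 = np.array([0.5, -0.5])   # current live point, log_like = -0.25
logL_min = -0.5              # constraint: inside the unit circle
x_new = slice_step(x0, log_like, logL_min, -5.0, 5.0, rng)
```

In a nested sampler, a new point would typically come from several such steps started at a randomly chosen live point, so that it decorrelates from its starting position.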

## 8. Key Takeaways

🎯 **What we've learned:**

1. **Most SBI methods other than NPE need a sampler**, and nested sampling is a natural choice
2. **BlackJAX provides GPU-native** nested sampling with JAX benefits
3. **Anesthetic makes visualization easy** and professional
4. **Performance gains** come from JIT compilation and GPU acceleration
5. **Model evidence** is a bonus feature for model comparison

🚀 **Next steps:**
- Try BlackJAX nested sampling on your own problems
- Explore the [BlackJAX nested sampling repository](https://github.com/handley-lab/blackjax)
- Use [Anesthetic](https://anesthetic.readthedocs.io) for publication-quality plots
- Join the community discussion on GitHub

💡 **Remember:** The future of scientific computing is GPU-native. BlackJAX positions you at the forefront of modern Bayesian inference!

---

**Questions? Suggestions? Let's discuss!** 🗣️