# BlackJAX Nested Sampling Workshop
## GPU-Native Bayesian Inference for SBI

**SBI Galev 2025 Workshop**  
*Will Handley, University of Cambridge*

---

### Workshop Overview

In this hands-on workshop, you'll learn to use **BlackJAX nested sampling**, a GPU-native implementation that uses JAX's automatic differentiation and JIT compilation for modern Bayesian inference.

**Learning Objectives:**
1. Understand when nested sampling excels over MCMC methods
2. Implement nested sampling with BlackJAX
3. Visualize results with Anesthetic
4. Compare performance: nested sampling vs. NUTS and affine-invariant ensemble sampling (AIES)
5. Leverage GPU acceleration for scientific inference

**Why BlackJAX Nested Sampling?**
- **GPU-native**: Full JAX integration (autodiff + JIT compilation)
- **Open source**: Community-owned alternative to legacy Fortran tools
- **Modern**: Designed for SBI workflows and scientific computing
- **Efficient**: Handles multimodal posteriors and evidence computation


## 1. Setup and Installation

We'll install BlackJAX from the nested sampling branch and set up our environment for GPU-accelerated inference.

In [None]:
# Install BlackJAX nested sampling branch and visualization tools
!pip install git+https://github.com/handley-lab/blackjax@nested_sampling
!pip install anesthetic matplotlib corner tqdm

# Check for GPU availability
import jax
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import time

# BlackJAX imports
import blackjax
import blackjax.ns.utils as ns_utils

# Visualization
from anesthetic import NestedSamples
import corner

# JAX configuration for precision and reproducibility
jax.config.update("jax_enable_x64", True)
# By default JAX runs on a GPU when one is visible; uncomment the next line to force CPU
# jax.config.update('jax_platform_name', 'cpu')

# Set random seed
np.random.seed(42)
rng_key = jax.random.key(42)

## 2. Problem Setup: 2D Gaussian Parameter Inference

We'll revisit the **2D Gaussian inference problem** from Viraj's JAX/SciML workshop, but this time tackle it with nested sampling.

**Problem**: Infer parameters of a 2D Gaussian from noisy pixelated observations
- **Parameters**: [μₓ, μᵧ, σₓ, σᵧ, ρₓᵧ] (5D parameter space)
- **Data**: 50×50 pixel noisy images of 2D Gaussians
- **Challenge**: Correlated, possibly multimodal posterior over (σₓ, σᵧ, ρ)
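Written out, the forward model implemented in `params_to_cov` and `simulator` below is as follows, where $\mathbf{r}_k$ are the 2500 pixel centres and $\sigma_n = 0.1$ is the noise level used in the code:

$$
\boldsymbol{\mu} = \begin{pmatrix}\mu_x \\ \mu_y\end{pmatrix}, \qquad
\Sigma = \begin{pmatrix}\sigma_x^2 & \rho\,\sigma_x\sigma_y \\ \rho\,\sigma_x\sigma_y & \sigma_y^2\end{pmatrix}, \qquad
q_k = (\mathbf{r}_k - \boldsymbol{\mu})^\top \Sigma^{-1} (\mathbf{r}_k - \boldsymbol{\mu})
$$

$$
I_\theta(\mathbf{r}_k) = \exp\!\left[-\tfrac{1}{2}\left(q_k - \min_j q_j\right)\right], \qquad
d_k = I_\theta(\mathbf{r}_k) + \epsilon_k, \quad \epsilon_k \sim \mathcal{N}(0, \sigma_n^2)
$$

Subtracting the maximum log-density in the code cancels the Gaussian normalization constant, so the brightest pixel of the noise-free image has intensity 1.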

In [None]:
# Problem configuration
image_size = 50
x = jnp.linspace(-3, 3, image_size)
y = jnp.linspace(-3, 3, image_size)
X, Y = jnp.meshgrid(x, y)
coords = jnp.stack([X.ravel(), Y.ravel()]).T  # (2500, 2)

# True parameters for data generation
true_params = jnp.array([0.5, -0.3, 1.2, 0.8, 0.4])  # [μₓ, μᵧ, σₓ, σᵧ, ρₓᵧ]

@jax.jit
def params_to_cov(params):
    """Convert parameters to mean vector and covariance matrix."""
    mu_x, mu_y, sigma_x, sigma_y, rho = params
    
    mean = jnp.array([mu_x, mu_y])
    
    # Covariance matrix
    cov = jnp.array([
        [sigma_x**2, rho * sigma_x * sigma_y],
        [rho * sigma_x * sigma_y, sigma_y**2]
    ])
    
    return mean, cov

@jax.jit
def simulator(params, rng_key, noise_sigma=0.1):
    """Simulate 2D Gaussian observation with noise."""
    mean, cov = params_to_cov(params)
    
    # Evaluate multivariate normal PDF on grid
    logpdf = jax.scipy.stats.multivariate_normal.logpdf(coords, mean, cov)
    image_clean = logpdf.reshape(image_size, image_size)
    image_clean = jnp.exp(image_clean - jnp.max(image_clean))  # Peak-normalize: brightest pixel = 1
    
    # Add Gaussian noise
    noise = jax.random.normal(rng_key, image_clean.shape) * noise_sigma
    
    return image_clean + noise

# Generate observed data
rng_key, subkey = jax.random.split(rng_key)
observed_data = simulator(true_params, subkey)

# Visualize the observed data
plt.figure(figsize=(8, 6))
plt.imshow(observed_data, origin='lower', extent=[-3, 3, -3, 3], cmap='viridis')
plt.colorbar(label='Intensity')
plt.title('Observed 2D Gaussian (with noise)')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

print(f"True parameters: μₓ={true_params[0]:.2f}, μᵧ={true_params[1]:.2f}, "
      f"σₓ={true_params[2]:.2f}, σᵧ={true_params[3]:.2f}, ρ={true_params[4]:.2f}")

## 3. Likelihood and Prior Setup

Define the likelihood function and prior distributions for Bayesian inference.
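In symbols, with $I_\theta(\mathbf{r}_k)$ the noise-free, peak-normalized model image at pixel $k$ and $\sigma_n = 0.1$, the log-likelihood in `loglikelihood_fn` (dropping a θ-independent constant) and the uniform prior on the box $[a_i, b_i]$ from `prior_bounds` are:

$$
\ln \mathcal{L}(\theta) = -\frac{1}{2}\sum_{k=1}^{2500}\left(\frac{d_k - I_\theta(\mathbf{r}_k)}{\sigma_n}\right)^2, \qquad
\pi(\theta) = \prod_{i=1}^{5} \frac{\mathbb{1}\left[a_i \le \theta_i \le b_i\right]}{b_i - a_i}
$$

The dropped constant, $-\tfrac{N}{2}\ln(2\pi\sigma_n^2)$ with $N = 2500$, shifts every log-evidence computed with this likelihood equally, so it cancels in Bayes factors. The positive-definiteness check is a safeguard: $\det\Sigma = \sigma_x^2\sigma_y^2(1-\rho^2)$, which is positive whenever $\sigma_x, \sigma_y \neq 0$ and $|\rho| < 1$, so every point inside the prior box (with $\rho \in [-0.99, 0.99]$) takes the valid branch.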

In [None]:
@jax.jit
def loglikelihood_fn(params):
    """Log-likelihood function for parameter inference."""
    # Simulate clean image from parameters
    mean, cov = params_to_cov(params)
    
    # Check if covariance matrix is positive definite
    det_cov = jnp.linalg.det(cov)
    
    # Return -inf if covariance is not positive definite
    def valid_cov():
        logpdf = jax.scipy.stats.multivariate_normal.logpdf(coords, mean, cov)
        image_pred = logpdf.reshape(image_size, image_size)
        image_pred = jnp.exp(image_pred - jnp.max(image_pred))
        
        # Gaussian likelihood (MSE)
        noise_sigma = 0.1
        residuals = (observed_data - image_pred) / noise_sigma
        return -0.5 * jnp.sum(residuals**2)
    
    def invalid_cov():
        return -jnp.inf
    
    return jax.lax.cond(det_cov > 0, valid_cov, invalid_cov)

# Prior bounds
prior_bounds = {
    "mu_x": (-2.0, 2.0),
    "mu_y": (-2.0, 2.0), 
    "sigma_x": (0.5, 3.0),
    "sigma_y": (0.5, 3.0),
    "rho": (-0.99, 0.99)  # Correlation must be in (-1, 1)
}

# Test likelihood function
test_loglik = loglikelihood_fn(true_params)
print(f"Log-likelihood at true parameters: {test_loglik:.3f}")

## 4. BlackJAX Nested Sampling

Now let's implement nested sampling with BlackJAX. Nested sampling is particularly powerful for:
- **Multimodal posteriors** (common in parameter inference)
- **Evidence computation** (model comparison)
- **Efficient exploration** of complex parameter spaces

### Why Nested Sampling for This Problem?
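The 2D Gaussian parameter inference can have strongly **correlated** and potentially **multimodal** posteriors:
- Correlations between σₓ, σᵧ and ρ create curved, non-Gaussian posterior geometry
- Traditional MCMC (like HMC/NUTS) may struggle to move between separated modes
- Nested sampling compresses from the prior inwards, so it does not rely on a good starting point and can populate several modes at once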
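To make the mechanics concrete, here is a deliberately naive nested sampler in plain NumPy. It keeps `nlive` points, repeatedly removes the lowest-likelihood one, credits its likelihood times the expected shell of prior volume to the evidence, and replaces it with a new prior draw above that likelihood threshold. It stops with the same "live points hold less than ~5% of the evidence" rule used in the BlackJAX run. Replacements use rejection sampling from the prior (BlackJAX uses short MCMC chains instead), so this is only a sketch for cheap toy problems; the function name and the toy likelihood are illustrative.

```python
import numpy as np

def toy_nested_sampling(loglike, ndim, nlive=200, tol=-3.0, seed=0):
    """Bare-bones nested sampling with a uniform prior on the unit hypercube.

    Replacements are drawn by naive rejection sampling from the prior, which is
    only practical for cheap, low-dimensional toy problems.
    """
    rng = np.random.default_rng(seed)
    live = rng.uniform(size=(nlive, ndim))
    live_logL = np.array([loglike(p) for p in live])
    logX, logZ = 0.0, -np.inf  # log remaining prior volume, log evidence accumulated so far

    while True:
        # Evidence still held by the live points: Z_live ≈ X * mean(L_live)
        m = live_logL.max()
        logZ_live = logX + m + np.log(np.mean(np.exp(live_logL - m)))
        if logZ_live - logZ < tol:  # stop once live points hold < e^tol of the evidence
            break

        # Kill the worst point; credit L* times the expected shell of prior volume
        worst = np.argmin(live_logL)
        logL_star = live_logL[worst]
        log_dX = logX + np.log1p(-np.exp(-1.0 / nlive))  # log(X_i - X_{i+1})
        logZ = np.logaddexp(logZ, logL_star + log_dX)
        logX -= 1.0 / nlive  # expected shrinkage: X_{i+1} ≈ X_i * exp(-1/nlive)

        # Replace it with a prior draw satisfying L > L* (rejection sampling)
        while True:
            candidate = rng.uniform(size=ndim)
            candidate_logL = loglike(candidate)
            if candidate_logL > logL_star:
                break
        live[worst], live_logL[worst] = candidate, candidate_logL

    # Add the remaining live points, each carrying volume X / nlive
    return np.logaddexp(logZ, logZ_live)
```

A normalized Gaussian likelihood sitting well inside the unit square has evidence essentially equal to 1 (log Z ≈ 0), which makes a convenient sanity check; the expected scatter of the estimate is roughly sqrt(H / nlive), with H the information in nats.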

In [None]:
# Nested sampling configuration
num_live = 1000  # Number of live points (controls precision)
num_dims = 5     # Parameter dimensionality
num_inner_steps = num_dims * 5  # MCMC steps per NS iteration (3-5 × ndim)
num_delete = 50  # Points removed and replaced per iteration (batched for GPU efficiency)

print(f"Nested sampling configuration:")
print(f"  Live points: {num_live}")
print(f"  Inner MCMC steps: {num_inner_steps}")
print(f"  Parallel deletion: {num_delete}")

# Initialize uniform prior and live points
rng_key, subkey = jax.random.split(rng_key)
particles, logprior_fn = ns_utils.uniform_prior(
    subkey, num_live, prior_bounds
)

print(f"\nInitialized {particles.shape[0]} live points in {particles.shape[1]}D space")
print(f"Parameter bounds: {prior_bounds}")

In [None]:
# Create nested sampler
nested_sampler = blackjax.nss(
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    num_delete=num_delete,
    num_inner_steps=num_inner_steps
)

# Initialize nested sampling state
rng_key, subkey = jax.random.split(rng_key)
live_state = nested_sampler.init(particles)

# JIT compile the sampling step for efficiency
jit_step = jax.jit(nested_sampler.step)

print(f"Initial evidence estimate: {live_state.logZ:.3f}")
print(f"Initial live evidence: {live_state.logZ_live:.3f}")
print("\nStarting nested sampling...")

# Run nested sampling until convergence
dead_points = []
iteration = 0
convergence_threshold = -3.0  # log(0.05) - stop when evidence contribution < 5%

start_time = time.time()

while (live_state.logZ_live - live_state.logZ) > convergence_threshold:
    rng_key, subkey = jax.random.split(rng_key)
    
    # Take nested sampling step
    live_state, dead_info = jit_step(subkey, live_state)
    dead_points.append(dead_info)
    
    iteration += 1
    
    # Progress updates
    if iteration % 50 == 0:
        remaining_evidence = live_state.logZ_live - live_state.logZ
        print(f"Iteration {iteration:4d}: logZ = {live_state.logZ:.3f}, "
              f"remaining = {remaining_evidence:.3f}")

sampling_time = time.time() - start_time

print(f"\nNested sampling completed!")
print(f"Total iterations: {iteration}")
print(f"Sampling time: {sampling_time:.2f} seconds")
print(f"Final evidence: logZ = {live_state.logZ:.3f}")

## 5. Results Processing and Visualization

Process the nested sampling results and create publication-quality visualizations with Anesthetic.
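Anesthetic needs `logL_birth` because, when `num_delete` points are removed per iteration, the number of live points is not constant. A minimal NumPy sketch of the idea: reconstruct the live-point count at each death from the birth and death contours, then the expected prior volumes and normalized posterior weights. This illustrates the principle only; it is not anesthetic's actual code, its exact volume convention may differ, and the function names are hypothetical.

```python
import numpy as np

def nlive_from_birth_death(logL, logL_birth):
    """Live-point count at each death, from death contours (logL) and birth contours (logL_birth).

    Point j is alive when point i dies if logL_birth[j] < logL[i] <= logL[j].
    Every point with logL[j] < logL[i] was also born below logL[i], so the count is
    (#births below logL[i]) - (#deaths below logL[i]).
    """
    order = np.argsort(logL)
    logL, logL_birth = logL[order], logL_birth[order]
    born_below = np.searchsorted(np.sort(logL_birth), logL, side='left')
    died_below = np.searchsorted(logL, logL, side='left')
    return logL, born_below - died_below

def volumes_and_weights(logL, logL_birth):
    """Expected log prior volumes, normalized posterior weights and log-evidence (sketch)."""
    logL, nlive = nlive_from_birth_death(logL, logL_birth)
    logX = np.cumsum(np.log(nlive / (nlive + 1.0)))  # mean shrinkage E[t] = n / (n + 1) per death
    X = np.exp(logX)
    dX = np.concatenate([[1.0], X[:-1]]) - X          # prior volume removed at each death
    logw = logL + np.log(dX)
    logZ = logw.max() + np.log(np.sum(np.exp(logw - logw.max())))
    return logX, np.exp(logw - logZ), logZ
```

For example, a run with 3 initial points (born at -inf) that spawns one replacement per death for three deaths and then lets the final live points die gives live counts [3, 3, 3, 3, 2, 1]: constant during the run, then shrinking as the final live set is used up.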

In [None]:
# Combine dead and remaining live points into a single nested sampling run
dead = ns_utils.finalise(live_state, dead_points)

# Extract samples and metadata
samples = dead.particles                # Parameter samples, shape (n_samples, 5)
logL = dead.loglikelihood               # Log-likelihood at which each point died
logL_birth = dead.loglikelihood_birth   # Likelihood threshold each point was born above

print(f"Total samples: {len(samples)}")

# Create NestedSamples object for anesthetic
param_names = ['μₓ', 'μᵧ', 'σₓ', 'σᵧ', 'ρ']
nested_samples = NestedSamples(
    data=np.asarray(samples),
    logL=np.asarray(logL),
    logL_birth=np.asarray(logL_birth),
    columns=param_names
)

# Evidence uncertainty from simulated prior-volume draws; information as KL divergence
logZ_draws = nested_samples.logZ(1000)
print(f"Evidence: logZ = {logZ_draws.mean():.3f} ± {logZ_draws.std():.3f}")
print(f"Information: D_KL = {nested_samples.D_KL():.3f} nats")

# Posterior statistics (anesthetic computes weighted means and standard deviations)
print("\nPosterior summary:")
for i, name in enumerate(param_names):
    mean = nested_samples[name].mean()
    std = nested_samples[name].std()
    true_val = true_params[i]
    print(f"  {name}: {mean:.3f} ± {std:.3f} (true: {true_val:.3f})")

In [None]:
# Create corner plot with anesthetic (diagonal + lower triangle)
from anesthetic import make_2d_axes

fig, axes = make_2d_axes(param_names, figsize=(12, 10), upper=False)

nested_samples.plot_2d(
    axes,
    kinds={'diagonal': 'hist_1d', 'lower': 'kde_2d'},
    alpha=0.7
)

# Add true parameter values
for i in range(num_dims):
    # Diagonal (1D marginals)
    axes.iloc[i, i].axvline(float(true_params[i]), color='red', linestyle='--', alpha=0.8, label='True')
    
    # Off-diagonal (2D marginals, lower triangle)
    for j in range(i):
        axes.iloc[i, j].axvline(float(true_params[j]), color='red', linestyle='--', alpha=0.8)
        axes.iloc[i, j].axhline(float(true_params[i]), color='red', linestyle='--', alpha=0.8)
        axes.iloc[i, j].plot(float(true_params[j]), float(true_params[i]), 'r*', markersize=10)

# Add legend
axes.iloc[0, 0].legend(loc='upper right')

fig.suptitle('BlackJAX Nested Sampling: 2D Gaussian Parameter Inference', fontsize=14)
plt.show()

In [None]:
# Nested sampling diagnostic plots
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

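# Evidence accumulation: cumulative log-evidence against log prior volume
logw = nested_samples.logw().to_numpy()  # log(L_i * ΔX_i) for each dead point, in logL order
axes[0, 0].plot(nested_samples.logX(), np.logaddexp.accumulate(logw))
axes[0, 0].set_xlabel('log X (prior volume)')
axes[0, 0].set_ylabel('Cumulative log Z')
axes[0, 0].set_title('Evidence Accumulation')
axes[0, 0].grid(True, alpha=0.3)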

# Log-likelihood sequence
axes[0, 1].plot(logL)
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Log-likelihood')
axes[0, 1].set_title('Likelihood Sequence')
axes[0, 1].grid(True, alpha=0.3)

# Posterior weights
weights = nested_samples.get_weights()
axes[1, 0].plot(weights)
axes[1, 0].set_xlabel('Sample index')
axes[1, 0].set_ylabel('Posterior weight')
axes[1, 0].set_title('Posterior Weights')
axes[1, 0].grid(True, alpha=0.3)

# Kish effective sample size (valid whether or not the weights are normalized)
n_eff = weights.sum()**2 / np.sum(weights**2)
axes[1, 1].text(0.5, 0.7, f'Effective samples: {n_eff:.0f}', 
               transform=axes[1, 1].transAxes, fontsize=12, ha='center')
axes[1, 1].text(0.5, 0.5, f'Total samples: {len(weights)}', 
               transform=axes[1, 1].transAxes, fontsize=12, ha='center')
axes[1, 1].text(0.5, 0.3, f'Efficiency: {n_eff/len(weights):.2%}', 
               transform=axes[1, 1].transAxes, fontsize=12, ha='center')
axes[1, 1].set_title('Sampling Efficiency')
axes[1, 1].set_xticks([])
axes[1, 1].set_yticks([])

plt.tight_layout()
plt.show()

logZ_draws = nested_samples.logZ(1000)
print(f"\nNested Sampling Diagnostics:")
print(f"  Effective sample size: {n_eff:.0f} / {len(weights)} ({n_eff/len(weights):.1%})")
print(f"  Evidence uncertainty: ±{logZ_draws.std():.3f}")
print(f"  Information gain (KL divergence): {nested_samples.D_KL():.3f} nats")

## 6. Performance Comparison: Nested Sampling vs. MCMC

Let's compare BlackJAX nested sampling with traditional MCMC methods (NUTS and AIES) on the same problem.

### Why This Comparison Matters:
- **NUTS**: State-of-the-art Hamiltonian Monte Carlo (requires gradients)
- **AIES**: Affine Invariant Ensemble Sampler (like `emcee`)
- **Nested Sampling**: Designed for multimodal posteriors and evidence computation
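Since the AIES cell below depends on whether this BlackJAX build ships an ensemble sampler, it may help to see the core of the method. Below is a minimal JAX sketch of the Goodman & Weare stretch move (the move `emcee` uses by default), with the ensemble split into two halves so each half is updated in parallel against the other. All function names here are hypothetical and are not BlackJAX API.

```python
import jax
import jax.numpy as jnp

def _stretch_half(key, walkers, logp, partners_pool, logdensity_fn, a):
    """Stretch-move update of `walkers`, pairing each with a random walker from `partners_pool`."""
    n, d = walkers.shape
    k_pick, k_z, k_acc = jax.random.split(key, 3)
    partners = partners_pool[jax.random.randint(k_pick, (n,), 0, partners_pool.shape[0])]
    # z ~ g(z) ∝ 1/sqrt(z) on [1/a, a], drawn by inverting its CDF
    z = (1.0 + (a - 1.0) * jax.random.uniform(k_z, (n,))) ** 2 / a
    proposals = partners + z[:, None] * (walkers - partners)
    logp_prop = jax.vmap(logdensity_fn)(proposals)
    # Accept with probability min(1, z^(d-1) p(y) / p(x))
    log_accept = (d - 1) * jnp.log(z) + logp_prop - logp
    accept = jnp.log(jax.random.uniform(k_acc, (n,))) < log_accept
    return (jnp.where(accept[:, None], proposals, walkers),
            jnp.where(accept, logp_prop, logp))

def stretch_step(key, walkers, logp, logdensity_fn, a=2.0):
    """One sweep: update the first half against the second, then the second against the new first."""
    h = walkers.shape[0] // 2
    k1, k2 = jax.random.split(key)
    w1, lp1 = _stretch_half(k1, walkers[:h], logp[:h], walkers[h:], logdensity_fn, a)
    w2, lp2 = _stretch_half(k2, walkers[h:], logp[h:], w1, logdensity_fn, a)
    return jnp.concatenate([w1, w2]), jnp.concatenate([lp1, lp2])

def run_stretch(key, walkers, logdensity_fn, num_steps):
    """Run `num_steps` sweeps with lax.scan; returns positions of shape (num_steps, n_walkers, d)."""
    logp = jax.vmap(logdensity_fn)(walkers)

    def body(carry, step_key):
        new_walkers, new_logp = stretch_step(step_key, carry[0], carry[1], logdensity_fn)
        return (new_walkers, new_logp), new_walkers

    _, chain = jax.lax.scan(body, (walkers, logp), jax.random.split(key, num_steps))
    return chain
```

The move needs no gradients and is invariant to affine rescalings of the parameters, which is why it copes with correlated posteriors without tuning; like any MCMC chain, its early steps should be discarded as burn-in.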

In [None]:
# NUTS (No-U-Turn Sampler) comparison
# NUTS needs a smooth log-density on all of R^5, so each parameter is mapped onto its
# prior interval [a, b] with a scaled sigmoid: θ = a + (b - a) * sigmoid(u).
# This reproduces the same uniform prior box used for nested sampling.
bounds_array = jnp.array(list(prior_bounds.values()))  # (5, 2): [lower, upper] in parameter order
lower, upper = bounds_array[:, 0], bounds_array[:, 1]

def transform_to_unconstrained(params):
    """Map parameters inside the prior box to R^5 (logit of the rescaled value)."""
    z = (params - lower) / (upper - lower)
    return jnp.log(z) - jnp.log1p(-z)

def transform_to_constrained(unconstrained_params):
    """Map u in R^5 back into the prior box."""
    return lower + (upper - lower) * jax.nn.sigmoid(unconstrained_params)

@jax.jit
def unconstrained_logdensity(unconstrained_params):
    """Log-posterior in unconstrained space (for NUTS): log-likelihood + log-Jacobian.

    Inside the box the prior is 1/(b - a) and |dθ/du| = (b - a) sigmoid(u) sigmoid(-u),
    so the (b - a) factors cancel. Every u maps inside the box, so no bounds check is needed.
    """
    constrained_params = transform_to_constrained(unconstrained_params)
    log_jacobian = jnp.sum(
        jax.nn.log_sigmoid(unconstrained_params) + jax.nn.log_sigmoid(-unconstrained_params)
    )
    return loglikelihood_fn(constrained_params) + log_jacobian

print("Setting up NUTS sampler...")

# Start near the true parameters, clipped to stay strictly inside the prior box
rng_key, subkey = jax.random.split(rng_key)
start = true_params + 0.1 * jax.random.normal(subkey, (5,))
start = jnp.clip(start, lower + 1e-3, upper - 1e-3)
initial_unconstrained = transform_to_unconstrained(start)

print(f"NUTS initialized at: {transform_to_constrained(initial_unconstrained)}")
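Getting the Jacobian term of a change of variables wrong silently biases NUTS towards or away from the edges of the prior. One cheap check is to compare a closed-form log-Jacobian with one computed by autodiff. A self-contained sketch for an elementwise scaled-sigmoid map onto the prior box; the map and function names are illustrative assumptions, and the bounds are copied from `prior_bounds`.

```python
import jax
import jax.numpy as jnp

# Prior box copied from prior_bounds (order: μₓ, μᵧ, σₓ, σᵧ, ρ)
lo = jnp.array([-2.0, -2.0, 0.5, 0.5, -0.99])
hi = jnp.array([2.0, 2.0, 3.0, 3.0, 0.99])

def box_from_unconstrained(u):
    """Scaled sigmoid: R^5 -> prior box."""
    return lo + (hi - lo) * jax.nn.sigmoid(u)

def log_jacobian_closed_form(u):
    """log|det dθ/du| for the elementwise scaled sigmoid: sum of log(b-a) + log σ(u) + log σ(-u)."""
    return jnp.sum(jnp.log(hi - lo) + jax.nn.log_sigmoid(u) + jax.nn.log_sigmoid(-u))

def log_jacobian_autodiff(u):
    """Same quantity from the full Jacobian matrix via forward-mode autodiff."""
    _, logdet = jnp.linalg.slogdet(jax.jacfwd(box_from_unconstrained)(u))
    return logdet

u_test = jnp.array([0.3, -1.2, 2.0, 0.0, -0.5])
print(log_jacobian_closed_form(u_test), log_jacobian_autodiff(u_test))
```

The same pattern works for any bijection: if the two numbers disagree, the hand-written Jacobian is wrong.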

In [None]:
# Run NUTS: Stan-style window adaptation (warmup), then sampling with the tuned kernel
print("Running NUTS sampling...")

num_nuts_samples = 2000
num_warmup = 1000

start_time = time.time()

# Warmup phase: adapt the step size and diagonal inverse mass matrix
rng_key, warmup_key = jax.random.split(rng_key)
warmup = blackjax.window_adaptation(blackjax.nuts, unconstrained_logdensity)
(current_state, nuts_parameters), _ = warmup.run(
    warmup_key, initial_unconstrained, num_steps=num_warmup
)
print(f"NUTS warmup finished ({num_warmup} steps)")

# Sampling phase with the adapted parameters
jit_nuts_step = jax.jit(blackjax.nuts(unconstrained_logdensity, **nuts_parameters).step)

nuts_samples = []
for i in range(num_nuts_samples):
    rng_key, subkey = jax.random.split(rng_key)
    current_state, info = jit_nuts_step(subkey, current_state)
    nuts_samples.append(transform_to_constrained(current_state.position))
    
    if (i + 1) % 500 == 0:
        print(f"NUTS sampling: {i + 1}/{num_nuts_samples}")

nuts_time = time.time() - start_time
nuts_samples = jnp.array(nuts_samples)

print(f"\nNUTS completed in {nuts_time:.2f} seconds")
print(f"Generated {len(nuts_samples)} samples")

# NUTS posterior statistics
print("\nNUTS posterior summary:")
for i, name in enumerate(param_names):
    mean = jnp.mean(nuts_samples[:, i])
    std = jnp.std(nuts_samples[:, i])
    true_val = true_params[i]
    print(f"  {name}: {mean:.3f} ± {std:.3f} (true: {true_val:.3f})")

In [None]:
# Affine Invariant Ensemble Sampler (AIES) comparison
print("Setting up AIES (emcee-like) sampler...")

# Use an affine-invariant ensemble sampler from BlackJAX, if this version provides one
try:
    from blackjax.mcmc.aies import init, build_kernel as aies_build_kernel
    
    # AIES configuration
    num_walkers = 50
    num_aies_steps = 2000
    
    # Prior box as a (5, 2) array of [lower, upper] bounds
    bounds_array = jnp.array(list(prior_bounds.values()))
    
    @jax.jit
    def aies_logdensity(params):
        """Log-density for AIES (gradient-free): flat prior inside the box, -inf outside."""
        inside = jnp.all((params >= bounds_array[:, 0]) & (params <= bounds_array[:, 1]))
        return jnp.where(inside, loglikelihood_fn(params), -jnp.inf)
    
    # Initialize walkers uniformly over the prior box
    rng_key, subkey = jax.random.split(rng_key)
    initial_ensemble = jax.random.uniform(
        subkey, (num_walkers, 5),
        minval=bounds_array[:, 0],
        maxval=bounds_array[:, 1]
    )
    
    aies_kernel = aies_build_kernel(aies_logdensity)
    aies_state = init(initial_ensemble)
    
    print(f"AIES initialized with {num_walkers} walkers")
    
    # Run AIES
    print("Running AIES sampling...")
    start_time = time.time()
    
    jit_aies_step = jax.jit(aies_kernel)
    
    aies_samples = []
    current_aies_state = aies_state
    
    for i in range(num_aies_steps):
        rng_key, subkey = jax.random.split(rng_key)
        current_aies_state, info = jit_aies_step(subkey, current_aies_state)
        
        # Store all walker positions
        aies_samples.append(current_aies_state.position.copy())
        
        if (i + 1) % 500 == 0:
            print(f"AIES step: {i + 1}/{num_aies_steps}")
    
    aies_time = time.time() - start_time
    aies_samples = jnp.array(aies_samples)
    
    # Discard the first half of each walker's chain as burn-in, then flatten walkers
    aies_samples_flat = aies_samples[num_aies_steps // 2:].reshape(-1, 5)
    
    print(f"\nAIES completed in {aies_time:.2f} seconds")
    print(f"Generated {len(aies_samples_flat)} samples")
    
    # AIES posterior statistics
    print("\nAIES posterior summary:")
    for i, name in enumerate(param_names):
        mean = jnp.mean(aies_samples_flat[:, i])
        std = jnp.std(aies_samples_flat[:, i])
        true_val = true_params[i]
        print(f"  {name}: {mean:.3f} ± {std:.3f} (true: {true_val:.3f})")
    
    aies_available = True
    
except ImportError:
    print("AIES not available in this BlackJAX version")
    print("Skipping AIES comparison...")
    aies_available = False
    aies_time = 0
    aies_samples_flat = None

In [None]:
# Performance comparison visualization
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Timing comparison
methods = ['Nested Sampling', 'NUTS']
times = [sampling_time, nuts_time]
colors = ['blue', 'orange']

if aies_available:
    methods.append('AIES')
    times.append(aies_time)
    colors.append('green')

axes[0, 0].bar(methods, times, color=colors, alpha=0.7)
axes[0, 0].set_ylabel('Sampling Time (seconds)')
axes[0, 0].set_title('Computational Performance')
axes[0, 0].grid(True, alpha=0.3)

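# Sample comparison for first two parameters
param_idx = [0, 1]  # μₓ, μᵧ

# Nested sampling: dead points carry posterior weights, so resample them into equally weighted draws
ns_weights = nested_samples.get_weights()
ns_draws = np.random.choice(len(ns_weights), size=2000, p=ns_weights / ns_weights.sum())
ns_posterior = nested_samples[param_names].to_numpy()[ns_draws]

axes[0, 1].scatter(ns_posterior[:, param_idx[0]], ns_posterior[:, param_idx[1]], 
                  alpha=0.3, s=10, label='Nested Sampling', c='blue')
axes[0, 1].plot(true_params[param_idx[0]], true_params[param_idx[1]], 
               'r*', markersize=15, label='True')
axes[0, 1].set_xlabel(param_names[param_idx[0]])
axes[0, 1].set_ylabel(param_names[param_idx[1]])
axes[0, 1].set_title('Nested Sampling Posterior')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)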

# NUTS
axes[0, 2].scatter(nuts_samples[:, param_idx[0]], nuts_samples[:, param_idx[1]], 
                  alpha=0.3, s=10, label='NUTS', c='orange')
axes[0, 2].plot(true_params[param_idx[0]], true_params[param_idx[1]], 
               'r*', markersize=15, label='True')
axes[0, 2].set_xlabel(param_names[param_idx[0]])
axes[0, 2].set_ylabel(param_names[param_idx[1]])
axes[0, 2].set_title('NUTS Posterior')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Parameter accuracy comparison
ns_errors = jnp.abs(jnp.array([nested_samples[name].mean() for name in param_names]) - true_params)
nuts_errors = jnp.abs(jnp.mean(nuts_samples, axis=0) - true_params)

x_pos = jnp.arange(len(param_names))
width = 0.25  # three bars per parameter without overlapping neighbouring groups

axes[1, 0].bar(x_pos - width, ns_errors, width, label='Nested Sampling', color='blue', alpha=0.7)
axes[1, 0].bar(x_pos, nuts_errors, width, label='NUTS', color='orange', alpha=0.7)

if aies_available:
    aies_errors = jnp.abs(jnp.mean(aies_samples_flat, axis=0) - true_params)
    axes[1, 0].bar(x_pos + width, aies_errors, width, label='AIES', color='green', alpha=0.7)

axes[1, 0].set_xlabel('Parameters')
axes[1, 0].set_ylabel('Absolute Error')
axes[1, 0].set_title('Parameter Estimation Accuracy')
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels(param_names)
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Marginal distributions comparison
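param_to_plot = 4  # ρ (correlation coefficient)

# Nested sampling dead points must be weighted to form the posterior histogram
axes[1, 1].hist(nested_samples[param_names[param_to_plot]].to_numpy(), bins=30, alpha=0.7,
               weights=nested_samples.get_weights(), density=True, label='Nested Sampling', color='blue')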
axes[1, 1].hist(nuts_samples[:, param_to_plot], bins=30, alpha=0.7, 
               density=True, label='NUTS', color='orange')

if aies_available:
    axes[1, 1].hist(aies_samples_flat[:, param_to_plot], bins=30, alpha=0.7, 
                   density=True, label='AIES', color='green')

axes[1, 1].axvline(true_params[param_to_plot], color='red', linestyle='--', 
                  linewidth=2, label='True')
axes[1, 1].set_xlabel(f'{param_names[param_to_plot]}')
axes[1, 1].set_ylabel('Density')
axes[1, 1].set_title(f'Marginal: {param_names[param_to_plot]}')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# Summary statistics
summary_text = f"""Method Comparison Summary:

Nested Sampling:
• Time: {sampling_time:.1f}s
• Samples: {len(samples)}
• Evidence: logZ = {live_state.logZ:.2f}
• Handles multimodality

NUTS:
• Time: {nuts_time:.1f}s  
• Samples: {len(nuts_samples)}
• Requires gradients
• May miss modes
"""

if aies_available:
    summary_text += f"""\nAIES:
• Time: {aies_time:.1f}s
• Samples: {len(aies_samples_flat)}
• Gradient-free
• Ensemble method"""

axes[1, 2].text(0.05, 0.95, summary_text, transform=axes[1, 2].transAxes, 
               fontsize=10, verticalalignment='top', fontfamily='monospace')
axes[1, 2].set_xlim(0, 1)
axes[1, 2].set_ylim(0, 1)
axes[1, 2].axis('off')
axes[1, 2].set_title('Method Comparison')

plt.tight_layout()
plt.show()

## 7. GPU Acceleration and Performance

BlackJAX's GPU-native implementation can give substantial speedups, especially for:
- **Large parameter spaces** (high-dimensional problems)
- **Expensive likelihood evaluations** (neural networks, PDE solvers)
- **Parallel deletion** in nested sampling (many points replaced per iteration)

Let's explore the performance characteristics and demonstrate GPU acceleration.
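The GPU benefit of parallel deletion comes from batching: if `num_delete` points are replaced per iteration, their likelihood evaluations can run as one vectorized call instead of one at a time. A minimal sketch of that pattern with `jax.vmap`, using a hypothetical toy likelihood; it illustrates the idea and is not how BlackJAX is implemented internally.

```python
import jax
import jax.numpy as jnp

def toy_loglikelihood(theta):
    """Hypothetical stand-in for an expensive likelihood: one parameter vector -> scalar."""
    grid = jnp.linspace(-3.0, 3.0, 2500)
    model = jnp.exp(-0.5 * ((grid - theta[0]) / theta[1]) ** 2)
    return -0.5 * jnp.sum((model - 0.5) ** 2)

# One compiled call evaluates a whole batch (e.g. the num_delete replacement points)
batched_loglikelihood = jax.jit(jax.vmap(toy_loglikelihood))

key = jax.random.PRNGKey(0)
batch = jnp.stack([jax.random.uniform(key, (50,), minval=-1.0, maxval=1.0),
                   jnp.full(50, 1.0)], axis=1)          # 50 points, shape (50, 2)
batch_values = batched_loglikelihood(batch)             # shape (50,)
loop_values = jnp.array([toy_loglikelihood(t) for t in batch])  # same numbers, one call per point
```

On a GPU the batched call keeps the device busy with a single large kernel launch, whereas the Python loop pays dispatch overhead per point; that is why larger `num_delete` values tend to favour GPUs.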

In [None]:
# Performance scaling experiment
print("Testing performance scaling with different configurations...")

# Test different numbers of live points
live_point_configs = [100, 500, 1000, 2000]
performance_results = []

for num_live_test in live_point_configs:
    print(f"\nTesting with {num_live_test} live points...")
    
    # Initialize nested sampling
    rng_key, subkey = jax.random.split(rng_key)
    particles_test, logprior_fn_test = ns_utils.uniform_prior(
        subkey, num_live_test, prior_bounds
    )
    
    num_delete_test = min(50, num_live_test // 10)  # Scale deletion parameter
    nested_sampler_test = blackjax.nss(
        logprior_fn=logprior_fn_test,
        loglikelihood_fn=loglikelihood_fn,
        num_delete=num_delete_test,
        num_inner_steps=25
    )
    
    # Initialize and run a few steps for timing
    live_state_test = nested_sampler_test.init(particles_test)
    jit_step_test = jax.jit(nested_sampler_test.step)
    
    # Warm up JIT compilation (blocking so compile time is excluded from the timing)
    rng_key, subkey = jax.random.split(rng_key)
    jax.block_until_ready(jit_step_test(subkey, live_state_test))
    
    # Time multiple steps
    num_test_steps = 10
    start_time = time.time()
    
    current_state = live_state_test
    for _ in range(num_test_steps):
        rng_key, subkey = jax.random.split(rng_key)
        current_state, _ = jit_step_test(subkey, current_state)
    jax.block_until_ready(current_state)  # JAX dispatches asynchronously: wait for the work to finish
    
    step_time = (time.time() - start_time) / num_test_steps
    
    # Each step removes and replaces num_delete points
    performance_results.append({
        'num_live': num_live_test,
        'step_time': step_time,
        'throughput': num_delete_test / step_time
    })
    
    print(f"  Step time: {step_time:.4f}s")
    print(f"  Throughput: {num_delete_test / step_time:.0f} dead points/second")

# Plot performance scaling
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

num_live_points = [r['num_live'] for r in performance_results]
step_times = [r['step_time'] for r in performance_results]
throughputs = [r['throughput'] for r in performance_results]

# Step time scaling
ax1.plot(num_live_points, step_times, 'o-', linewidth=2, markersize=8)
ax1.set_xlabel('Number of Live Points')
ax1.set_ylabel('Step Time (seconds)')
ax1.set_title('Nested Sampling Step Time Scaling')
ax1.grid(True, alpha=0.3)
ax1.set_xscale('log')
ax1.set_yscale('log')

# Throughput scaling
ax2.plot(num_live_points, throughputs, 's-', color='orange', linewidth=2, markersize=8)
ax2.set_xlabel('Number of Live Points')
ax2.set_ylabel('Throughput (dead points/second)')
ax2.set_title('Sampling Throughput')
ax2.grid(True, alpha=0.3)
ax2.set_xscale('log')

plt.tight_layout()
plt.show()

print(f"\nPerformance Summary:")
print(f"  Best throughput: {max(throughputs):.0f} dead points/second")
# Fit step_time ∝ num_live^alpha on a log-log scale
alpha = np.polyfit(np.log(num_live_points), np.log(step_times), 1)[0]
print(f"  Step time scales roughly as num_live^{alpha:.2f}")

## 8. Advanced Features and Extensions

BlackJAX nested sampling offers several advanced features for scientific applications:

### Key Advantages:
1. **Evidence computation**: Essential for model comparison in SBI
2. **Multimodal handling**: Robust exploration of complex posteriors
3. **GPU acceleration**: Leverages modern HPC infrastructure
4. **JAX integration**: Seamless autodiff and JIT compilation
5. **Open source**: Community-driven development

### When to Use Nested Sampling:
- **Model comparison** (Bayesian evidence needed)
- **Multimodal posteriors** (phase transitions, symmetries)
- **High-dimensional problems** (gradient-free exploration)
- **SBI workflows** (NLE, NRE, NJE all need sampling)
- **Scientific inference** (evidence quantification important)
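Evidence turns into a model-comparison statement through the Bayes factor B₁₂ = Z₁/Z₂: the posterior odds of two models are B₁₂ times their prior odds. A small sketch converting a log Bayes factor (such as the one computed in the model-comparison cell) into a posterior model probability, assuming equal prior odds unless a log prior odds is passed; the helper name is hypothetical.

```python
import math

def posterior_model_probability(log_bayes_factor, log_prior_odds=0.0):
    """P(M1 | D) for two models, given ln B12 = ln Z1 - ln Z2 and optional ln prior odds."""
    log_posterior_odds = log_bayes_factor + log_prior_odds
    # Logistic function of the log posterior odds, written to avoid overflow
    if log_posterior_odds >= 0:
        return 1.0 / (1.0 + math.exp(-log_posterior_odds))
    e = math.exp(log_posterior_odds)
    return e / (1.0 + e)
```

For instance, ln B₁₂ = ln 3 ≈ 1.1 (just past the |ln B| > 1 threshold used below) corresponds to P(M₁ | D) = 0.75 under equal prior odds, which is a useful reminder that such a threshold marks only modest evidence.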

In [None]:
# Model comparison demonstration
print("Demonstrating Bayesian model comparison with nested sampling...")

# Define two competing models
# Model 1: Full 2D Gaussian (5 parameters)
# Model 2: Circular Gaussian (3 parameters: μₓ, μᵧ, σ)

@jax.jit
def circular_gaussian_loglikelihood(params_3d):
    """Likelihood for circular Gaussian model (σₓ = σᵧ, ρ = 0)."""
    mu_x, mu_y, sigma = params_3d
    
    # Convert to 5D parameter space
    params_5d = jnp.array([mu_x, mu_y, sigma, sigma, 0.0])
    
    return loglikelihood_fn(params_5d)

# Model 2 prior bounds (3D)
prior_bounds_3d = {
    "mu_x": (-2.0, 2.0),
    "mu_y": (-2.0, 2.0),
    "sigma": (0.5, 3.0)
}

# Run nested sampling for Model 2 (circular Gaussian)
print("\nRunning nested sampling for Model 2 (circular Gaussian)...")

num_live_simple = 500
rng_key, subkey = jax.random.split(rng_key)
particles_3d, logprior_fn_3d = ns_utils.uniform_prior(
    subkey, num_live_simple, prior_bounds_3d
)

nested_sampler_3d = blackjax.nss(
    logprior_fn=logprior_fn_3d,
    loglikelihood_fn=circular_gaussian_loglikelihood,
    num_delete=25,
    num_inner_steps=15
)

live_state_3d = nested_sampler_3d.init(particles_3d)
jit_step_3d = jax.jit(nested_sampler_3d.step)

# Run until convergence
dead_points_3d = []
iteration_3d = 0

while (live_state_3d.logZ_live - live_state_3d.logZ) > -3.0:
    rng_key, subkey = jax.random.split(rng_key)
    live_state_3d, dead_info_3d = jit_step_3d(subkey, live_state_3d)
    dead_points_3d.append(dead_info_3d)
    iteration_3d += 1
    
    if iteration_3d % 25 == 0:
        remaining = live_state_3d.logZ_live - live_state_3d.logZ
        print(f"  Iteration {iteration_3d}: logZ = {live_state_3d.logZ:.3f}, remaining = {remaining:.3f}")

# Wrap Model 2's run in an anesthetic NestedSamples object, like Model 1's
dead_3d = ns_utils.finalise(live_state_3d, dead_points_3d)
samples_3d = dead_3d.particles
nested_samples_3d = NestedSamples(
    data=np.asarray(samples_3d),
    logL=np.asarray(dead_3d.loglikelihood),
    logL_birth=np.asarray(dead_3d.loglikelihood_birth),
    columns=['μₓ', 'μᵧ', 'σ']
)

# Draws of logZ from simulated prior volumes give both a mean and an uncertainty
logZ_draws_1 = nested_samples.logZ(1000)
logZ_draws_2 = nested_samples_3d.logZ(1000)

print(f"\nModel Comparison Results:")
print(f"Model 1 (Full 2D Gaussian):     logZ = {logZ_draws_1.mean():.3f} ± {logZ_draws_1.std():.3f}")
print(f"Model 2 (Circular Gaussian):    logZ = {logZ_draws_2.mean():.3f} ± {logZ_draws_2.std():.3f}")

# Calculate Bayes factor (independent runs, so uncertainties add in quadrature)
log_bayes_factor = logZ_draws_1.mean() - logZ_draws_2.mean()
log_bayes_factor_err = np.hypot(logZ_draws_1.std(), logZ_draws_2.std())
bayes_factor = np.exp(log_bayes_factor)

print(f"\nBayes Factor (Model 1 / Model 2): {bayes_factor:.2f}")
print(f"Log Bayes Factor: {log_bayes_factor:.3f} ± {log_bayes_factor_err:.3f}")

if log_bayes_factor > 1:
    print("Evidence favors Model 1 (Full 2D Gaussian)")
elif log_bayes_factor < -1:
    print("Evidence favors Model 2 (Circular Gaussian)")
else:
    print("Evidence is inconclusive")

# Visualize model comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Evidence comparison
models = ['Full 2D\nGaussian', 'Circular\nGaussian']
evidences = [logZ_draws_1.mean(), logZ_draws_2.mean()]
errors = [logZ_draws_1.std(), logZ_draws_2.std()]

ax1.bar(models, evidences, yerr=errors, capsize=5, alpha=0.7, color=['blue', 'orange'])
ax1.set_ylabel('Log Evidence')
ax1.set_title('Model Comparison: Bayesian Evidence')
ax1.grid(True, alpha=0.3)

# Compare μₓ and μᵧ posteriors: resample dead points by weight to get posterior draws
for ns, label, color in [(nested_samples, 'Model 1 (5D)', 'blue'),
                         (nested_samples_3d, 'Model 2 (3D)', 'orange')]:
    w = ns.get_weights()
    idx = np.random.choice(len(w), size=2000, p=w / w.sum())
    ax2.scatter(ns['μₓ'].to_numpy()[idx], ns['μᵧ'].to_numpy()[idx],
                alpha=0.3, s=10, label=label, c=color)
ax2.plot(true_params[0], true_params[1], 'r*', markersize=15, label='True')
ax2.set_xlabel('μₓ')
ax2.set_ylabel('μᵧ')
ax2.set_title('Position Parameter Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Your Turn: Apply BlackJAX to Your Problems!

Now it's time to experiment with BlackJAX nested sampling on your own research problems. Here are some suggestions:

### Exercise Options:

1. **Modify the current problem**:
   - Change the true parameters and see how the inference performs
   - Add more noise to make the problem harder
   - Try different prior bounds

2. **Use your own JAX likelihood**:
   - Bring code from Viraj's workshop
   - Implement your own scientific model
   - Compare nested sampling vs. your usual inference method

3. **Explore different configurations**:
   - Increase the number of live points for higher precision
   - Try different `num_inner_steps` values
   - Experiment with the `num_delete` parameter for GPU optimization

### Template for Your Own Problem:

In [None]:
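# Template for your own BlackJAX nested sampling experiment
# (filled in with a toy straight-line fit so the cell runs; replace each step with your problem)

# Step 1: Define your data and likelihood function
rng_key, data_key = jax.random.split(rng_key)
x_data = jnp.linspace(0.0, 1.0, 20)
noise_level = 0.1
your_data = 2.0 * x_data + 0.5 + noise_level * jax.random.normal(data_key, x_data.shape)  # toy data

@jax.jit
def your_loglikelihood_fn(params):
    """
    Replace this with your own likelihood function.
    
    Args:
        params: JAX array of parameters (here: [slope, intercept])
    
    Returns:
        Log-likelihood value
    """
    slope, intercept = params
    
    # Your model/simulation here
    model_prediction = slope * x_data + intercept
    
    # Your likelihood calculation here (Gaussian noise with known sigma)
    residuals = (your_data - model_prediction) / noise_level
    return -0.5 * jnp.sum(residuals**2)

# Step 2: Define your prior bounds
your_prior_bounds = {
    "slope": (-5.0, 5.0),
    "intercept": (-5.0, 5.0),
    # Add more parameters as needed
}

# Step 3: Configure and run nested sampling
your_num_live = 500  # Adjust based on your problem complexity
your_num_dims = len(your_prior_bounds)
your_num_inner_steps = your_num_dims * 5

# Initialize
rng_key, subkey = jax.random.split(rng_key)
your_particles, your_logprior_fn = ns_utils.uniform_prior(
    subkey, your_num_live, your_prior_bounds
)

# Create sampler
your_sampler = blackjax.nss(
    logprior_fn=your_logprior_fn,
    loglikelihood_fn=your_loglikelihood_fn,
    num_delete=25,
    num_inner_steps=your_num_inner_steps
)

# Run sampling until the live points hold < ~5% of the evidence (same rule as above)
your_live_state = your_sampler.init(your_particles)
your_step = jax.jit(your_sampler.step)
your_dead_points = []
while (your_live_state.logZ_live - your_live_state.logZ) > -3.0:
    rng_key, subkey = jax.random.split(rng_key)
    your_live_state, your_dead_info = your_step(subkey, your_live_state)
    your_dead_points.append(your_dead_info)

# Step 4: Analyze results with Anesthetic
your_dead = ns_utils.finalise(your_live_state, your_dead_points)
your_nested_samples = NestedSamples(
    data=np.asarray(your_dead.particles),
    logL=np.asarray(your_dead.loglikelihood),
    logL_birth=np.asarray(your_dead.loglikelihood_birth),
    columns=list(your_prior_bounds.keys())
)
print(f"logZ = {your_nested_samples.logZ():.3f}")
# Make corner plots with your_nested_samples.plot_2d(...), compare evidences, etc.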

## 10. Resources and Next Steps

### Key Resources:

**BlackJAX Nested Sampling:**
- Repository: https://github.com/handley-lab/blackjax
- Documentation: [BlackJAX docs](https://blackjax-devs.github.io/blackjax/)
- Installation: `pip install git+https://github.com/handley-lab/blackjax@nested_sampling`

**Visualization:**
- Anesthetic: https://anesthetic.readthedocs.io/en/latest/plotting.html
- Corner plots for nested sampling results
- Evidence evolution diagnostics

**JAX Ecosystem:**
- JAX documentation: https://jax.readthedocs.io/
- NumPyro (probabilistic programming): https://num.pyro.ai/
- Optax (optimization): https://optax.readthedocs.io/

### When to Use BlackJAX Nested Sampling:

✅ **Good for:**
- Multimodal posteriors
- Model comparison (evidence computation)
- High-dimensional problems
- GPU-accelerated inference
- SBI workflows (NLE, NRE, NJE)
- Scientific applications requiring evidence quantification

❌ **Consider alternatives for:**
- Simple unimodal posteriors (HMC/NUTS may be faster)
- Very low-dimensional problems (< 3D)
- When you only need posterior samples (not evidence)

### Next Steps:

1. **Try on your research problems**: Apply to your own JAX-based models
2. **Experiment with configurations**: Optimize for your specific use case
3. **Compare methods**: Benchmark against your current inference approach
4. **Contribute**: BlackJAX is community-driven - report issues, suggest features
5. **Stay updated**: Follow BlackJAX development for new features

### Questions?

- GitHub Issues: https://github.com/handley-lab/blackjax/issues
- Discussion: BlackJAX community channels
- This workshop: Experiment with the provided templates!

---

**Thank you for participating in the BlackJAX Nested Sampling Workshop!**

*The future of scientific inference is GPU-native, open-source, and community-driven.*