# BlackJAX Nested Sampling GPU Performance Testing

This notebook benchmarks BlackJAX nested sampling on an Apple Silicon GPU through JAX's Metal backend (`jax-metal`).

**Goal**: Find the optimal `num_delete` parameter to maximize GPU utilization.

The `num_delete` parameter sets how many of the lowest-likelihood live points are deleted and replaced in each iteration.
The replacements are generated in parallel, so higher values give each iteration more independent work, which can keep more GPU cores busy.
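To make the role of `num_delete` concrete, here is a NumPy sketch of one batch-deletion step in the standard nested-sampling picture. It is not BlackJAX's implementation, and both function names are hypothetical. The second function shows the statistical side of the trade-off: deleting `k` points at once shrinks the expected prior volume by more per iteration than deleting one.

```python
import numpy as np


def select_dead(loglikelihoods, num_delete):
    """Pick the num_delete lowest-likelihood live points (one batch deletion).

    Returns their indices and the new likelihood threshold, i.e. the largest
    log-likelihood among the deleted points. Replacement points must be
    drawn from the prior with log-likelihood above this threshold.
    """
    n_live = len(loglikelihoods)
    if not 1 <= num_delete < n_live:
        raise ValueError("num_delete must satisfy 1 <= num_delete < n_live")
    dead_idx = np.argsort(loglikelihoods)[:num_delete]
    threshold = loglikelihoods[dead_idx].max()
    return dead_idx, threshold


def expected_log_compression(n_live, num_delete):
    """Expected change in log prior volume from one batch deletion.

    Removing k points at once from n live points shrinks the volume like
    k single deletions with n, n-1, ..., n-k+1 live points.
    """
    return -sum(1.0 / (n_live - j) for j in range(num_delete))


logL = np.array([-3.0, -1.0, -7.0, -2.0, -5.0])
dead_idx, threshold = select_dead(logL, num_delete=2)
print(sorted(dead_idx.tolist()), threshold)  # [2, 4] -5.0
print(expected_log_compression(500, 1), expected_log_compression(500, 100))
```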

## Setup: Installing JAX Metal

If JAX Metal is not installed, run:
```bash
pip install jax-metal
```

**Note**: JAX Metal requires macOS 12.0+ and an Apple Silicon (M1/M2/M3) chip.

In [None]:
# First, let's check the current JAX setup
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

# Check whether a Metal (or other GPU) device is visible to JAX
try:
    metal_devices = [d for d in jax.devices() if d.platform.lower() in ('metal', 'gpu')]
    if metal_devices:
        print(f"\n✓ Metal GPU available: {metal_devices}")
    else:
        print("\n⚠ No Metal GPU detected. You may need to install jax-metal:")
        print("  pip install jax-metal")
except Exception as e:
    print(f"Error checking devices: {e}")

## Install jax-metal if needed

Uncomment and run the cell below if Metal is not available:

In [None]:
# Uncomment to install jax-metal
# !pip install jax-metal

# After installing, you may need to restart the kernel

## Set up the Discovery likelihood

We'll use the same 3-pulsar red noise model from the example notebook.

In [None]:
import discovery as ds
import glob
import os
import time
import numpy as np

# Find discovery package location
discovery_location = os.path.dirname(ds.__file__)
print(f"Discovery package at: {discovery_location}")

# Load pulsars
allpsrs = [ds.Pulsar.read_feather(psrfile) 
           for psrfile in sorted(glob.glob(f'{discovery_location}/../../data/*-[JB]*.feather'))]
print(f"Loaded {len(allpsrs)} pulsars")

# Use 3 pulsars for testing
psrs = allpsrs[:3]

print("Building likelihood...")
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                        ds.makenoise_measurement(psr, psr.noisedict),
                                        ds.makegp_ecorr(psr, psr.noisedict),
                                        ds.makegp_timing(psr, svd=True),
                                        ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')])
                for psr in psrs])
print(f"Done. Parameters: {m.logL.params}")

In [None]:
# Define priors
priors = {
    'B1855+09_rednoise_gamma': {'dist': 'uniform', 'min': 0, 'max': 7},
    'B1855+09_rednoise_log10_A': {'dist': 'uniform', 'min': -20, 'max': -11},
    'B1937+21_rednoise_gamma': {'dist': 'uniform', 'min': 0, 'max': 7},
    'B1937+21_rednoise_log10_A': {'dist': 'uniform', 'min': -20, 'max': -11},
    'B1953+29_rednoise_gamma': {'dist': 'uniform', 'min': 0, 'max': 7},
    'B1953+29_rednoise_log10_A': {'dist': 'uniform', 'min': -20, 'max': -11},
}

ndim = len(priors)
print(f"Number of dimensions: {ndim}")
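The prior dictionary hard-codes three pulsar names, which only matches the model if `allpsrs[:3]` happens to be exactly those pulsars. A quick consistency check catches a mismatch before any sampling. `check_prior_keys` is a hypothetical helper, and it assumes `m.logL.params` lists the sampled parameter names, as printed when the likelihood was built.

```python
def check_prior_keys(model_params, priors):
    """Compare a model's parameter names against the keys of a priors dict.

    Returns (missing, extra): parameters that have no prior, and priors that
    match no model parameter. Both lists should be empty.
    """
    params = set(model_params)
    keys = set(priors)
    return sorted(params - keys), sorted(keys - params)


# Example with made-up names; in the notebook call check_prior_keys(m.logL.params, priors)
missing, extra = check_prior_keys(
    ["J0000+0000_rednoise_gamma", "J0000+0000_rednoise_log10_A"],
    {"J0000+0000_rednoise_gamma": {}, "B1855+09_rednoise_gamma": {}},
)
print(missing)  # ['J0000+0000_rednoise_log10_A']
print(extra)    # ['B1855+09_rednoise_gamma']
```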

## Import BlackJAX interface

In [None]:
import sys
sys.path.insert(0, '..')  # Add parent directory to path

from src.discoverysamplers.blackjax_interface import DiscoveryBlackJAXBridge

## Benchmark function

This function runs the nested sampler with different `num_delete` values and measures performance.

In [None]:
def benchmark_nested_sampling(
    model, 
    priors,
    n_live: int = 500,
    num_delete_values: tuple = (1, 5, 10, 20, 50, 100),
    max_iterations: int = 100,
    num_inner_steps: int = None,
    seed: int = 42,
    warmup_iterations: int = 5,
):
    """
    Benchmark nested sampling with different num_delete values.
    
    Parameters
    ----------
    model : Discovery model
    priors : dict
    n_live : int
        Number of live points
    num_delete_values : sequence of int
        Values of num_delete to test (each must be smaller than n_live)
    max_iterations : int
        Number of timed iterations per run
    num_inner_steps : int, optional
        HRSS steps per iteration (default: 5 * ndim)
    seed : int
        Random seed
    warmup_iterations : int
        Iterations to run for JIT warmup (not timed)
    
    Returns
    -------
    list of dict : One entry per tested num_delete, with timing and efficiency metrics
    """
    import blackjax
    
    # The bridge does not depend on num_delete, so build it once.
    # bridge.ndim is the number of sampled parameters; use it instead of
    # re-parsing the priors dict (the old parsing crashed on tuple-style priors).
    bridge = DiscoveryBlackJAXBridge(model, priors)
    ndim = bridge.ndim
    if num_inner_steps is None:
        num_inner_steps = ndim * 5
    
    results = []
    
    for num_delete in num_delete_values:
        # At least one live point must survive each iteration to seed the replacements
        if num_delete >= n_live:
            print(f"Skipping num_delete={num_delete} (must be < n_live={n_live})")
            continue
            
        print(f"\nTesting num_delete={num_delete}...")
        
        # Create sampler (a new num_delete changes array shapes, so step is re-jitted)
        algo = blackjax.nss(
            logprior_fn=bridge.log_prior_fn,
            loglikelihood_fn=bridge.loglikelihood_fn,
            num_delete=num_delete,
            num_inner_steps=num_inner_steps,
        )
        
        # Initialize live points by mapping uniform draws through the prior transform
        rng_key = jax.random.PRNGKey(seed)
        rng_key, init_key = jax.random.split(rng_key)
        unit_samples = jax.random.uniform(init_key, (n_live, ndim))
        initial_particles = jax.vmap(bridge.prior_transform)(unit_samples)
        
        live = algo.init(initial_particles)
        step_fn = jax.jit(algo.step)
        
        # Warmup (includes JIT compilation; not timed)
        print(f"  Warming up ({warmup_iterations} iterations for JIT compilation)...")
        for i in range(warmup_iterations):
            rng_key, subkey = jax.random.split(rng_key)
            live, _ = step_fn(subkey, live)
        
        # JAX dispatches asynchronously: wait for the warmup to finish before starting the clock
        jax.block_until_ready(live)
        
        # Timed run
        print(f"  Running {max_iterations} timed iterations...")
        dead = []
        
        start_time = time.perf_counter()
        for i in range(max_iterations):
            rng_key, subkey = jax.random.split(rng_key)
            live, dead_info = step_fn(subkey, live)
            dead.append(dead_info)
        
        # Wait for the last step to finish before stopping the clock
        jax.block_until_ready(live)
        end_time = time.perf_counter()
        
        total_time = end_time - start_time
        time_per_iter = total_time / max_iterations
        dead_points_per_sec = (num_delete * max_iterations) / total_time
        
        result = {
            'num_delete': num_delete,
            'n_live': n_live,
            'num_inner_steps': num_inner_steps,
            'iterations': max_iterations,
            'total_time': total_time,
            'time_per_iter': time_per_iter,
            'dead_points_per_sec': dead_points_per_sec,
            # Partial evidence after warmup + timed iterations, not a converged logZ.
            # Runs with different num_delete have removed different numbers of points.
            'logZ': float(live.logZ),
        }
        results.append(result)
        
        print(f"  Total time: {total_time:.2f}s")
        print(f"  Time per iteration: {time_per_iter*1000:.2f}ms")
        print(f"  Dead points/sec: {dead_points_per_sec:.1f}")
        print(f"  Partial logZ: {result['logZ']:.2f}")
    
    return results
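The warmup-then-time pattern above matters because JAX dispatches work asynchronously: a jitted call returns before the device has finished, so the clock must not stop until the result is ready. A minimal, framework-agnostic sketch of the same pattern follows; `time_steps` is a hypothetical helper, and its `block` argument is where `jax.block_until_ready` would go.

```python
import time
from typing import Any, Callable


def time_steps(
    step: Callable[[Any], Any],
    state: Any,
    n_warmup: int,
    n_timed: int,
    block: Callable[[Any], Any] = lambda x: x,
):
    """Run n_warmup untimed steps, then time n_timed further steps.

    `block` must wait until `state` is actually computed. With JAX, pass
    jax.block_until_ready; otherwise asynchronous dispatch can stop the
    clock before the device has finished.
    """
    if n_timed < 1:
        raise ValueError("n_timed must be at least 1")
    for _ in range(n_warmup):
        state = step(state)
    block(state)  # warmup (including any JIT compilation) is done
    start = time.perf_counter()
    for _ in range(n_timed):
        state = step(state)
    block(state)  # wait for the final timed step
    elapsed = time.perf_counter() - start
    return state, elapsed, elapsed / n_timed


# Toy usage with a pure-Python step
final, total, per_step = time_steps(lambda s: s + 1, 0, n_warmup=3, n_timed=5)
print(final)  # 8
```

In this notebook, `state` could be the pair `(rng_key, live)`, with a step that splits the key and calls `step_fn(subkey, live)`, and `block=jax.block_until_ready`.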

## Run the benchmark

We'll test different `num_delete` values to find the sweet spot for GPU utilization.

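**Key insights**:
- Higher `num_delete` means more parallel work per iteration
- Returns diminish once the GPU is saturated
- Memory constraints may cap the largest usable value
- `num_delete` must be smaller than `n_live`, so that at least one live point survives to seed the replacements
- The reported `logZ` is only a partial sum after a fixed number of iterations, so it is not comparable across `num_delete` values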

In [None]:
# Check current device
print(f"Running on: {jax.devices()}")
print(f"Backend: {jax.default_backend()}")
print()

In [None]:
# Run benchmark with moderate settings
# Adjust n_live and num_delete_values based on your GPU memory

benchmark_results = benchmark_nested_sampling(
    model=m,
    priors=priors,
    n_live=500,
    num_delete_values=[1, 5, 10, 20, 50, 100],
    max_iterations=50,
    warmup_iterations=5,
    seed=42,
)

## Visualize results

In [None]:
import matplotlib.pyplot as plt

if benchmark_results:
    num_deletes = [r['num_delete'] for r in benchmark_results]
    times_per_iter = [r['time_per_iter'] * 1000 for r in benchmark_results]  # ms
    dead_per_sec = [r['dead_points_per_sec'] for r in benchmark_results]
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Time per iteration
    axes[0].plot(num_deletes, times_per_iter, 'o-', markersize=8)
    axes[0].set_xlabel('num_delete')
    axes[0].set_ylabel('Time per iteration (ms)')
    axes[0].set_title('Iteration Time vs num_delete')
    axes[0].set_xscale('log')
    axes[0].grid(True, alpha=0.3)
    
    # Throughput (dead points per second)
    axes[1].plot(num_deletes, dead_per_sec, 'o-', markersize=8, color='green')
    axes[1].set_xlabel('num_delete')
    axes[1].set_ylabel('Dead points / second')
    axes[1].set_title('Throughput vs num_delete')
    axes[1].set_xscale('log')
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary table
    print("\nSummary Table:")
    print("-" * 70)
    print(f"{'num_delete':>12} {'time/iter (ms)':>15} {'dead pts/sec':>15} {'partial logZ':>12}")
    print("-" * 70)
    for r in benchmark_results:
        print(f"{r['num_delete']:>12} {r['time_per_iter']*1000:>15.2f} {r['dead_points_per_sec']:>15.1f} {r['logZ']:>12.2f}")
    print("-" * 70)
    
    # Find optimal
    best = max(benchmark_results, key=lambda x: x['dead_points_per_sec'])
    print(f"\n✓ Optimal num_delete = {best['num_delete']} ({best['dead_points_per_sec']:.1f} dead points/sec)")
else:
    print("No benchmark results to plot.")
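Picking the raw maximum, as the cell above does, can favour a large `num_delete` whose advantage is within run-to-run timing noise, even though larger batches compress the prior more coarsely per iteration. One alternative is to take the smallest `num_delete` whose throughput is within a tolerance of the best. The helper name and the 5% default tolerance are illustrative assumptions, not recommendations from BlackJAX.

```python
def smallest_near_best(results, key="dead_points_per_sec", tolerance=0.05):
    """Return the result with the smallest num_delete whose `key` is within
    `tolerance` (a fraction) of the best value.

    `results` is a list of dicts like those returned by benchmark_nested_sampling.
    """
    if not results:
        raise ValueError("results is empty")
    best = max(r[key] for r in results)
    candidates = [r for r in results if r[key] >= (1.0 - tolerance) * best]
    return min(candidates, key=lambda r: r["num_delete"])


# Toy input (not benchmark output) to show the selection rule
toy = [
    {"num_delete": 1, "dead_points_per_sec": 10.0},
    {"num_delete": 20, "dead_points_per_sec": 97.0},
    {"num_delete": 50, "dead_points_per_sec": 100.0},
]
print(smallest_near_best(toy)["num_delete"])  # 20
```

In the notebook this would be called as `smallest_near_best(benchmark_results)`.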

## Extended benchmark with more live points

If GPU memory allows, try higher `n_live` values for better GPU utilization:

In [None]:
# Uncomment to run extended benchmark with more live points
# This will take longer but may show better GPU utilization

# extended_results = benchmark_nested_sampling(
#     model=m,
#     priors=priors,
#     n_live=1000,
#     num_delete_values=[10, 25, 50, 100, 200],
#     max_iterations=30,
#     warmup_iterations=5,
#     seed=42,
# )
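When comparing runs with different `n_live`, it can help to express the candidate `num_delete` values as fractions of `n_live` rather than retyping the list. The default fractions below reproduce the list in the cell above for `n_live=1000`; `num_delete_grid` is a hypothetical helper.

```python
def num_delete_grid(n_live, fractions=(0.01, 0.025, 0.05, 0.1, 0.2)):
    """Candidate num_delete values as fractions of n_live (hypothetical helper)."""
    values = {max(1, round(f * n_live)) for f in fractions}
    # num_delete must leave at least one surviving live point
    return sorted(v for v in values if v < n_live)


print(num_delete_grid(1000))  # [10, 25, 50, 100, 200]
```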

## Memory profiling

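A back-of-the-envelope estimate of the sampler's array memory as a function of `n_live`. It does not model `num_delete`, the likelihood evaluation, or JAX's own buffers, so treat it as an order-of-magnitude guide rather than a measurement of GPU memory use: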

In [None]:
# Memory estimation
def estimate_memory_usage(n_live, ndim, num_inner_steps, dtype_bytes=4):
    """
    Rough, order-of-magnitude estimate of the sampler's array memory.
    
    Main arrays:
    - Live particles: n_live * ndim
    - Covariance matrix: ndim * ndim
    - Intermediate arrays for slice sampling (crude upper bound: assumes a
      copy of every live point is kept for every inner step)
    
    Excludes the likelihood evaluation and JAX's own buffers.
    dtype_bytes=4 corresponds to float32.
    """
    particles = n_live * ndim * dtype_bytes
    covariance = ndim * ndim * dtype_bytes
    # Rough estimate for intermediate computations
    intermediate = n_live * ndim * num_inner_steps * dtype_bytes
    
    total_bytes = particles + covariance + intermediate
    return total_bytes / (1024 ** 2)  # MiB

print("Estimated memory usage (MiB):")
print("-" * 50)
for n_live in [100, 500, 1000, 2000, 5000]:
    mem = estimate_memory_usage(n_live, ndim, ndim * 5)
    print(f"  n_live={n_live:>5}: ~{mem:.1f} MiB")

print("\nNote: Actual GPU memory usage may be higher due to JAX allocations.")
print("Apple Silicon unified memory allows larger allocations than discrete GPUs.")
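As a worked instance of the estimate, take the values this notebook uses: $d = 6$ parameters, $5d = 30$ inner steps, 4-byte float32 values, and $n_\text{live} = 500$:

$$
\begin{aligned}
\text{particles} &= 500 \times 6 \times 4 = 12{,}000\ \text{B} \\
\text{covariance} &= 6 \times 6 \times 4 = 144\ \text{B} \\
\text{intermediate} &= 500 \times 6 \times 30 \times 4 = 360{,}000\ \text{B} \\
\text{total} &= 372{,}144\ \text{B} \approx 0.35\ \text{MiB}
\end{aligned}
$$

Even at $n_\text{live} = 5000$ the same formula gives about 3.5 MiB, so by this estimate the sampler's own arrays are small. If memory does become a constraint, it is more likely to come from what the estimate leaves out, such as the likelihood evaluation's intermediates and JAX's own buffers.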

## Full production run

Once you've found the optimal `num_delete`, run a full nested sampling:

In [None]:
# Run full nested sampling with optimal settings
# Adjust num_delete based on benchmark results above

optimal_num_delete = 20  # Adjust based on your benchmark results

bridge = DiscoveryBlackJAXBridge(m, priors)

print(f"Running full nested sampling with num_delete={optimal_num_delete}...")
start = time.perf_counter()

results = bridge.run_sampler(
    n_live=500,
    num_delete=optimal_num_delete,
    termination_threshold=-3.0,
    seed=42,
    progress=True,
)

elapsed = time.perf_counter() - start
print(f"\nCompleted in {elapsed:.1f}s")
print(f"logZ = {results['logZ']:.2f} ± {results['logZ_err']:.2f}")
print(f"Total samples: {results['samples'].shape[0]}")
print(f"Iterations: {results['n_iterations']}")
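`termination_threshold=-3.0` is passed straight to the bridge's `run_sampler`. A common nested-sampling convention, assumed here rather than confirmed from the bridge's source, is to stop once the evidence the live points could still contribute is small compared with what has been accumulated, compared on a log scale: with `-3.0` the run stops once that ratio falls below $e^{-3} \approx 5\%$. A NumPy sketch of that rule, with hypothetical function names:

```python
import numpy as np


def live_evidence_ratio(logZ_dead, live_logL, log_X):
    """log(Z_live / Z_dead): evidence still held by the live points vs. accumulated.

    Z_live is estimated by giving each live point an equal share X / n_live
    of the remaining prior volume X (log_X is its logarithm).
    """
    live_logL = np.asarray(live_logL, dtype=float)
    n_live = live_logL.size
    logZ_live = np.logaddexp.reduce(live_logL) + log_X - np.log(n_live)
    return logZ_live - logZ_dead


def should_stop(logZ_dead, live_logL, log_X, termination_threshold=-3.0):
    """Stop once the live points could add only a small fraction to the evidence."""
    return live_evidence_ratio(logZ_dead, live_logL, log_X) < termination_threshold


print(should_stop(0.0, np.zeros(4), log_X=-5.0))  # True
print(should_stop(0.0, np.zeros(4), log_X=-1.0))  # False
```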

In [None]:
# Plot the results
fig = bridge.plot_corner()
plt.show()

## Metal GPU Tips

### Maximizing Metal GPU performance:

1. **Batch size matters**: Higher `num_delete` values allow more parallel work
2. **Unified memory advantage**: Apple Silicon shares memory between CPU/GPU, allowing larger `n_live`
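3. **JIT compilation**: The first call to a jitted step compiles it and is slow; later calls reuse the compiled kernel as long as array shapes and settings baked into the step (such as `num_delete`) stay the same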
4. **Float32 vs Float64**: Metal works better with float32 (JAX default)
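Float32 resolution is relative to the magnitude of the number, which matters when log-likelihoods are large: differences smaller than the gap between adjacent float32 values near that magnitude are lost. A quick NumPy check, using 1e5 purely as an illustrative magnitude (not a value taken from this model):

```python
import numpy as np

# Gap between adjacent representable numbers near 1e5
value = 1e5
print(np.spacing(np.float32(value)))  # 0.0078125
print(np.spacing(np.float64(value)))
```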

### Known Metal limitations:

- Some JAX operations may fall back to CPU
- Less mature than CUDA backend
- Limited profiling tools compared to NVIDIA

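### To force the Metal backend:
```python
import jax

# Run this before any other JAX call (in this notebook, at the very top, before
# the first cell's jax.devices()); once JAX has initialized its backends, the
# default is already chosen.
jax.config.update('jax_platform_name', 'metal')
```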