# Notebook 4: Uncertainty Quantification with Annealed SMC

This notebook uses Annealed Sequential Monte Carlo (SMC) to perform full Bayesian inference over Gaussian process hyperparameters.

**Learning objectives:**
- Understand the difference between MAP-II and full Bayesian inference
- Learn how Annealed SMC works
- Quantify hyperparameter uncertainty
- Compare point estimates vs. posterior distributions

## Setup

In [None]:
# Enable auto-reload for development
%load_ext autoreload
%autoreload 2

# Fix import path
import sys
if '..' not in sys.path:
    sys.path.insert(0, '..')

# Force CPU execution; the environment variable must be set before JAX is imported
import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'

import jax
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from infodynamics_jax.core import Phi
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.energy import InertialEnergy, InertialCFG
from infodynamics_jax.inference.particle import AnnealedSMC, AnnealedSMCCFG
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.infodynamics import run, RunCFG

print(f"JAX version: {jax.__version__}")

## 1. MAP-II vs. Full Bayesian Inference

### MAP-II (Maximum A Posteriori Type-II)
- Finds a **single best** hyperparameter setting: $\phi^* = \arg\max_{\phi} p(\phi | y)$
- Fast, deterministic
- No uncertainty quantification for hyperparameters
- Risk of overfitting with small datasets

### Full Bayesian (Annealed SMC)
- Approximates **entire posterior**: $p(\phi | y)$
- Represents uncertainty via particle cloud
- More robust to overfitting
- Computational cost: $O(P \times T)$ energy evaluations for reweighting, where $P$ = particles and $T$ = annealing steps; rejuvenation moves add to this

### Annealing Schedule

Annealed SMC uses a temperature schedule $\{\beta_t\}_{t=0}^T$ where $0 = \beta_0 < \beta_1 < \dots < \beta_T = 1$, which defines a sequence of tempered targets:

$$\pi_t(\phi) \propto p(\phi) \cdot p(y | \phi)^{\beta_t}$$

- $t=0$: Start from prior $p(\phi)$
- $t=T$: End at posterior $p(\phi | y)$
- Intermediate: Tempered distributions
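
The tempering formula above can be sketched in a few lines of plain Python. This is an illustration only, not the `infodynamics_jax` implementation: the linearly spaced schedule, the helper names, and the toy log-density values are all assumptions made for this sketch.

```python
def linear_schedule(n_steps):
    """Temperatures beta_0 = 0 < ... < beta_T = 1, linearly spaced (an assumed choice)."""
    return [t / n_steps for t in range(n_steps + 1)]


def tempered_log_density(log_prior, log_lik, beta):
    """Unnormalised log pi_t(phi) = log p(phi) + beta_t * log p(y | phi)."""
    return log_prior + beta * log_lik


def incremental_log_weight(log_lik, beta_prev, beta_next):
    """Log of pi_t / pi_{t-1} at a particle: (beta_t - beta_{t-1}) * log p(y | phi)."""
    return (beta_next - beta_prev) * log_lik


betas = linear_schedule(4)
log_prior_value, log_lik_value = -1.0, -10.0  # toy numbers for a single particle

# beta = 0 leaves only the prior term; beta = 1 adds the full log-likelihood.
at_prior = tempered_log_density(log_prior_value, log_lik_value, betas[0])
at_posterior = tempered_log_density(log_prior_value, log_lik_value, betas[-1])

# The incremental weights telescope: summed over the schedule they give log p(y | phi).
total_log_weight = sum(
    incremental_log_weight(log_lik_value, b0, b1)
    for b0, b1 in zip(betas[:-1], betas[1:])
)
```

Because each step changes the target only by a factor $p(y | \phi)^{\beta_t - \beta_{t-1}}$, a finer schedule gives smaller weight updates per step.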

## 2. Generate Regression Data

In [None]:
key = jax.random.key(789)

# Small dataset, so that hyperparameter uncertainty is clearly visible
N_train = 30
X_train = jnp.linspace(-4, 4, N_train)[:, None]

# True function
def true_function(x):
    return jnp.sin(2 * x[:, 0]) + 0.3 * x[:, 0]

f_train = true_function(X_train)

# Add noise
key, subkey = jax.random.split(key)
noise_std = 0.3
Y_train = f_train + noise_std * jax.random.normal(subkey, (N_train,))

# Test set
X_test = jnp.linspace(-5, 5, 100)[:, None]
f_test = true_function(X_test)

print(f"Training set: {N_train} points")
print(f"True noise std: {noise_std}")

In [None]:
# Visualize data
plt.figure(figsize=(10, 4))
plt.scatter(X_train[:, 0], Y_train, c='red', s=50, alpha=0.6, label='Training data')
plt.plot(X_test[:, 0], f_test, 'g-', linewidth=2, label='True function')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Small Dataset for Bayesian Inference')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 3. Baseline: MAP-II Optimization

In [None]:
# Initial hyperparameters and M inducing points spread evenly over the training inputs
kernel_params = KernelParams(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
M = 10
Z = jnp.linspace(X_train.min(), X_train.max(), M)[:, None]

phi_init = Phi(
    kernel_params=kernel_params,
    Z=Z,
    likelihood_params={"noise_var": jnp.array(0.1)},
    jitter=1e-5,
)

# Energy
gaussian_likelihood = get_likelihood("gaussian")
inertial_cfg = InertialCFG(estimator="gh", gh_n=20, inner_steps=0)
inertial_energy = InertialEnergy(
    kernel_fn=rbf_kernel,
    likelihood=gaussian_likelihood,
    cfg=inertial_cfg,
)

# Run MAP-II
typeii_cfg = TypeIICFG(steps=100, lr=1e-2, optimizer="adam", jit=True, constrain_params=True)
method = TypeII(cfg=typeii_cfg)

key, subkey = jax.random.split(key)
out_map = run(
    key=subkey,
    method=method,
    energy=inertial_energy,
    phi_init=phi_init,
    energy_args=(X_train, Y_train),
    cfg=RunCFG(jit=True),
)

phi_map = out_map.result.phi

print("MAP-II Results:")
print(f"  Lengthscale: {float(phi_map.kernel_params.lengthscale):.3f}")
print(f"  Variance: {float(phi_map.kernel_params.variance):.3f}")
print(f"  Noise var: {float(phi_map.likelihood_params['noise_var']):.3f}")

## 4. Run Annealed SMC
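
Before calling the library, it may help to see the three moves a generic annealed SMC sampler cycles through: reweight, resample, rejuvenate. The sketch below is a textbook-style toy in plain Python on a 1-D conjugate problem (prior $N(0, 1)$, one observation $y = 2$ with unit noise), so the exact posterior $N(1, 0.5)$ is known. Everything in it is an assumption made for illustration: the linear schedule, resampling at every step, and the random-walk Metropolis kernel. `AnnealedSMC` may differ in each of these respects (for instance, it is configured with HMC rejuvenation and an ESS threshold in this notebook).

```python
import math
import random


def log_prior(phi):
    # Standard normal prior N(0, 1), including its normalising constant.
    return -0.5 * phi * phi - 0.5 * math.log(2 * math.pi)


def log_lik(phi, y=2.0):
    # Single Gaussian observation y ~ N(phi, 1).
    return -0.5 * (y - phi) ** 2 - 0.5 * math.log(2 * math.pi)


def log_mean_exp(xs):
    # Numerically stable log of the average of exp(x).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))


def toy_annealed_smc(n_particles=500, n_steps=10, mh_steps=2, seed=0):
    rng = random.Random(seed)
    particles = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]  # draws from the prior
    log_z = 0.0
    for t in range(1, n_steps + 1):
        beta_prev, beta = (t - 1) / n_steps, t / n_steps
        # 1. Reweight: incremental weight (beta_t - beta_{t-1}) * log p(y | phi).
        log_w = [(beta - beta_prev) * log_lik(p) for p in particles]
        log_z += log_mean_exp(log_w)  # running estimate of log p(y)
        # 2. Resample (multinomial, at every step for simplicity).
        m = max(log_w)
        weights = [math.exp(lw - m) for lw in log_w]
        particles = rng.choices(particles, weights=weights, k=n_particles)
        # 3. Rejuvenate: random-walk Metropolis moves that leave pi_t invariant.
        for _ in range(mh_steps):
            moved = []
            for p in particles:
                q = p + rng.gauss(0.0, 0.5)
                log_accept = (log_prior(q) + beta * log_lik(q)) - (
                    log_prior(p) + beta * log_lik(p)
                )
                accept = rng.random() < math.exp(min(0.0, log_accept))
                moved.append(q if accept else p)
            particles = moved
    return particles, log_z


particles_toy, log_z_toy = toy_annealed_smc()
posterior_mean_toy = sum(particles_toy) / len(particles_toy)  # analytic value: 1.0
exact_log_z = -0.5 * math.log(4 * math.pi) - 1.0  # log N(2; 0, 2)
```

Comparing `posterior_mean_toy` with 1.0 and `log_z_toy` with `exact_log_z` gives a quick sanity check that the three moves are wired together correctly.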

In [None]:
def init_particles_fn(key, n_particles: int):
    """
    Initialize particles from prior.
    Add noise around initial values to explore the space.
    """
    keys = jax.random.split(key, n_particles)
    
    def init_one(key_i):
        # Sample from prior with some spread
        key_l, key_v, key_z, key_n = jax.random.split(key_i, 4)
        
        # Lengthscale: log-normal around 1.0
        lengthscale = jnp.exp(jax.random.normal(key_l, ()) * 0.5)
        
        # Variance: log-normal around 1.0
        variance = jnp.exp(jax.random.normal(key_v, ()) * 0.5)
        
        # Inducing points: slight perturbation
        Z_noisy = phi_init.Z + jax.random.normal(key_z, phi_init.Z.shape) * 0.2
        
        # Noise variance: log-normal around 0.1
        noise_var = jnp.exp(jnp.log(0.1) + jax.random.normal(key_n, ()) * 0.5)
        
        # Floor the samples away from zero (exp already makes them positive)
        lengthscale = jnp.maximum(lengthscale, 0.1)
        variance = jnp.maximum(variance, 0.1)
        noise_var = jnp.maximum(noise_var, 0.01)
        
        return Phi(
            kernel_params=KernelParams(lengthscale=lengthscale, variance=variance),
            Z=Z_noisy,
            likelihood_params={"noise_var": noise_var},
            jitter=phi_init.jitter,
        )
    
    particles = jax.vmap(init_one)(keys)
    return particles

print("Particle initialization function created!")

In [None]:
# Configure Annealed SMC
smc_cfg = AnnealedSMCCFG(
    n_particles=64,           # Number of particles
    n_steps=20,               # Number of annealing steps
    ess_threshold=0.5,        # Resample when ESS < 0.5 * n_particles
    rejuvenation="hmc",       # Use HMC for particle rejuvenation
    rejuvenation_steps=2,     # HMC steps per rejuvenation
    jit=True,                 # Enable JIT
)

method_smc = AnnealedSMC(cfg=smc_cfg)

print("Annealed SMC Configuration:")
print(f"  Particles: {smc_cfg.n_particles}")
print(f"  Annealing steps: {smc_cfg.n_steps}")
print(f"  Rejuvenation: {smc_cfg.rejuvenation} ({smc_cfg.rejuvenation_steps} steps)")

In [None]:
# Run Annealed SMC
print("Running Annealed SMC...")
print("This may take a minute...")

key, subkey = jax.random.split(key)
result_smc = method_smc.run(
    energy=inertial_energy,
    init_particles_fn=init_particles_fn,
    key=subkey,
    energy_args=(X_train, Y_train),
)

particles = result_smc.particles
logw = result_smc.logw
ess_trace = result_smc.ess_trace
logZ_est = result_smc.logZ_est

print("\nAnnealed SMC Results:")
print(f"  Final ESS: {float(ess_trace[-1]):.1f} / {smc_cfg.n_particles}")
print(f"  logZ estimate: {float(logZ_est):.2f}")
print(f"  Log weight range: [{float(logw.min()):.2f}, {float(logw.max()):.2f}]")

## 5. Analyze Posterior Distribution

In [None]:
# Extract particle values
lengthscales = np.array(particles.kernel_params.lengthscale)
variances = np.array(particles.kernel_params.variance)
noise_vars = np.array(particles.likelihood_params["noise_var"])

# Normalize weights
weights = np.exp(logw - logw.max())
weights = weights / weights.sum()

# Weighted statistics
lengthscale_mean = float(np.sum(weights * lengthscales))
lengthscale_std = float(np.sqrt(np.sum(weights * (lengthscales - lengthscale_mean)**2)))

variance_mean = float(np.sum(weights * variances))
variance_std = float(np.sqrt(np.sum(weights * (variances - variance_mean)**2)))

noise_var_mean = float(np.sum(weights * noise_vars))
noise_var_std = float(np.sqrt(np.sum(weights * (noise_vars - noise_var_mean)**2)))

print("Posterior Statistics (Weighted):")
print("\nLengthscale:")
print(f"  Mean: {lengthscale_mean:.3f} ± {lengthscale_std:.3f}")
print(f"  MAP:  {float(phi_map.kernel_params.lengthscale):.3f}")

print("\nVariance:")
print(f"  Mean: {variance_mean:.3f} ± {variance_std:.3f}")
print(f"  MAP:  {float(phi_map.kernel_params.variance):.3f}")

print("\nNoise Variance:")
print(f"  Mean: {noise_var_mean:.3f} ± {noise_var_std:.3f}")
print(f"  MAP:  {float(phi_map.likelihood_params['noise_var']):.3f}")
print(f"  True: {noise_std**2:.3f}")

## 6. Visualize Posterior Distributions

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# ESS trace
ax = axes[0, 0]
ax.plot(ess_trace, 'b-', linewidth=2)
ax.axhline(y=smc_cfg.ess_threshold * smc_cfg.n_particles, 
          color='red', linestyle='--', label=f'Threshold ({smc_cfg.ess_threshold * smc_cfg.n_particles:.0f})')
ax.set_xlabel('Annealing Step')
ax.set_ylabel('ESS')
ax.set_title('Effective Sample Size')
ax.legend()
ax.grid(True, alpha=0.3)

# Lengthscale distribution
ax = axes[0, 1]
ax.hist(lengthscales, bins=30, weights=weights, alpha=0.7, edgecolor='black')
ax.axvline(x=lengthscale_mean, color='blue', linestyle='-', linewidth=2, label=f'Mean: {lengthscale_mean:.2f}')
ax.axvline(x=float(phi_map.kernel_params.lengthscale), color='red', linestyle='--', linewidth=2, label=f'MAP: {float(phi_map.kernel_params.lengthscale):.2f}')
ax.set_xlabel('Lengthscale')
ax.set_ylabel('Posterior Mass per Bin (Weighted)')
ax.set_title('Lengthscale Posterior')
ax.legend()
ax.grid(True, alpha=0.3)

# Variance distribution
ax = axes[0, 2]
ax.hist(variances, bins=30, weights=weights, alpha=0.7, edgecolor='black')
ax.axvline(x=variance_mean, color='blue', linestyle='-', linewidth=2, label=f'Mean: {variance_mean:.2f}')
ax.axvline(x=float(phi_map.kernel_params.variance), color='red', linestyle='--', linewidth=2, label=f'MAP: {float(phi_map.kernel_params.variance):.2f}')
ax.set_xlabel('Variance')
ax.set_ylabel('Posterior Mass per Bin (Weighted)')
ax.set_title('Variance Posterior')
ax.legend()
ax.grid(True, alpha=0.3)

# Noise variance distribution
ax = axes[1, 0]
ax.hist(noise_vars, bins=30, weights=weights, alpha=0.7, edgecolor='black')
ax.axvline(x=noise_var_mean, color='blue', linestyle='-', linewidth=2, label=f'Mean: {noise_var_mean:.2f}')
ax.axvline(x=float(phi_map.likelihood_params['noise_var']), color='red', linestyle='--', linewidth=2, label=f'MAP: {float(phi_map.likelihood_params["noise_var"]):.2f}')
ax.axvline(x=noise_std**2, color='green', linestyle=':', linewidth=2, label=f'True: {noise_std**2:.2f}')
ax.set_xlabel('Noise Variance')
ax.set_ylabel('Posterior Mass per Bin (Weighted)')
ax.set_title('Noise Variance Posterior')
ax.legend()
ax.grid(True, alpha=0.3)

# Joint distribution: lengthscale vs variance
ax = axes[1, 1]
scatter = ax.scatter(lengthscales, variances, c=weights, s=100, alpha=0.6, cmap='viridis', edgecolors='black')
ax.scatter(float(phi_map.kernel_params.lengthscale), float(phi_map.kernel_params.variance), 
          c='red', s=200, marker='*', edgecolors='black', linewidths=2, label='MAP', zorder=10)
ax.set_xlabel('Lengthscale')
ax.set_ylabel('Variance')
ax.set_title('Joint Posterior (Lengthscale vs Variance)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax, label='Weight')

# Particle weights
ax = axes[1, 2]
ax.bar(range(len(weights)), np.sort(weights)[::-1], alpha=0.7, edgecolor='black')
ax.set_xlabel('Particle Index (sorted)')
ax.set_ylabel('Normalized Weight')
ax.set_title('Particle Weights Distribution')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## Summary

In this notebook, we demonstrated full Bayesian inference with Annealed SMC:

### Key Findings

1. **Posterior Uncertainty**: SMC provides full posterior distributions, not just point estimates
2. **MAP vs. Bayesian Mean**: MAP estimates can differ from posterior means
3. **Hyperparameter Correlations**: Joint distributions reveal parameter dependencies
4. **Uncertainty Quantification**: Standard deviations quantify estimation uncertainty
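
Beyond a weighted mean and standard deviation, a weighted particle set also yields credible intervals. The sketch below is a minimal plain-Python illustration; `weighted_quantile` and `credible_interval` are hypothetical helpers written for this example (not part of `infodynamics_jax`), and the particle values and weights are made up.

```python
def weighted_quantile(values, weights, q):
    """Smallest particle value whose cumulative normalised weight reaches q."""
    pairs = sorted(zip(values, weights))
    total = sum(w for _, w in pairs)
    cumulative = 0.0
    for value, weight in pairs:
        cumulative += weight / total
        if cumulative >= q:
            return value
    return pairs[-1][0]  # guard against round-off when q is close to 1


def credible_interval(values, weights, mass=0.9):
    """Central interval containing roughly `mass` of the weighted particles."""
    tail = (1.0 - mass) / 2.0
    return (
        weighted_quantile(values, weights, tail),
        weighted_quantile(values, weights, 1.0 - tail),
    )


# Made-up particle values (e.g. lengthscales) and unnormalised weights.
toy_values = [1.5, 0.5, 2.0, 1.0]
toy_weights = [4.0, 1.0, 1.0, 2.0]

median = weighted_quantile(toy_values, toy_weights, 0.5)
interval = credible_interval(toy_values, toy_weights, mass=0.9)
```

With only a few dozen particles such intervals are coarse, since they can only take values that particles actually occupy.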

### When to Use Annealed SMC

**Use Annealed SMC when:**
- Small datasets (high parameter uncertainty)
- Need robust inference (avoid overfitting)
- Want to quantify hyperparameter uncertainty
- Model selection (via marginal likelihood estimate)

**Use MAP-II when:**
- Large datasets (posterior concentrates)
- Speed is critical
- Point estimates are sufficient
- Production deployment

### Annealing Best Practices

1. **Number of particles**: Start with 50-100, increase if ESS drops too much
2. **Annealing steps**: 10-20 usually sufficient
3. **Rejuvenation**: HMC with 1-5 steps maintains diversity
4. **ESS threshold**: 0.5 is standard; a lower threshold triggers resampling less often
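
The ESS-based resampling rule can be made concrete with a short sketch. It uses the common definition $\mathrm{ESS} = 1 / \sum_i w_i^2$ for normalised weights $w_i$; the helper names are assumptions for this example, and the library's internal computation may differ.

```python
import math


def effective_sample_size(log_weights):
    """ESS = (sum w)^2 / sum(w^2), computed stably from unnormalised log-weights."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return total * total / sum(x * x for x in w)


def should_resample(log_weights, ess_threshold=0.5):
    """Resample when ESS falls below ess_threshold * number of particles."""
    return effective_sample_size(log_weights) < ess_threshold * len(log_weights)


uniform = [0.0, 0.0, 0.0, 0.0]                # all particles equally weighted
degenerate = [0.0, -1000.0, -1000.0, -1000.0]  # one particle carries all the weight

ess_uniform = effective_sample_size(uniform)        # equals the particle count
ess_degenerate = effective_sample_size(degenerate)  # collapses to a single particle
```

For the degenerate weights, `should_resample(degenerate, 0.5)` is true while `should_resample(degenerate, 0.2)` is false, because an ESS of 1 is below 2 but not below 0.8.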

### Computational Cost

- **MAP-II**: $O(N \times I)$ where $I$ = iterations (typically 100-500)
- **Annealed SMC**: $O(P \times T \times C)$ where:
  - $P$ = particles (50-200)
  - $T$ = annealing steps (10-20)
  - $C$ = cost per particle (similar to MAP-II iteration)

**Rule of thumb**: SMC is 10-50× slower than MAP-II, but provides full uncertainty quantification.
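
As a rough worked example of the cost formula, the snippet below counts energy evaluations for the configuration used in this notebook (64 particles, 20 annealing steps, 2 rejuvenation steps, 100 MAP-II iterations). It is back-of-the-envelope arithmetic under loud assumptions: rejuvenation is taken to run at every annealing step, and `evals_per_rejuvenation_step` is a hypothetical knob standing in for the number of energy or gradient evaluations per HMC step, which depends on settings not shown here.

```python
def smc_energy_evaluations(
    n_particles, n_steps, rejuvenation_steps=0, evals_per_rejuvenation_step=1
):
    """Rough count of energy (or gradient) evaluations in one annealed SMC run."""
    reweighting = n_particles * n_steps
    rejuvenation = (
        n_particles * n_steps * rejuvenation_steps * evals_per_rejuvenation_step
    )
    return reweighting + rejuvenation


map_iterations = 100  # TypeIICFG(steps=100) in this notebook

reweight_only = smc_energy_evaluations(n_particles=64, n_steps=20)
with_rejuvenation = smc_energy_evaluations(64, 20, rejuvenation_steps=2)

ratio_low = reweight_only / map_iterations       # 1280 / 100
ratio_high = with_rejuvenation / map_iterations  # 3840 / 100
```

Under these assumptions the ratios come out at roughly 13× and 38×. Actual wall-clock ratios depend on batching across particles, JIT compilation, and the true cost of each rejuvenation step.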