In [None]:
import os

# Environment flags are set here, ahead of the JAX imports further down this cell.
# Optional: force the CPU backend on Apple Silicon to avoid Metal issues.
# os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import matplotlib.pyplot as plt

# Plot text: no LaTeX rendering, DejaVu Serif as the serif face.
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["DejaVu Serif"]

import pickle
import time

import jax.numpy as jnp
from jax import random
from sbijax import NLE
from sbijax.nn import make_maf

from vr_foraging_sbi_ddm.models import Config, format_name
from vr_foraging_sbi_ddm.simulator import JaxPatchForagingDdm, create_prior

# SNLE
from vr_foraging_sbi_ddm.snle.snle_inference_jax import infer_parameters_snle, train_snle

# Shared utilities
from vr_foraging_sbi_ddm.snle.snle_utils_jax import (
    pairplot,
)
from vr_foraging_sbi_ddm.validation import (
    compute_sbc_metrics,
    plot_recovery_scatter,
    plot_sbc_diagnostics,
    validate_parameter_recovery,
)

# Setup
Run the Configuration, Simulator, and Prior cells below first; the SNLE pipeline cells depend on the objects they define (`CONFIG`, `simulator`, `prior_fn`, `rng_key`).

In [None]:
# Run configuration. With force_retrain=True the training cell ignores any
# saved model.pkl and trains from scratch.
CONFIG = Config(
    n_simulations=500_000,
    batch_size=256,
    n_iter=100_000,
    n_early_stopping_patience=20,
    force_retrain=True,
)
print("Configuration loaded", f"Model directory: {CONFIG.filename}", sep="\n")

In [None]:
# ============================================================================
# Shared setup: PRNG key, prior, and simulator
# ============================================================================

# Root PRNG key for the notebook; later cells split it before every random draw.
rng_key = random.PRNGKey(CONFIG.seed)

# Prior built from the lower/upper bounds stored in CONFIG.
prior_fn = create_prior(
    prior_low=jnp.array(CONFIG.prior_low),
    prior_high=jnp.array(CONFIG.prior_high),
)

simulator = JaxPatchForagingDdm(
    # Values fixed in this notebook.
    initial_prob=0.8,
    depletion_rate=-0.1,
    threshold=1.0,
    start_point=0.0,
    # Values read from CONFIG.
    inter_site_min=CONFIG.inter_site_min,
    inter_site_max=CONFIG.inter_site_max,
    inter_site_exp_alpha=CONFIG.inter_site_exp_alpha,
    odor_site_length=CONFIG.odor_site_length,
    length_normalizing_factor=CONFIG.length_normalizing_factor,
    max_sites_per_window=CONFIG.window_size,
    n_feat=CONFIG.n_feat,
)

print(
    "Simulator initialized",
    f"Prior bounds: {CONFIG.prior_low} -> {CONFIG.prior_high}",
    sep="\n",
)

# SNLE + MCMC Pipeline

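Sequential Neural Likelihood Estimation (SNLE) trains a neural density estimator on simulated data to approximate the likelihood $p(x \mid \theta)$. MCMC is then run on the learned likelihood combined with the prior to draw samples from the posterior $p(\theta \mid x)$.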

In [None]:
# ============================================================================
# 1: Train SNLE (or reload a previously trained model)
# ============================================================================
# Artifacts live in CONFIG.base_output_dir / format_name(CONFIG):
#   model.pkl     - pickled dict: snle_params, losses, y_mean, y_std, config
#   checkpoints/  - directory created here (not passed to train_snle in this cell)
#
# Both branches below bind the same names for later cells:
#   snle, snle_params, losses, y_mean, y_std, rng_key

model_dir = CONFIG.base_output_dir / format_name(CONFIG)
model_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = model_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
model_path = model_dir / "model.pkl"

print(f"Model directory: {model_dir}")

if model_path.exists() and not CONFIG.force_retrain:
    print(f"Loading existing model from {model_path}")
    # SECURITY: pickle.load can execute arbitrary code. Only load model files
    # written by this notebook; never point model_path at an untrusted file.
    with open(model_path, "rb") as f:
        model_data = pickle.load(f)
    snle_params = model_data["snle_params"]
    y_mean = model_data["y_mean"]
    y_std = model_data["y_std"]
    # Keep `losses` defined on this path too; None if the file has no "losses" entry.
    losses = model_data.get("losses")

    # Reconstruct the SNLE object. The flow's dimension is read off the last
    # axis of one simulated draw, so a single prior sample is pushed through
    # the simulator first.
    rng_key, test_key = random.split(rng_key)
    # Use independent subkeys for the prior draw and the simulation: a JAX key
    # must not be consumed twice. Splitting test_key (rather than rng_key)
    # leaves the rng_key stream seen by later cells unchanged.
    prior_key, sim_key = random.split(test_key)
    test_theta = prior_fn().sample(seed=prior_key)
    test_x = simulator.simulator_fn(seed=sim_key, theta=test_theta)
    flow = make_maf(
        n_dimension=test_x.shape[-1],
        n_layers=CONFIG.num_layers,
        hidden_sizes=(CONFIG.hidden_dim, CONFIG.hidden_dim),
    )
    snle = NLE((prior_fn, simulator.simulator_fn), flow)
    print("Model loaded")
else:
    print("Training new SNLE model...")
    # train_snle also returns the advanced rng_key, which replaces the
    # notebook-level key.
    snle, snle_params, losses, rng_key, y_mean, y_std = train_snle(
        simulator,
        prior_fn,
        n_simulations=CONFIG.n_simulations,
        hidden_dim=CONFIG.hidden_dim,
        num_layers=CONFIG.num_layers,
        n_iter=CONFIG.n_iter,
        batch_size=CONFIG.batch_size,
        n_early_stopping_patience=CONFIG.n_early_stopping_patience,
        learning_rate=CONFIG.learning_rate,
        transition_steps=CONFIG.transition_steps,
        decay_rate=CONFIG.decay_rate,
        percentage_data_as_validation_set=0.1,
        rng_key=rng_key,
    )

    # Persist everything the load branch above needs, plus the loss history
    # and the configuration that produced the model.
    model_data = {
        "snle_params": snle_params,
        "losses": losses,
        "y_mean": y_mean,
        "y_std": y_std,
        "config": CONFIG.model_dump(),
    }
    with open(model_path, "wb") as f:
        pickle.dump(model_data, f)
    print(f"Model saved to {model_path}")

In [None]:
# ============================================================================
# 2: Synthetic test observation, then SNLE + MCMC inference
# ============================================================================

# Ground truth: one draw from the prior (the "theta" entry of the sample).
rng_key, subkey = random.split(rng_key)
true_theta = prior_fn().sample(seed=subkey)["theta"]

# Synthetic data: simulate one window at the ground truth with a fresh subkey
# and keep the second element of the returned pair.
rng_key, subkey = random.split(rng_key)
_, observed_stats = simulator.simulate_one_window(true_theta, subkey)
print(
    f"True theta: {true_theta}",
    f"Observed stats shape: {observed_stats.shape}",
    sep="\n",
)

# Sampler settings forwarded to infer_parameters_snle.
mcmc_settings = {"num_samples": 1_000, "num_warmup": 500, "num_chains": 4}

# Wall-clock timing covers only the inference call; the key split happens
# inside the timed region, as before.
start_time = time.time()
rng_key, subkey = random.split(rng_key)
posterior_samples, diagnostics = infer_parameters_snle(
    snle,
    snle_params,
    observed_stats,
    y_mean,
    y_std,
    rng_key=subkey,
    **mcmc_settings,
)
snle_time = time.time() - start_time
print(f"\nSNLE inference time: {snle_time:.2f}s")

In [None]:
# ============================================================================
# 3: Plot SNLE posteriors
# ============================================================================
# One marginal histogram per parameter, followed by a pairplot of the samples.
# Column i of posterior_samples is plotted against param_names[i] and
# true_theta[i], so the order of param_names must match the parameter order.

param_names = ["drift_rate", "reward_bump", "failure_bump", "noise_std"]
n_params = len(param_names)

# The panel count follows param_names instead of a hard-coded 4.
# squeeze=False keeps the axes 2-D even for a single parameter, so the row
# can always be indexed the same way.
fig, axes_grid = plt.subplots(1, n_params, figsize=(2.5 * n_params, 2), squeeze=False)
axes = axes_grid[0]
fig.suptitle("SNLE + MCMC Posterior", fontsize=12, fontweight="bold")

for i, (ax, name) in enumerate(zip(axes, param_names)):
    counts, bins, _ = ax.hist(posterior_samples[:, i], bins=30, color="dodgerblue", edgecolor=None, alpha=0.7)
    # Centre of the tallest histogram bin. NOTE: this is the mode of the 1-D
    # marginal histogram, which is what the "MAP estimate" line shows; it is
    # not the maximiser of the joint posterior.
    mode_index = int(counts.argmax())
    posterior_mode = (bins[mode_index] + bins[mode_index + 1]) / 2

    ax.axvline(true_theta[i], color="orangered", linestyle="--", label="true value")
    ax.axvline(posterior_mode, color="k", linestyle="--", label="MAP estimate")
    ax.set_xlabel(name)
    # Show the full prior range so posterior contraction is visible.
    ax.set_xlim(CONFIG.prior_low[i], CONFIG.prior_high[i])
    ax.set_ylabel("Frequency")

# A single legend, placed to the right of the last panel.
axes[-1].legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.tight_layout()

pairplot(posterior_samples, true_theta, param_names, figsize_per_param=2.0)
plt.show()

In [None]:
# ============================================================================
# 4: SNLE Validation (Parameter Recovery + SBC)
# ============================================================================
# Note: Each test runs MCMC sampling — can take 30s-2min per test.
# For faster iteration, reduce n_tests. For publication, use n_tests >= 100.

# Positional arguments shared by both validation calls.
validation_args = (snle, snle_params, y_mean, y_std, simulator, prior_fn, infer_parameters_snle)
banner = "=" * 80

# --- Parameter recovery ---
print(banner, "SNLE PARAMETER RECOVERY", banner, sep="\n")
start_time = time.time()

recovery_results_snle = validate_parameter_recovery(
    *validation_args,
    n_tests=5,
    num_samples=1000,
    num_warmup=500,
    num_chains=4,
)
plot_recovery_scatter(recovery_results_snle, save_path=model_dir / "recovery.png")

# The recorded time includes producing and saving the scatter plot.
recovery_time = time.time() - start_time
print(f"\nRecovery completed in {recovery_time:.1f}s")

# --- Simulation-based calibration (SBC) ---
print("\n" + banner, "SNLE SBC", banner, sep="\n")
start_time = time.time()

sbc_results_snle = compute_sbc_metrics(
    *validation_args,
    n_tests=20,
    num_samples=500,
    num_warmup=100,
    num_chains=2,
)
plot_sbc_diagnostics(sbc_results_snle, bins=10, save_path=model_dir / "sbc.png")

# The recorded time includes producing and saving the diagnostics plot.
sbc_time = time.time() - start_time
print(
    f"\nSBC completed in {sbc_time:.1f}s",
    f"Total SNLE validation: {recovery_time + sbc_time:.1f}s",
    sep="\n",
)