# Exercise 2: Trust but Verify - Diagnostic Tools for SBI

**Time:** 20 minutes  
**Goal:** Learn to diagnose whether your SBI results are trustworthy

## 🎯 Learning Objectives

By the end of this exercise, you will:
1. Understand why diagnostics are crucial for SBI
2. Perform posterior predictive checks
3. Run simulation-based calibration (coverage tests)
4. Identify common failure modes and warning signs

---

## 📖 The Story Continues...

Your initial analysis of the wolf-deer populations impressed the environmental agency! However, before they implement costly interventions based on your predictions, the senior ecologist asks:

*"How do we know the neural network learned correctly? What if it's just memorizing patterns rather than understanding the true dynamics? These predictions will inform important decisions about hunting quotas and conservation efforts."*

Excellent question! Let's verify our results with diagnostic tools.

## Step 1: Setup and Quick Inference Recap

Let's quickly repeat the inference from Exercise 1 (with fewer simulations for speed):

In [None]:
# Imports
import matplotlib.pyplot as plt
import numpy as np
import torch
from sbi import utils
from sbi.inference import NPE

# Our simulator
from simulators.lotka_volterra import (
    create_lotka_volterra_prior,
    generate_observed_data,
    lotka_volterra_simulator,
)

# Set plotting style
plt.style.use("seaborn-v0_8-darkgrid")
plt.rcParams["font.size"] = 14
plt.rcParams["axes.labelsize"] = 16

print("🔄 Running quick inference for diagnostics...")

# Setup
prior = create_lotka_volterra_prior()
observed_data, true_params = generate_observed_data(seed=2025)

# Quick training (fewer simulations for speed)
npe = NPE(prior)
theta, x = utils.simulate_for_sbi(lotka_volterra_simulator, prior, num_simulations=5000)
density_estimator = npe.append_simulations(theta, x).train(
    training_batch_size=50, max_num_epochs=30
)
posterior = npe.build_posterior(density_estimator)

print("✅ Inference complete! Now let's check if we can trust it...")

## Step 2: Diagnostic 1 - Posterior Predictive Check

**Key Question:** *If we simulate data using parameters from our posterior, does it look like our observed data?*

This is the most intuitive diagnostic:
1. Sample parameters from the posterior
2. Simulate data with those parameters
3. Compare simulated data to observations

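If the posterior is accurate, the observations should look like typical draws from these simulations: they should fall inside the bulk of the simulated spread, not in its tails. Passing this check is necessary but not sufficient, because a posterior that is too wide can still reproduce the data. That is why we pair it with a calibration test in Step 3.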
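The three steps can be wrapped in a small helper. The sketch below uses plain NumPy, and `ppc_summary` is our own name, not an sbi function. Besides the interval check used in the next cell, it returns a two-sided predictive tail probability for each observation. Values near 0 mean the observation sits in the extreme tail of what the posterior predicts. The Gaussian arrays are synthetic placeholders for `predictive_sims` and `observed_data`, not Lotka-Volterra output.

```python
import numpy as np


def ppc_summary(predictive_sims, observed, level=0.95):
    """Per-observation posterior predictive check.

    predictive_sims: array of shape (n_samples, ...) simulated from posterior draws.
    observed: array with the trailing shape of a single simulation.
    Returns (inside, tail_prob): a boolean mask of observations inside the central
    `level` predictive interval, and each observation's two-sided tail probability.
    """
    sims = np.asarray(predictive_sims, dtype=float)
    obs = np.asarray(observed, dtype=float)
    alpha = 1.0 - level
    lower = np.quantile(sims, alpha / 2, axis=0)
    upper = np.quantile(sims, 1 - alpha / 2, axis=0)
    inside = (lower <= obs) & (obs <= upper)
    # Fraction of simulated values at or below / at or above each observation
    frac_below = (sims <= obs).mean(axis=0)
    frac_above = (sims >= obs).mean(axis=0)
    tail_prob = np.minimum(1.0, 2 * np.minimum(frac_below, frac_above))
    return inside, tail_prob


# Synthetic stand-ins: 500 simulations of a (5 years x 2 species) dataset
rng = np.random.default_rng(0)
sims = rng.normal(loc=50.0, scale=5.0, size=(500, 5, 2))
obs = np.full((5, 2), 50.0)
obs[2, 1] = 90.0  # one observation far outside the predictive spread

inside, tail = ppc_summary(sims, obs)
print(inside.sum(), "of", inside.size, "inside")  # prints: 9 of 10 inside
```

In the notebook, you could call `ppc_summary(predictive_sims, observed_data)` directly, assuming both are CPU tensors or arrays with matching trailing shapes.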

In [None]:
print("🔍 Diagnostic 1: Posterior Predictive Check")
print("=" * 50)

# Sample from posterior and simulate
n_predictive_samples = 500
posterior_samples = posterior.sample((n_predictive_samples,), x=observed_data.flatten())

# Simulate data for each posterior sample
predictive_sims = []
for params in posterior_samples:
    sim = lotka_volterra_simulator(params, add_noise=True)
    predictive_sims.append(sim)
predictive_sims = torch.stack(predictive_sims)

# Calculate statistics
sim_mean = predictive_sims.mean(dim=0)
sim_std = predictive_sims.std(dim=0)
sim_lower = torch.quantile(predictive_sims, 0.025, dim=0)
sim_upper = torch.quantile(predictive_sims, 0.975, dim=0)

# Visualization
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

times = np.array([10, 15, 20, 25, 30])

# Deer populations
ax = axes[0]
ax.fill_between(
    times, sim_lower[:, 0], sim_upper[:, 0], alpha=0.3, color="brown", label="95% CI"
)
ax.plot(times, sim_mean[:, 0], "o-", color="brown", label="Posterior predictive mean")
ax.plot(
    times,
    observed_data[:, 0],
    "s-",
    color="black",
    linewidth=2,
    markersize=8,
    label="Observed",
)
ax.set_ylabel("Deer Population", fontsize=14)
ax.legend(loc="upper right", fontsize=12)
ax.grid(True, alpha=0.3)

# Wolf populations
ax = axes[1]
ax.fill_between(
    times, sim_lower[:, 1], sim_upper[:, 1], alpha=0.3, color="orange", label="95% CI"
)
ax.plot(times, sim_mean[:, 1], "o-", color="orange", label="Posterior predictive mean")
ax.plot(
    times,
    observed_data[:, 1],
    "s-",
    color="black",
    linewidth=2,
    markersize=8,
    label="Observed",
)
ax.set_xlabel("Time (years)", fontsize=14)
ax.set_ylabel("Wolf Population", fontsize=14)
ax.legend(loc="upper right", fontsize=12)
ax.grid(True, alpha=0.3)

plt.suptitle("Posterior Predictive Check", fontsize=18)
plt.tight_layout()
plt.show()

# Check if observations fall within predictive intervals
print("\n📊 Coverage Analysis:")
for i, year in enumerate(times):
    deer_in_ci = sim_lower[i, 0] <= observed_data[i, 0] <= sim_upper[i, 0]
    wolf_in_ci = sim_lower[i, 1] <= observed_data[i, 1] <= sim_upper[i, 1]

    deer_symbol = "✅" if deer_in_ci else "❌"
    wolf_symbol = "✅" if wolf_in_ci else "❌"

    print(f"Year {year}: Deer {deer_symbol}, Wolves {wolf_symbol}")

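# Overall assessment. The thresholds are heuristics: with only 10 observations
# and 95% intervals, an occasional miss is expected by chance.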
total_obs = len(times) * 2  # deer and wolves
in_ci = sum(
    [
        (sim_lower[i, j] <= observed_data[i, j] <= sim_upper[i, j]).item()
        for i in range(len(times))
        for j in range(2)
    ]
)
coverage = in_ci / total_obs * 100

print(f"\n🎯 Overall: {in_ci}/{total_obs} observations within 95% CI ({coverage:.1f}%)")
if coverage >= 80:
    print("✅ PASS: Good agreement between posterior predictions and observations!")
elif coverage >= 60:
    print(
        "⚠️  WARNING: Moderate agreement - consider more simulations or checking model"
    )
else:
    print("❌ FAIL: Poor agreement - the posterior may be unreliable!")

## Step 3: Diagnostic 2 - Simulation-Based Calibration (Coverage Test)

**Key Question:** *If we know the true parameters, does our method recover them correctly across many scenarios?*

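This is a more rigorous test, because it checks the inference across many parameter values instead of a single observation:
1. Sample "true" parameters from the prior
2. Simulate data with those parameters
3. Infer parameters from the simulated data
4. Check if true parameters fall within credible intervals

For well-calibrated inference, X% credible intervals should contain the true value X% of the time, averaged over parameters drawn from the prior. Coverage well below X% signals an overconfident (too narrow) posterior. Coverage well above X% signals an underconfident (too wide) one. Strictly speaking, this is an expected-coverage test. The rank-based form of simulation-based calibration instead checks that the rank of each true value among its posterior samples is uniformly distributed, which tests all credibility levels at once.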
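To see the four steps on a case where the right answer is known, here is a minimal NumPy sketch on a toy model chosen for illustration (it is not the Lotka-Volterra problem). In this model, θ ~ N(0, 1) and x | θ ~ N(θ, 1), so the exact posterior is N(x/2, 1/2). The exact posterior should give coverage close to 0.90. A deliberately overconfident posterior with half the spread should drop to roughly 0.6. All function names are hypothetical.

```python
import numpy as np


def expected_coverage(posterior_sampler, n_tests=2000, n_post=1000, level=0.9, seed=0):
    """Fraction of tests whose central `level` credible interval contains the true theta.

    Toy model (an assumption for illustration): theta ~ N(0, 1), x | theta ~ N(theta, 1).
    `posterior_sampler(x, n, rng)` must return an array of shape (len(x), n).
    """
    rng = np.random.default_rng(seed)
    theta_true = rng.normal(0.0, 1.0, size=n_tests)  # 1. sample "true" parameters from the prior
    x = rng.normal(theta_true, 1.0)  # 2. simulate one observation per parameter
    samples = posterior_sampler(x, n_post, rng)  # 3. infer (draw posterior samples)
    alpha = 1.0 - level
    lower = np.quantile(samples, alpha / 2, axis=1)
    upper = np.quantile(samples, 1 - alpha / 2, axis=1)
    return np.mean((lower <= theta_true) & (theta_true <= upper))  # 4. check containment


def exact_posterior(x, n, rng):
    # Conjugate result for this toy model: theta | x ~ N(x / 2, 1 / 2)
    return rng.normal(x[:, None] / 2, np.sqrt(0.5), size=(len(x), n))


def overconfident_posterior(x, n, rng):
    # Same mean, half the spread: a stand-in for a network that is too sure of itself
    return rng.normal(x[:, None] / 2, 0.5 * np.sqrt(0.5), size=(len(x), n))


print(f"exact posterior:         {expected_coverage(exact_posterior):.3f}")
print(f"overconfident posterior: {expected_coverage(overconfident_posterior):.3f}")
```

The overconfident case shows the failure mode this test is designed to catch: the intervals look tight and plausible for any single dataset, but across many simulated datasets they miss the truth far more often than advertised.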

In [None]:
print("🔍 Diagnostic 2: Simulation-Based Calibration")
print("=" * 50)
print(
    "Testing if our 90% credible intervals contain the true value ~90% of the time..."
)
print("(Note: Running fewer tests for speed - in practice, use 100-1000)\n")

# Run calibration tests
n_tests = 20  # Reduced for tutorial speed (use 100+ in practice)
credibility_level = 0.9  # 90% credible interval
contained = []

for test_idx in range(n_tests):
    # Sample true parameters
    true_theta = prior.sample((1,)).squeeze()

    # Simulate observation
    x_test = lotka_volterra_simulator(true_theta)

    # Get posterior samples
    posterior_samples_test = posterior.sample((1000,), x=x_test.flatten())

    # Check if true parameters are within credible intervals
    lower = torch.quantile(posterior_samples_test, (1 - credibility_level) / 2, dim=0)
    upper = torch.quantile(
        posterior_samples_test, 1 - (1 - credibility_level) / 2, dim=0
    )

    # Check each parameter
    param_contained = [
        (lower[i] <= true_theta[i] <= upper[i]).item() for i in range(len(true_theta))
    ]
    contained.append(param_contained)

    # Progress indicator
    if (test_idx + 1) % 5 == 0:
        print(f"  Completed {test_idx + 1}/{n_tests} tests...")

# Analyze results
contained = np.array(contained)
coverage_per_param = contained.mean(axis=0) * 100
overall_coverage = contained.mean() * 100

print("\n📊 Coverage Results (Expected: ~90%):")
print("-" * 40)
param_names = [
    "α (deer birth)",
    "β (predation)",
    "δ (wolf efficiency)",
    "γ (wolf death)",
]
for i, name in enumerate(param_names):
    cov = coverage_per_param[i]
    if 80 <= cov <= 100:
        symbol = "✅"
    elif 70 <= cov < 80:
        symbol = "⚠️"
    else:
        symbol = "❌"
    print(f"{symbol} {name:20s}: {cov:.1f}%")

print(f"\n🎯 Overall coverage: {overall_coverage:.1f}%")

# Assessment
if 85 <= overall_coverage <= 95:
    print("✅ PASS: Well-calibrated inference!")
elif 75 <= overall_coverage < 85 or 95 < overall_coverage <= 100:
    print("⚠️  WARNING: Slightly miscalibrated - consider more training data")
else:
    print("❌ FAIL: Poorly calibrated - inference may be unreliable!")

# Visualization
fig, ax = plt.subplots(figsize=(10, 6))
x_pos = np.arange(len(param_names))
bars = ax.bar(
    x_pos,
    coverage_per_param,
    color=[
        "green" if 80 <= c <= 100 else "orange" if 70 <= c < 80 else "red"
        for c in coverage_per_param
    ],
)
ax.axhline(y=90, color="black", linestyle="--", label="Expected (90%)")
ax.set_ylabel("Coverage (%)", fontsize=14)
ax.set_xlabel("Parameter", fontsize=14)
ax.set_title("Coverage Test Results", fontsize=16)
ax.set_xticks(x_pos)
ax.set_xticklabels([name.split(" ")[0] for name in param_names], fontsize=12)
ax.set_ylim([0, 110])  # headroom so labels above 100% bars are not clipped
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3, axis="y")

# Add value labels on bars
for bar, cov in zip(bars, coverage_per_param, strict=False):
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2.0,
        height + 1,
        f"{cov:.0f}%",
        ha="center",
        va="bottom",
        fontsize=12,
    )

plt.tight_layout()
plt.show()
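With only 20 tests, the coverage estimate itself is noisy. If inference were perfectly calibrated, the number of tests whose 90% interval contains a given parameter would follow a Binomial(n_tests, 0.9) distribution. The spread of that distribution tells us how far an honest result can stray from 90%. The sketch below computes the central 95% range of that distribution. It uses only the standard library, and `coverage_acceptance_range` is a hypothetical helper. For 20 tests, the range spans 75% to 100%, so a single parameter at 75% is not, on its own, evidence of miscalibration.

```python
from math import exp, lgamma, log


def coverage_acceptance_range(n_tests, nominal=0.9, alpha=0.05):
    """Central (1 - alpha) range of contained-counts under perfect calibration.

    The count of tests whose interval contains the truth is Binomial(n_tests, nominal);
    this returns its alpha/2 and 1 - alpha/2 quantiles. Requires 0 < nominal < 1.
    The pmf is evaluated in log space so large n_tests do not overflow.
    """
    log_p, log_q = log(nominal), log(1.0 - nominal)
    cdf, lo = 0.0, None
    for k in range(n_tests + 1):
        log_pmf = (
            lgamma(n_tests + 1) - lgamma(k + 1) - lgamma(n_tests - k + 1)
            + k * log_p + (n_tests - k) * log_q
        )
        cdf += exp(log_pmf)
        if lo is None and cdf >= alpha / 2:
            lo = k
        if cdf >= 1 - alpha / 2:
            return lo, k
    return lo, n_tests  # guards against rounding leaving the CDF just below the target


for n in (20, 100, 500):
    lo, hi = coverage_acceptance_range(n)
    print(f"n_tests={n:4d}: {100 * lo / n:.0f}% to {100 * hi / n:.0f}%")
```

The range narrows roughly as 1/√n_tests, which is why 100 or more tests are recommended in practice.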

## Step 4: Common Warning Signs and What They Mean

Let's explore what different diagnostic failures tell us:

In [None]:
print("⚠️  Common SBI Warning Signs and Solutions")
print("=" * 60)

warning_signs = [
    {
        "symptom": "Posterior much wider than expected",
        "causes": ["Too few simulations", "Uninformative observations"],
        "solutions": ["Increase num_simulations", "Use more/better observations"],
    },
    {
        "symptom": "Posterior too narrow (overconfident)",
        "causes": ["Overfitting", "Model misspecification"],
        "solutions": ["Use validation set", "Check simulator assumptions"],
    },
    {
        "symptom": "Poor coverage (<70%)",
        "causes": ["Network architecture issues", "Training problems"],
        "solutions": ["Try different density estimator", "Increase training epochs"],
    },
    {
        "symptom": "Multimodal posterior when not expected",
        "causes": ["Parameter identifiability issues", "Multiple solutions"],
        "solutions": ["Add more observations", "Reparameterize model"],
    },
]

for i, warning in enumerate(warning_signs, 1):
    print(f"\n{i}. {warning['symptom']}")
    print("   Possible causes:")
    for cause in warning["causes"]:
        print(f"     • {cause}")
    print("   Solutions:")
    for solution in warning["solutions"]:
        print(f"     → {solution}")
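The first two warning signs, a posterior that is too wide or too narrow, can be quantified with posterior contraction, 1 - Var_post / Var_prior, computed per parameter. Values near 0 mean the data barely narrowed the prior, and values near 1 mean the parameter is tightly constrained. Negative values, where the posterior is wider than the prior, point to a problem. This NumPy sketch assumes a box-uniform prior like the one whose `low`/`high` bounds are read in Step 5. `posterior_contraction` is a hypothetical name, and the samples below are synthetic.

```python
import numpy as np


def posterior_contraction(posterior_samples, prior_low, prior_high):
    """Per-parameter contraction 1 - Var_post / Var_prior, assuming a Uniform(low, high) prior."""
    samples = np.asarray(posterior_samples, dtype=float)
    low = np.asarray(prior_low, dtype=float)
    high = np.asarray(prior_high, dtype=float)
    prior_var = (high - low) ** 2 / 12.0  # variance of Uniform(low, high)
    post_var = samples.var(axis=0, ddof=1)
    return 1.0 - post_var / prior_var


rng = np.random.default_rng(1)
low, high = np.array([0.0, 0.0]), np.array([1.0, 2.0])
wide = rng.uniform(low, high, size=(5000, 2))  # behaves like the prior: nothing learned
narrow = rng.normal([0.5, 1.0], [0.02, 0.05], size=(5000, 2))  # strongly constrained
print(posterior_contraction(wide, low, high).round(2))
print(posterior_contraction(narrow, low, high).round(2))
```

Whether a given contraction is "expected" still depends on how informative the observations are. The measure is therefore most useful for comparing parameters and runs, for example via `posterior_contraction(posterior_samples, prior.base_dist.low, prior.base_dist.high)` after Step 2.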

## Step 5: Quick Diagnostic - Training Loss Check

A simple but useful check: did the neural network training converge?

In [None]:
print("📉 Training Loss Analysis")
print("=" * 40)

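# Get training losses from the summary
try:
    summary = npe.summary
    best_loss = summary["best_validation_loss"][-1]

    print(f"Best validation loss: {best_loss:.4f}")

    # Rough heuristic only: the loss is a negative log-density, so its absolute
    # value depends on the scale of the parameters. Compare runs of the same
    # problem rather than trusting a universal threshold.
    if best_loss < -1.0:
        print("✅ Low validation loss")
    elif best_loss < 0:
        print("⚠️  Moderate validation loss - consider more epochs or simulations")
    else:
        print("❌ High validation loss - check your setup")

except (AttributeError, KeyError, IndexError) as err:
    print(f"Could not access training summary: {err!r}")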

# Additional quick checks
print("\n🔍 Quick Sanity Checks:")

# Check 1: Can we sample from posterior?
try:
    test_samples = posterior.sample((10,), x=observed_data.flatten())
    print("✅ Posterior sampling works")
except Exception as err:  # report the failure instead of hiding it
    print(f"❌ Cannot sample from posterior: {err!r}")

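# Check 2: Are posterior samples within prior bounds?
# sbi's default posterior rejects samples outside the prior support, so this
# is a sanity check of the setup rather than of the trained network.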
posterior_test = posterior.sample((100,), x=observed_data.flatten())
prior_low = prior.base_dist.low
prior_high = prior.base_dist.high

within_bounds = torch.all(
    (posterior_test >= prior_low) & (posterior_test <= prior_high)
)
if within_bounds:
    print("✅ All posterior samples within prior bounds")
else:
    print("❌ Some posterior samples outside prior bounds!")

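# Check 3: Posterior not collapsed?
posterior_std = posterior_test.std(dim=0)
# Compare against the prior spread so the threshold does not depend on parameter units
prior_std = (prior_high - prior_low) / 12**0.5  # std of a uniform distribution
if torch.all(posterior_std / prior_std > 0.01):  # arbitrary small threshold: 1% of prior std
    print("✅ Posterior has reasonable variance")
else:
    print("❌ Posterior may have collapsed to point mass")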
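Because the loss threshold above depends on parameter units, a more portable convergence check is whether the per-epoch validation loss had flattened out before training stopped. This matters here. sbi's early stopping waits `stop_after_epochs` epochs without improvement (20 by default), so with `max_num_epochs=30`, training may well have ended at the cap rather than by converging. The sketch below takes a list of per-epoch validation losses, where lower is better. `loss_has_plateaued` is a hypothetical helper.

```python
import numpy as np


def loss_has_plateaued(val_losses, patience=5, min_delta=1e-3):
    """True if the last `patience` epochs did not beat the earlier best by more than `min_delta`.

    False means the validation loss was still falling when training stopped,
    i.e. training was probably cut short (for example by an epoch cap).
    """
    losses = np.asarray(val_losses, dtype=float)
    if len(losses) <= patience:
        return False  # too few epochs to judge
    best_before = losses[:-patience].min()
    best_recent = losses[-patience:].min()
    return bool(best_recent > best_before - min_delta)


still_falling = [3.0 - 0.1 * epoch for epoch in range(30)]
flattened = [3.0 - 0.1 * min(epoch, 15) for epoch in range(30)]
print(loss_has_plateaued(still_falling))  # False
print(loss_has_plateaued(flattened))  # True
```

sbi trainers record per-epoch losses in a summary dictionary. In the versions we have used, this is the private attribute `npe._summary` with a `"validation_loss"` key. Treat that name as an assumption and check your installed version before calling `loss_has_plateaued(npe._summary["validation_loss"])`.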

## Step 6: Creating Your Diagnostic Report

Let's create a summary diagnostic report:

In [None]:
print("=" * 60)
print("          📋 SBI DIAGNOSTIC REPORT")
print("=" * 60)

# Collect all diagnostic results
diagnostics = {
    "Posterior Predictive Check": coverage >= 80,
    "Coverage Calibration": 85 <= overall_coverage <= 95,
    "Training Convergence": True,  # Placeholder - set based on actual check
    "Sampling Works": True,  # Placeholder
    "Within Prior Bounds": within_bounds.item(),
}

# Overall assessment
passed = sum(diagnostics.values())
total = len(diagnostics)

print(f"\n📊 Diagnostic Summary: {passed}/{total} checks passed\n")

for test, result in diagnostics.items():
    symbol = "✅" if result else "❌"
    print(f"{symbol} {test}")

print("\n" + "=" * 60)

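if passed == total:
    print("🎉 OVERALL: ALL CHECKS PASSED")
    print("No diagnostic flagged a problem, so the inference looks trustworthy for decision-making.")
    print("Before relying on this verdict, check that no entry above is still a placeholder.")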
elif passed >= total * 0.6:
    print("⚠️  OVERALL: USE WITH CAUTION")
    print("Some diagnostics failed. Consider:")
    print("  • Increasing training simulations")
    print("  • Checking model assumptions")
    print("  • Running more thorough diagnostics")
else:
    print("❌ OVERALL: DO NOT USE")
    print("Multiple diagnostic failures detected.")
    print("Please review your setup and retrain.")

print("\n" + "=" * 60)
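To keep a report with the analysis that produced it, the `diagnostics` dictionary can be written to disk together with a timestamp and the settings used. Below is a minimal sketch using only the standard library. `diagnostic_report_json` and the metadata keys are illustrative choices, not part of sbi.

```python
import json
from datetime import datetime, timezone


def diagnostic_report_json(diagnostics, metadata=None):
    """Serialize a {check name: passed} dict plus free-form metadata to a JSON string."""
    report = {
        "created_utc": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        # bool() also converts NumPy/PyTorch booleans, which json cannot serialize directly
        "checks": {name: bool(passed) for name, passed in diagnostics.items()},
        "passed": int(sum(bool(v) for v in diagnostics.values())),
        "total": len(diagnostics),
        "metadata": metadata or {},
    }
    return json.dumps(report, indent=2, ensure_ascii=False)


example = {"Posterior Predictive Check": True, "Coverage Calibration": False}
text = diagnostic_report_json(example, metadata={"num_simulations": 5000, "n_tests": 20})
print(text)

# To save it next to your results:
# with open("diagnostic_report.json", "w", encoding="utf-8") as f:
#     f.write(text)
```

In the notebook, you can pass `diagnostics` along with settings such as the number of training simulations and coverage tests. This makes reports from different retraining runs directly comparable.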

## 🎯 Key Takeaways

### Why Diagnostics Matter

1. **Neural networks can fail silently** - They might produce confident-looking but wrong results
2. **Not all posteriors are created equal** - Some might be overconfident, others too uncertain
3. **Trust but verify** - Always check your inference before making decisions

### Your Diagnostic Toolkit

| Diagnostic | What it checks | Red flag |
|------------|---------------|----------|
| Posterior Predictive | Can we recreate observations? | Observations outside CI |
| Coverage Test | Are credible intervals calibrated? | Coverage far from nominal |
| Training Loss | Did the network converge? | Validation loss still falling when training stopped, or rising |
| Prior Bounds | Does the posterior respect the prior support? | Samples outside prior bounds |

### Best Practices

✅ **Always run diagnostics** - Make it part of your workflow  
✅ **Start simple** - Posterior predictive checks first  
✅ **Document results** - Keep diagnostic reports with your analysis  
✅ **Iterate if needed** - Poor diagnostics → adjust and retrain  

---

## 🚀 Challenge: Stress Testing

Try breaking the inference and seeing how diagnostics detect it:

1. **Too few simulations**: Retrain with only 500 simulations
2. **Wrong prior**: Use a prior that doesn't contain true parameters
3. **Corrupted data**: Add extreme outliers to observations

How do the diagnostics change? Which tests catch which problems?
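For the third experiment, one simple way to corrupt the data is to scale a few randomly chosen entries by a large factor. `add_outliers` is a hypothetical helper, not part of the exercise's `simulators` package. Multiplying works for strictly positive population counts; a zero entry would stay zero. The array below uses made-up numbers in the (5 years × 2 species) layout of `observed_data`.

```python
import numpy as np


def add_outliers(observations, n_outliers=2, scale=10.0, seed=0):
    """Return (corrupted copy, flat indices) with `n_outliers` entries multiplied by `scale`."""
    rng = np.random.default_rng(seed)
    corrupted = np.array(observations, dtype=float, copy=True)
    flat = corrupted.reshape(-1)  # a view, so edits below change `corrupted`
    idx = rng.choice(flat.size, size=n_outliers, replace=False)  # distinct entries
    flat[idx] *= scale
    return corrupted, idx


clean = np.array([[30.0, 5.0], [25.0, 7.0], [20.0, 8.0], [22.0, 6.0], [28.0, 5.0]])  # made-up data
corrupted, idx = add_outliers(clean)
print(corrupted)
```

Assuming `observed_data` is a float32 tensor, convert the result back with `torch.as_tensor(corrupted, dtype=torch.float32)`. Then rerun the posterior predictive check with it in place of `observed_data`.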

In [None]:
# Space for experimentation
