# Exercise 2: Diagnostic Checks for SBI 🔍

**Time:** 20 minutes  
**Difficulty:** Intermediate  

In this exercise, you'll learn how to validate your neural posterior estimator using diagnostic tools. These checks are crucial for building trust in your inference results.

## 🎯 Learning Objectives

By the end of this exercise, you will be able to:

1. ✅ Perform prior predictive checks to validate your prior
2. ✅ Monitor neural network training with loss curves
3. ✅ Run posterior predictive checks to validate inference
4. ✅ Use simulation-based calibration (SBC) to test the inference method
5. ✅ Interpret diagnostic results and identify problems

## 📖 The Story Continues...

Your initial analysis of the wolf-deer populations impressed the environmental agency! However, before they implement costly interventions based on your predictions, the senior ecologist asks:

*"How do we know the neural network learned correctly? What if it's just memorizing patterns rather than understanding the true dynamics? These predictions will inform important decisions about hunting quotas and conservation efforts."*

Excellent question! Let's verify our results with diagnostic tools.

## Step 1: Setup and Quick Inference Recap

Rather than retraining, we load the inference object saved in Exercise 1 and rebuild the posterior from it:

In [None]:
# Imports
from functools import partial
import matplotlib.pyplot as plt
import pickle
from sbi import utils as utils
from sbi.diagnostics import run_sbc
from sbi.inference import simulate_for_sbi

# Our simulator
from simulators.lotka_volterra import (
    create_lotka_volterra_prior,
    generate_observed_data,
    get_lv_summary_stats_names,
    lotka_volterra_simulator,
)

# Import our utility functions
from utils import plot_predictive_check, plot_training_diagnostics

# Set plotting style
plt.style.use("seaborn-v0_8-darkgrid")
plt.rcParams["font.size"] = 14
plt.rcParams["axes.labelsize"] = 16

# Setup
USE_AUTOCORRELATION = True
prior = create_lotka_volterra_prior()
observed_data, true_params = generate_observed_data(
    use_autocorrelation=USE_AUTOCORRELATION
)
lotka_volterra_simulator = partial(
    lotka_volterra_simulator, use_autocorrelation=USE_AUTOCORRELATION
)
num_workers: int = 5

# Let's load the inference object from exercise 01
with open(
    f"lv_inference_{'with_autocorrelation' if USE_AUTOCORRELATION else 'without_autocorrelation'}.pt",
    "rb",
) as handle:
    npe = pickle.load(handle)
# And build a new posterior object
posterior = npe.build_posterior(prior=prior)
# Set default x to observed data
posterior.set_default_x(observed_data)

print("\n✅ Inference loaded! Now let's check if we can trust it...")

## Step 2: Prior Predictive Check

**Key Question:** *Does our observed data fall within the range of data that the prior can generate?*

Before looking at the posterior, we should check if our prior is reasonable:
1. Sample parameters from the prior
2. Simulate data with those parameters  
3. Check if observed data falls within this distribution

If the observed data is far outside the prior predictive distribution, we may need to reconsider our prior.

In [None]:
print("🔍 Diagnostic 1: Prior Predictive Check")
print("=" * 50)
print("Checking if our observed data is consistent with the prior...\n")

# Define the statistics names for the Lotka-Volterra model
stat_names = get_lv_summary_stats_names(USE_AUTOCORRELATION)

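# Reuse the simulations stored on the inference object from Exercise 1 instead of
# simulating again. `_x_roundwise` is a private attribute of the trainer and may
# change between sbi versions; to simulate afresh from the prior, use:
# theta, x = simulate_for_sbi(lotka_volterra_simulator, prior, num_simulations=10000, num_workers=num_workers)
x = npe._x_roundwise[0]
# Per-statistic (min, max) of the prior predictive data, reused later as plot limits.
x_min, x_max = x.min(dim=0).values, x.max(dim=0).values
prior_data_limits: list[tuple[float, float]] = [
    (lo.item(), hi.item()) for lo, hi in zip(x_min, x_max)
]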

# Run the prior predictive check using our reusable function
plot_predictive_check(
    x=x,
    observed_data=observed_data,
    stat_names=stat_names,
    title="Prior Predictive Check: Is our observed data consistent with the prior?",
);
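
To complement the plot with numbers, the three steps above can be sketched on toy data. The helper names `empirical_quantiles` and `flag_extreme`, the 1% tail threshold, and the toy values are all illustrative assumptions, not part of the tutorial's `utils` module. The sketch uses plain Python lists so it runs without any dependencies.

```python
def empirical_quantiles(simulations, observed):
    """Fraction of simulated values below the observed value, per statistic."""
    num_sims = len(simulations)
    return [
        sum(row[j] < obs for row in simulations) / num_sims
        for j, obs in enumerate(observed)
    ]


def flag_extreme(quantiles, tail=0.01):
    """Indices of statistics whose observed value lies in either extreme tail."""
    return [j for j, q in enumerate(quantiles) if q < tail or q > 1.0 - tail]


# Toy data: 5 prior predictive simulations of 2 summary statistics (made-up numbers).
toy_sims = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [5.0, 50.0]]
toy_observed = [3.5, 99.0]

quantiles = empirical_quantiles(toy_sims, toy_observed)
print(quantiles)                # [0.6, 1.0]
print(flag_extreme(quantiles))  # [1] -> the second statistic is outside the prior range
```

With the tensors in this notebook, the same quantiles could presumably be computed as `(x < observed_data).float().mean(dim=0)`, assuming the two shapes broadcast.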

## Step 3: Neural Network Training Diagnostics

**Key Question:** *Did the neural network converge during training?*

We need to check whether the neural density estimator was trained properly:
- Training loss should decrease and stabilize
- Validation loss should not increase (a rising validation loss signals overfitting)
- Both loss curves should ideally converge to a plateau

The `NPE` trainer class saves the training and validation losses, as well as other
statistics, during training.

In [None]:
# We can just pass the npe object to our plotting function.
plot_training_diagnostics(npe)
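
The visual check can be backed by a few numbers. The sketch below is a hypothetical helper (`summarize_losses` is not part of sbi or of this tutorial's utilities) that takes two plain lists of per-epoch losses. In recent sbi versions these histories appear to be available through the trainer's `summary` dictionary, but the exact key names depend on the installed version, so treat that as an assumption.

```python
def summarize_losses(training_loss, validation_loss):
    """Summarize convergence from per-epoch training and validation losses."""
    # Epoch with the lowest validation loss (0-indexed).
    best_epoch = min(range(len(validation_loss)), key=validation_loss.__getitem__)
    return {
        "best_epoch": best_epoch,
        "best_validation_loss": validation_loss[best_epoch],
        # Many epochs without improvement suggest a plateau or overfitting.
        "epochs_since_best": len(validation_loss) - 1 - best_epoch,
        # A large gap between final validation and training loss hints at overfitting.
        "final_gap": validation_loss[-1] - training_loss[-1],
        "validation_rising": validation_loss[-1] > validation_loss[best_epoch],
    }


# Toy curves (made-up numbers): training keeps improving, validation turns upward.
toy_train = [5.0, 3.0, 2.0, 1.5, 1.2, 1.0]
toy_val = [5.2, 3.3, 2.4, 2.2, 2.5, 2.9]

for key, value in summarize_losses(toy_train, toy_val).items():
    print(f"{key}: {value}")
```

In the toy curves, the validation loss bottoms out at epoch 3 and then rises while the training loss keeps falling, which is the overfitting pattern described above.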

## Step 4: Posterior Predictive Check

**Key Question:** *If we simulate data using parameters from our posterior, does it look like our observed data?*

This is the most intuitive diagnostic:
1. Sample parameters from the posterior
2. Simulate data with those parameters
3. Compare simulated data to observations

If the posterior is correct, simulated data should be consistent with the observations,
i.e., it should "center around" the observed data. Some fluctuation around the observed
data is expected because of simulator noise and parameter uncertainty.
However, when the observed data lies at the edges or outside of the posterior predictive
distributions, this hints at a problem with the overall inference setup.

In [None]:
# Posterior predictive check: simulate data with parameters sampled from the posterior.
_, predictive_sims = simulate_for_sbi(
    lotka_volterra_simulator, posterior, num_simulations=1000, num_workers=num_workers
)

# Visualize the posterior predictive distributions together with the observed data.
plot_predictive_check(
    predictive_sims,
    observed_data,
    stat_names=stat_names,
    limits=prior_data_limits,
    percentile_allowance=90,
);
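
A rough numerical version of this check asks whether each observed statistic falls inside a central interval of the posterior predictive simulations. The sketch below uses plain Python lists and toy numbers. The helper names, the 90% mass, and the index-based interval are illustrative assumptions and are not claimed to match what `plot_predictive_check` does internally.

```python
def central_interval(values, mass=0.9):
    """Approximate central interval that covers `mass` of the values."""
    ordered = sorted(values)
    n = len(ordered)
    # Index of the lower bound; round() guards against floating-point error.
    lower_idx = round((1.0 - mass) / 2.0 * (n - 1))
    upper_idx = (n - 1) - lower_idx
    return ordered[lower_idx], ordered[upper_idx]


def observed_within_interval(simulations, observed, mass=0.9):
    """Per statistic, report whether the observation lies in the central interval."""
    results = []
    for j, obs in enumerate(observed):
        low, high = central_interval([row[j] for row in simulations], mass)
        results.append(low <= obs <= high)
    return results


# Toy posterior predictive simulations: 101 draws of 2 statistics (made-up numbers).
toy_predictive = [[float(i), 2.0 * i] for i in range(101)]
toy_observation = [50.0, 195.0]

print(central_interval([row[0] for row in toy_predictive]))       # (5.0, 95.0)
print(observed_within_interval(toy_predictive, toy_observation))  # [True, False]
```

With tensors, the inputs could be converted first, for example with `predictive_sims.tolist()`.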

## Step 5: Simulation-Based Calibration (SBC)

**Key Question:** *Are our credible intervals properly calibrated?*

SBC is a test that checks if our inference method is statistically well calibrated:
- Generate many "ground truth" parameters from the prior
- For each, simulate data and perform inference (many different posteriors)
- Check if X% credible intervals contain the true value X% of the time
- Do this by computing the rank of each ground-truth parameter among samples from the
  corresponding posterior; if the method is calibrated, these ranks are uniformly
  distributed across all SBC samples
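
The rank procedure can be sketched on a toy problem where the exact posterior is known. The model below is an assumption chosen for illustration and has nothing to do with Lotka-Volterra: a standard normal prior on a single parameter and one noisy observation with unit variance, which gives a Gaussian posterior with mean `x / 2` and variance `1 / 2`. Only the standard library is used.

```python
import random


def toy_sbc_ranks(num_runs=500, num_posterior_samples=100, seed=0):
    """Run SBC on a conjugate Gaussian toy model and return one rank per run."""
    rng = random.Random(seed)
    ranks = []
    for _ in range(num_runs):
        theta_true = rng.gauss(0.0, 1.0)   # ground truth drawn from the prior
        x = rng.gauss(theta_true, 1.0)     # simulate one observation
        # Exact posterior for this model: N(x / 2, 1 / 2).
        post_mean, post_std = x / 2.0, 0.5 ** 0.5
        samples = [
            rng.gauss(post_mean, post_std) for _ in range(num_posterior_samples)
        ]
        # Rank = number of posterior samples below the ground truth.
        ranks.append(sum(s < theta_true for s in samples))
    return ranks


ranks_toy = toy_sbc_ranks()
mean_rank = sum(ranks_toy) / len(ranks_toy)
lower_half = sum(r < 50 for r in ranks_toy) / len(ranks_toy)
print(f"mean rank: {mean_rank:.1f} (uniform ranks would average 50)")
print(f"fraction of ranks below 50: {lower_half:.2f} (about 0.5 if uniform)")
```

Because the posterior here is exact, the ranks should look uniform up to sampling noise. A learned posterior takes the place of the analytic one in a real SBC run.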

Importantly, SBC can detect whether an SBI method is systematically biased,
e.g., whether it systematically over- or underestimates the location of the true
posterior, or whether the posterior variances are overconfident (too narrow) or
underconfident (too wide). These scenarios are revealed by the shape of the rank
distribution: a skewed distribution indicates bias, a U shape indicates overconfidence,
and ranks that pile up in the middle indicate underconfidence.
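
As a rough illustration of reading the rank distribution, the sketch below labels a set of ranks using two summary numbers: the mean normalized rank and the mass in the outer 10% on each side. The function name, the tolerance of 0.1, and the rank convention (the number of posterior samples below the ground truth) are assumptions made for this example. It is a heuristic, not a replacement for the statistical tests in sbi.

```python
def describe_rank_shape(ranks, num_posterior_samples, tol=0.1):
    """Heuristically label the shape of an SBC rank distribution.

    Assumes each rank counts the posterior samples below the ground-truth
    parameter, so ranks lie in [0, num_posterior_samples].
    """
    normalized = [r / num_posterior_samples for r in ranks]
    mean_rank = sum(normalized) / len(normalized)
    # Fraction of ranks in the outer 10% on either side; about 0.2 if uniform.
    tail_mass = sum(u < 0.1 or u > 0.9 for u in normalized) / len(normalized)

    if mean_rank > 0.5 + tol:
        return "biased: posterior tends to underestimate the parameter"
    if mean_rank < 0.5 - tol:
        return "biased: posterior tends to overestimate the parameter"
    if tail_mass > 0.2 + tol:
        return "overconfident: posterior too narrow (U-shaped ranks)"
    if tail_mass < 0.2 - tol:
        return "underconfident: posterior too wide (ranks pile up in the middle)"
    return "no obvious deviation from uniform"


# Synthetic rank sets (made-up) for 100 posterior samples per run.
print(describe_rank_shape(list(range(101)), 100))       # uniform
print(describe_rank_shape([0] * 50 + [100] * 50, 100))  # U-shaped
print(describe_rank_shape([50] * 100, 100))             # concentrated in the middle
print(describe_rank_shape([90] * 100, 100))             # shifted to high ranks
```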

Note that SBC can be "gamed" by a very wide posterior estimate: the prior itself passes
the check. SBC should therefore always be used in conjunction with posterior predictive
checks, which show that the posterior estimate is actually able to predict the
observed data.
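
A toy demonstration of this caveat: an "estimator" that ignores the data and always returns prior samples still produces uniform ranks. The one-parameter Gaussian setup below is an assumption for illustration only and uses just the standard library.

```python
import random


def ranks_with_prior_as_posterior(num_runs=500, num_samples=100, seed=1):
    """SBC ranks for a useless estimator that returns prior samples for every x."""
    rng = random.Random(seed)
    ranks = []
    for _ in range(num_runs):
        theta_true = rng.gauss(0.0, 1.0)  # ground truth from the prior N(0, 1)
        _x = rng.gauss(theta_true, 1.0)   # simulated data, ignored by the estimator
        samples = [rng.gauss(0.0, 1.0) for _ in range(num_samples)]
        ranks.append(sum(s < theta_true for s in samples))
    return ranks


gamed_ranks = ranks_with_prior_as_posterior()
gamed_mean = sum(gamed_ranks) / len(gamed_ranks)
print(f"mean rank: {gamed_mean:.1f} (uniform ranks would average 50)")
```

The ranks look calibrated even though the estimator has learned nothing from the data, which is why a posterior predictive check is needed alongside SBC.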

We'll use the built-in `run_sbc` function from the sbi package for efficiency.

In [None]:
print("🔍 Diagnostic 4: Simulation-Based Calibration (SBC)")
print("=" * 50)
print("Testing if our credible intervals are properly calibrated...")
print("(Note: Using 200 runs for speed; in practice, 100 to 1000 runs are typical)\n")

# Run SBC using sbi's built-in function
num_sbc_runs = 200  # Lower end of the typical range, for tutorial speed

thetas, xs = simulate_for_sbi(
    lotka_volterra_simulator,
    prior,
    num_simulations=num_sbc_runs,
    num_workers=num_workers,
)

print(f"Running {num_sbc_runs} SBC tests...")
ranks, dap_samples = run_sbc(
    thetas,
    xs,
    posterior,
    num_posterior_samples=1000,
    reduce_fns="marginals",
    use_batched_sampling=True,
)

In [None]:
from sbi.analysis import sbc_rank_plot
from sbi.diagnostics import check_sbc

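# Perform statistical checks on the SBC results:
# - ks_pvals: Kolmogorov-Smirnov test of rank uniformity, one p-value per parameter
# - c2st_ranks: classifier two-sample test (C2ST) of the ranks vs. a uniform distribution
# - c2st_daps: C2ST of data-averaged posterior (DAP) samples vs. prior samples
ks_pvals, c2st_ranks, c2st_daps = check_sbc(ranks, thetas, dap_samples).values()

# Small p-values or C2ST scores far from 0.5 point to miscalibration.
print(f"KS p-values: {ks_pvals}")
print(f"C2ST ranks:  {c2st_ranks}")
print(f"C2ST DAP:    {c2st_daps}")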

param_names = [
    "α (deer birth)",
    "β (predation)",
    "δ (wolf efficiency)",
    "γ (wolf death)",
]

sbc_rank_plot(
    ranks,
    1000,  # number of posterior samples per SBC run, as passed to run_sbc above
    num_bins=100,
    parameter_labels=param_names,
    plot_type="cdf",
    ranks_labels=["Lotka-Volterra"],
);
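
The numerical checks can be turned into a short per-parameter report. The helper below is hypothetical (`flag_sbc_checks` is not an sbi function), and its thresholds, a 0.05 significance level and a tolerance of 0.1 around the C2ST chance level of 0.5, are assumptions to adjust as needed. It expects one value per parameter in plain Python sequences, so tensors would need `.tolist()` first.

```python
def flag_sbc_checks(ks_pvals, c2st_ranks, c2st_daps, alpha=0.05, c2st_tol=0.1):
    """Return, for each parameter, a list of SBC checks that look suspicious."""
    report = []
    for p_value, c_rank, c_dap in zip(ks_pvals, c2st_ranks, c2st_daps):
        problems = []
        if p_value < alpha:
            problems.append("ranks not uniform (KS test)")
        if abs(c_rank - 0.5) > c2st_tol:
            problems.append("ranks distinguishable from uniform (C2ST)")
        if abs(c_dap - 0.5) > c2st_tol:
            problems.append("data-averaged posterior differs from prior (C2ST)")
        report.append(problems)
    return report


# Made-up check results for four parameters.
toy_ks = [0.40, 0.01, 0.30, 0.70]
toy_c2st_ranks = [0.52, 0.68, 0.49, 0.51]
toy_c2st_daps = [0.50, 0.55, 0.75, 0.48]

for i, problems in enumerate(flag_sbc_checks(toy_ks, toy_c2st_ranks, toy_c2st_daps)):
    print(f"parameter {i}: {'OK' if not problems else '; '.join(problems)}")
```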

## 🎯 Key Takeaways

### Why Diagnostics Matter

1. **Neural networks can fail silently** - They might produce confident-looking but wrong results
2. **Not all posteriors are created equal** - Some might be overconfident, others too uncertain
3. **Trust but verify** - Always check your inference before making decisions

### Your Diagnostic Toolkit

| Diagnostic | What it checks | Red flag |
|------------|---------------|----------|
| Prior Predictive | Prior covers observations | Observations in extreme tails |
| Training Diagnostics | Network convergence | Loss increasing or unstable |
| Posterior Predictive | Can recreate observations | Observations at the edge of or outside the predictive distribution |
| SBC | Calibrated credible intervals | Non-uniform rank distribution |

### Best Practices

✅ **Always run diagnostics** - Make it part of your workflow  
✅ **Start with the prior** - Bad prior → bad posterior  
✅ **Document results** - Keep diagnostic reports with your analysis  
✅ **Iterate if needed** - Poor diagnostics → adjust and retrain  

---

## 🚀 Challenge: Stress Testing

Try breaking the inference and see how the diagnostics detect it:

1. **Too few simulations**: Retrain with only 500 simulations
2. **Wrong prior**: Use a prior that does not contain the true parameters
3. **Corrupted data**: Add extreme outliers to observations

How do the diagnostics change? Which tests catch which problems?

In [None]:
# Space for experimentation
