# Exercise 3: Apply SBI to Your Own Problem 🚀

**Time:** 20 minutes  
**Goal:** Apply what you've learned to a new simulator

## 🎯 Learning Objectives

By the end of this exercise, you will:
1. ✅ Adapt the SBI workflow to a new problem
2. ✅ Define appropriate priors for your parameters
3. ✅ Run inference and diagnostics on your simulator
4. ✅ Leave with working code you can adapt

## Choose Your Adventure!

We provide two well-tested example simulators, or you can bring your own:

### 🎾 Option A: Ball Throw Physics
- **Story**: You're analyzing baseball pitches or golf drives
- **Physics**: Projectile motion with air resistance
- **Challenge**: Infer launch conditions from landing position

### 🦠 Option B: SIR Epidemic Model
- **Story**: You're tracking disease spread in a community
- **Model**: Classic compartmental epidemic model
- **Challenge**: Infer transmission rates from outbreak data

### 🔬 Option C: Your Own Simulator
- Bring your research problem!
- We'll help you adapt it

## Setup

In [None]:
# Core imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# SBI imports
from sbi import inference
from sbi import analysis
from sbi import utils

# Our example simulators
import sys

sys.path.append("..")
from simulators.ball_throw import ball_throw_simulator, create_ball_throw_prior
from simulators.sir_model import sir_epidemic_simulator, create_sir_prior

# Set style
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("colorblind")

# Random seed
torch.manual_seed(42)
np.random.seed(42)

print("✅ Ready to apply SBI to your problem!")

## Part 1: Explore the Simulators

Let's understand what each simulator does before choosing one.

### 🎾 Ball Throw Physics

This simulator models projectile motion with air resistance:

**Differential equations:**
- Horizontal: `d²x/dt² = wind - friction·dx/dt`
- Vertical: `d²y/dt² = -gravity - friction·dy/dt`

**Parameters to infer:**
1. Initial velocity (5-30 m/s)
2. Launch angle (0.2-1.4 radians ≈ 11°-80°)
3. Friction coefficient (0.0-0.5)

**What we observe:**
- Landing distance (meters)
- Maximum height reached (meters)
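To make the equations above concrete, here is a minimal, noise-free numpy sketch that integrates them with explicit Euler steps and returns the two observables. The function name `ball_summaries`, the step size, the linear-interpolation landing rule and g = 9.81 m/s² are assumptions for illustration. The tutorial's `ball_throw_simulator` may use a different integrator and also adds noise.

```python
import numpy as np

def ball_summaries(v0, angle, friction, wind=0.0, gravity=9.81, dt=1e-3):
    """Return [landing distance, max height] for projectile motion with linear drag.

    Illustrative, noise-free sketch using explicit Euler steps of
      d2x/dt2 = wind - friction * dx/dt
      d2y/dt2 = -gravity - friction * dy/dt
    Not the tutorial's `ball_throw_simulator`.
    """
    x, y = 0.0, 0.0
    vx, vy = v0 * np.cos(angle), v0 * np.sin(angle)
    max_height = 0.0
    while True:
        ax = wind - friction * vx
        ay = -gravity - friction * vy
        x_new, y_new = x + vx * dt, y + vy * dt
        vx, vy = vx + ax * dt, vy + ay * dt
        if y_new < 0.0:
            # Interpolate linearly between the last two points to find the ground crossing
            frac = y / (y - y_new)
            return np.array([x + frac * (x_new - x), max_height])
        x, y = x_new, y_new
        max_height = max(max_height, y)

print(ball_summaries(15.0, 0.8, 0.0))  # drag-free: near v0² sin(2θ)/g and (v0 sin θ)²/(2g)
print(ball_summaries(15.0, 0.8, 0.1))  # with friction: shorter and lower
```

With `friction=0` and no wind, the output should match the textbook drag-free formulas, which makes a handy sanity check.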

In [None]:
# Test the ball throw simulator
test_params = torch.tensor([15.0, 0.8, 0.1])  # 15 m/s, ~45°, low friction
observations = ball_throw_simulator(test_params)

print("🎾 Ball Throw Test:")
print(
    f"  Parameters: v₀={test_params[0]:.1f} m/s, θ={test_params[1]:.2f} rad, μ={test_params[2]:.2f}"
)
print(
    f"  Observations: distance={observations[0]:.1f}m, max_height={observations[1]:.1f}m"
)

# Visualize a trajectory
obs, x_traj, y_traj = ball_throw_simulator(test_params, return_trajectory=True)

plt.figure(figsize=(10, 4))
plt.plot(x_traj, y_traj, "b-", linewidth=2, label="Trajectory")
plt.scatter([obs[0].item()], [0], color="red", s=100, zorder=5, label="Landing")
plt.scatter(
    [x_traj[np.argmax(y_traj)]],
    [obs[1].item()],
    color="green",
    s=100,
    zorder=5,
    label="Peak",
)
plt.xlabel("Distance (m)", fontsize=12)
plt.ylabel("Height (m)", fontsize=12)
plt.title("Ball Trajectory with Air Resistance", fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.show()

print(
    "\n💡 We observe only the landing distance and max height, not the full trajectory!"
)

### 🦠 SIR Epidemic Model

This simulator models disease spread through a population:

**Compartments:**
- **S**usceptible: Can catch the disease
- **I**nfected: Currently sick and contagious
- **R**ecovered: Immune after recovery

**Differential equations:**
- `dS/dt = -β·S·I/N` (infection)
- `dI/dt = β·S·I/N - γ·I` (infection - recovery)
- `dR/dt = γ·I` (recovery)

**Parameters to infer:**
1. β: Infection rate (0.1-2.0 per day)
2. γ: Recovery rate (0.05-0.5 per day)
3. I₀: Initial infected count (1-100 people)

**What we observe:**
- Peak number of infected
- Time to reach peak (days)
- Total recovered at end
- Epidemic duration (days)
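A minimal sketch of how the four observables can be computed from the SIR equations, using a deterministic Euler integration in numpy. `sir_summaries`, the 10,000-person population, the 160-day horizon and the duration rule are assumptions, not the internals of `sir_epidemic_simulator` (which also adds noise).

```python
import numpy as np

def sir_summaries(beta, gamma, i0, population=10_000, t_max=160.0, dt=0.1):
    """Return [peak infected, time to peak, total recovered, duration] for a
    deterministic SIR run integrated with explicit Euler steps.

    Illustrative sketch: population size, time horizon, duration rule and the
    absence of noise are assumptions, not the tutorial's simulator.
    """
    n_steps = int(round(t_max / dt))
    t = np.arange(n_steps + 1) * dt
    s, i, r = population - i0, float(i0), 0.0
    infected = np.empty(n_steps + 1)
    infected[0] = i
    for k in range(n_steps):
        new_infections = beta * s * i / population * dt  # flow S -> I
        new_recoveries = gamma * i * dt                   # flow I -> R
        s -= new_infections
        i += new_infections - new_recoveries
        r += new_recoveries
        infected[k + 1] = i
    peak = int(np.argmax(infected))
    # Assumed duration rule: last time step with at least one infected person
    active = np.nonzero(infected >= 1.0)[0]
    duration = t[active[-1]] if active.size else 0.0
    return np.array([infected[peak], t[peak], r, duration])

print(sir_summaries(0.5, 0.1, 10))  # R0 = beta / gamma = 5
```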

In [None]:
# Test the SIR simulator
test_params = torch.tensor([0.5, 0.1, 10])  # β=0.5, γ=0.1, I₀=10
observations = sir_epidemic_simulator(test_params)

print("🦠 SIR Epidemic Test:")
print(
    f"  Parameters: β={test_params[0]:.2f}, γ={test_params[1]:.2f}, I₀={test_params[2]:.0f}"
)
print(f"  Basic reproduction number R₀ = β/γ = {test_params[0] / test_params[1]:.1f}")
print(f"\n  Observations:")
print(f"    Peak infected: {observations[0]:.0f} people")
print(f"    Time to peak: {observations[1]:.0f} days")
print(f"    Total recovered: {observations[2]:.0f} people")
print(f"    Epidemic duration: {observations[3]:.0f} days")

# Visualize epidemic curves
obs, time_series = sir_epidemic_simulator(test_params, return_time_series=True)

plt.figure(figsize=(10, 5))
plt.plot(
    time_series["t"], time_series["S"], label="Susceptible", linewidth=2, color="blue"
)
plt.plot(time_series["t"], time_series["I"], label="Infected", linewidth=2, color="red")
plt.plot(
    time_series["t"], time_series["R"], label="Recovered", linewidth=2, color="green"
)

# Mark observations
peak_idx = np.argmax(time_series["I"])
plt.scatter(
    [time_series["t"][peak_idx]],
    [time_series["I"][peak_idx]],
    color="red",
    s=100,
    zorder=5,
    label=f"Peak: {obs[0]:.0f}",
)

plt.xlabel("Time (days)", fontsize=12)
plt.ylabel("Number of individuals", fontsize=12)
plt.title("SIR Epidemic Dynamics (Population = 10,000)", fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.show()

print("\n💡 We observe summary statistics, not the full time series!")

### 🔬 Your Own Simulator

If you brought your own simulator, adapt this template:

In [None]:
def your_simulator(params):
    """
    Template for your own simulator.

    Requirements:
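    1. Takes one parameter vector (torch.Tensor or numpy array)
    2. Returns one observation vector (torch.Tensor, float32)
    3. Should include some stochasticity (noise)
    4. Runs fast: training calls it once per simulation, so at
       5000 simulations even 1 second per call adds up to about 1.4 hours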
    """
    # Convert to torch if needed
    if isinstance(params, np.ndarray):
        params = torch.tensor(params, dtype=torch.float32)

    # Your simulation code here
    # ...

    # Add observation noise (important!)
    # observations = observations * (1 + torch.randn_like(observations) * 0.05)

    # Return as a float32 torch tensor (as_tensor accepts arrays and tensors without a warning)
    # return torch.as_tensor(observations, dtype=torch.float32)

    pass  # Remove this when implementing

## Part 2: Choose Your Simulator and Run SBI

**👇 Option A is active by default. To switch, comment out the Option A lines and uncomment ONE of the other sections:**

In [None]:
# ========== OPTION A: Ball Throw ==========
simulator = ball_throw_simulator
prior = create_ball_throw_prior(include_wind=False)  # Set True to include wind
param_names = ["v₀ (m/s)", "θ (rad)", "μ (friction)"]
obs_names = ["distance (m)", "max height (m)"]

# ========== OPTION B: SIR Model ==========
# simulator = sir_epidemic_simulator
# prior = create_sir_prior()
# param_names = ["β (infection)", "γ (recovery)", "I₀ (initial)"]
# obs_names = ["peak infected", "time to peak", "total recovered", "duration"]

# ========== OPTION C: Your Simulator ==========
# simulator = your_simulator
# prior = utils.BoxUniform(
#     low=torch.tensor([...]),   # Your parameter lower bounds
#     high=torch.tensor([...])   # Your parameter upper bounds
# )
# param_names = [...]  # Your parameter names
# obs_names = [...]    # Your observable names

print(f"Selected simulator: {simulator.__name__}")
print(f"Parameters: {param_names}")
print(f"Observables: {obs_names}")
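Option C builds its prior with `utils.BoxUniform`: independent uniform distributions between `low` and `high`. Below is a numpy sketch of what such a prior computes. `BoxUniformSketch` is a hypothetical stand-in, not sbi's class, and the bounds are the SIR ranges listed in Part 1 (whether `create_sir_prior` uses exactly these bounds is an assumption).

```python
import numpy as np

class BoxUniformSketch:
    """Hypothetical numpy stand-in for a box-uniform prior (not sbi's class)."""

    def __init__(self, low, high, seed=None):
        self.low = np.asarray(low, dtype=float)
        self.high = np.asarray(high, dtype=float)
        self.rng = np.random.default_rng(seed)

    def sample(self, n):
        # Each parameter is drawn independently and uniformly from [low, high)
        return self.rng.uniform(self.low, self.high, size=(n, self.low.size))

    def log_prob(self, theta):
        theta = np.atleast_2d(theta)
        inside = np.all((theta >= self.low) & (theta <= self.high), axis=1)
        # Constant density 1 / volume inside the box, zero (log = -inf) outside
        log_volume = np.sum(np.log(self.high - self.low))
        return np.where(inside, -log_volume, -np.inf)

# Bounds taken from the SIR parameter ranges in Part 1: beta, gamma, I0
prior_sketch = BoxUniformSketch(low=[0.1, 0.05, 1.0], high=[2.0, 0.5, 100.0], seed=0)
draws = prior_sketch.sample(1000)
print(draws.shape)
print(prior_sketch.log_prob([[0.5, 0.1, 10.0], [3.0, 0.1, 10.0]]))
```

The log-density is constant inside the box and `-inf` outside it, so parameter values outside the prior can never appear in the posterior.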

## Part 3: Generate "Observed" Data

In a real application, `observed_data` would be your measured data. Here we simulate it from known parameters so we can check whether inference recovers them.

In [None]:
# Generate synthetic observation (ground truth for testing)
true_params = prior.sample()
observed_data = simulator(true_params)

print("\n🎯 True parameters (hidden in real applications):")
for i, name in enumerate(param_names):
    print(f"  {name}: {true_params[i]:.3f}")

print("\n📊 Observed data:")
for i, name in enumerate(obs_names):
    print(f"  {name}: {observed_data[i]:.3f}")

print("\n🎲 Challenge: Can we recover the true parameters from observations alone?")

## Part 4: Run Neural Posterior Estimation 🚀

The same 4-step workflow from Exercise 1!

In [None]:
# Step 1: Create NPE object
npe = inference.NPE(prior=prior)

# Step 2: Train on simulations
print("🏃 Training neural network...")
print("This will take 20-40 seconds...\n")

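# Simulate a training set: draw parameters from the prior, run the simulator once per draw
theta = prior.sample((5000,))  # Use 10000+ for better results
x = torch.stack([simulator(t) for t in theta])

# append_simulations expects (theta, x) tensors, not the simulator itself
npe = npe.append_simulations(theta, x).train()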

# Step 3: Build posterior for our observation
posterior = npe.build_posterior()

# Step 4: Sample from posterior
posterior_samples = posterior.sample((2000,), x=observed_data, show_progress_bars=True)

print("\n✅ Inference complete! Let's see what we learned...")

## Part 5: Visualize Results 📊

In [None]:
# Compare posterior to prior and truth
fig, axes = plt.subplots(1, len(true_params), figsize=(4 * len(true_params), 4))

# Handle single parameter case
if len(true_params) == 1:
    axes = [axes]

for i, ax in enumerate(axes):
    # Posterior
    ax.hist(
        posterior_samples[:, i],
        bins=30,
        alpha=0.7,
        density=True,
        label="Posterior",
        color="green",
        edgecolor="darkgreen",
    )

    # Prior
    prior_samples = prior.sample((1000,))
    ax.hist(
        prior_samples[:, i],
        bins=30,
        alpha=0.3,
        density=True,
        label="Prior",
        color="gray",
        edgecolor="black",
    )

    # Truth
    ax.axvline(
        true_params[i], color="red", linewidth=2, linestyle="--", label="True value"
    )

    # Posterior mean
    post_mean = posterior_samples[:, i].mean()
    ax.axvline(post_mean, color="blue", linewidth=2, linestyle="-", label="Post. mean")

    ax.set_xlabel(param_names[i], fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.legend(fontsize=10, loc="upper right")
    ax.grid(True, alpha=0.3)

plt.suptitle("Posterior vs Prior: Learning from Data", fontsize=14)
plt.tight_layout()
plt.show()

# Print quantitative summary
print("\n📈 Parameter Recovery Summary:")
print("=" * 60)
print(f"{'Parameter':<20} {'True':<10} {'Post. Mean ± Std':<20} {'Error':<10}")
print("-" * 60)
for i, name in enumerate(param_names):
    true_val = true_params[i].item()
    post_mean = posterior_samples[:, i].mean().item()
    post_std = posterior_samples[:, i].std().item()
    error = abs(true_val - post_mean)
    # Format mean ± std as one string so it lines up with the header column
    mean_std = f"{post_mean:.3f} ± {post_std:.3f}"
    print(f"{name:<20} {true_val:<10.3f} {mean_std:<20} {error:<10.3f}")

# Calculate credible intervals
print("\n📊 95% Credible Intervals:")
for i, name in enumerate(param_names):
    q_low = torch.quantile(posterior_samples[:, i], 0.025).item()
    q_high = torch.quantile(posterior_samples[:, i], 0.975).item()
    true_val = true_params[i].item()
    in_ci = "✅" if q_low <= true_val <= q_high else "❌"
    print(f"  {name}: [{q_low:.3f}, {q_high:.3f}] {in_ci}")
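A single ✅ or ❌ is only one draw: a calibrated 95% interval still misses the truth about 1 time in 20. Repeating the check over many synthetic (truth, observation) pairs estimates the coverage. The numpy sketch below uses a toy Gaussian model whose exact posterior is known, standing in for the NPE posterior (an assumption for illustration). For a trained sbi posterior, sbi ships simulation-based calibration (SBC) utilities; see its diagnostics documentation.

```python
import numpy as np

rng = np.random.default_rng(1)
prior_std, noise_std, n_trials = 1.0, 0.5, 2000

# Toy model with a known posterior (stand-in for a trained NPE posterior):
#   theta ~ N(0, prior_std^2),  x | theta ~ N(theta, noise_std^2)
post_var = 1.0 / (1.0 / prior_std**2 + 1.0 / noise_std**2)
post_std = np.sqrt(post_var)

hits = 0
for _ in range(n_trials):
    theta_true = rng.normal(0.0, prior_std)      # draw a "true" parameter from the prior
    x_obs = rng.normal(theta_true, noise_std)    # simulate one observation
    post_mean = post_var * x_obs / noise_std**2  # exact posterior mean (prior mean is 0)
    samples = rng.normal(post_mean, post_std, size=1000)
    q_low, q_high = np.quantile(samples, [0.025, 0.975])
    hits += int(q_low <= theta_true <= q_high)   # same check as the ✅/❌ above

coverage = hits / n_trials
print(f"Empirical coverage of the 95% intervals: {coverage:.3f}")
```

Coverage well below 95% would indicate overconfident intervals; well above, underconfident ones.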

## Part 6: Diagnostic - Posterior Predictive Check 🔍

Can parameters from our posterior reproduce the observed data?

In [None]:
# Generate posterior predictive samples
n_predictive = 200

print("Generating posterior predictive samples...")
# Draw all parameter sets from the posterior at once
param_samples = posterior.sample(
    (n_predictive,), x=observed_data, show_progress_bars=False
)
# Simulate once per posterior draw
predictive_data = torch.stack([simulator(theta) for theta in param_samples])

# Visualize
fig, axes = plt.subplots(1, len(observed_data), figsize=(4 * len(observed_data), 4))

if len(observed_data) == 1:
    axes = [axes]

for i, ax in enumerate(axes):
    # Predictive distribution
    ax.hist(
        predictive_data[:, i],
        bins=25,
        alpha=0.6,
        label="Predictive",
        color="blue",
        density=True,
    )

    # Observed value
    ax.axvline(
        observed_data[i], color="red", linewidth=2, linestyle="--", label="Observed"
    )

    # Shade the central 90% predictive interval (5th to 95th percentile)
    q5 = torch.quantile(predictive_data[:, i], 0.05).item()
    q95 = torch.quantile(predictive_data[:, i], 0.95).item()
    ax.axvspan(q5, q95, alpha=0.2, color="blue", label="90% predictive interval")

    ax.set_xlabel(obs_names[i], fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

plt.suptitle("Posterior Predictive Check: Can We Reproduce the Data?", fontsize=14)
plt.tight_layout()
plt.show()

# Check if observations fall within predictive distribution
print("\n✅ Diagnostic Results:")
for i, name in enumerate(obs_names):
    percentile = (predictive_data[:, i] < observed_data[i]).float().mean() * 100
    print(f"  {name}: Observed value is at {percentile:.1f}th percentile")
    if 5 < percentile < 95:
        print(f"    ✅ Well within predictive distribution")
    else:
        print(f"    ⚠️ Near edge of predictive distribution")

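print("\n💡 If the observed data fall well inside the predictive distribution,")
print("   the posterior is consistent with the observation. Values in the far tails")
print("   can point to a problem with the posterior, the prior, or the simulator.")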

## Part 7: Explore Parameter Correlations 🔗

In [None]:
if len(param_names) > 1:
    # Compute correlation matrix
    correlation_matrix = torch.corrcoef(posterior_samples.T)

    # Visualize
    plt.figure(figsize=(8, 6))
    im = plt.imshow(correlation_matrix, cmap="RdBu_r", vmin=-1, vmax=1)
    plt.colorbar(im, label="Correlation")

    # Add labels
    n_params = len(param_names)
    plt.xticks(range(n_params), param_names, rotation=45, ha="right")
    plt.yticks(range(n_params), param_names)

    # Add correlation values
    for i in range(n_params):
        for j in range(n_params):
            text = plt.text(
                j,
                i,
                f"{correlation_matrix[i, j]:.2f}",
                ha="center",
                va="center",
                color="white" if abs(correlation_matrix[i, j]) > 0.5 else "black",
                fontsize=12,
                fontweight="bold",
            )

    plt.title("Parameter Correlations in Posterior", fontsize=14)
    plt.tight_layout()
    plt.show()

    # Identify strong correlations
    print("\n🔗 Parameter Correlations:")
    for i in range(n_params):
        for j in range(i + 1, n_params):
            corr = correlation_matrix[i, j].item()
            if abs(corr) > 0.3:
                strength = "Strong" if abs(corr) > 0.7 else "Moderate"
                direction = "positive" if corr > 0 else "negative"
                print(
                    f"  {param_names[i]} ↔ {param_names[j]}: {strength} {direction} ({corr:.2f})"
                )

    print("\n💡 Correlations reveal parameter trade-offs and identifiability issues!")
else:
    print("Single parameter - no correlations to show.")
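To see where such correlations come from, take the SIR model, where R₀ = β/γ. If the data mostly pin down R₀, many (β, γ) pairs fit equally well and the posterior stretches along a ridge. The numpy sketch below assumes a hypothetical scenario (R₀ known to be 5 within ±5%) and keeps only the uniform prior draws consistent with it. This crude filtering only visualizes the trade-off; it is not how NPE works.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 200_000

# Uniform prior draws over the SIR ranges listed in Part 1
beta = rng.uniform(0.1, 2.0, size=n_draws)
gamma = rng.uniform(0.05, 0.5, size=n_draws)

# Hypothetical scenario: the data fix R0 = beta / gamma to 5, within ±5%
keep = np.abs(beta / gamma - 5.0) < 0.25
corr = np.corrcoef(beta[keep], gamma[keep])[0, 1]
print(f"Kept {keep.sum()} draws; corr(beta, gamma) = {corr:.2f}")
```

A strong positive correlation here means β and γ are not separately identifiable from such data; only their ratio is.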

## Part 8: Experiments & What-If Analysis 🔬

Let's explore how different choices affect our inference!

In [None]:
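# Experiment 1: What happens if the observation is perturbed?
print("🔬 Experiment 1: Perturbing the Observation\n")

# Perturb the observation with 20% multiplicative noise
noisy_obs = observed_data * (1 + torch.randn_like(observed_data) * 0.2)

# Evaluate the already-trained (amortized) posterior at the perturbed data
noisy_posterior_samples = posterior.sample((500,), x=noisy_obs)

print(f"Original observations:  {observed_data.numpy()}")
print(f"Perturbed observations: {noisy_obs.numpy()}")
print("\nPosterior mean:")
print(f"  Original:  {posterior_samples.mean(dim=0).numpy()}")
print(f"  Perturbed: {noisy_posterior_samples.mean(dim=0).numpy()}")
print("\nPosterior uncertainty (std):")
print(f"  Original:  {posterior_samples.std(dim=0).numpy()}")
print(f"  Perturbed: {noisy_posterior_samples.std(dim=0).numpy()}")
print("\n💡 Perturbing x mainly shifts the posterior; it need not widen it, because the")
print("   network was trained with the simulator's own noise level. To study noisier")
print("   measurements, add the extra noise inside the simulator and retrain.")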
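Why the noise level, rather than a perturbed x, controls the posterior width can be seen in a toy model whose posterior is known exactly: θ ~ N(0, 1), x | θ ~ N(θ, σ²). The function `gaussian_posterior` is a hypothetical helper for this conjugate Gaussian case, not part of sbi.

```python
import numpy as np

def gaussian_posterior(x, noise_std, prior_mean=0.0, prior_std=1.0):
    """Hypothetical helper: exact posterior (mean, std) for
    theta ~ N(prior_mean, prior_std^2) and x | theta ~ N(theta, noise_std^2)."""
    post_var = 1.0 / (1.0 / prior_std**2 + 1.0 / noise_std**2)
    post_mean = post_var * (prior_mean / prior_std**2 + x / noise_std**2)
    return post_mean, np.sqrt(post_var)

# Same assumed noise, perturbed observation: the mean moves, the width does not
print(gaussian_posterior(0.30, noise_std=0.1))
print(gaussian_posterior(0.36, noise_std=0.1))
# Noisier measurement process (what retraining on a noisier simulator corresponds to)
print(gaussian_posterior(0.30, noise_std=0.3))
```

The first two calls differ only in their posterior means; only the third, which changes the assumed noise, widens the posterior.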

In [None]:
# Experiment 2: How many simulations do we really need?
print("🔬 Experiment 2: Effect of Training Data Size\n")

# Train with 10x fewer simulations
small_theta = prior.sample((500,))
small_x = torch.stack([simulator(t) for t in small_theta])

small_npe = inference.NPE(prior=prior)
small_npe = small_npe.append_simulations(small_theta, small_x).train(
    show_train_summary=False
)

small_posterior = small_npe.build_posterior()
small_samples = small_posterior.sample((500,), x=observed_data)

# Compare accuracy
print("Parameter recovery (distance from truth):")
for i, name in enumerate(param_names):
    error_5000 = abs(posterior_samples[:, i].mean() - true_params[i]).item()
    error_500 = abs(small_samples[:, i].mean() - true_params[i]).item()
    print(f"  {name}:")
    print(f"    5000 sims: error = {error_5000:.4f}")
    print(f"    500 sims:  error = {error_500:.4f}")

print("\n💡 More simulations usually improve parameter recovery, but a single")
print("   comparison like this is noisy, and the gains level off at a problem-dependent budget.")

## 🎯 Challenge: Multiple Observations

What if you had multiple independent observations?

In [None]:
# Generate multiple observations with the same true parameters
n_observations = 3
multi_obs = torch.stack([simulator(true_params) for _ in range(n_observations)])

print(f"🔄 Generated {n_observations} independent observations:")
for i, obs in enumerate(multi_obs):
    print(f"  Obs {i + 1}: {obs.numpy()}")

# Naive strategy: evaluate the single-observation posterior at the mean observation
mean_obs = multi_obs.mean(dim=0)
print(f"\n📊 Mean observation: {mean_obs.numpy()}")

# Caveat: the network was trained on single simulations, not on averages of
# n_observations runs, so it treats mean_obs like one ordinary observation
multi_posterior_samples = posterior.sample((1000,), x=mean_obs)

# Compare uncertainties
print("\n📈 Posterior uncertainty (std):")
print(f"  Single obs:   {posterior_samples.std(dim=0).numpy()}")
print(f"  Mean of {n_observations} obs: {multi_posterior_samples.std(dim=0).numpy()}")

print("\n⚠️ Widths may barely change: this posterior does not know the data were averaged.")
print("\n💡 To let multiple observations reduce uncertainty, retrain NPE on simulations that")
print("   match the data (e.g. the mean of n_observations runs, or all runs concatenated),")
print("   or use NLE/NRE, which can combine independent trials.")
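One way to make NPE benefit from repeated measurements is to train on simulations that look like the data: if you will condition on the average of n runs, simulate the average of n runs. Below is a minimal numpy sketch of such a wrapper. `make_mean_of_n_simulator` is a hypothetical helper and `toy_simulator` is a toy Gaussian stand-in for the notebook's simulators (with those, use `torch.stack` and `.mean(dim=0)` instead). The sketch also checks that averaging shrinks the noise by about √n. Note that the mean is a lossy summary unless the noise is of a special form such as Gaussian.

```python
import numpy as np

def make_mean_of_n_simulator(simulator, n):
    """Hypothetical helper: wrap a single-run simulator so that each call
    returns the element-wise average of n independent runs at the same theta."""
    def mean_simulator(theta):
        runs = np.stack([simulator(theta) for _ in range(n)])
        return runs.mean(axis=0)
    return mean_simulator

rng = np.random.default_rng(0)

def toy_simulator(theta):
    """Toy stand-in: observation = theta + Gaussian noise with std 0.1."""
    theta = np.asarray(theta, dtype=float)
    return theta + 0.1 * rng.standard_normal(theta.shape)

mean_of_3 = make_mean_of_n_simulator(toy_simulator, n=3)
theta = np.array([1.0, 2.0])
single = np.stack([toy_simulator(theta) for _ in range(5000)])
averaged = np.stack([mean_of_3(theta) for _ in range(5000)])
print("single-run noise std:", single.std(axis=0))
print("mean-of-3 noise std: ", averaged.std(axis=0))  # about 0.1 / sqrt(3)
```

To train NPE with it, sample θ from the prior, run the wrapped simulator once per θ, and condition the resulting posterior on the mean of your n measured observations.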

## 🎉 Congratulations!

You've successfully:
- ✅ Applied SBI to different problems
- ✅ Learned the universal NPE workflow
- ✅ Performed diagnostic checks
- ✅ Explored how choices affect inference

### 🔑 Key Takeaways:

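1. **SBI is universal** - The same workflow applies to any simulator you can run on sampled parameters
2. **Prior choice matters** - It must cover plausible parameter values; the posterior cannot put mass outside the prior
3. **Diagnostics are essential** - Always check predictive distributions
4. **More data helps** - More simulations make the posterior approximation more accurate; more observations (modeled properly) make the posterior narrower
5. **Parameters can be correlated** - Correlations reveal trade-offs and identifiability issues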

### 🚀 Next Steps:

For your research:
1. **Start simple** - Test with synthetic data first
2. **Scale up gradually** - Increase complexity step by step
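3. **Use Sequential NPE** - More simulation-efficient for expensive simulators, at the cost of amortization (the posterior is tailored to one observation)
4. **Try other methods** - NLE and NRE, e.g. when you need to combine several independent trials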

### 📚 Resources:

- 📖 [SBI Documentation](https://sbi-dev.github.io/sbi)
- 💻 [GitHub Repository](https://github.com/sbi-dev/sbi)
- 📰 [JOSS Paper](https://joss.theoj.org/papers/10.21105/joss.02505)
- 💬 [Community Forum](https://github.com/sbi-dev/sbi/discussions)

---

## 🙏 Thank you for participating!

**Now go forth and quantify uncertainty in your simulators!**