# Bayesian Inference for Parameter Estimation with Uncertainty

## Learning Objectives

This tutorial covers:
1. **What is Bayesian inference?** - Probabilistic parameter estimation
2. **Prior distributions** - Encoding domain knowledge
3. **MCMC sampling** - Exploring posterior distributions
4. **Convergence diagnostics** - R-hat, effective sample size
5. **Credible intervals** - 95% uncertainty bounds
6. **Comparison with PINN** - When to use each method

## Bayes' Theorem

$$P(\theta|D) = \frac{P(D|\theta) \times P(\theta)}{P(D)}$$

Where:
- **P(θ|D)** = Posterior (what we want) - parameter distributions given data
- **P(D|θ)** = Likelihood - how well parameters explain data
- **P(θ)** = Prior - what we know before seeing data
- **P(D)** = Evidence (normalizing constant)

**Key Insight:** We get full probability distributions, not just point estimates!
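
To make the formula concrete, here is a minimal sketch that evaluates Bayes' theorem on a grid for a toy one-parameter problem. It is not part of the `ens_gi_bayesian` package. The data, the known noise level, and the Normal(120, 20) prior are illustrative assumptions. The evidence P(D) is approximated by numerically integrating likelihood × prior over the grid.

```python
import numpy as np

# Hypothetical toy problem: infer a single parameter theta from noisy measurements.
rng = np.random.default_rng(0)
sigma = 2.0                                   # assumed known measurement noise
data = 105.0 + rng.normal(0.0, sigma, size=20)

# Grid of candidate parameter values
theta_grid = np.linspace(60.0, 180.0, 2401)
d_theta = theta_grid[1] - theta_grid[0]

# Log prior: Normal(120, 20), constants dropped
log_prior = -0.5 * ((theta_grid - 120.0) / 20.0) ** 2

# Log likelihood: independent Gaussian errors around theta
residuals = data[:, None] - theta_grid[None, :]
log_lik = -0.5 * np.sum((residuals / sigma) ** 2, axis=0)

# Posterior is proportional to likelihood x prior; normalise by the evidence
log_post = log_prior + log_lik
log_post -= log_post.max()                    # avoid underflow in exp
posterior = np.exp(log_post)
posterior /= posterior.sum() * d_theta        # numerical stand-in for P(D)

post_mean = np.sum(theta_grid * posterior) * d_theta
post_sd = np.sqrt(np.sum((theta_grid - post_mean) ** 2 * posterior) * d_theta)
print(f"Posterior mean: {post_mean:.2f}, posterior SD: {post_sd:.2f}")
```

Grid evaluation only works for one or two parameters, because the number of grid points grows exponentially with dimension. That is the motivation for sampling methods such as MCMC.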

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
import sys

sys.path.insert(0, '..')

from ens_gi_core import ENSGIDigitalTwin

try:
    from ens_gi_bayesian import BayesianEstimator, BayesianConfig, get_default_priors
    import pymc3 as pm
    import arviz as az
    BAYESIAN_AVAILABLE = True
    print(f"PyMC3 version: {pm.__version__}")
    print(f"ArviZ version: {az.__version__}")
except ImportError as e:
    BAYESIAN_AVAILABLE = False
    print(f"Bayesian framework not available: {e}")
    print("Install with: pip install pymc3 arviz")

if BAYESIAN_AVAILABLE:
    print("\nBayesian framework ready!")
else:
    print("\nThis tutorial requires PyMC3 and ArviZ. Please install and restart.")

## Part 1: Understanding Prior Distributions

Priors encode what we know about parameters *before* seeing data.
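
As a rough illustration of a bounded, weakly informative prior, the sketch below draws samples from a Normal(120, 20) distribution restricted to a physical range. The bounds [50, 200] and the helper `sample_truncated_normal` are hypothetical choices for this example, not values or functions taken from `get_default_priors()`.

```python
import numpy as np

def sample_truncated_normal(mu, sigma, lower, upper, size, rng):
    """Rejection-sample Normal(mu, sigma) restricted to [lower, upper]."""
    samples = np.empty(0)
    while samples.size < size:
        draws = rng.normal(mu, sigma, size=size)
        kept = draws[(draws >= lower) & (draws <= upper)]
        samples = np.concatenate([samples, kept])
    return samples[:size]

rng = np.random.default_rng(42)

# Assumed prior for a conductance: centred at 120, SD 20, bounded to [50, 200]
prior_draws = sample_truncated_normal(120.0, 20.0, 50.0, 200.0, 10_000, rng)

lo, hi = np.percentile(prior_draws, [2.5, 97.5])
print(f"Prior mean: {prior_draws.mean():.1f}")
print(f"Central 95% of prior mass: [{lo:.1f}, {hi:.1f}]")
print(f"Range of draws: [{prior_draws.min():.1f}, {prior_draws.max():.1f}]")
```

Inspecting prior draws like this is a quick sanity check that the prior covers every plausible value without putting mass on unphysical ones.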

In [None]:
if not BAYESIAN_AVAILABLE:
    raise ImportError("PyMC3 and ArviZ are required for this tutorial")

# Get default priors
priors = get_default_priors()

print("Default Prior Distributions:")
print("="*70)
print(f"{'Parameter':<15} {'Distribution':<15} {'Mean':<10} {'Bounds'}")
print("-"*70)

for prior in priors[:6]:  # Show first 6
    dist_str = f"{prior.distribution}"
    # Only distributions with a 'mu' parameter have a mean to report in this column
    mean_str = str(prior.params.get('mu', 'N/A'))
    bounds_str = f"[{prior.bounds[0]:.1f}, {prior.bounds[1]:.1f}]" if prior.bounds else "unbounded"
    print(f"{prior.name:<15} {dist_str:<15} {mean_str:<10} {bounds_str}")

print("="*70)
print("\nPrior Interpretation:")
print("  • g_Na ~ Normal(120, 20): We expect ~120 mS/cm² ± 20")
print("  • Bounds prevent unphysical values (e.g., negative conductances)")
print("  • Weakly informative: Allow data to dominate")

## Part 2: Create Synthetic Patient Data

In [None]:
# Create patient with known parameters (ground truth)
print("Creating synthetic patient...")
true_g_Na = 105.0
true_g_K = 38.0

patient = ENSGIDigitalTwin(n_segments=8)
for neuron in patient.network.neurons:
    neuron.params.g_Na = true_g_Na
    neuron.params.g_K = true_g_K

# Generate data
result = patient.run(1000, dt=0.1, I_stim={3: 10.0}, verbose=False)

# Add Gaussian measurement noise (seeded so the synthetic data are reproducible)
noise_level = 2.0  # mV
rng = np.random.default_rng(42)
noisy_voltages = result['voltages'] + rng.normal(0, noise_level, result['voltages'].shape)

print(f"\nData generated:")
print(f"  Clean voltage: {result['voltages'].shape}")
print(f"  Noisy voltage: {noisy_voltages.shape}")
print(f"  Noise level: {noise_level} mV")
print(f"\nTrue parameters (what we want to recover):")
print(f"  g_Na = {true_g_Na}")
print(f"  g_K = {true_g_K}")

## Part 3: Configure Bayesian Estimator

In [None]:
# Configure MCMC sampling
config = BayesianConfig(
    n_chains=4,  # Run 4 parallel chains
    n_draws=1000,  # 1000 samples per chain
    n_tune=500,  # 500 tuning (warm-up) steps, discarded before analysis
    sampler='NUTS',  # No-U-Turn Sampler (best for continuous params)
    target_accept=0.95,  # Higher target acceptance: smaller steps, fewer divergences, slower sampling
    progressbar=True
)

# Create estimator
twin_for_bayes = ENSGIDigitalTwin(n_segments=8)
bayes = BayesianEstimator(
    digital_twin=twin_for_bayes,
    config=config,
    parameter_names=['g_Na', 'g_K']
)

print("Bayesian Estimator configured:")
print(f"  Sampler: {config.sampler}")
print(f"  Chains: {config.n_chains}")
print(f"  Samples per chain: {config.n_draws}")
print(f"  Total samples: {config.n_chains * config.n_draws}")
print(f"  Parameters: {bayes.parameter_names}")

## Part 4: Run MCMC Sampling

This explores the posterior distribution using Markov Chain Monte Carlo.
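
To show the basic idea behind MCMC, the following sketch implements a random-walk Metropolis sampler for a toy one-parameter posterior. This is a deliberately simpler algorithm than the NUTS sampler used in this tutorial. The data, prior, step size, and function names are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
data = 105.0 + rng.normal(0.0, 2.0, size=20)   # hypothetical noisy measurements

def log_posterior(theta):
    """Unnormalised log posterior: Normal(120, 20) prior, Gaussian noise with SD 2."""
    log_prior = -0.5 * ((theta - 120.0) / 20.0) ** 2
    log_lik = -0.5 * np.sum(((data - theta) / 2.0) ** 2)
    return log_prior + log_lik

def metropolis(log_post, start, n_steps, step_size, rng):
    """Random-walk Metropolis: propose a nearby value, accept with prob min(1, ratio)."""
    chain = np.empty(n_steps)
    current = start
    current_lp = log_post(current)
    n_accept = 0
    for i in range(n_steps):
        proposal = current + rng.normal(0.0, step_size)
        proposal_lp = log_post(proposal)
        # Accept if log(u) < log posterior ratio; otherwise stay at the current value
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
            n_accept += 1
        chain[i] = current
    return chain, n_accept / n_steps

chain, accept_rate = metropolis(log_posterior, start=120.0, n_steps=6000,
                                step_size=1.0, rng=rng)
samples = chain[1000:]   # discard warm-up draws taken before the chain settled

print(f"Acceptance rate: {accept_rate:.2f}")
print(f"Posterior mean: {samples.mean():.2f}, SD: {samples.std(ddof=1):.2f}")
```

The evidence P(D) never appears: only the ratio of posterior values matters, so the normalizing constant cancels. NUTS relies on the same cancellation but uses gradient information to make longer, better-directed moves.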

In [None]:
# Run MCMC (this may take 3-5 minutes)
print("Running MCMC sampling...")
print("This will take 3-5 minutes. Please be patient!\n")

try:
    trace = bayes.estimate_parameters(
        observed_voltages=noisy_voltages,
        parameter_names=['g_Na', 'g_K']
    )
    
    print("\nMCMC sampling complete!")
    print(f"Trace contains {len(trace.posterior.chain) * len(trace.posterior.draw)} samples")
    
except Exception as e:
    print(f"\nMCMC failed: {e}")
    print("This is OK in test environments. Continuing with demo...")
    trace = None

## Part 5: Analyze Posterior Distributions

In [None]:
if trace is not None:
    # Get summary statistics
    summary = bayes.summarize_posterior(trace)
    
    print("\n" + "="*70)
    print("BAYESIAN POSTERIOR SUMMARY")
    print("="*70)
    print(f"\n{'Parameter':<12} {'True':<10} {'Mean':<12} {'SD':<10} {'95% CI':<25} {'R-hat'}")
    print("-"*70)
    
    true_vals = {'g_Na': true_g_Na, 'g_K': true_g_K}
    
    for param in ['g_Na', 'g_K']:
        if param in summary:
            s = summary[param]
            true = true_vals[param]
            ci_str = f"[{s['ci_lower']:.2f}, {s['ci_upper']:.2f}]"
            covered = "✓" if s['ci_lower'] <= true <= s['ci_upper'] else "✗"
            
            # Report NaN rather than a passing value when R-hat is missing from the summary
            print(f"{param:<12} {true:<10.2f} {s['mean']:<12.2f} {s['std']:<10.2f} "
                  f"{ci_str:<25} {s.get('rhat', float('nan')):.3f} {covered}")
    
    print("="*70)
    print("\nInterpretation:")
    print("  • Mean: Expected value (point estimate)")
    print("  • SD: Uncertainty in estimate")
    print("  • 95% CI: Given the model and data, the parameter lies in this range with 95% posterior probability")
    print("  • R-hat < 1.01: No sign of non-convergence across chains (good!)")
    print("  • ✓: True value within 95% CI (validation success!)")
else:
    print("Trace not available (MCMC skipped).")
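
A 95% interval can be computed from posterior samples in more than one way. The sketch below contrasts an equal-tailed interval with a highest density interval (HDI) on synthetic samples. The helper functions and the sample distributions are assumptions for illustration. The source does not say which kind `summarize_posterior` returns in `ci_lower` and `ci_upper`.

```python
import numpy as np

def equal_tailed_interval(samples, prob=0.95):
    """Interval that cuts equal probability mass from each tail."""
    tail = (1.0 - prob) / 2.0 * 100.0
    lo, hi = np.percentile(samples, [tail, 100.0 - tail])
    return float(lo), float(hi)

def hdi(samples, prob=0.95):
    """Shortest interval containing a fraction `prob` of the samples."""
    sorted_s = np.sort(samples)
    n = sorted_s.size
    window = int(np.ceil(prob * n))            # number of samples inside the interval
    widths = sorted_s[window - 1:] - sorted_s[:n - window + 1]
    start = int(np.argmin(widths))             # left edge of the narrowest window
    return float(sorted_s[start]), float(sorted_s[start + window - 1])

rng = np.random.default_rng(7)
symmetric = rng.normal(105.0, 3.0, size=20_000)         # stand-in for a symmetric posterior
skewed = rng.gamma(shape=2.0, scale=1.0, size=20_000)   # stand-in for a skewed posterior

eti_sym, hdi_sym = equal_tailed_interval(symmetric), hdi(symmetric)
eti_skew, hdi_skew = equal_tailed_interval(skewed), hdi(skewed)

print(f"Symmetric: ETI [{eti_sym[0]:.2f}, {eti_sym[1]:.2f}]  HDI [{hdi_sym[0]:.2f}, {hdi_sym[1]:.2f}]")
print(f"Skewed:    ETI [{eti_skew[0]:.2f}, {eti_skew[1]:.2f}]  HDI [{hdi_skew[0]:.2f}, {hdi_skew[1]:.2f}]")
```

For a symmetric posterior the two intervals nearly coincide. For a skewed one the HDI is shorter and shifted toward the mode, so it is worth stating which interval is being reported.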

## Part 6: Visualize Posterior Distributions

In [None]:
if trace is not None:
    # Trace plot (shows MCMC chains)
    print("Generating trace plot...")
    az.plot_trace(trace, compact=True, figsize=(12, 6))
    plt.tight_layout()
    plt.show()
    
    print("\nTrace Plot Interpretation:")
    print("  Left: Distributions (should be smooth, unimodal)")
    print("  Right: MCMC chains (should be 'hairy caterpillars' - good mixing)")
    print("  All chains should overlap (convergence)")

In [None]:
if trace is not None:
    # Posterior density plot
    print("Generating posterior density plot...")
    az.plot_posterior(trace, figsize=(12, 4), hdi_prob=0.95)
    plt.tight_layout()
    plt.show()
    
    print("\nPosterior Plot Shows:")
    print("  • Peak: Most likely value")
    print("  • Width: Uncertainty")
    print("  • Shaded region: 95% highest density interval (HDI)")

## Part 7: Convergence Diagnostics

In [None]:
if trace is not None:
    # Check convergence
    print("Convergence Diagnostics:")
    print("="*70)
    
    # R-hat (should be < 1.01)
    rhat = az.rhat(trace)
    print("\nR-hat (Gelman-Rubin):")
    for var in rhat.data_vars:
        val = float(rhat[var].values)
        status = "✓ Good" if val < 1.01 else "✗ Poor" if val < 1.1 else "✗✗ Bad"
        print(f"  {var}: {val:.4f} {status}")
    
    # Effective sample size
    ess = az.ess(trace)
    print("\nEffective Sample Size (ESS):")
    for var in ess.data_vars:
        val = float(ess[var].values)
        total = config.n_chains * config.n_draws
        efficiency = val / total * 100
        status = "✓ Good" if efficiency > 10 else "~ OK" if efficiency > 5 else "✗ Poor"
        print(f"  {var}: {val:.0f} / {total} ({efficiency:.1f}%) {status}")
    
    print("\n" + "="*70)
    print("Interpretation:")
    print("  • R-hat < 1.01: No sign of non-convergence across chains")
    print("  • ESS > 10% of total draws: Enough effectively independent samples for stable estimates")
else:
    print("Convergence diagnostics require trace (MCMC skipped).")
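
To show what R-hat measures, here is a minimal sketch of the classic Gelman-Rubin statistic, which compares between-chain and within-chain variance. The synthetic chains and the function name `gelman_rubin_rhat` are assumptions for illustration. ArviZ's `az.rhat` uses a rank-normalized split variant by default, so its values can differ slightly from this simplified version.

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Classic R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    chain_vars = chains.var(axis=1, ddof=1)
    W = chain_vars.mean()                      # average within-chain variance
    B = n * chain_means.var(ddof=1)            # between-chain variance
    var_hat = (n - 1) / n * W + B / n          # pooled estimate of posterior variance
    return float(np.sqrt(var_hat / W))

rng = np.random.default_rng(3)

# Four well-mixed chains drawn from the same distribution
good = rng.normal(105.0, 3.0, size=(4, 1000))

# Same chains, but the last one is stuck in a different region
bad = good + np.array([0.0, 0.0, 0.0, 6.0])[:, None]

rhat_good = gelman_rubin_rhat(good)
rhat_bad = gelman_rubin_rhat(bad)
print(f"R-hat (well-mixed chains): {rhat_good:.3f}")
print(f"R-hat (one shifted chain): {rhat_bad:.3f}")
```

When every chain explores the same distribution, the between-chain and within-chain variances agree and R-hat is close to 1. A chain stuck elsewhere inflates the between-chain term and pushes R-hat well above the 1.01 threshold.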

## Part 8: Compare with PINN Estimates

In [None]:
# Simulate PINN estimates (normally would run actual PINN)
print("Comparing Bayesian vs PINN:\n")

# Mock PINN estimates
pinn_estimates = {
    'g_Na': true_g_Na + np.random.normal(0, 3),
    'g_K': true_g_K + np.random.normal(0, 2)
}
pinn_uncertainties = {'g_Na': 5.0, 'g_K': 3.0}

if trace is not None:
    comparison = bayes.compare_with_pinn(
        pinn_estimates=pinn_estimates,
        pinn_uncertainties=pinn_uncertainties,
        trace=trace
    )
    
    print(f"{'Parameter':<12} {'PINN':<15} {'Bayesian':<15} {'Agreement'}")
    print("-"*60)
    
    for param in ['g_Na', 'g_K']:
        if param in comparison:
            comp = comparison[param]
            bayes_str = f"{summary[param]['mean']:.2f} ± {summary[param]['std']:.2f}"
            pinn_str = f"{pinn_estimates[param]:.2f} ± {pinn_uncertainties[param]:.2f}"
            agree = comp['agreement']
            
            print(f"{param:<12} {pinn_str:<15} {bayes_str:<15} {agree}")
    
    print("\nWhen to use each method:")
    print("  PINN: Fast (minutes), good for screening")
    print("  Bayesian: Rigorous uncertainty, regulatory-grade")
    print("  Best: Use both! PINN first, then Bayesian refinement")
else:
    print("Comparison requires trace (MCMC skipped).")

## Part 9: Save and Load Traces

In [None]:
if trace is not None:
    # Save trace
    bayes.save_trace('bayesian_trace.nc')
    print("Trace saved to: bayesian_trace.nc")
    
    # Load trace
    loaded_trace = BayesianEstimator.load_trace('bayesian_trace.nc')
    print("Trace loaded successfully!")
    
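    # Verify that the reloaded trace reproduces the original summary
    loaded_summary = bayes.summarize_posterior(loaded_trace)
    print(f"\nOriginal g_Na mean: {summary['g_Na']['mean']:.2f}")
    print(f"Loaded g_Na mean: {loaded_summary['g_Na']['mean']:.2f}")
    if np.isclose(summary['g_Na']['mean'], loaded_summary['g_Na']['mean']):
        print("Traces match!")
    else:
        print("Warning: loaded trace does not match the original.")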
else:
    print("Save/load requires trace (MCMC skipped).")

## Summary

This tutorial demonstrated:

1. ✅ **Bayesian inference** - Probabilistic parameter estimation
2. ✅ **Prior distributions** - Encoding domain knowledge
3. ✅ **MCMC sampling** - NUTS algorithm
4. ✅ **Posterior analysis** - Mean, SD, credible intervals
5. ✅ **Convergence diagnostics** - R-hat, ESS
6. ✅ **Visualization** - Trace plots, posterior densities
7. ✅ **Comparison with PINN** - Complementary methods
8. ✅ **Persistence** - Save/load traces

### Key Advantages of Bayesian Inference:

- **Full distributions**: Not just point estimates
- **Rigorous uncertainty**: 95% credible intervals
- **Principled**: Based on probability theory
- **Flexible**: Can incorporate expert knowledge (priors)
- **Interpretable**: "Given the model and data, there is a 95% probability the parameter is between X and Y"

### Bayesian vs PINN:

| Aspect | Bayesian | PINN |
|--------|----------|------|
| Speed | Slow (minutes-hours) | Fast (<5 min) |
| Uncertainty | Full posterior | Bootstrap CI |
| Interpretability | Probabilistic | Point + interval |
| Scalability | Limited (10-20 params) | High (100+ params) |
| Rigor | Regulatory-grade | Research-grade |

### Recommended Workflow:

1. **PINN** - Quick screening (~2 minutes)
2. **Bayesian** - Refinement with uncertainty (~30 minutes)
3. **Clinical decision** - Use Bayesian credible intervals

### Next Steps:

- Try the PINN tutorial (`pinn_tutorial.ipynb`)
- Explore clinical workflow (`clinical_workflow.ipynb`)
- Test on real patient data
- Submit for regulatory approval!