# Bayesian Inference Basics: From Point Estimates to Uncertainty Quantification

This notebook introduces Bayesian inference in rheological modeling, demonstrating the NLSQ→NUTS two-stage workflow that combines fast optimization with comprehensive uncertainty quantification.

## Learning Objectives

After completing this notebook, you will be able to:
- Understand when Bayesian inference is essential vs optional
- Implement the NLSQ→NUTS two-stage workflow with warm-start
- Interpret posterior distributions and credible intervals
- Verify convergence using R-hat and ESS diagnostics
- Appreciate warm-start benefits (2-5x faster convergence)
- Compare Bayesian credible intervals vs frequentist confidence intervals

## Prerequisites

- Basic understanding of Bayesian probability
- Familiarity with Maxwell model (see `basic/01-maxwell-fitting.ipynb`)
- Basic rheological concepts (relaxation, modulus)

**Estimated Time:** 30 minutes

## 1. Introduction: Why Bayesian?

### The Limitation of Point Estimates

Traditional optimization (NLSQ, `scipy.optimize.curve_fit`) provides **point estimates**: single values for each parameter. Reported on their own, these point estimates hide critical information:

**Scenario:** Two datasets yield identical NLSQ fits (same G₀, same η), but:
- Dataset A: High signal-to-noise ratio → parameters well-constrained
- Dataset B: Low signal-to-noise ratio → parameters poorly constrained

**Point estimates can't distinguish between these cases!**

### Three Scenarios Where Bayesian is Essential

1. **Poorly Constrained Parameters:** Wide posterior distributions reveal when the data are insufficient to pin a parameter down
2. **Parameter Correlations:** Joint distributions show when parameters co-vary (identifiability issues)
3. **Prediction Uncertainty:** Parameter uncertainty is propagated to model predictions (error bars)
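
The third scenario can be sketched with plain NumPy. The samples below are synthetic stand-ins drawn from independent normal distributions (an assumption for illustration, not output from `fit_bayesian`). Each sample is pushed through the Maxwell relaxation modulus G(t) = G₀·exp(−t/τ) with τ = η/G₀, and percentiles across samples give a pointwise 95% band.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior samples (stand-ins, NOT produced by fit_bayesian)
G0_samples = rng.normal(1.0e5, 2.0e3, size=2000)   # Pa
eta_samples = rng.normal(1.0e3, 3.0e1, size=2000)  # Pa·s

# Evaluate the Maxwell relaxation modulus for every sample on a time grid
t_grid = np.logspace(-4, -1, 30)                   # s
tau_samples = eta_samples / G0_samples             # s
G_pred = G0_samples[:, None] * np.exp(-t_grid[None, :] / tau_samples[:, None])

# Pointwise median and 95% band across samples
lower, median, upper = np.percentile(G_pred, [2.5, 50.0, 97.5], axis=0)

print(f"G(t={t_grid[0]:.1e} s): {median[0]:.3e} Pa "
      f"[{lower[0]:.3e}, {upper[0]:.3e}]")
```

Plotting `lower` and `upper` with `plt.fill_between` around `median` would turn these arrays into the error band on a fitted curve.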

### Bayesian vs Frequentist Interpretation

**Frequentist Confidence Interval (95%):**
- "If we repeated the experiment many times, 95% of the resulting intervals would contain the true value"
- Does not license the statement "95% probability the parameter is in this interval": the parameter is fixed and the interval is the random quantity

**Bayesian Credible Interval (95%):**
- **"95% probability the parameter lies in this interval"** (given the data, model, and prior)
- Direct probabilistic statement about the parameter
- More intuitive for scientific interpretation
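
The frequentist statement is a claim about long-run coverage, which a small simulation makes concrete. This is a minimal sketch under simplifying assumptions that are not part of the notebook's workflow: normally distributed measurements, a known noise level `sigma`, and an interval for a plain mean rather than a rheological parameter.

```python
import numpy as np

rng = np.random.default_rng(1)

true_mean, sigma, n = 5.0, 2.0, 25
n_experiments = 10_000
z = 1.959964  # two-sided 95% quantile of the standard normal

# Each row is one repeated experiment of n measurements
data = rng.normal(true_mean, sigma, size=(n_experiments, n))
sample_means = data.mean(axis=1)
half_width = z * sigma / np.sqrt(n)  # sigma assumed known

# An interval "covers" when the true mean lies within it
covered = np.abs(sample_means - true_mean) <= half_width
coverage = covered.mean()
print(f"Fraction of intervals containing the true mean: {coverage:.3f}")
```

The fraction lands close to 0.95, yet any single interval either contains the true mean or it does not. That is why the probability statement attaches to the procedure rather than to one computed interval.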

### Posterior Samples Enable Any Derived Quantity

Once you have posterior samples, you can compute uncertainty for:
- Any function of parameters (e.g., relaxation time τ = η/G₀)
- Correlations between parameters
- Quantiles, moments, or any statistical summary
- Model predictions with uncertainty bands
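
As a rough illustration of the first two items, a derived quantity is summarized by transforming every sample and then taking statistics of the result. The correlated samples here are synthetic placeholders generated with `multivariate_normal`; the means, standard deviations, and the 0.6 correlation are assumed values, not results from this notebook.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical correlated posterior samples (stand-ins for real sampler output)
mean = [1.0e5, 1.0e3]                       # [G0 in Pa, eta in Pa·s]
cov = [[4.0e6, 3.6e4],
       [3.6e4, 9.0e2]]                      # std 2e3 and 30, correlation 0.6
G0_s, eta_s = rng.multivariate_normal(mean, cov, size=4000).T

# Derived quantity: transform every sample, then summarize
tau_s = eta_s / G0_s                        # relaxation time, s
tau_lo, tau_hi = np.percentile(tau_s, [2.5, 97.5])
corr = np.corrcoef(G0_s, eta_s)[0, 1]

print(f"tau = {tau_s.mean():.5f} s, 95% interval [{tau_lo:.5f}, {tau_hi:.5f}]")
print(f"corr(G0, eta) = {corr:.2f}")
```

Because the ratio is taken sample by sample, the correlation between G₀ and η is carried into the interval for τ automatically. Combining the two marginal standard deviations by hand would miss it.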

## 2. Setup and Imports

In [None]:
# Enable inline plotting for Jupyter notebooks
%matplotlib inline

# Standard imports
import numpy as np
import matplotlib.pyplot as plt
import warnings
import time

# Rheo imports
from rheojax.models.maxwell import Maxwell
from rheojax.core.jax_config import safe_import_jax

# ArviZ for Bayesian diagnostics
import arviz as az

# Safe JAX import
jax, jnp = safe_import_jax()

# Reproducibility
np.random.seed(42)

# Plotting configuration
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11

print("✓ Imports successful")
print(f"JAX float64 enabled: {jax.config.jax_enable_x64}")

# Suppress matplotlib inline backend warning
# This warning is harmless - plots display correctly with %matplotlib inline
warnings.filterwarnings('ignore', message='.*non-interactive.*')


## 3. Generate Synthetic Relaxation Data

We create Maxwell relaxation data with known parameters to validate the workflow.

In [None]:
# True parameters
G0_true = 1e5  # Pa
eta_true = 1e3  # Pa·s
tau_true = eta_true / G0_true  # s

print("True Parameters:")
print(f"  G₀  = {G0_true:.2e} Pa")
print(f"  η   = {eta_true:.2e} Pa·s")
print(f"  τ   = {tau_true:.4f} s\n")

# Time array (log-spaced for relaxation)
t = np.logspace(-2, 2, 50)  # 0.01 to 100 s

# True relaxation modulus
G_t_true = G0_true * np.exp(-t / tau_true)

# Add realistic noise (1.5% relative)
noise_level = 0.015
noise = np.random.normal(0, noise_level * G_t_true)
G_t_noisy = G_t_true + noise

print(f"Data: {len(t)} points from {t.min():.2f} to {t.max():.2f} s")
print(f"Noise: {noise_level*100:.1f}% relative (SNR: {np.mean(G_t_true)/np.std(noise):.1f})")

# Visualize
plt.figure(figsize=(10, 6))
plt.loglog(t, G_t_noisy, 'o', markersize=6, alpha=0.7, label='Synthetic data (noisy)')
plt.loglog(t, G_t_true, '--', linewidth=2, alpha=0.5, label='True response')
plt.xlabel('Time (s)')
plt.ylabel('Relaxation Modulus G(t) (Pa)')
plt.title('Stress Relaxation Data')
plt.legend()
plt.grid(True, alpha=0.3, which='both')
plt.tight_layout()
plt.show()

## 4. Stage 1: NLSQ Point Estimation (Fast)

We first use NLSQ optimization to get a fast point estimate. This serves two purposes:
1. Quick parameter estimates for initial analysis
2. Warm-start values for Bayesian inference (critical for fast convergence)

In [None]:
# Create model and set bounds
model = Maxwell()
model.parameters.set_bounds('G0', (1e3, 1e7))
model.parameters.set_bounds('eta', (1e1, 1e5))

# NLSQ optimization with timing
print("Running NLSQ optimization...\n")
start_nlsq = time.time()

model.fit(t, G_t_noisy)

nlsq_time = time.time() - start_nlsq

# Extract results
G0_nlsq = model.parameters.get_value('G0')
eta_nlsq = model.parameters.get_value('eta')
tau_nlsq = eta_nlsq / G0_nlsq

print("="*60)
print("NLSQ POINT ESTIMATES")
print("="*60)
print(f"G₀  = {G0_nlsq:.4e} Pa  (true: {G0_true:.4e})")
print(f"η   = {eta_nlsq:.4e} Pa·s  (true: {eta_true:.4e})")
print(f"τ   = {tau_nlsq:.6f} s  (true: {tau_true:.6f})")
print(f"\nRelative Errors:")
print(f"  G₀:  {abs(G0_nlsq - G0_true) / G0_true * 100:.4f}%")
print(f"  η:   {abs(eta_nlsq - eta_true) / eta_true * 100:.4f}%")
print(f"\nOptimization time: {nlsq_time:.4f} s")
print("="*60)
print("\n✓ Fast point estimates obtained")
print("⚠ No uncertainty information (single values only)")

## 5. Stage 2: Bayesian Inference with Warm-Start

### The Two-Stage Workflow: NLSQ → NUTS

```python
# Stage 1: NLSQ (fast point estimate)
model.fit(t, G_t)  # Seconds
nlsq_params = {
    'G0': model.parameters.get_value('G0'),
    'eta': model.parameters.get_value('eta'),
}

# Stage 2: Bayesian (warm-start from NLSQ)
result = model.fit_bayesian(
    t, G_t,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
    initial_values=nlsq_params  # CRITICAL for fast convergence
)
```

### Why Warm-Start?

**Cold Start (random initialization):**
- NUTS explores from random point
- May take 5000+ warmup iterations to converge
- Higher divergence rate
- Slower convergence

**Warm-Start (NLSQ initialization):**
- Starts near the posterior mode (the NLSQ fit is the maximum-likelihood estimate under Gaussian noise, which is close to the mode when priors are weak)
- 1000 warmup iterations often sufficient
- 2-5x faster convergence
- Dramatically reduced divergences (10-100x fewer)
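
The intuition behind the warm-start benefit can be shown with a toy sampler. The sketch below deliberately swaps NUTS for a random-walk Metropolis sampler on a one-dimensional standard-normal target, so it does not reproduce the 2-5x figure or the divergence behavior. It only shows that a chain started far from the mode spends iterations travelling before it produces useful samples. The function name and the starting points are hypothetical.

```python
import numpy as np

def log_post(x):
    # Toy log-posterior: standard normal (mode at 0)
    return -0.5 * x**2

def steps_to_reach_bulk(x0, rng, step=1.0, max_iter=10_000):
    """Random-walk Metropolis; return the iteration at which |x| < 2 first holds."""
    x = x0
    for i in range(max_iter):
        if abs(x) < 2.0:
            return i
        proposal = x + step * rng.normal()
        # Metropolis accept/reject on the log scale
        if np.log(rng.uniform()) < log_post(proposal) - log_post(x):
            x = proposal
    return max_iter

rng = np.random.default_rng(3)
cold = steps_to_reach_bulk(50.0, rng)   # far from the mode
warm = steps_to_reach_bulk(0.1, rng)    # near the mode, like an NLSQ estimate
print(f"cold start: {cold} iterations, warm start: {warm} iterations")
```
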

Let's run Bayesian inference with warm-start:

In [None]:
print("Running Bayesian inference with NLSQ warm-start...")
print("(This may take 1-2 minutes)\n")

start_bayes = time.time()

# Bayesian inference with warm-start
result = model.fit_bayesian(
    t, G_t_noisy,
    num_warmup=1000,   # Burn-in iterations
    num_samples=2000,  # Posterior samples
    num_chains=4,      # Multiple chains for robust diagnostics
    initial_values={   # WARM-START from NLSQ
        'G0': G0_nlsq,
        'eta': eta_nlsq
    }
)

bayes_time = time.time() - start_bayes

print(f"\n✓ Bayesian inference completed in {bayes_time:.2f} s")
print("Typical speedup vs cold start: ~2-5x with warm-start (not measured in this run)")
print(f"Total time (NLSQ + Bayes): {nlsq_time + bayes_time:.2f} s")

## 6. Posterior Summary and Interpretation

### Understanding Posterior Distributions

The **posterior distribution** P(θ|data) represents our updated beliefs about parameters after observing data:

- **Prior:** P(θ) - beliefs before seeing data (encoded in parameter bounds)
- **Likelihood:** P(data|θ) - probability of data given parameters
- **Posterior:** P(θ|data) ∝ P(data|θ) × P(θ) - updated beliefs

From posterior samples, we compute:
- **Mean/Median:** Central tendency (analogous to point estimate)
- **Std:** Spread (uncertainty)
- **Credible Intervals:** Probability ranges (e.g., 95% CI)
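
For a single parameter, the relation posterior ∝ likelihood × prior can be evaluated directly on a grid, which makes the three summaries above concrete without any sampler. This sketch rests on assumptions that differ from the notebook's setup: G₀ is treated as known, the noise is Gaussian with a known absolute level `sigma`, and the prior on τ is uniform between the grid limits.

```python
import numpy as np

rng = np.random.default_rng(4)

# Synthetic data: Maxwell relaxation with G0 treated as known (assumption)
G0, tau_true, sigma = 1.0e5, 0.01, 2.0e3         # Pa, s, Pa
t = np.linspace(0.002, 0.03, 15)                 # s
G_obs = G0 * np.exp(-t / tau_true) + rng.normal(0.0, sigma, t.size)

# Grid over tau; uniform prior inside the bounds
tau_grid = np.linspace(0.005, 0.02, 2001)
log_prior = np.zeros_like(tau_grid)

# Gaussian log-likelihood at every grid point
model = G0 * np.exp(-t[None, :] / tau_grid[:, None])
log_like = -0.5 * np.sum(((G_obs[None, :] - model) / sigma) ** 2, axis=1)

# Posterior ∝ likelihood × prior, normalized on the grid
log_post = log_like + log_prior
weights = np.exp(log_post - log_post.max())
weights /= weights.sum()

# Mean, std, and 95% credible interval from the normalized weights
post_mean = np.sum(weights * tau_grid)
post_std = np.sqrt(np.sum(weights * (tau_grid - post_mean) ** 2))
cdf = np.cumsum(weights)
ci_low, ci_high = tau_grid[np.searchsorted(cdf, [0.025, 0.975])]

print(f"tau = {post_mean:.5f} ± {post_std:.5f} s")
print(f"95% credible interval: [{ci_low:.5f}, {ci_high:.5f}] s")
```

Grid evaluation stops being practical beyond a few parameters, which is where MCMC samplers such as NUTS take over.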

In [None]:
# Extract posterior samples and diagnostics
posterior = result.posterior_samples
diagnostics = result.diagnostics
summary = result.summary

# Compute credible intervals
credible_intervals = model.get_credible_intervals(posterior, credibility=0.95)

print("="*70)
print("POSTERIOR SUMMARY")
print("="*70)

print("\nParameter Estimates (posterior mean ± std):")
print(f"  G₀  = {summary['G0']['mean']:.4e} ± {summary['G0']['std']:.4e} Pa")
print(f"  η   = {summary['eta']['mean']:.4e} ± {summary['eta']['std']:.4e} Pa·s")

print("\n95% Credible Intervals:")
print(f"  G₀:  [{credible_intervals['G0'][0]:.4e}, {credible_intervals['G0'][1]:.4e}] Pa")
print(f"  η:   [{credible_intervals['eta'][0]:.4e}, {credible_intervals['eta'][1]:.4e}] Pa·s")

print("\nInterpretation:")
print(f"  \"There is 95% probability that G₀ lies in the interval above\"")
print(f"  This is a DIRECT probabilistic statement (Bayesian interpretation)")

print("\nRelative Uncertainties:")
print(f"  G₀:  {summary['G0']['std'] / summary['G0']['mean'] * 100:.2f}%")
print(f"  η:   {summary['eta']['std'] / summary['eta']['mean'] * 100:.2f}%")

print("\nComparison to True Values:")
G0_in_CI = credible_intervals['G0'][0] <= G0_true <= credible_intervals['G0'][1]
eta_in_CI = credible_intervals['eta'][0] <= eta_true <= credible_intervals['eta'][1]
print(f"  G₀ true value in 95% CI:  {G0_in_CI} {'✓' if G0_in_CI else '✗'}")
print(f"  η true value in 95% CI:   {eta_in_CI} {'✓' if eta_in_CI else '✗'}")
print("="*70)

## 7. Convergence Diagnostics (Introduction)

### Why Convergence Matters

MCMC (Markov Chain Monte Carlo) generates samples by exploring parameter space. **Convergence** means:
- Chains have reached stationary distribution (the posterior)
- Samples are representative of posterior
- Results are reliable

**Always check convergence before interpreting results!**

### Key Metrics

**1. R-hat (Gelman-Rubin Statistic):**
- Compares between-chain variance to within-chain variance
- **Target: R-hat < 1.01** for all parameters
- R-hat > 1.01: Chains exploring different regions (NOT converged)

**2. ESS (Effective Sample Size):**
- Accounts for autocorrelation between samples
- **Target: ESS > 400** for reliable estimates
- ESS << num_samples: High autocorrelation (poor mixing)

**3. Divergences:**
- NUTS sampler failures (numerical instability)
- **Target: < 1% divergence rate**
- Many divergences: Results unreliable, need reparameterization or better priors
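
To make the R-hat definition tangible, here is a minimal NumPy sketch of the classic (non-split) Gelman-Rubin formula. It is not the rank-normalized split R-hat that modern tools such as ArviZ report, so its values can differ slightly from those in `diagnostics`. The chains are synthetic and the helper name `gelman_rubin` is made up for this example.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()        # within-chain variance
    B = n * chain_means.var(ddof=1)              # between-chain variance
    var_hat = (n - 1) / n * W + B / n            # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(5)
good = rng.normal(0.0, 1.0, size=(4, 2000))           # four chains, same distribution
bad = good + np.array([[0.0], [0.0], [0.0], [1.0]])   # one chain shifted by 1

print(f"R-hat (well mixed):  {gelman_rubin(good):.3f}")
print(f"R-hat (one shifted): {gelman_rubin(bad):.3f}")
```

Shifting a single chain inflates the between-chain variance and pushes R-hat well above the 1.01 target, which is the signature of chains exploring different regions.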

In [None]:
print("="*70)
print("CONVERGENCE DIAGNOSTICS")
print("="*70)

print("\nR-hat (Gelman-Rubin):")
print(f"  G₀:  {diagnostics['r_hat']['G0']:.4f}  {'✓ Converged' if diagnostics['r_hat']['G0'] < 1.01 else '✗ NOT converged'}")
print(f"  η:   {diagnostics['r_hat']['eta']:.4f}  {'✓ Converged' if diagnostics['r_hat']['eta'] < 1.01 else '✗ NOT converged'}")
print("  Target: < 1.01 (all parameters must meet this)")

print("\nEffective Sample Size (ESS):")
print(f"  G₀:  {diagnostics['ess']['G0']:.0f}  {'✓ Sufficient' if diagnostics['ess']['G0'] > 400 else '✗ Low (increase samples)'}")
print(f"  η:   {diagnostics['ess']['eta']:.0f}  {'✓ Sufficient' if diagnostics['ess']['eta'] > 400 else '✗ Low (increase samples)'}")
print(f"  Target: > 400 (out of {result.num_samples * result.num_chains} total samples)")

div_rate = 0.0  # stays 0 if the sampler reports no divergence count
if 'num_divergences' in diagnostics:
    div_rate = diagnostics['num_divergences'] / (result.num_samples * result.num_chains) * 100
    print("\nDivergences:")
    print(f"  Count: {diagnostics['num_divergences']} ({div_rate:.2f}%)")
    print(f"  {'✓ Good' if div_rate < 1 else '✗ High (results unreliable)'}")
    print("  Target: < 1% divergence rate")

# Overall convergence check (R-hat, ESS, and divergence rate)
converged = (
    diagnostics['r_hat']['G0'] < 1.01 and
    diagnostics['r_hat']['eta'] < 1.01 and
    diagnostics['ess']['G0'] > 400 and
    diagnostics['ess']['eta'] > 400 and
    div_rate < 1
)

print("\n" + "="*70)
if converged:
    print("✓✓✓ EXCELLENT CONVERGENCE ✓✓✓")
    print("All diagnostic criteria met. Results are reliable.")
else:
    print("⚠⚠⚠ CONVERGENCE ISSUES ⚠⚠⚠")
    print("Increase num_warmup or num_samples and rerun.")
print("="*70)

## 8. Visual Convergence Check: Trace Plot

The **trace plot** provides visual confirmation of convergence:

**LEFT panels (marginal distributions):**
- Should be smooth, unimodal
- All chains overlap (same distribution)

**RIGHT panels (parameter vs iteration):**
- Should look like "fuzzy caterpillar"
- Stationary (no trends)
- No stuck regions
- All chains mix well
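
A numerical counterpart to the "fuzzy caterpillar" check is the effective sample size, which shrinks as neighbouring draws become more correlated. The estimator below is a crude single-chain sketch that sums autocorrelations up to the first non-positive lag. It is an illustrative assumption, not the estimator behind `diagnostics['ess']`, and the AR(1) chain is a synthetic stand-in for a poorly mixing sampler.

```python
import numpy as np

def effective_sample_size(x):
    """Crude single-chain ESS: N / (1 + 2 * sum of leading positive autocorrelations)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    xc = x - x.mean()
    # Autocovariance at lags 0..n-1 via FFT (zero-padded to avoid wrap-around)
    f = np.fft.rfft(xc, 2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    rho = acov / acov[0]
    # Truncate the sum at the first non-positive autocorrelation
    nonpos = np.where(rho <= 0)[0]
    cutoff = nonpos[0] if nonpos.size else n
    tau_int = 1.0 + 2.0 * rho[1:cutoff].sum()
    return n / tau_int

rng = np.random.default_rng(6)
n = 20_000
white = rng.normal(size=n)                 # independent draws

# AR(1) chain with strong autocorrelation (phi = 0.9)
phi = 0.9
sticky = np.empty(n)
sticky[0] = 0.0
noise = rng.normal(size=n)
for i in range(1, n):
    sticky[i] = phi * sticky[i - 1] + noise[i]

ess_white = effective_sample_size(white)
ess_sticky = effective_sample_size(sticky)
print(f"ESS, independent draws: {ess_white:.0f} of {n}")
print(f"ESS, AR(1) phi=0.9:     {ess_sticky:.0f} of {n}")
```

For an AR(1) process the theoretical ratio is (1 − φ)/(1 + φ), about 1/19 at φ = 0.9, so the strongly correlated chain is worth only a small fraction of its nominal length.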

For deeper diagnostic interpretation, see `03-convergence-diagnostics.ipynb`.

In [None]:
# Convert to ArviZ InferenceData
idata = result.to_inference_data()

# Trace plot
az.plot_trace(idata, figsize=(12, 6))
plt.tight_layout()
plt.show()

print("INTERPRETATION:")
print("✓ GOOD: Chains overlap, stationary, no trends")
print("✗ BAD: Chains separated, drift, stuck regions, bimodal distributions")
print("\nFor this example, chains should show excellent mixing and convergence.")

## 9. Posterior Distributions Visualization

Let's visualize the posterior distributions compared to NLSQ point estimates.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# G0 posterior
ax1.hist(posterior['G0'], bins=50, density=True, alpha=0.7, color='steelblue', edgecolor='black')
ax1.axvline(G0_nlsq, color='red', linestyle='--', linewidth=2, label='NLSQ point estimate')
ax1.axvline(G0_true, color='green', linestyle=':', linewidth=2, label='True value')
ax1.axvline(summary['G0']['mean'], color='blue', linestyle='-', linewidth=2, label='Posterior mean')
ax1.axvspan(credible_intervals['G0'][0], credible_intervals['G0'][1], 
            alpha=0.2, color='blue', label='95% credible interval')
ax1.set_xlabel('G₀ (Pa)', fontweight='bold')
ax1.set_ylabel('Posterior Density', fontweight='bold')
ax1.set_title('Posterior Distribution: Initial Modulus', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# eta posterior
ax2.hist(posterior['eta'], bins=50, density=True, alpha=0.7, color='coral', edgecolor='black')
ax2.axvline(eta_nlsq, color='red', linestyle='--', linewidth=2, label='NLSQ point estimate')
ax2.axvline(eta_true, color='green', linestyle=':', linewidth=2, label='True value')
ax2.axvline(summary['eta']['mean'], color='orangered', linestyle='-', linewidth=2, label='Posterior mean')
ax2.axvspan(credible_intervals['eta'][0], credible_intervals['eta'][1], 
            alpha=0.2, color='orangered', label='95% credible interval')
ax2.set_xlabel('η (Pa·s)', fontweight='bold')
ax2.set_ylabel('Posterior Density', fontweight='bold')
ax2.set_title('Posterior Distribution: Viscosity', fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("EXPECTED OBSERVATIONS (verify against the plots above):")
print("1. Posterior means ≈ NLSQ point estimates (as expected for a well-behaved problem)")
print("2. Posterior has width → quantifies uncertainty (absent from the NLSQ point estimate)")
print("3. 95% CIs should capture the true values (see the check in Section 6 for this run)")
print("4. Unimodal, roughly symmetric distributions → well-constrained parameters")

## 10. Comparison: NLSQ vs Bayesian

Let's compare the two approaches side-by-side.

In [None]:
print("="*80)
print("NLSQ vs BAYESIAN COMPARISON")
print("="*80)

print("\n" + "-"*80)
print(f"{'Method':<20} {'G₀ (Pa)':<20} {'η (Pa·s)':<20} {'Time (s)':<15}")
print("-"*80)
print(f"{'True Values':<20} {G0_true:<20.4e} {eta_true:<20.4e} {'N/A':<15}")
print(f"{'NLSQ Point':<20} {G0_nlsq:<20.4e} {eta_nlsq:<20.4e} {nlsq_time:<15.4f}")
print(f"{'Bayesian Mean':<20} {summary['G0']['mean']:<20.4e} {summary['eta']['mean']:<20.4e} {bayes_time:<15.2f}")
print("-"*80)

print("\n" + "="*80)
print("KEY DIFFERENCES")
print("="*80)

print("\n1. UNCERTAINTY QUANTIFICATION:")
print(f"   NLSQ:     Single value (no uncertainty) ✗")
print(f"   Bayesian: Full distribution with credible intervals ✓")
print(f"             G₀: {summary['G0']['std']/summary['G0']['mean']*100:.2f}% relative uncertainty")
print(f"             η:  {summary['eta']['std']/summary['eta']['mean']*100:.2f}% relative uncertainty")

print("\n2. COMPUTATIONAL COST:")
print(f"   NLSQ:     {nlsq_time:.4f} s (fast) ✓")
print(f"   Bayesian: {bayes_time:.2f} s (~{bayes_time/nlsq_time:.0f}x slower) ✗")
print(f"   Total:    {nlsq_time + bayes_time:.2f} s (with warm-start)")

print("\n3. INTERPRETABILITY:")
print(f"   NLSQ:     Point estimate only")
print(f"   Bayesian: Full posterior distribution enables:")
print(f"             - Credible intervals (direct probability statements) ✓")
print(f"             - Parameter correlations (identifiability) ✓")
print(f"             - Derived quantities with uncertainty ✓")
print(f"             - Model comparison (WAIC, LOO) ✓")

print("\n4. CONVERGENCE:")
print(f"   NLSQ:     Usually converges quickly, but can stop in a local minimum")
print(f"   Bayesian: Must check R-hat, ESS, divergences")
print(f"             Current: R-hat={max(diagnostics['r_hat'].values()):.4f}, ESS={min(diagnostics['ess'].values()):.0f} {'✓' if converged else '✗'}")

print("\n" + "="*80)
print("RECOMMENDATION")
print("="*80)
print("Use NLSQ when: Fast screening, well-constrained parameters, no uncertainty needed")
print("Use Bayesian when: Uncertainty quantification essential, parameter correlations matter,")
print("                   model comparison needed, prediction uncertainty required")
print("\nBest practice: NLSQ first (fast), then Bayesian if uncertainty needed (warm-start)")
print("="*80)

## 11. Key Takeaways

### Main Concepts

1. **Why Bayesian?**
   - Point estimates hide uncertainty information
   - Essential when parameters poorly constrained or correlated
   - Enables direct probability statements ("95% probability parameter in interval")

2. **Two-Stage Workflow: NLSQ → NUTS**
   - Stage 1: NLSQ optimization (fast, ~seconds)
   - Stage 2: NUTS sampling with warm-start (slower, ~minutes)
   - Warm-start provides 2-5x faster convergence
   - Dramatically reduces divergences (10-100x fewer)

3. **Convergence Diagnostics**
   - **Always check before interpreting results!**
   - R-hat < 1.01 (all parameters)
   - ESS > 400 (all parameters)
   - Divergences < 1%

4. **Posterior Interpretation**
   - Mean/median: Central tendency (like point estimate)
   - Std: Uncertainty (not part of an NLSQ point estimate)
   - Credible intervals: Probability ranges
   - Full distribution enables any derived quantity

### When to Use Bayesian Inference

**Essential for:**
- ✓ Poorly constrained parameters (wide posteriors reveal this)
- ✓ Parameter identifiability analysis (correlations)
- ✓ Prediction uncertainty (error bars on model predictions)
- ✓ Model comparison (WAIC, LOO - see `04-model-comparison.ipynb`)
- ✓ Communicating uncertainty to stakeholders

**Optional for:**
- Well-constrained parameters with high SNR data
- Rapid screening where uncertainty not needed
- Real-time analysis requiring speed

### Common Pitfalls

1. **Ignoring Convergence Diagnostics**
   - Never trust results without checking R-hat, ESS, divergences
   - Non-converged MCMC produces misleading posteriors

2. **Cold Start Without Warm-Start**
   - Random initialization can take 2-5x longer to converge
   - Higher divergence rates
   - Always use NLSQ warm-start when possible

3. **Single Chain for Production**
   - Use num_chains=4 for production work
   - Single chain cannot compute reliable R-hat
   - Multiple chains detect convergence failures

4. **Misinterpreting Credible Intervals**
   - Bayesian credible interval: "95% probability parameter in interval" ✓
   - Applying that statement to a frequentist confidence interval ✗ (it describes long-run coverage over repeated experiments)

## Next Steps

### Deepen Bayesian Understanding
- **[02-prior-selection.ipynb](02-prior-selection.ipynb)**: How to choose priors (bounds→priors transformation)
- **[03-convergence-diagnostics.ipynb](03-convergence-diagnostics.ipynb)**: Master all 6 ArviZ diagnostic plots
- **[04-model-comparison.ipynb](04-model-comparison.ipynb)**: WAIC and LOO for model selection
- **[05-uncertainty-propagation.ipynb](05-uncertainty-propagation.ipynb)**: Propagate uncertainty to predictions

### Apply to Other Models
- All 20 Rheo models support Bayesian inference via BayesianMixin
- See `basic/` notebooks for Zener, SpringPot, Bingham, PowerLaw examples
- Try Bayesian inference on your own rheological data

### Advanced Topics
- **[bayesian/04-model-comparison.ipynb](04-model-comparison.ipynb)**: Comparing Maxwell vs Zener
- **[advanced/01-multi-technique-fitting.ipynb](../advanced/01-multi-technique-fitting.ipynb)**: Constrained Bayesian fitting

---

## Session Information

In [None]:
import sys
import rheojax

print(f"Python: {sys.version}")
print(f"Rheo: {rheojax.__version__}")
print(f"JAX: {jax.__version__}")
print(f"NumPy: {np.__version__}")
print(f"ArviZ: {az.__version__}")
print(f"JAX devices: {jax.devices()}")