# Bayesian Workflow Demo: Complete NLSQ → NUTS → ArviZ Pipeline

This notebook demonstrates the recommended three-stage Bayesian workflow:
1. **Stage 1**: NLSQ point estimation (fast, ~seconds)
2. **Stage 2**: NUTS posterior sampling with warm-start (slower; ~20-60 seconds for this demo)
3. **Stage 3**: ArviZ diagnostic plots (visual verification)

**Expected runtime**: ~30 seconds (depending on hardware)

**Requirements**:
- rheojax with Bayesian dependencies (numpyro, arviz)
- matplotlib for visualization

In [None]:
# Google Colab Setup - Run this cell first!
# Skip if running locally with rheojax already installed

import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install rheojax and dependencies
    !pip install -q rheojax
    
    # JAX defaults to float32; float64 is needed for numerical stability.
    # The flag MUST be set before JAX is first imported.
    import os
    os.environ['JAX_ENABLE_X64'] = 'true'
    
    print("✓ RheoJAX installed successfully!")
    print("✓ Float64 precision enabled")

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display
from rheojax.models import Maxwell

from rheojax.core.jax_config import safe_import_jax

# Safe JAX import (ensures float64 precision)
jax, jnp = safe_import_jax()

print("="*70)
print("BAYESIAN WORKFLOW DEMONSTRATION")
print("="*70)
print("\nThis demo shows the recommended NLSQ → NUTS → ArviZ workflow")
print("for uncertainty quantification in rheological modeling.\n")

import sys, os
# Add the parent of the working directory to sys.path so `utils` can be imported
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"


## Step 1: Generate Synthetic Relaxation Data

We'll create synthetic Maxwell relaxation data with realistic noise to demonstrate the workflow.
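
The Maxwell model has a single relaxation time τ = η/G₀, and its relaxation modulus is G(t) = G₀·exp(−t/τ). As a minimal sketch in plain NumPy (the helper `maxwell_relaxation` is hypothetical and not part of rheojax), the following shows how quickly the signal decays relative to τ, which is worth checking when choosing a time window:

```python
import numpy as np

def maxwell_relaxation(t, G0, eta):
    """Relaxation modulus of a single Maxwell element: G(t) = G0 * exp(-t / tau)."""
    tau = eta / G0  # relaxation time in seconds
    return G0 * np.exp(-np.asarray(t, dtype=float) / tau)

G0, eta = 1e5, 1e3          # same values as the demo
tau = eta / G0              # 0.01 s
t_check = np.array([0.0, tau, 5 * tau])
G_check = maxwell_relaxation(t_check, G0, eta)

# Fraction of the initial modulus remaining at t = 0, tau, and 5*tau
remaining = G_check / G0
print(remaining)  # roughly 1, 0.37, 0.007
```

After five relaxation times less than 1% of the initial modulus remains, so data points far beyond that carry almost no information about G₀ or η.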

In [None]:
print("Step 1: Generating synthetic Maxwell relaxation data...")
print("-"*70)

# True parameters
G0_true = 1e5  # Pa
eta_true = 1e3  # Pa·s
tau_true = eta_true / G0_true  # 0.01 s

print(f"  True parameters:")
print(f"    G₀  = {G0_true:.2e} Pa")
print(f"    η   = {eta_true:.2e} Pa·s")
print(f"    τ   = {tau_true:.4f} s")

# Time array (log-spaced, bracketing τ = 0.01 s so the decay is resolved)
t = np.logspace(-4, 0, 50)  # 1e-4 to 1 second

# True relaxation modulus
G_t_true = G0_true * np.exp(-t / tau_true)

# Add realistic noise (1.5% relative)
np.random.seed(42)
noise_level = 0.015
noise = np.random.normal(0, noise_level * G_t_true)
G_t_noisy = G_t_true + noise

print(f"\n  Generated {len(t)} data points from {t.min():.2e} to {t.max():.1f} s")
print(f"  Noise level: {noise_level*100:.1f}% relative")
print(f"  Signal-to-noise ratio: {np.mean(G_t_true)/np.std(noise):.1f}")

## Stage 1: NLSQ Point Estimation (Fast)

First, we obtain fast point estimates using nonlinear least squares optimization.
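
For a single-mode Maxwell model, a log-linear regression offers a quick independent sanity check on an NLSQ result, because ln G = ln G₀ − t/τ is a straight line in t. The sketch below uses only NumPy on hypothetical synthetic data with 1.5% multiplicative noise. It is a cross-check under those assumptions, not the optimizer that `model.fit(..., method='nlsq')` runs internally.

```python
import numpy as np

# Hypothetical synthetic data: Maxwell relaxation sampled around tau
G0_true, eta_true = 1e5, 1e3
tau_true = eta_true / G0_true
t = np.logspace(-4, -1.5, 30)                  # keeps G(t) well above zero
rng = np.random.default_rng(0)
G = G0_true * np.exp(-t / tau_true) * (1 + 0.015 * rng.standard_normal(t.size))

# Straight-line fit of ln G against t: slope = -1/tau, intercept = ln G0
slope, intercept = np.polyfit(t, np.log(G), 1)
tau_est = -1.0 / slope
G0_est = np.exp(intercept)
eta_est = G0_est * tau_est                     # eta = G0 * tau

print(f"G0 = {G0_est:.4e} Pa, eta = {eta_est:.4e} Pa·s, tau = {tau_est:.4e} s")
```

If the NLSQ estimates differ from this cross-check by much more than the noise level, the bounds or the time window deserve a second look before moving on to sampling.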

In [None]:
print("\n" + "="*70)
print("STAGE 1: NLSQ POINT ESTIMATION")
print("="*70)

model = Maxwell()
model.parameters.set_bounds('G0', (1e3, 1e7))
model.parameters.set_bounds('eta', (1e1, 1e5))

print("\nRunning NLSQ optimization...")
import time

start_nlsq = time.time()

model.fit(t, G_t_noisy, method='nlsq')

nlsq_time = time.time() - start_nlsq

# Extract NLSQ results
G0_nlsq = model.parameters.get_value('G0')
eta_nlsq = model.parameters.get_value('eta')
tau_nlsq = eta_nlsq / G0_nlsq

print(f"\n✓ NLSQ completed in {nlsq_time:.3f} seconds")
print(f"\nNLSQ Point Estimates:")
print(f"  G₀  = {G0_nlsq:.4e} Pa  (error: {abs(G0_nlsq-G0_true)/G0_true*100:.2f}%)")
print(f"  η   = {eta_nlsq:.4e} Pa·s  (error: {abs(eta_nlsq-eta_true)/eta_true*100:.2f}%)")
print(f"  τ   = {tau_nlsq:.6f} s  (error: {abs(tau_nlsq-tau_true)/tau_true*100:.2f}%)")
print(f"\n⚠  Note: NLSQ provides point estimates only (no uncertainty)")

## Stage 2: Bayesian Inference with Warm-Start

Now we perform MCMC sampling using NUTS, warm-starting from the NLSQ estimates for faster convergence.
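
Warm-starting helps because the warmup phase would otherwise have to locate the region of high posterior density on its own. The sketch below, in plain NumPy, evaluates an unnormalized log-posterior for the Maxwell model at a near-optimal point and at an arbitrary point inside the parameter bounds. It assumes a Gaussian likelihood with a known noise scale and flat priors over the Stage 1 bounds; the likelihood and priors that `fit_bayesian` actually uses may differ, and `log_posterior` is a hypothetical helper.

```python
import numpy as np

def log_posterior(theta, t, G_obs, sigma, bounds):
    """Unnormalized log-posterior: Gaussian likelihood plus flat priors on bounds."""
    G0, eta = theta
    (G0_lo, G0_hi), (eta_lo, eta_hi) = bounds
    if not (G0_lo <= G0 <= G0_hi and eta_lo <= eta <= eta_hi):
        return -np.inf                         # outside the prior support
    G_model = G0 * np.exp(-t * G0 / eta)       # t / tau with tau = eta / G0
    return -0.5 * np.sum(((G_obs - G_model) / sigma) ** 2)

rng = np.random.default_rng(0)
t = np.logspace(-4, -1.5, 30)
G_clean = 1e5 * np.exp(-t / 0.01)
sigma = 0.015 * G_clean                        # 1.5% relative noise, assumed known
G_obs = G_clean + sigma * rng.standard_normal(t.size)

bounds = ((1e3, 1e7), (1e1, 1e5))              # same bounds as Stage 1
warm_start = (1.0e5, 1.0e3)                    # stand-in for the NLSQ estimates
cold_start = (1.0e6, 1.0e2)                    # arbitrary point inside the bounds

lp_warm = log_posterior(warm_start, t, G_obs, sigma, bounds)
lp_cold = log_posterior(cold_start, t, G_obs, sigma, bounds)
print(f"log-posterior at warm start: {lp_warm:.1f}")
print(f"log-posterior at cold start: {lp_cold:.1f}")
```

The gap between the two values indicates how far a cold-started chain would have to travel during warmup before it produces useful draws.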

In [None]:
print("\n" + "="*70)
print("STAGE 2: BAYESIAN INFERENCE (NUTS)")
print("="*70)

print("\nRunning NUTS sampling with NLSQ warm-start...")
print("  Configuration:")
print("    • num_chains: 4 (for robust diagnostics)")
print("    • num_warmup: 1000 (warmup/adaptation iterations, discarded)")
print("    • num_samples: 2000 (posterior samples per chain)")
print("    • warm-start: Yes (from NLSQ estimates)")
print("\n  This may take 20-60 seconds depending on your hardware...")

start_bayes = time.time()

# Run Bayesian inference with warm-start
result = model.fit_bayesian(
    t, G_t_noisy,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
    initial_values={  # Warm-start from NLSQ
        'G0': G0_nlsq,
        'eta': eta_nlsq
    }
)

bayes_time = time.time() - start_bayes

print(f"\n✓ Bayesian inference completed in {bayes_time:.1f} seconds")
print(f"  Total time (NLSQ + Bayes): {nlsq_time + bayes_time:.1f} seconds")
print(f"  Generated {result.num_chains * result.num_samples} posterior samples")

## Posterior Results

Extract and display the posterior summary statistics and credible intervals.
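
Summary statistics and credible intervals are simple functions of the posterior draws. The sketch below computes a mean, standard deviation, and equal-tailed 95% interval with NumPy on hypothetical draws. Whether `get_credible_intervals` returns an equal-tailed or a highest-density interval is not stated here, so treat this as one common convention rather than a description of that method.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical posterior draws for G0 (Pa); real draws would come from the sampler
samples = rng.normal(loc=1.0e5, scale=5.0e2, size=8000)

mean = samples.mean()
std = samples.std(ddof=1)
# Equal-tailed 95% interval: cut 2.5% of the draws from each tail
lower, upper = np.percentile(samples, [2.5, 97.5])

print(f"G0 = {mean:.4e} ± {std:.4e} Pa")
print(f"95% interval: [{lower:.4e}, {upper:.4e}] Pa")
```

For a roughly Gaussian posterior the interval width is close to 2 × 1.96 standard deviations; a strongly asymmetric interval suggests a skewed posterior, where mean ± std is a poor summary.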

In [None]:
print("\n" + "="*70)
print("POSTERIOR RESULTS")
print("="*70)

posterior = result.posterior_samples
summary = result.summary

print("\nPosterior Estimates (mean ± std):")
print(f"  G₀  = {summary['G0']['mean']:.4e} ± {summary['G0']['std']:.4e} Pa")
print(f"  η   = {summary['eta']['mean']:.4e} ± {summary['eta']['std']:.4e} Pa·s")

# Compute credible intervals
intervals = model.get_credible_intervals(posterior, credibility=0.95)
print("\n95% Credible Intervals:")
print(f"  G₀:  [{intervals['G0'][0]:.4e}, {intervals['G0'][1]:.4e}] Pa")
print(f"  η:   [{intervals['eta'][0]:.4e}, {intervals['eta'][1]:.4e}] Pa·s")

print("\nInterpretation:")
print("  'Given the model, priors, and data, there is a 95% probability that G₀ lies in the interval above'")
print("  This is a DIRECT probabilistic statement (Bayesian interpretation)")

# Relative uncertainties
print("\nRelative Uncertainties:")
print(f"  G₀:  {summary['G0']['std']/summary['G0']['mean']*100:.2f}%")
print(f"  η:   {summary['eta']['std']/summary['eta']['mean']*100:.2f}%")

# Check if true values are in credible intervals
G0_in_CI = intervals['G0'][0] <= G0_true <= intervals['G0'][1]
eta_in_CI = intervals['eta'][0] <= eta_true <= intervals['eta'][1]
print("\nValidation (true values in 95% CI):")
print(f"  G₀:  {'✓ Yes' if G0_in_CI else '✗ No'}")
print(f"  η:   {'✓ Yes' if eta_in_CI else '✗ No'}")

## Stage 3: Convergence Diagnostics (CRITICAL!)

**Always check convergence before interpreting Bayesian results!**

We examine:
- **R-hat (Gelman-Rubin)**: Should be < 1.01 for all parameters
- **ESS (Effective Sample Size)**: Should be > 400 for reliable inference
- **Divergences**: Ideally 0; this notebook flags rates of 1% or more of post-warmup samples
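
To make the R-hat threshold concrete, here is a minimal NumPy sketch of the classic split R-hat, which compares between-chain and within-chain variance. The helper `split_rhat` and the draws are hypothetical. ArviZ reports a rank-normalized variant by default, so its values will not match this formula exactly.

```python
import numpy as np

def split_rhat(chains):
    """Split R-hat for draws of shape (num_chains, num_samples)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split every chain in two so within-chain drift also inflates R-hat
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()          # W
    between = n * split.mean(axis=1).var(ddof=1)       # B
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(2)
mixed = rng.normal(0.0, 1.0, size=(4, 2000))            # four chains, same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain offset by 3 sigma

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(f"well-mixed chains: R-hat = {rhat_mixed:.3f}")
print(f"one shifted chain: R-hat = {rhat_stuck:.3f}")

# Divergence rate as a percentage of all post-warmup draws
num_divergences, total_draws = 12, 4 * 2000             # hypothetical counts
div_rate = 100.0 * num_divergences / total_draws
print(f"divergence rate: {div_rate:.2f}%")
```

A single chain exploring a different region is enough to push R-hat well above 1.01, which is why the check needs several chains to be informative.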

In [None]:
print("\n" + "="*70)
print("STAGE 3: CONVERGENCE DIAGNOSTICS")
print("="*70)

diagnostics = result.diagnostics

print("\n⚠  ALWAYS check convergence before interpreting Bayesian results!")
print("\n1. R-hat (Gelman-Rubin Statistic):")
print(f"   Target: < 1.01 for all parameters")
for param in ['G0', 'eta']:
    rhat = diagnostics['r_hat'][param]
    status = '✓ Converged' if rhat < 1.01 else '✗ NOT converged'
    print(f"     {param:<5} R-hat = {rhat:.4f}  {status}")

print("\n2. ESS (Effective Sample Size):")
print(f"   Target: > 400 (out of {result.num_chains * result.num_samples} total)")
for param in ['G0', 'eta']:
    ess = diagnostics['ess'][param]
    efficiency = ess / (result.num_chains * result.num_samples) * 100
    status = '✓ Sufficient' if ess > 400 else '✗ Low (increase samples)'
    print(f"     {param:<5} ESS = {ess:.0f} ({efficiency:.1f}% efficient)  {status}")

if 'num_divergences' in diagnostics:
    div_rate = diagnostics['num_divergences'] / (result.num_chains * result.num_samples) * 100
    print("\n3. Divergences:")
    print(f"   Count: {diagnostics['num_divergences']} ({div_rate:.2f}%)")
    status = '✓ Good' if div_rate < 1 else '✗ High (results unreliable)'
    print(f"   Target: < 1%  {status}")

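# Overall convergence assessment (R-hat, ESS, and divergence rate)
total_draws = result.num_chains * result.num_samples
n_div = diagnostics['num_divergences'] if 'num_divergences' in diagnostics else 0
all_converged = (
    all(diagnostics['r_hat'][p] < 1.01 for p in ['G0', 'eta']) and
    all(diagnostics['ess'][p] > 400 for p in ['G0', 'eta']) and
    n_div / total_draws < 0.01
)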

print("\n" + "-"*70)
if all_converged:
    print("✓✓✓ EXCELLENT CONVERGENCE ✓✓✓")
    print("All diagnostic criteria met. Results are reliable.")
else:
    print("⚠⚠⚠ CONVERGENCE ISSUES ⚠⚠⚠")
    print("Increase num_warmup or num_samples and rerun.")
print("-"*70)

## Visual Diagnostics (ArviZ Integration)

ArviZ provides comprehensive diagnostic visualizations. We'll generate 6 key plots:

1. **Trace plot**: Visual convergence check
2. **Rank plot**: Per-chain rank histograms; non-uniform bars reveal mixing problems that trace plots can hide
3. **Pair plot**: Parameter correlations + divergences
4. **Autocorrelation plot**: Mixing quality
5. **Energy plot**: Sampler energy diagnostic (matches the plot set listed in the plotting cell)
6. **Forest plot**: Credible interval comparison
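
As an illustration of what a rank plot displays, the sketch below ranks all draws across chains and histograms the ranks per chain, using NumPy only. The helper `rank_histograms` and the draws are hypothetical; ArviZ draws the corresponding figure itself, so this only shows the underlying idea.

```python
import numpy as np

def rank_histograms(chains, num_bins=10):
    """Histogram, per chain, of each draw's rank within the pooled draws."""
    chains = np.asarray(chains, dtype=float)
    # argsort of argsort gives the rank (0 .. N-1) of every pooled draw
    ranks = chains.ravel().argsort().argsort().reshape(chains.shape)
    edges = np.linspace(0, chains.size, num_bins + 1)
    return np.array([np.histogram(r, bins=edges)[0] for r in ranks])

rng = np.random.default_rng(3)
mixed = rng.normal(size=(4, 2000))
shifted = mixed + np.array([[0.0], [0.0], [0.0], [1.0]])   # last chain sits higher

counts_mixed = rank_histograms(mixed)      # each row roughly flat
counts_shifted = rank_histograms(shifted)  # last row piles up in the high-rank bins
print(counts_mixed)
print(counts_shifted)
```

When all chains sample the same distribution, every row is close to uniform; a chain that sits too high or too low shows up as a sloped histogram.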

In [None]:
print("\n" + "="*70)
print("VISUAL DIAGNOSTICS (ArviZ Integration)")
print("="*70)

print("\nGenerating diagnostic plots...")

# Convert to ArviZ InferenceData (kept for any further ArviZ calls)
idata = result.to_inference_data()

In [None]:
# ArviZ diagnostic plots (trace, pair, forest, energy, autocorr, rank)
display_arviz_diagnostics(result, ['G0', 'eta'], fast_mode=FAST_MODE)

## Workflow Summary

Recap of the complete 3-stage workflow and key results.

In [None]:
print("\n" + "="*70)
print("WORKFLOW SUMMARY")
print("="*70)

print("\n✓ Completed 3-stage Bayesian workflow:")
print(f"  [1] NLSQ point estimation:      {nlsq_time:.2f}s")
print(f"  [2] NUTS posterior sampling:     {bayes_time:.1f}s (warm-start)")
print(f"  [3] ArviZ diagnostic plots:      6 plots generated")

print(f"\n✓ Convergence assessment:")
if all_converged:
    print("  • All parameters converged (R-hat < 1.01, ESS > 400)")
    print("  • Results are reliable and can be interpreted")
else:
    print("  • ⚠ Convergence issues detected")
    print("  • Increase num_warmup or num_samples and rerun")

print(f"\n✓ Uncertainty quantification:")
print(f"  • G₀ uncertainty: ±{summary['G0']['std']/summary['G0']['mean']*100:.1f}%")
print(f"  • η uncertainty:  ±{summary['eta']['std']/summary['eta']['mean']*100:.1f}%")

## Next Steps

1. Apply this workflow to your own rheological data
2. Try different models (20 models available)
3. Explore tutorial notebooks in `examples/bayesian/`:
   - `01-bayesian-basics.ipynb` (40 min)
   - `02-prior-selection.ipynb` (35 min)
   - `03-convergence-diagnostics.ipynb` (45 min)
   - `04-model-comparison.ipynb` (40 min)
   - `05-uncertainty-propagation.ipynb` (45 min)

4. Read documentation:
   - `docs/BAYESIAN_QUICK_START.md`
   - `docs/BAYESIAN_WORKFLOW_SUMMARY.md`

**For questions or issues:**
- GitHub: https://github.com/imewei/rheojax
- Docs: https://rheojax.readthedocs.io