# MCMC Convergence Diagnostics: A Comprehensive Guide to ArviZ Plots

This notebook walks through six ArviZ diagnostic plots (trace, rank, pair, autocorrelation, ESS, and energy) and the R-hat and ESS summary statistics. Together they form one workflow for systematically diagnosing MCMC failures and troubleshooting convergence issues.

## Learning Objectives

After completing this notebook, you will be able to:
- Use all 6 ArviZ diagnostic plots as an integrated workflow
- Diagnose common MCMC failures systematically
- Troubleshoot divergences, poor mixing, and non-convergence
- Understand when to increase warmup vs samples
- Apply multi-chain MCMC for robust diagnostics

## Prerequisites

- Bayesian basics (`01-bayesian-basics.ipynb`)
- Prior selection (`02-prior-selection.ipynb`)
- Understanding of MCMC sampling

**Estimated Time:** 45 minutes

## 1. Introduction: The Diagnostic Workflow

### Why Convergence Diagnostics Matter

MCMC (Markov chain Monte Carlo) generates samples by exploring parameter space. **Convergence** means:
- The chains have reached their stationary distribution (the posterior)
- The samples are representative of that distribution
- Summaries computed from them (means, credible intervals) are reliable

**Non-converged MCMC produces misleading posteriors!**

### Recommended Diagnostic Sequence

1. **R-hat & ESS** (automated) → Quick pass/fail
2. **Trace plot** → Visual convergence check
3. **Rank plot** → Most sensitive visual convergence check
4. **Pair plot** → Parameter correlations & divergences
5. **Autocorrelation plot** → Mixing quality (if ESS low)
6. **ESS plot** → Per-parameter efficiency
7. **Energy plot** → Posterior geometry (needs NUTS energy statistics)

### Setup Requirements

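**Use multi-chain MCMC (num_chains=4)** for robust diagnostics:
- R-hat is most informative when it compares several independently initialized chains
- The energy plot reports BFMI per chain, so an outlying chain stands out
- Chains stuck in different regions can only be detected with more than one chain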
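
To make "between-chain vs within-chain variance" concrete, here is a minimal NumPy sketch of the classic split-R-hat (Gelman-Rubin with each chain split in half). It assumes draws are arranged as `(chains, draws)` and uses synthetic Gaussian chains rather than the Maxwell fit. The helper name `split_rhat` is hypothetical, and ArviZ's `az.rhat` additionally rank-normalizes and folds the draws, so its values will differ slightly.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat (no rank normalization) for draws of shape (chains, draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain into two halves so a trend within a chain also raises R-hat
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)    # between-chain variance
    var_plus = (n - 1) / n * W + B / n        # pooled estimate of posterior variance
    return float(np.sqrt(var_plus / W))


rng = np.random.default_rng(0)
good = rng.normal(size=(4, 2000))   # four chains sampling the same N(0, 1)
bad = good.copy()
bad[3] += 1.0                       # fourth chain stuck one sd higher

print(f"well-mixed chains: R-hat = {split_rhat(good):.3f}")
print(f"one offset chain:  R-hat = {split_rhat(bad):.3f}")
```

Only one of the four chains is offset, yet R-hat moves well past the 1.01 threshold, because the spread of the chain means is compared with the spread inside each chain.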

## 2. Setup and Imports

In [None]:
# Enable inline plotting for Jupyter/VS Code
%matplotlib inline

# Standard imports
import numpy as np
import matplotlib.pyplot as plt
import warnings

# Rheo imports
from rheojax.models.maxwell import Maxwell
from rheojax.core.jax_config import safe_import_jax

# ArviZ for diagnostics
import arviz as az

# Safe JAX import
jax, jnp = safe_import_jax()

# Reproducibility
np.random.seed(42)

# Plotting configuration
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

print("✓ Imports successful")
print(f"ArviZ version: {az.__version__}")

# Suppress matplotlib backend warning in VS Code
warnings.filterwarnings('ignore', message='.*non-interactive.*')


## 3. Generate Data and Run Multi-Chain MCMC

In [None]:
# True parameters
G0_true = 1e5  # Pa
eta_true = 1e3  # Pa·s
tau_true = eta_true / G0_true  # s

# Generate relaxation data (τ = 0.01 s sits at the start of this grid,
# so G(t) is essentially zero for t > 0.1 s)
t = np.logspace(-2, 2, 50)
G_t_true = G0_true * np.exp(-t / tau_true)
noise = np.random.normal(0, 0.015 * G_t_true)
G_t_noisy = G_t_true + noise

print("Data Generated:")
print(f"  True G₀  = {G0_true:.2e} Pa")
print(f"  True η   = {eta_true:.2e} Pa·s")
print(f"  True τ   = {tau_true:.4f} s\n")

# NLSQ warm-start
model = Maxwell()
model.parameters.set_bounds('G0', (1e3, 1e7))
model.parameters.set_bounds('eta', (1e1, 1e5))
model.fit(t, G_t_noisy)

nlsq_params = {
    'G0': model.parameters.get_value('G0'),
    'eta': model.parameters.get_value('eta')
}

print("NLSQ Warm-Start:")
print(f"  G₀  = {nlsq_params['G0']:.4e} Pa")
print(f"  η   = {nlsq_params['eta']:.4e} Pa·s\n")

# Multi-chain MCMC (CRITICAL: 4 chains for robust diagnostics)
print("Running MULTI-CHAIN MCMC (4 chains)...")
print("(This may take 2-3 minutes)\n")

result = model.fit_bayesian(
    t, G_t_noisy,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,  # IMPORTANT: Multiple chains for diagnostics
    initial_values=nlsq_params
)

print("✓ Inference complete\n")

# Convert to ArviZ InferenceData
idata = result.to_inference_data()

print("ArviZ InferenceData structure:")
print(idata)

## 4. Step 1: Automated Checks (R-hat & ESS)

In [None]:
# Automated convergence checks
diagnostics = result.diagnostics

print("="*70)
print("AUTOMATED CONVERGENCE DIAGNOSTICS")
print("="*70)

print("\n1. R-hat (Gelman-Rubin Statistic):")
print("   Measures between-chain vs within-chain variance")
print("   Target: < 1.01\n")
for param in ['G0', 'eta']:
    rhat = diagnostics['r_hat'][param]
    status = '✓ Converged' if rhat < 1.01 else '✗ NOT converged'
    print(f"   {param:<5} R-hat = {rhat:.4f}  {status}")

print("\n2. ESS (Effective Sample Size):")
print("   Accounts for autocorrelation between samples")
print(f"   Target: > 400 (out of {result.num_samples * result.num_chains} total)\n")
for param in ['G0', 'eta']:
    ess = diagnostics['ess'][param]
    status = '✓ Sufficient' if ess > 400 else '✗ Low'
    efficiency = ess / (result.num_samples * result.num_chains) * 100
    print(f"   {param:<5} ESS = {ess:.0f}  ({efficiency:.1f}% efficiency)  {status}")

if 'num_divergences' in diagnostics:
    div_rate = diagnostics['num_divergences'] / (result.num_samples * result.num_chains) * 100
    print("\n3. Divergences:")
    print(f"   Count: {diagnostics['num_divergences']} ({div_rate:.2f}%)")
    status = '✓ Good' if div_rate < 1 else '✗ High'
    print(f"   Target: < 1%  {status}")

# Overall assessment
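all_converged = (
    all(diagnostics['r_hat'][p] < 1.01 for p in ['G0', 'eta']) and
    all(diagnostics['ess'][p] > 400 for p in ['G0', 'eta']) and
    # div_rate is only defined when divergence counts are available
    ('num_divergences' not in diagnostics or div_rate < 1)
)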

print("\n" + "="*70)
if all_converged:
    print("✓✓✓ QUICK CHECK: PASSED ✓✓✓")
    print("Proceed to visual diagnostics for detailed assessment.")
else:
    print("⚠⚠⚠ QUICK CHECK: FAILED ⚠⚠⚠")
    print("Use visual diagnostics below to identify issues.")
print("="*70)

## 5. Step 2: Trace Plot (Visual Convergence)

### What It Shows
- **Left panels**: Marginal posterior distributions
- **Right panels**: Parameter evolution over iterations

### Target Patterns
- ✓ **Left**: Smooth, unimodal distributions; all chains overlap
- ✓ **Right**: Stationary "fuzzy caterpillar" with no trends
- ✗ **Bad**: Trends, jumps, stuck (flat) regions, chains that disagree, or unexpected multimodality
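
The "no trends" criterion can be made numeric by averaging consecutive windows of a chain. A stationary chain gives roughly equal window means, while a drifting one climbs steadily. The sketch below uses a hypothetical helper, `window_means`, on synthetic chains. It is a rough stand-in for eyeballing the right-hand trace panels, not an ArviZ function.

```python
import numpy as np


def window_means(chain, n_windows=5):
    """Mean of consecutive, equal-length windows of a single chain."""
    chain = np.asarray(chain, dtype=float)
    usable = len(chain) - len(chain) % n_windows   # drop any remainder
    return chain[:usable].reshape(n_windows, -1).mean(axis=1)


rng = np.random.default_rng(1)
n = 2000
stationary = rng.normal(size=n)                        # "fuzzy caterpillar"
drifting = rng.normal(size=n) + np.linspace(0, 3, n)   # slow upward trend

print("stationary:", np.round(window_means(stationary), 2))
print("drifting:  ", np.round(window_means(drifting), 2))
```

Even for a stationary chain, the window means still fluctuate, by roughly the posterior standard deviation divided by the square root of the window's effective sample size. Judge drift relative to that noise.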

In [None]:
# Trace plot
az.plot_trace(idata, var_names=['G0', 'eta'], figsize=(14, 8))
plt.tight_layout()
plt.show()

print("\nTRACE PLOT INTERPRETATION:")
print("-" * 70)
print("LEFT PANELS (Marginal Distributions):")
print("  ✓ GOOD: Smooth, unimodal, all chains overlap")
print("  ✗ BAD: Bimodal, ragged, chains don't overlap")
print("\nRIGHT PANELS (Parameter vs Iteration):")
print("  ✓ GOOD: Fuzzy caterpillar, stationary, no trends")
print("  ✗ BAD: Drift, stuck regions, discontinuities")
print("-" * 70)
print("\nCOMMON ISSUES:")
print("1. Trend in trace → Not converged (increase num_warmup)")
print("2. Stuck regions → Chain trapped (check priors, reparameterize)")
print("3. Bimodal distribution → Multiple modes (check priors and initial values; longer runs rarely let NUTS cross modes)")
print("4. Chains don't overlap → Not converged (R-hat will be high)")

## 6. Step 3: Rank Plot (Most Sensitive)

### Why Rank Plot?
- **More sensitive than the trace plot** to chains that sample different regions
- Shows *how* chains disagree, which a single R-hat value cannot
- Recommended alongside rank-normalized R-hat in current MCMC practice

### What It Shows
- All draws are ranked jointly across chains; each chain then gets a histogram of its own ranks
- If the chains sample the same distribution, every histogram should be **uniform** (flat)

### Target Pattern
- ✓ **Uniform histogram** across all bins
- ✗ **Peaks, valleys, trends** indicate non-convergence
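
To see why a converged sampler produces flat histograms, the sketch below reproduces the computation behind a rank plot in NumPy: pool and rank all draws, then histogram each chain's ranks. It uses a hypothetical helper, `rank_histograms`, on synthetic Gaussian chains. When one chain sits one standard deviation higher, its ranks pile up in the top bins.

```python
import numpy as np


def rank_histograms(draws, n_bins=20):
    """Per-chain histograms of ranks computed over all chains pooled.

    draws: array of shape (num_chains, num_draws).
    Returns integer counts of shape (num_chains, n_bins).
    """
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    pooled = draws.ravel()   # chain c occupies a contiguous block
    ranks = pooled.argsort().argsort().reshape(n_chains, n_draws)   # ranks 0 .. N-1
    edges = np.linspace(0, pooled.size, n_bins + 1)
    return np.stack([np.histogram(r, bins=edges)[0] for r in ranks])


rng = np.random.default_rng(2)
good = rng.normal(size=(4, 2000))
bad = good.copy()
bad[3] += 1.0   # chain 4 sits one posterior sd higher

# With uniform ranks each bin holds 2000 / 20 = 100 draws per chain
print("good, chain 4:", rank_histograms(good)[3])
print("bad,  chain 4:", rank_histograms(bad)[3])
```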

In [None]:
# Rank plot (most sensitive)
az.plot_rank(idata, var_names=['G0', 'eta'], figsize=(14, 5))
plt.tight_layout()
plt.show()

print("\nRANK PLOT INTERPRETATION:")
print("-" * 70)
print("UNIFORM HISTOGRAM (all bins similar height):")
print("  ✓ CONVERGED: Chains sampling from same distribution")
print("\nNON-UNIFORM (peaks, valleys, trends):")
print("  ✗ NOT CONVERGED: Chains exploring different regions")
print("-" * 70)
print("\nACTION ITEMS:")
print("1. Non-uniform → Increase num_warmup (double: 1000 → 2000)")
print("2. Peaks at edges → Chain sticking (check initial values, priors)")
print("3. Consistent pattern → Systematic bias (reparameterize model)")
print("\nWHY RANK PLOT OVER TRACE PLOT?")
print("  - More sensitive to subtle non-convergence")
print("  - Stays readable for long and many chains, where traces overplot")
print("  - Rank-based, so insensitive to posterior scale and heavy tails")

## 7. Step 4: Pair Plot (Correlations & Divergences)

### What It Shows
- **Diagonal** (with `marginals=True`): Marginal posteriors
- **Off-diagonal**: Joint distributions (correlations)
- **Highlighted points** (with `divergences=True`): Divergent transitions (MCMC failures)

### Correlation Patterns
- **Elliptical**: Moderate correlation (normal)
- **Narrow diagonal ridge**: Strong correlation (identifiability issue)
- **Funnel**: The spread of one parameter depends on the value of another (reparameterization needed)
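
The narrow-ridge pattern arises naturally in the Maxwell model. The sketch below uses synthetic draws, not output of the fit in this notebook, under the purely illustrative assumption that the data pin down τ far more tightly than G₀. In that case η = τ·G₀ inherits almost all of its spread from G₀, so G₀ and η lie on a line, while G₀ and τ stay nearly uncorrelated.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 8000
# Illustrative assumption: G0 known to ~2%, tau to ~0.1%, drawn independently
G0 = rng.normal(1e5, 2e3, size=n)      # Pa
tau = rng.normal(1e-2, 1e-5, size=n)   # s
eta = tau * G0                         # Pa·s, implied by tau = eta / G0

rho_G0_eta = np.corrcoef(G0, eta)[0, 1]
rho_G0_tau = np.corrcoef(G0, tau)[0, 1]
print(f"corr(G0, eta) = {rho_G0_eta:.3f}")   # near-linear ridge
print(f"corr(G0, tau) = {rho_G0_tau:.3f}")   # roughly uncorrelated
```

This is also why sampling τ instead of η is one common way to remove such a ridge, wherever the model lets you choose the parameterization.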

### Divergences
- **< 1%**: Acceptable
- **1-5%**: Moderate (increase target_accept_prob)
- **> 5%**: Problematic (results unreliable)
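
These thresholds are rates per total post-warmup draw. The small sketch below applies them to this notebook's 4 × 2000 draws using a hypothetical helper, `classify_divergences`. The 1% and 5% cut-offs are the rules of thumb listed here, and any nonzero count is still worth locating in the pair plot.

```python
def classify_divergences(num_divergences, num_samples, num_chains):
    """Return (rate in %, verdict) using the 1% / 5% rules of thumb."""
    rate = 100.0 * num_divergences / (num_samples * num_chains)
    if rate < 1:
        return rate, "acceptable"
    if rate <= 5:
        return rate, "moderate: raise target_accept_prob (e.g. 0.9)"
    return rate, "problematic: reparameterize or tighten priors"


for n_div in (12, 160, 600):
    rate, verdict = classify_divergences(n_div, num_samples=2000, num_chains=4)
    print(f"{n_div:>4} divergences -> {rate:.2f}% ({verdict})")
```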

In [None]:
# Pair plot with divergences
az.plot_pair(
    idata,
    var_names=['G0', 'eta'],
    kind='scatter',
    marginals=True,      # draw the marginal posteriors on the diagonal
    divergences=True,    # highlight divergent transitions
    figsize=(10, 8)
)
plt.tight_layout()
plt.show()

# Compute correlation
# Flatten in case samples are stored per chain as (chains, draws)
G0_samples = np.ravel(result.posterior_samples['G0'])
eta_samples = np.ravel(result.posterior_samples['eta'])
correlation = np.corrcoef(G0_samples, eta_samples)[0, 1]

print("\nPAIR PLOT INTERPRETATION:")
print("-" * 70)
print(f"Parameter Correlation: ρ(G₀, η) = {correlation:.3f}\n")
print("CORRELATION STRENGTH:")
print("  |ρ| < 0.3:  Weakly correlated (well-identified) ✓")
print("  0.3 < |ρ| < 0.7:  Moderate correlation (acceptable) ✓")
print("  |ρ| > 0.7:  Strong correlation (identifiability issue) ✗")
print("-" * 70)
print("\nDIVERGENCE TROUBLESHOOTING:")
if 'num_divergences' in diagnostics:
    div_rate = diagnostics['num_divergences'] / (result.num_samples * result.num_chains) * 100
    print(f"Divergence rate: {div_rate:.2f}%")
    if div_rate < 1:
        print("  ✓ < 1%: Acceptable, model fit reliable")
    elif div_rate < 5:
        print("  ⚠ 1-5%: Moderate")
        print("    Solution: Increase target_accept_prob to 0.9")
    else:
        print("  ✗ > 5%: Problematic")
        print("    Solution 1: Increase target_accept_prob to 0.95")
        print("    Solution 2: Reparameterize (non-centered)")
        print("    Solution 3: Tighter priors to constrain problematic regions")
print("-" * 70)
print("\nFOR MAXWELL MODEL:")
print(f"  ρ(G₀, η) = {correlation:.3f}; a positive correlation is typical")
print("  The data constrain τ = η/G₀, so η tends to rise together with G₀")
print("  This correlation is physical, not a sampler failure")

## 8. Step 5: Autocorrelation Plot (Mixing Quality)

### What It Shows
- Correlation between samples at different lags
- High autocorrelation → many samples needed for good ESS

### Target
- ✓ Autocorrelation drops to ~0 within 20-30 lags
- ✗ Slow decay → high autocorrelation → poor mixing

### Relation to ESS
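```
ESS ≈ N / (1 + 2 × Σ_{k=1}^{∞} ρ_k)
```
Here N is the number of draws and ρ_k is the autocorrelation at lag k. In practice the sum is truncated once ρ_k becomes indistinguishable from noise. High autocorrelation → low ESS.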
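
The minimal NumPy sketch below applies this formula to a single chain and checks it against an AR(1) process, whose ESS is known in closed form as N(1 - φ)/(1 + φ). The helper names are hypothetical. The sum is truncated at the first negative autocorrelation, which is a simpler rule than the Geyer-style truncation ArviZ uses, so estimates will differ somewhat from `az.ess`.

```python
import numpy as np


def autocorr(x):
    """Normalized autocorrelation of a 1-D chain, computed via FFT."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = x.size
    f = np.fft.rfft(x, n=2 * n)                  # zero-pad to avoid wrap-around
    acov = np.fft.irfft(f * np.conj(f))[:n]      # autocovariance at lags 0 .. n-1
    return acov / acov[0]


def ess_single_chain(x):
    """ESS = N / (1 + 2 * sum of autocorrelations), stopping at the first negative lag."""
    total = 0.0
    for r in autocorr(x)[1:]:
        if r < 0:
            break
        total += r
    return len(x) / (1 + 2 * total)


rng = np.random.default_rng(4)
phi, n = 0.9, 20000
chain = np.empty(n)
chain[0] = rng.normal()
for t in range(1, n):   # AR(1): a slowly mixing chain
    chain[t] = phi * chain[t - 1] + rng.normal()

theory = n * (1 - phi) / (1 + phi)   # exact asymptotic ESS for AR(1)
print(f"estimated ESS = {ess_single_chain(chain):.0f}, theory = {theory:.0f}")
```

With φ = 0.9, only about 5% of the draws count as independent. This is the "slow decay" regime in the interpretation printed below.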

In [None]:
# Autocorrelation plot
az.plot_autocorr(
    idata,
    var_names=['G0', 'eta'],
    max_lag=100,
    figsize=(14, 5)
)
plt.tight_layout()
plt.show()

print("\nAUTOCORRELATION PLOT INTERPRETATION:")
print("-" * 70)
print("TARGET: Autocorrelation drops to ~0 within 20-30 lags\n")
print("FAST DECAY (< 30 lags):")
print("  ✓ Good mixing, independent samples obtained quickly")
print("  ✓ High ESS (>50% efficiency)")
print("\nMODERATE DECAY (30-50 lags):")
print("  ⚠ Acceptable mixing")
print("  ⚠ ESS ~20-50% of total draws")
print("  → Consider increasing num_samples if ESS < 400")
print("\nSLOW DECAY (> 50 lags):")
print("  ✗ Poor mixing, high autocorrelation")
print("  ✗ ESS < 20% of total draws")
print("  → Solution 1: Increase num_samples significantly")
print("  → Solution 2: Check pair plot for strong correlations")
print("  → Solution 3: Reparameterize if due to model structure")
print("-" * 70)
print("\nRELATION TO ESS:")
for param in ['G0', 'eta']:
    ess = diagnostics['ess'][param]
    efficiency = ess / (result.num_samples * result.num_chains) * 100
    print(f"  {param}: ESS = {ess:.0f} ({efficiency:.1f}% efficiency)")
    if efficiency > 50:
        print("       ✓ Excellent: Fast mixing")
    elif efficiency > 20:
        print("       ✓ Good: Acceptable mixing")
    else:
        print("       ✗ Poor: High autocorrelation")

## 9. Step 6: ESS Plot (Sampling Efficiency)

### What It Shows
- Effective sample size per parameter
- Quantifies sampling efficiency

### ESS Types
- **Bulk ESS**: Reliability of estimates in the center of the posterior (mean, median)
- **Tail ESS**: Reliability of the 5% and 95% quantiles (credible interval ends)
- **Local ESS**: ESS for small intervals across the range of quantiles (what `kind='local'` plots)

### Targets
- Bulk ESS > 400: Reliable mean/median estimates
- Tail ESS > 400: Reliable credible intervals
- If tail ESS is well below bulk ESS (and under 400): Run more samples before trusting interval endpoints
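
The action items for low ESS rely on the fact that, once a chain mixes well, ESS grows roughly in proportion to the number of draws. Here is a back-of-the-envelope sketch using a hypothetical helper, `draws_needed`, and assuming that linear scaling holds:

```python
import math


def draws_needed(ess, total_draws, target_ess=400):
    """Total post-warmup draws needed to reach target_ess, assuming ESS scales
    linearly with the number of draws (i.e. mixing quality stays the same)."""
    if ess >= target_ess:
        return total_draws
    return math.ceil(total_draws * target_ess / ess)


# Example: 4 chains x 2000 draws = 8000 draws, but tail ESS is only 250
total = draws_needed(ess=250, total_draws=8000)
per_chain = math.ceil(total / 4)
print(f"need about {total} draws in total, i.e. num_samples ≈ {per_chain} per chain")
```

The linear-scaling assumption breaks down when ESS is very low because the sampler is stuck or mixing badly. In that case more draws only repeat the problem, and the autocorrelation and pair plots are the better next step.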

In [None]:
# ESS plot (local)
az.plot_ess(
    idata,
    var_names=['G0', 'eta'],
    kind='local',
    figsize=(14, 5)
)
plt.tight_layout()
plt.show()

print("\nESS PLOT INTERPRETATION:")
print("-" * 70)
print("ESS measures effective independent samples (accounts for autocorrelation)\n")
print("TARGETS:")
print("  Bulk ESS > 400:  Reliable mean/median ✓")
print("  Tail ESS > 400:  Reliable credible intervals ✓")
print("-" * 70)
print("\nACTION ITEMS:")
print("ESS < 100:    Critical - check mixing (autocorr, pair plot), then increase samples up to 10x")
print("ESS 100-400:  Increase num_samples by roughly 400/ESS (2-4x)")
print("ESS > 400:    ✓ Sufficient")
print("Tail << Bulk: Increase samples for CI reliability")
print("-" * 70)
print("\nEFFICIENCY CALCULATION:")
total_samples = result.num_samples * result.num_chains
print(f"Total samples: {total_samples}")
for param in ['G0', 'eta']:
    ess = diagnostics['ess'][param]
    efficiency = ess / total_samples * 100
    print(f"  {param}: ESS = {ess:.0f}, Efficiency = {efficiency:.1f}%")
    if efficiency > 50:
        print("       ✓ Excellent sampling efficiency")
    elif efficiency > 20:
        print("       ✓ Good efficiency")
    elif efficiency > 10:
        print("       ⚠ Acceptable but could be better")
    else:
        print("       ✗ Poor efficiency (need more samples or reparameterization)")

## 10. Step 7: Energy Plot (Posterior Geometry)

### What It Shows
- NUTS energy diagnostic
- **Marginal energy**: Distribution of the Hamiltonian energy across draws
- **Energy transition**: Distribution of the change in energy between successive draws

### Requirements
- NUTS energy statistics (`energy`) stored in `idata.sample_stats`
- BFMI is computed per chain; multiple chains let you compare them

### Target
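- ✓ **Good overlap**: The energy-transition distribution is about as wide as the marginal energy distribution
- ✗ **Mismatch**: A much narrower transition distribution (low BFMI) → posterior geometry issues; the sampler explores the tails slowly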

### Causes of Mismatch
- Funnel-shaped posteriors
- Heavy tails
- Complex correlations
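
The energy plot condenses this mismatch into the Bayesian fraction of missing information (BFMI), which ArviZ reports per chain in the plot legend. Below is a minimal NumPy sketch of that ratio: the mean squared energy change between successive draws divided by the energy variance, which is the formula behind `az.bfmi`. It is applied to synthetic energies rather than `idata.sample_stats.energy`, and the helper name `bfmi` is used here for illustration.

```python
import numpy as np


def bfmi(energy):
    """Bayesian fraction of missing information, one value per chain.

    energy: array (num_chains, num_draws) of NUTS Hamiltonian energies.
    """
    energy = np.asarray(energy, dtype=float)
    return np.square(np.diff(energy, axis=1)).mean(axis=1) / np.var(energy, axis=1)


rng = np.random.default_rng(5)
n = 2000
# Energies that jump freely between draws: transitions as wide as the marginal
healthy = rng.normal(size=(2, n))
# Energies that barely move between draws (random-walk-like): low BFMI
sticky = np.cumsum(0.05 * rng.normal(size=(2, n)), axis=1)

print("healthy BFMI:", np.round(bfmi(healthy), 2))
print("sticky BFMI: ", np.round(bfmi(sticky), 2))
```

Values below about 0.3 are commonly treated as a warning sign that the energy-transition distribution is too narrow for efficient exploration.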

In [None]:
# Energy plot (needs NUTS energy statistics in idata.sample_stats)
try:
    az.plot_energy(idata)
    plt.tight_layout()
    plt.show()
    
    print("\nENERGY PLOT INTERPRETATION:")
    print("-" * 70)
    print("GOOD FIT (distributions overlap well):")
    print("  ✓ Posterior geometry is well-behaved")
    print("  ✓ NUTS sampling efficient")
    print("\nPOOR FIT (distributions don't match):")
    print("  ✗ Posterior geometry problematic")
    print("  → Possible causes:")
    print("    - Funnel-shaped posterior")
    print("    - Heavy tails")
    print("    - Complex parameter correlations")
    print("-" * 70)
    print("\nACTION ITEMS:")
    print("1. Mismatch detected → Reparameterize model")
    print("   (Use non-centered parameterization for hierarchical models)")
    print("2. Persistent issues → Tighter priors to regularize")
    print("3. Use with pair plot to identify problematic parameters")
except Exception as e:
    print("\nENERGY PLOT FAILED:")
    print(f"Error: {e}")
    print("\nLikely cause: no NUTS energy statistics in idata.sample_stats")
    print("Check that idata.sample_stats contains an 'energy' variable")

## 11. Systematic Troubleshooting Guide

The function below automates the decision logic. It checks R-hat, ESS, and the divergence rate, then prints recommended actions for each failed check:

In [None]:
# Automated troubleshooting function

def diagnose_convergence(result, show_solutions=True):
    """
    Automated convergence diagnostic with actionable recommendations.
    """
    diagnostics = result.diagnostics
    issues = []
    solutions = []
    
    print("\n" + "="*70)
    print("AUTOMATED TROUBLESHOOTING REPORT")
    print("="*70)
    
    # Check R-hat
    max_rhat = max(diagnostics['r_hat'].values())
    if max_rhat > 1.01:
        issues.append(f"R-hat > 1.01 (max: {max_rhat:.4f})")
        solutions.append("[R-hat] Increase num_warmup (current → 2× current)")
        solutions.append("[R-hat] Check trace plot: Are chains exploring the same region?")
        solutions.append("[R-hat] Check rank plot: Is the histogram uniform?")
    
    # Check ESS
    min_ess = min(diagnostics['ess'].values())
    if min_ess < 400:
        issues.append(f"ESS < 400 (min: {min_ess:.0f})")
        solutions.append("[ESS] Increase num_samples (2000 → 5000)")
        solutions.append("[ESS] Check autocorrelation plot: Is mixing slow?")
        solutions.append("[ESS] Check pair plot: Are parameters correlated?")
    
    # Check divergences (the < 1% criterion used throughout this notebook)
    if 'num_divergences' in diagnostics:
        div_rate = diagnostics['num_divergences'] / (result.num_samples * result.num_chains)
        if div_rate > 0.01:
            issues.append(f"High divergence rate ({div_rate*100:.1f}%)")
            solutions.append("[Divergences] Increase target_accept_prob (0.8 → 0.9 or 0.95)")
            solutions.append("[Divergences] Check pair plot: Where do divergences occur?")
            solutions.append("[Divergences] Use tighter priors or reparameterize")
    
    # Report
    if not issues:
        print("\n✓✓✓ NO ISSUES DETECTED ✓✓✓")
        print("All convergence criteria met.")
        print(f"  - R-hat: {max_rhat:.4f} < 1.01")
        print(f"  - ESS: {min_ess:.0f} > 400")
        if 'num_divergences' in diagnostics:
            print(f"  - Divergences: {diagnostics['num_divergences']} ({div_rate*100:.2f}%)")
    else:
        print("\n⚠⚠⚠ ISSUES DETECTED ⚠⚠⚠\n")
        for issue in issues:
            print(f"  - {issue}")
        
        if show_solutions:
            print("\nRECOMMENDED ACTIONS:\n")
            for solution in solutions:
                print(f"  {solution}")
    
    print("="*70)
    return len(issues) == 0

# Run diagnosis
converged = diagnose_convergence(result)

## 12. Common Failure Modes Reference

| Symptom | Likely Cause | Primary Diagnostic | Solution |
|---------|--------------|-------------------|---------|
| R-hat > 1.01 | Not converged | Rank plot | Increase `num_warmup` |
| ESS < 400 | High autocorrelation | Autocorr plot | Increase `num_samples` |
| Many divergences | Bad geometry | Pair plot | Increase `target_accept_prob` |
| Bimodal posterior | Multiple modes | Trace plot | Check priors and initial values (longer chains rarely cross modes) |
| Strong correlations | Identifiability | Pair plot | More data or tighter priors |
| Energy mismatch | Funnel geometry | Energy plot | Non-centered parameterization |
| Slow autocorr decay | Poor mixing | Autocorr + Pair | Reparameterize |
| Non-uniform rank | Non-convergence | Rank plot (most sensitive) | Increase `num_warmup` |

## 13. Key Takeaways

### Main Concepts

1. **Diagnostic Workflow**
   - Use all 6 plots as an integrated system
   - The rank plot is the most sensitive visual convergence check
   - The pair plot reveals correlations and divergences
   - Multi-chain MCMC (4 chains) is best practice

2. **Convergence Criteria**
   - R-hat < 1.01 (all parameters)
   - ESS > 400 (all parameters)
   - Divergences < 1%
   - **Always verify before trusting results!**

3. **Common Issues and Solutions**
   - R-hat > 1.01 → Increase warmup
   - ESS < 400 → Increase samples
   - Divergences > 1% → Increase target_accept_prob or tighten priors
   - Strong correlations → More data, data from a different test mode, or accept them if physically expected

4. **Multi-Chain Benefits**
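   - Robust R-hat estimation
   - Per-chain energy (BFMI) values can be compared across chains
   - Chains can run in parallel (little wall-clock penalty with enough cores or devices)
   - More reliable detection of non-convergence than a single chain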

### Best Practices

1. **Always use num_chains=4** for production work
2. **Run all 6 diagnostics** before interpreting results
3. **Rank plot** for convergence assessment (most sensitive visual check)
4. **Pair plot** for correlations and divergences
5. **Document convergence** in reports (R-hat, ESS, divergences)

## Next Steps

### Apply Diagnostics
- **[04-model-comparison.ipynb](04-model-comparison.ipynb)**: WAIC and LOO for model selection
- **[05-uncertainty-propagation.ipynb](05-uncertainty-propagation.ipynb)**: Propagate uncertainty to predictions

### Related Content
- All `basic/` notebooks include Bayesian sections
- **[01-bayesian-basics.ipynb](01-bayesian-basics.ipynb)**: NLSQ → NUTS workflow
- **[02-prior-selection.ipynb](02-prior-selection.ipynb)**: Prior choices and sensitivity

---

## Session Information

In [None]:
import sys
import rheojax

print(f"Python: {sys.version}")
print(f"Rheo: {rheojax.__version__}")
print(f"JAX: {jax.__version__}")
print(f"NumPy: {np.__version__}")
print(f"ArviZ: {az.__version__}")
print(f"JAX devices: {jax.devices()}")