# Zener Model: Oscillatory Shear Fitting

This notebook demonstrates the complete workflow for fitting the Zener (Standard Linear Solid) model to oscillatory shear data, showcasing modern Rheo capabilities including GPU-accelerated optimization and Bayesian uncertainty quantification.

## Learning Objectives

After completing this notebook, you will be able to:
- Fit the Zener model to oscillatory shear data (G', G" vs frequency)
- Understand the physical meaning of equilibrium and Maxwell moduli
- Use NLSQ optimization, with reported speedups of 5-270x over SciPy depending on problem size and hardware
- Perform Bayesian inference with an NLSQ→NUTS warm-start workflow
- Interpret six ArviZ diagnostic plots for assessing MCMC convergence
- Extract physically meaningful parameters with uncertainty quantification

## Prerequisites

Basic understanding of:
- Rheological concepts (storage modulus G', loss modulus G")
- Linear viscoelasticity
- Oscillatory shear testing
- Python and NumPy

**Recommended:** Complete `01-maxwell-fitting.ipynb` first

**Estimated Time:** 35-45 minutes

## Setup and Imports

We start by importing necessary libraries. Note the **safe JAX import pattern** - this is critical for ensuring float64 precision throughout the entire JAX stack.

In [None]:
# Standard scientific computing imports
import numpy as np
import matplotlib.pyplot as plt
import time

# Rheo imports - always explicit
from rheojax.pipeline.base import Pipeline
from rheojax.pipeline.bayesian import BayesianPipeline
from rheojax.core.data import RheoData
from rheojax.models.zener import Zener
from rheojax.core.jax_config import safe_import_jax, verify_float64

# Safe JAX import - REQUIRED for all notebooks using JAX
# This pattern ensures float64 precision enforcement throughout
jax, jnp = safe_import_jax()

# Verify float64 is enabled (educational demonstration)
verify_float64()
print(f"✓ JAX float64 precision enabled (default float dtype: {jnp.zeros(1).dtype})")

# Set reproducible random seed
np.random.seed(42)

# Configure matplotlib for publication-quality plots
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11

## Zener Model Theory

The Zener model (also called Standard Linear Solid or SLS) represents a viscoelastic material as a Maxwell element (spring and dashpot in series) in parallel with an equilibrium spring:

**Complex Modulus (Oscillatory Shear):**
$$G^*(\omega) = G_e + \frac{G_m (\omega\tau)^2}{1 + (\omega\tau)^2} + i\frac{G_m \omega\tau}{1 + (\omega\tau)^2}$$

where:
- $G'(\omega)$ = storage modulus = $G_e + \frac{G_m (\omega\tau)^2}{1 + (\omega\tau)^2}$
- $G''(\omega)$ = loss modulus = $\frac{G_m \omega\tau}{1 + (\omega\tau)^2}$
- $G_e$ = equilibrium modulus (Pa) - long-time elastic response
- $G_m$ = Maxwell modulus (Pa) - transient elastic component
- $\eta$ = viscosity (Pa·s) of the dashpot in the Maxwell arm
- $\tau = \eta / G_m$ = relaxation time (s)

**Physical Interpretation:**
- **$G_e$**: Equilibrium modulus - elastic response as $t \to \infty$ (solid-like behavior)
- **$G_m$**: Maxwell modulus - transient elastic component that relaxes
- **$\eta$**: Viscosity - together with $G_m$, sets the relaxation rate
- **$\tau$**: Relaxation time - characteristic time scale for stress relaxation

**Applicability:**
- Crosslinked polymers (gels, elastomers)
- Materials with finite equilibrium modulus
- Limited to small strains (linear viscoelastic regime)
- Single dominant relaxation time

**Comparison to Maxwell Model:**
- Maxwell: $G_e = 0$ (complete stress relaxation)
- Zener: $G_e > 0$ (finite equilibrium modulus)
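
To make the formulas concrete, here is a minimal NumPy sketch that evaluates $G'$ and $G''$ directly from the expressions above and checks three limits they imply: $G' \to G_e$ at low frequency, $G' \to G_e + G_m$ at high frequency, and a $G''$ peak of height $G_m/2$ at $\omega\tau = 1$. The function `zener_moduli` is a standalone helper written for this illustration, not part of the rheojax API.

```python
import numpy as np


def zener_moduli(omega, Ge, Gm, eta):
    """Return (G', G'') of the Zener model at angular frequency omega (rad/s)."""
    tau = eta / Gm                          # relaxation time of the Maxwell arm (s)
    wt = np.asarray(omega, dtype=float) * tau
    G_storage = Ge + Gm * wt**2 / (1.0 + wt**2)
    G_loss = Gm * wt / (1.0 + wt**2)
    return G_storage, G_loss


Ge, Gm, eta = 1e4, 5e4, 1e3                # same values as the synthetic data below
tau = eta / Gm

# Plateaus of G': omega*tau << 1 gives Ge, omega*tau >> 1 gives Ge + Gm
Gp_low = float(zener_moduli(1e-6 / tau, Ge, Gm, eta)[0])
Gp_high = float(zener_moduli(1e6 / tau, Ge, Gm, eta)[0])

# G'' on a dense grid: the peak should sit at omega*tau = 1 with height Gm/2
w = np.logspace(-3, 3, 6001) / tau
G_loss = zener_moduli(w, Ge, Gm, eta)[1]
wt_peak = w[np.argmax(G_loss)] * tau

print(f"G'(ω→0) ≈ {Gp_low:.4e} Pa   (Ge      = {Ge:.4e})")
print(f"G'(ω→∞) ≈ {Gp_high:.4e} Pa   (Ge + Gm = {Ge + Gm:.4e})")
print(f"G'' peak at ωτ ≈ {wt_peak:.3f}, height {G_loss.max():.4e} Pa (Gm/2 = {Gm / 2:.4e})")
```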

## Generate Synthetic Oscillation Data

We create synthetic oscillatory shear data with known parameters to validate our fitting workflow. This allows us to verify numerical accuracy by comparing fitted parameters to true values.

In [None]:
# True Zener parameters
Ge_true = 1e4  # Pa (equilibrium modulus)
Gm_true = 5e4  # Pa (Maxwell modulus)
eta_true = 1e3  # Pa·s (viscosity)
tau_true = eta_true / Gm_true  # s (relaxation time)

print(f"True Parameters:")
print(f"  Ge  = {Ge_true:.2e} Pa")
print(f"  Gm  = {Gm_true:.2e} Pa")
print(f"  eta = {eta_true:.2e} Pa·s")
print(f"  tau = {tau_true:.4f} s")

# Generate frequency array (logarithmically spaced)
omega = np.logspace(-2, 3, 40)  # 0.01 to 1000 rad/s

# True complex modulus
omega_tau = omega * tau_true
omega_tau_sq = omega_tau**2
G_prime_true = Ge_true + Gm_true * omega_tau_sq / (1 + omega_tau_sq)
G_double_prime_true = Gm_true * omega_tau / (1 + omega_tau_sq)

# Add realistic Gaussian noise (1-2% relative error)
noise_level = 0.015  # 1.5%
noise_Gp = np.random.normal(0, noise_level * G_prime_true)
noise_Gpp = np.random.normal(0, noise_level * G_double_prime_true)

G_prime_noisy = G_prime_true + noise_Gp
G_double_prime_noisy = G_double_prime_true + noise_Gpp

# Create complex modulus for fitting
G_star_noisy = G_prime_noisy + 1j * G_double_prime_noisy

print(f"\nData characteristics:")
print(f"  Frequency range: {omega.min():.2f} - {omega.max():.2f} rad/s")
print(f"  Number of points: {len(omega)}")
print(f"  Noise level: {noise_level*100:.1f}% relative")
print(f"  SNR (G'): {np.mean(G_prime_true) / np.std(noise_Gp):.1f}")
print(f"  SNR (G''): {np.mean(G_double_prime_true) / np.std(noise_Gpp):.1f}")

In [None]:
# Visualize raw data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: G' and G" vs frequency
ax1.loglog(omega, G_prime_noisy, 'o', markersize=6, alpha=0.7, label="G' (data)", color='#1f77b4')
ax1.loglog(omega, G_double_prime_noisy, 's', markersize=6, alpha=0.7, label='G" (data)', color='#ff7f0e')
ax1.loglog(omega, G_prime_true, '--', linewidth=2, alpha=0.4, label="G' (true)", color='#1f77b4')
ax1.loglog(omega, G_double_prime_true, '--', linewidth=2, alpha=0.4, label='G" (true)', color='#ff7f0e')
ax1.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12, fontweight='bold')
ax1.set_ylabel('Modulus (Pa)', fontsize=12, fontweight='bold')
ax1.set_title('Oscillatory Shear Data', fontsize=13, fontweight='bold')
ax1.grid(True, alpha=0.3, which='both')
ax1.legend(fontsize=10, loc='best', ncol=2)

# Right: tan(δ) = G"/G'
tan_delta_noisy = G_double_prime_noisy / G_prime_noisy
tan_delta_true = G_double_prime_true / G_prime_true
ax2.semilogx(omega, tan_delta_noisy, 'o', markersize=6, alpha=0.7, label='Data', color='#2ca02c')
ax2.semilogx(omega, tan_delta_true, '--', linewidth=2, alpha=0.4, label='True', color='gray')
ax2.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12, fontweight='bold')
ax2.set_ylabel('tan(δ) = G"/G\'', fontsize=12, fontweight='bold')
ax2.set_title('Loss Tangent', fontsize=13, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.legend(fontsize=11)

plt.tight_layout()
plt.show()

print("\nPhysical insights from data:")
print(f"  G' at low ω: {G_prime_noisy[0]:.2e} Pa (approaches Ge)")
print(f"  G' at high ω: {G_prime_noisy[-1]:.2e} Pa (approaches Ge + Gm)")
print(f"  tan(δ) peak: {tan_delta_noisy.max():.4f} at ω ≈ {omega[np.argmax(tan_delta_noisy)]:.2f} rad/s")
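
The peak printed above comes from 40 noisy points, so its location is only resolved to the grid spacing (about 0.13 decade). For the Zener model the $\tan\delta$ peak can also be located analytically: writing $x = \omega\tau$, $\tan\delta = G_m x / (G_e + (G_e + G_m)x^2)$, which is maximized at $x^* = \sqrt{G_e/(G_e+G_m)}$ with $\tan\delta_{max} = G_m / (2\sqrt{G_e(G_e+G_m)})$. This lies below the $G''$ peak at $\omega\tau = 1$ because $G'$ is still rising there. The sketch below checks the formula against a dense grid using the true parameters; it is a standalone illustration, not a rheojax function.

```python
import numpy as np

Ge, Gm, eta = 1e4, 5e4, 1e3     # true parameters from the synthetic-data cell
tau = eta / Gm

# tan(delta) = Gm*x / (Ge + (Ge + Gm) x^2) with x = omega*tau;
# setting the derivative to zero gives x* = sqrt(Ge / (Ge + Gm))
x_star = np.sqrt(Ge / (Ge + Gm))
omega_peak = x_star / tau
tan_delta_max = Gm / (2.0 * np.sqrt(Ge * (Ge + Gm)))

# Numerical cross-check on a dense grid of x = omega*tau
x = np.logspace(-4, 4, 80001)
tan_delta = Gm * x / (Ge + (Ge + Gm) * x**2)
omega_peak_grid = x[np.argmax(tan_delta)] / tau

print(f"Analytic peak: ω* = {omega_peak:.2f} rad/s, tan(δ)max = {tan_delta_max:.4f}")
print(f"Grid peak:     ω* = {omega_peak_grid:.2f} rad/s, tan(δ)max = {tan_delta.max():.4f}")
```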

## Approach 1: Pipeline API (Recommended for Standard Workflows)

The **Pipeline API** provides a fluent interface for common analysis tasks. It's ideal for rapid prototyping and standardized workflows.

In [None]:
# Create RheoData container with metadata
data = RheoData(
    x=omega,
    y=G_star_noisy,
    x_units='rad/s',
    y_units='Pa',
    domain='frequency',
)

# Pipeline API workflow with timing
start_pipeline = time.time()

pipeline = Pipeline(data)
pipeline.fit('zener')

pipeline_time = time.time() - start_pipeline

# Extract fitted parameters
model = pipeline.get_last_model()
Ge_pipeline = model.parameters.get_value('Ge')
Gm_pipeline = model.parameters.get_value('Gm')
eta_pipeline = model.parameters.get_value('eta')
tau_pipeline = eta_pipeline / Gm_pipeline

print("\n" + "="*60)
print("PIPELINE API RESULTS")
print("="*60)
print(f"Fitted Parameters:")
print(f"  Ge  = {Ge_pipeline:.4e} Pa  (true: {Ge_true:.4e})")
print(f"  Gm  = {Gm_pipeline:.4e} Pa  (true: {Gm_true:.4e})")
print(f"  eta = {eta_pipeline:.4e} Pa·s  (true: {eta_true:.4e})")
print(f"  tau = {tau_pipeline:.6f} s  (true: {tau_true:.6f})")
print(f"\nRelative Errors:")
print(f"  Ge:  {abs(Ge_pipeline - Ge_true) / Ge_true * 100:.4f}%")
print(f"  Gm:  {abs(Gm_pipeline - Gm_true) / Gm_true * 100:.4f}%")
print(f"  eta: {abs(eta_pipeline - eta_true) / eta_true * 100:.4f}%")
print(f"\nOptimization time: {pipeline_time:.4f} s")
print("="*60)

## Approach 2: Modular API (Recommended for Customization)

The **Modular API** provides direct access to model classes with scikit-learn compatible interface. Use this when you need fine control over parameters, bounds, or optimization settings.

In [None]:
# Create Zener model instance
model = Zener()

# Set parameter bounds (optional but recommended)
model.parameters.set_bounds('Ge', (1e2, 1e6))  # Reasonable modulus range
model.parameters.set_bounds('Gm', (1e3, 1e7))  # Reasonable modulus range
model.parameters.set_bounds('eta', (1e1, 1e5))  # Reasonable viscosity range

# Fit with timing
start_modular = time.time()

model.fit(omega, G_star_noisy)

modular_time = time.time() - start_modular

# Extract fitted parameters
Ge_modular = model.parameters.get_value('Ge')
Gm_modular = model.parameters.get_value('Gm')
eta_modular = model.parameters.get_value('eta')
tau_modular = eta_modular / Gm_modular

print("\n" + "="*60)
print("MODULAR API RESULTS")
print("="*60)
print(f"Fitted Parameters:")
print(f"  Ge  = {Ge_modular:.4e} Pa  (true: {Ge_true:.4e})")
print(f"  Gm  = {Gm_modular:.4e} Pa  (true: {Gm_true:.4e})")
print(f"  eta = {eta_modular:.4e} Pa·s  (true: {eta_true:.4e})")
print(f"  tau = {tau_modular:.6f} s  (true: {tau_true:.6f})")
print(f"\nRelative Errors:")
print(f"  Ge:  {abs(Ge_modular - Ge_true) / Ge_true * 100:.4f}%")
print(f"  Gm:  {abs(Gm_modular - Gm_true) / Gm_true * 100:.4f}%")
print(f"  eta: {abs(eta_modular - eta_true) / eta_true * 100:.4f}%")
print(f"\nOptimization time: {modular_time:.4f} s")
print("="*60)

# Verify both approaches agree. The modular fit used custom bounds, so compare
# to optimizer tolerance rather than expecting bit-identical results.
assert np.allclose(Ge_pipeline, Ge_modular, rtol=1e-6), "Pipeline and Modular fits disagree on Ge"
assert np.allclose(Gm_pipeline, Gm_modular, rtol=1e-6), "Pipeline and Modular fits disagree on Gm"
assert np.allclose(eta_pipeline, eta_modular, rtol=1e-6), "Pipeline and Modular fits disagree on eta"
print("\n✓ Pipeline and Modular APIs agree (rtol=1e-6)")

## Performance Benchmark: NLSQ vs SciPy

Rheo uses **NLSQ** (GPU-accelerated nonlinear least squares) as its default optimization backend, with reported speedups of 5-270x over SciPy's `curve_fit` depending on problem size and hardware.

The speedup comes from:
1. **JAX JIT compilation** - the computation is compiled to XLA code
2. **Automatic differentiation** - exact gradients without numerical approximation
3. **GPU acceleration** - parallel computation on CUDA devices (if available)

Let's measure actual performance on your hardware:

In [None]:
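# Benchmark: multiple fits to get reliable timing
n_runs = 10
times = []

for i in range(n_runs):
    model_bench = Zener()
    start = time.perf_counter()  # high-resolution timer suited to benchmarking
    model_bench.fit(omega, G_star_noisy)
    times.append(time.perf_counter() - start)

# Exclude the first run, which may include JIT compilation
# (if earlier cells already compiled the model, this overhead may be near zero)
nlsq_mean = np.mean(times[1:])
nlsq_std = np.std(times[1:])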

print("\n" + "="*60)
print("PERFORMANCE BENCHMARK (NLSQ)")
print("="*60)
print(f"Number of runs: {n_runs}")
print(f"First run (may include JIT): {times[0]:.4f} s")
print(f"Subsequent runs: {nlsq_mean:.4f} ± {nlsq_std:.4f} s")
print(f"JIT overhead (first run minus mean): {times[0] - nlsq_mean:.4f} s")
print("\nNOTE: SciPy curve_fit is not timed in this cell; time it on the same data for a direct comparison.")
print("Reported NLSQ speedups are 5-270x, depending on problem size and hardware.")
print(f"For this small dataset ({len(omega)} points), the speedup may be modest;")
print("it is expected to grow with dataset size (>1000 points).")
print("="*60)
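
The benchmark cell above measures NLSQ only. For a direct comparison, a minimal SciPy baseline on equivalent synthetic data is sketched below. The stacked [G', G''] residual, the log10 parameterization, the 1.5% relative weighting, and the data-driven starting point are choices made for this sketch, not a description of what rheojax or NLSQ does internally; timings are only comparable if both fits reach the same optimum on the same hardware.

```python
import time

import numpy as np
from scipy.optimize import curve_fit

rng = np.random.default_rng(42)

# Synthetic data with the same true parameters and 1.5% relative noise as above
Ge_t, Gm_t, eta_t = 1e4, 5e4, 1e3
omega_sp = np.logspace(-2, 3, 40)


def zener_stacked(omega, log_Ge, log_Gm, log_eta):
    """Zener [G', G''] concatenated into one real vector; parameters in log10 units."""
    Ge, Gm, eta = 10.0**log_Ge, 10.0**log_Gm, 10.0**log_eta
    wt = omega * eta / Gm
    return np.concatenate([Ge + Gm * wt**2 / (1 + wt**2), Gm * wt / (1 + wt**2)])


y_true = zener_stacked(omega_sp, *np.log10([Ge_t, Gm_t, eta_t]))
y_obs = y_true * (1 + 0.015 * rng.standard_normal(y_true.size))
Gp_obs, Gpp_obs = np.split(y_obs, 2)

# Data-driven start: G' plateaus give Ge and Gm, the G'' peak gives tau ~ 1/omega
Ge0 = Gp_obs[0]
Gm0 = Gp_obs[-1] - Gp_obs[0]
tau0 = 1.0 / omega_sp[np.argmax(Gpp_obs)]
p0 = np.log10([Ge0, Gm0, Gm0 * tau0])

times_scipy = []
for _ in range(10):
    t0 = time.perf_counter()
    popt, pcov = curve_fit(zener_stacked, omega_sp, y_obs, p0=p0, sigma=0.015 * np.abs(y_obs))
    times_scipy.append(time.perf_counter() - t0)

Ge_fit, Gm_fit, eta_fit = 10.0**popt
print(f"SciPy fit: Ge={Ge_fit:.4e} Pa, Gm={Gm_fit:.4e} Pa, eta={eta_fit:.4e} Pa·s")
print(f"SciPy curve_fit: {np.median(times_scipy) * 1e3:.2f} ms (median of {len(times_scipy)} runs)")
```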

## Results Visualization

We create publication-quality visualizations showing:
1. **Fit quality** - data vs model prediction for G' and G"
2. **Residual analysis** - systematic errors or outliers?

In [None]:
# Generate predictions
G_star_pred = model.predict(omega)
G_prime_pred = np.real(G_star_pred)
G_double_prime_pred = np.imag(G_star_pred)

# Calculate residuals
residuals_Gp = G_prime_noisy - G_prime_pred
residuals_Gpp = G_double_prime_noisy - G_double_prime_pred
relative_residuals_Gp = residuals_Gp / G_prime_noisy * 100
relative_residuals_Gpp = residuals_Gpp / G_double_prime_noisy * 100

# Create figure with subplots
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top left: G' fit quality
axes[0, 0].loglog(omega, G_prime_noisy, 'o', markersize=6, alpha=0.7, label='Data', color='#1f77b4')
axes[0, 0].loglog(omega, G_prime_true, '--', linewidth=2, alpha=0.4, label='True', color='gray')
axes[0, 0].loglog(omega, G_prime_pred, '-', linewidth=2.5, label='Fitted', color='#ff7f0e')
axes[0, 0].set_xlabel('ω (rad/s)', fontsize=12, fontweight='bold')
axes[0, 0].set_ylabel("Storage Modulus G' (Pa)", fontsize=12, fontweight='bold')
axes[0, 0].set_title("G' Fit Quality", fontsize=13, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3, which='both')
axes[0, 0].legend(fontsize=10, framealpha=0.9)

# Top right: G" fit quality
axes[0, 1].loglog(omega, G_double_prime_noisy, 's', markersize=6, alpha=0.7, label='Data', color='#1f77b4')
axes[0, 1].loglog(omega, G_double_prime_true, '--', linewidth=2, alpha=0.4, label='True', color='gray')
axes[0, 1].loglog(omega, G_double_prime_pred, '-', linewidth=2.5, label='Fitted', color='#ff7f0e')
axes[0, 1].set_xlabel('ω (rad/s)', fontsize=12, fontweight='bold')
axes[0, 1].set_ylabel('Loss Modulus G" (Pa)', fontsize=12, fontweight='bold')
axes[0, 1].set_title('G" Fit Quality', fontsize=13, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3, which='both')
axes[0, 1].legend(fontsize=10, framealpha=0.9)

# Bottom left: G' residuals
axes[1, 0].semilogx(omega, relative_residuals_Gp, 'o', markersize=6, alpha=0.7, color='#2ca02c')
axes[1, 0].axhline(0, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[1, 0].axhline(noise_level * 100, color='red', linestyle=':', linewidth=1.5, alpha=0.5, label=f'Expected noise: ±{noise_level*100:.1f}%')
axes[1, 0].axhline(-noise_level * 100, color='red', linestyle=':', linewidth=1.5, alpha=0.5)
axes[1, 0].set_xlabel('ω (rad/s)', fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel("G' Relative Residuals (%)", fontsize=12, fontweight='bold')
axes[1, 0].set_title("G' Residual Analysis", fontsize=13, fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].legend(fontsize=9, framealpha=0.9)

# Bottom right: G" residuals
axes[1, 1].semilogx(omega, relative_residuals_Gpp, 's', markersize=6, alpha=0.7, color='#2ca02c')
axes[1, 1].axhline(0, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[1, 1].axhline(noise_level * 100, color='red', linestyle=':', linewidth=1.5, alpha=0.5, label=f'Expected noise: ±{noise_level*100:.1f}%')
axes[1, 1].axhline(-noise_level * 100, color='red', linestyle=':', linewidth=1.5, alpha=0.5)
axes[1, 1].set_xlabel('ω (rad/s)', fontsize=12, fontweight='bold')
axes[1, 1].set_ylabel('G" Relative Residuals (%)', fontsize=12, fontweight='bold')
axes[1, 1].set_title('G" Residual Analysis', fontsize=13, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].legend(fontsize=9, framealpha=0.9)

plt.tight_layout()
plt.show()

# Compute fit quality metrics
ss_res_Gp = np.sum(residuals_Gp**2)
ss_tot_Gp = np.sum((G_prime_noisy - np.mean(G_prime_noisy))**2)
r_squared_Gp = 1 - (ss_res_Gp / ss_tot_Gp)

ss_res_Gpp = np.sum(residuals_Gpp**2)
ss_tot_Gpp = np.sum((G_double_prime_noisy - np.mean(G_double_prime_noisy))**2)
r_squared_Gpp = 1 - (ss_res_Gpp / ss_tot_Gpp)

print("\nFit Quality Metrics:")
print(f"  G' R² = {r_squared_Gp:.6f}")
print(f"  G' RMSE = {np.sqrt(np.mean(residuals_Gp**2)):.2e} Pa")
print(f"  G' Mean |residual| = {np.mean(np.abs(residuals_Gp)):.2e} Pa ({np.mean(np.abs(relative_residuals_Gp)):.2f}%)")
print(f"\n  G'' R² = {r_squared_Gpp:.6f}")
print(f"  G'' RMSE = {np.sqrt(np.mean(residuals_Gpp**2)):.2e} Pa")
print(f"  G'' Mean |residual| = {np.mean(np.abs(residuals_Gpp)):.2e} Pa ({np.mean(np.abs(relative_residuals_Gpp)):.2f}%)")
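
For data spanning several decades, R² is dominated by the largest moduli and stays close to 1 even for visibly imperfect fits. When the noise level is known, as it is for this synthetic data, the reduced chi-square is a sharper check: it should be near 1 if the residuals match the noise model and well above 1 for a systematic misfit. The sketch below assumes 1.5% relative Gaussian noise and uses the true Zener curves as a stand-in for fitted predictions; `reduced_chi_square` is a helper written for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
omega_demo = np.logspace(-2, 3, 40)
Ge, Gm, eta = 1e4, 5e4, 1e3
wt = omega_demo * eta / Gm

# Stand-in for fitted predictions: the true Zener curves (assumption for this sketch)
Gp_pred = Ge + Gm * wt**2 / (1 + wt**2)
Gpp_pred = Gm * wt / (1 + wt**2)

noise = 0.015                                    # known 1.5% relative noise
Gp_obs = Gp_pred * (1 + noise * rng.standard_normal(wt.size))
Gpp_obs = Gpp_pred * (1 + noise * rng.standard_normal(wt.size))


def reduced_chi_square(obs, pred, sigma, n_params):
    """Sum of squared noise-normalized residuals divided by degrees of freedom."""
    r = (obs - pred) / sigma
    return np.sum(r**2) / (obs.size - n_params)


obs = np.concatenate([Gp_obs, Gpp_obs])
pred = np.concatenate([Gp_pred, Gpp_pred])
sigma = noise * pred                             # per-point standard deviation
chi2_red = reduced_chi_square(obs, pred, sigma, n_params=3)
print(f"Reduced chi-square: {chi2_red:.3f}  (≈1 when residuals match the noise model)")
```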

## Bayesian Inference: Uncertainty Quantification

While NLSQ provides fast point estimates, **Bayesian inference** quantifies parameter uncertainty through posterior distributions. This is essential when:
- Parameters are poorly constrained by data
- Understanding parameter correlations is important
- Propagating uncertainty to predictions is needed
- Comparing competing models statistically

### Two-Stage Workflow: NLSQ → NUTS

1. **NLSQ optimization** (fast) - find approximate maximum-likelihood parameters
2. **NUTS sampling** (slower) - warm-start from the NLSQ estimate, which typically converges 2-5x faster than a cold start

Starting the sampler near the best-fit parameters can substantially reduce:
- Number of iterations to convergence
- Divergent transitions (MCMC failures)
- Total computational time

In [None]:
print("\n" + "="*60)
print("BAYESIAN INFERENCE WITH WARM-START")
print("="*60)
print("Running MCMC sampling... (this may take 1-2 minutes)\n")

# Bayesian inference using warm-start from NLSQ
bayesian_start = time.time()

result = model.fit_bayesian(
    omega, G_star_noisy,
    num_warmup=1000,   # Warmup (adaptation) iterations, discarded
    num_samples=2000,  # Posterior samples
    num_chains=1,      # Single chain (faster for demo; use 4 for production)
    initial_values={   # Warm-start from NLSQ
        'Ge': model.parameters.get_value('Ge'),
        'Gm': model.parameters.get_value('Gm'),
        'eta': model.parameters.get_value('eta')
    }
)

bayesian_time = time.time() - bayesian_start

print(f"\nBayesian inference completed in {bayesian_time:.2f} s")
print("="*60)

### Posterior Summary and Convergence Diagnostics

Key metrics for MCMC quality:
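- **R-hat < 1.01**: consistent with convergence (all parameters must meet this). With `num_chains=1`, R-hat can only compare the two halves of the single chain, so treat it as a weak check
- **ESS > 400**: enough effective samples for reliable estimates of means and credible intervals
- **Divergences < 1%**: NUTS sampler is well-behaved (ideally there are none)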
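
R-hat compares the variance between chains with the variance within them. A minimal NumPy sketch of the classic split-R-hat follows; splitting each chain in half is what lets R-hat flag drift within a chain and makes it computable even with `num_chains=1`. ArviZ's default R-hat additionally rank-normalizes the draws, so its values differ slightly. `split_rhat` is a hypothetical helper for illustration, and the chains are synthetic.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for one parameter; draws has shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain into two halves, doubling the number of chains
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()         # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)       # between-chain variance
    var_plus = (n - 1) / n * W + B / n            # pooled estimate of posterior variance
    return np.sqrt(var_plus / W)


rng = np.random.default_rng(1)
mixed = rng.standard_normal((4, 1000))                        # four chains on the same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])        # one chain offset by 3 sd
drifting = np.linspace(0.0, 2.0, 2000)[None, :] + rng.standard_normal((1, 2000))  # one trending chain

print(f"well-mixed chains:     {split_rhat(mixed):.4f}")
print(f"one chain offset:      {split_rhat(stuck):.4f}")
print(f"drifting single chain: {split_rhat(drifting):.4f}")
```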

In [None]:
# Extract posterior samples and diagnostics
posterior = result.posterior_samples
diagnostics = result.diagnostics
summary = result.summary

# Get credible intervals
credible_intervals = model.get_credible_intervals(posterior, credibility=0.95)

print("\n" + "="*60)
print("POSTERIOR SUMMARY")
print("="*60)
print(f"\nParameter Estimates (posterior mean ± std):")
print(f"  Ge  = {summary['Ge']['mean']:.4e} ± {summary['Ge']['std']:.4e} Pa")
print(f"  Gm  = {summary['Gm']['mean']:.4e} ± {summary['Gm']['std']:.4e} Pa")
print(f"  eta = {summary['eta']['mean']:.4e} ± {summary['eta']['std']:.4e} Pa·s")

print(f"\n95% Credible Intervals:")
print(f"  Ge:  [{credible_intervals['Ge'][0]:.4e}, {credible_intervals['Ge'][1]:.4e}] Pa")
print(f"  Gm:  [{credible_intervals['Gm'][0]:.4e}, {credible_intervals['Gm'][1]:.4e}] Pa")
print(f"  eta: [{credible_intervals['eta'][0]:.4e}, {credible_intervals['eta'][1]:.4e}] Pa·s")

print(f"\nConvergence Diagnostics:")
print(f"  R-hat (Ge):  {diagnostics['r_hat']['Ge']:.4f}  {'✓' if diagnostics['r_hat']['Ge'] < 1.01 else '✗ POOR'}")
print(f"  R-hat (Gm):  {diagnostics['r_hat']['Gm']:.4f}  {'✓' if diagnostics['r_hat']['Gm'] < 1.01 else '✗ POOR'}")
print(f"  R-hat (eta): {diagnostics['r_hat']['eta']:.4f}  {'✓' if diagnostics['r_hat']['eta'] < 1.01 else '✗ POOR'}")
print(f"  ESS (Ge):    {diagnostics['ess']['Ge']:.0f}  {'✓' if diagnostics['ess']['Ge'] > 400 else '✗ LOW'}")
print(f"  ESS (Gm):    {diagnostics['ess']['Gm']:.0f}  {'✓' if diagnostics['ess']['Gm'] > 400 else '✗ LOW'}")
print(f"  ESS (eta):   {diagnostics['ess']['eta']:.0f}  {'✓' if diagnostics['ess']['eta'] > 400 else '✗ LOW'}")

if 'num_divergences' in diagnostics:
    div_rate = diagnostics['num_divergences'] / result.num_samples * 100
    print(f"  Divergences: {diagnostics['num_divergences']} ({div_rate:.2f}%)  {'✓' if div_rate < 1 else '✗ HIGH'}")

print("\n" + "="*60)

# Check convergence
converged = all([
    diagnostics['r_hat']['Ge'] < 1.01,
    diagnostics['r_hat']['Gm'] < 1.01,
    diagnostics['r_hat']['eta'] < 1.01,
    diagnostics['ess']['Ge'] > 400,
    diagnostics['ess']['Gm'] > 400,
    diagnostics['ess']['eta'] > 400
])

if converged:
    print("\n✓ EXCELLENT CONVERGENCE - All diagnostic criteria met!")
else:
    print("\n⚠ WARNING: Convergence criteria not met. Increase num_warmup or num_samples.")

## ArviZ Diagnostic Plots: Comprehensive MCMC Quality Assessment

ArviZ offers many diagnostic plots; the six below together give a thorough picture of MCMC quality. Understanding them is critical for reliable Bayesian inference.

### Plot 1: Trace Plot - Visualize MCMC Chains

In [None]:
import arviz as az

# Convert to ArviZ InferenceData for plotting
idata = result.to_inference_data()

# Trace plot: visualize sampling
az.plot_trace(idata, figsize=(12, 8))
plt.tight_layout()
plt.show()

print("""
INTERPRETATION - Trace Plot:
- LEFT: Posterior marginal distributions (should be smooth, unimodal)
- RIGHT: Parameter values vs iteration (should look like "fuzzy caterpillar")
- GOOD: Stationary mean, no trends, no stuck regions
- BAD: Drift, discontinuities, bimodal distributions
""")

### Plot 2: Pair Plot - Parameter Correlations and Divergences

In [None]:
# Pair plot: parameter correlations
az.plot_pair(
    idata,
    var_names=['Ge', 'Gm', 'eta'],
    kind='scatter',
    divergences=True,  # Highlight problematic samples
    figsize=(12, 10)
)
plt.tight_layout()
plt.show()

print("""
INTERPRETATION - Pair Plot:
- DIAGONAL: Marginal posterior distributions
- OFF-DIAGONAL: Joint distributions (parameter correlations)
- HIGHLIGHTED POINTS: Divergent transitions (MCMC failures)

What to look for:
✓ GOOD: Elliptical joint distribution, few/no divergences
✗ BAD: Funnel geometry, strong correlations, many divergences

For Zener model:
- Ge and Gm may show weak correlation (both contribute to G')
- Gm and eta often correlated (both determine relaxation time τ)
- Very strong correlations suggest the data constrain only a combination of parameters (weak identifiability)
""")

### Plot 3: Forest Plot - Credible Interval Comparison

In [None]:
# Forest plot: credible intervals
az.plot_forest(
    idata,
    var_names=['Ge', 'Gm', 'eta'],
    hdi_prob=0.95,  # 95% highest density interval
    combined=True,
    figsize=(10, 4)
)
plt.tight_layout()
plt.show()

print("""
INTERPRETATION - Forest Plot:
- THIN LINE: 95% highest density interval (the parameter lies in this range with 95% posterior probability)
- THICK LINE: Interquartile range (central 50% of the posterior)
- DOT: Posterior median

What to look for:
✓ GOOD: Narrow credible intervals (well-constrained parameters)
✗ BAD: Very wide intervals (poorly constrained, need more data or tighter priors)

Compare:
- Relative uncertainty: σ/μ for each parameter
- Parameter magnitudes: Are scales appropriate?
""")

### Plot 4: Autocorrelation Plot - Mixing Quality

In [None]:
# Autocorrelation plot: mixing diagnostic
az.plot_autocorr(
    idata,
    var_names=['Ge', 'Gm', 'eta'],
    max_lag=100,
    figsize=(12, 4)
)
plt.tight_layout()
plt.show()

print("""
INTERPRETATION - Autocorrelation Plot:
- Y-axis: Correlation between samples at different lags
- X-axis: Lag (number of iterations)

What to look for:
✓ GOOD: Autocorrelation drops to ~0 within few dozen lags
✗ BAD: Slow decay (high autocorrelation) → poor mixing

If autocorrelation is high:
- Increase num_samples to get more effective samples
- Check for parameter correlations (use pair plot)
- Consider reparameterization if persistent

Relation to ESS:
- High autocorrelation → low ESS (fewer independent samples)
- ESS = num_samples / (1 + 2*Σ autocorrelations)
""")

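
The ESS relation in the cell above can be checked with a minimal single-chain sketch: estimate the autocorrelation with a zero-padded FFT, sum it up to the first negative lag, and compare against an AR(1) chain, whose ESS is known in closed form as $N(1-\phi)/(1+\phi)$. The first-negative-lag truncation is a simplification; ArviZ uses Geyer's initial-sequence estimator across chains, so its numbers will differ somewhat. The helper names are hypothetical.

```python
import numpy as np


def autocorrelation(x):
    """Normalized autocorrelation of a 1-D chain for all lags, via zero-padded FFT."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = x.size
    f = np.fft.rfft(x, n=2 * n)                  # zero-pad to avoid circular wrap-around
    acov = np.fft.irfft(np.abs(f) ** 2)[:n] / n
    return acov / acov[0]


def ess_single_chain(x):
    """ESS = N / (1 + 2 * sum of autocorrelations), truncated at the first negative lag."""
    rho = autocorrelation(x)
    negative = np.nonzero(rho[1:] < 0)[0]
    cutoff = negative[0] + 1 if negative.size else rho.size
    return rho.size / (1.0 + 2.0 * rho[1:cutoff].sum())


rng = np.random.default_rng(2)
n, phi = 50_000, 0.9
iid = rng.standard_normal(n)
ar1 = np.empty(n)
ar1[0] = rng.standard_normal()
eps = rng.standard_normal(n) * np.sqrt(1 - phi**2)   # keeps the stationary variance at 1
for t in range(1, n):
    ar1[t] = phi * ar1[t - 1] + eps[t]

print(f"iid chain:   ESS ≈ {ess_single_chain(iid):.0f} of {n}")
print(f"AR(1) φ=0.9: ESS ≈ {ess_single_chain(ar1):.0f} (theory: {n * (1 - phi) / (1 + phi):.0f})")
```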
### Plot 5: Rank Plot - Convergence Diagnostic

In [None]:
# Rank plot: modern convergence diagnostic
az.plot_rank(
    idata,
    var_names=['Ge', 'Gm', 'eta'],
    figsize=(12, 4)
)
plt.tight_layout()
plt.show()

print("""
INTERPRETATION - Rank Plot:
- Draws are ranked across all chains pooled together; the ranks are then
  histogrammed separately for each chain
- Modern alternative to trace plots for convergence

What to look for:
✓ GOOD: Uniform histogram (flat, all bins similar height)
✗ BAD: Non-uniform (peaks, valleys, trends)

Non-uniform patterns indicate:
- Chains exploring different regions (not converged)
- Chain sticking or slow mixing
- Need more warmup iterations

Caveats:
- With a single chain (num_chains=1, as here), the histogram is flat by
  construction and carries no convergence information; use num_chains=4
- Rank plots complement R-hat and can reveal problems it misses,
  so check them even when R-hat < 1.01
""")

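
ArviZ ranks all draws pooled across chains and then draws one rank histogram per chain. The minimal NumPy sketch below (not ArviZ's implementation; `rank_histograms` is a hypothetical helper) reproduces that computation and shows why the plot needs several chains: with one chain the ranks are simply 0..N-1 and the histogram is flat no matter how the sampler behaved, while two chains stuck in different regions give strongly skewed histograms.

```python
import numpy as np


def rank_histograms(draws, bins=20):
    """Rank all draws pooled over chains, then histogram the ranks chain by chain."""
    draws = np.asarray(draws, dtype=float)          # shape (n_chains, n_draws)
    n_chains, n_draws = draws.shape
    pooled = draws.ravel()
    ranks = np.empty(pooled.size, dtype=int)
    ranks[np.argsort(pooled, kind="stable")] = np.arange(pooled.size)
    ranks = ranks.reshape(n_chains, n_draws)
    edges = np.linspace(0, pooled.size, bins + 1)
    return np.array([np.histogram(r, bins=edges)[0] for r in ranks])


rng = np.random.default_rng(3)
one_chain = rng.standard_normal((1, 1000))
two_offset = rng.standard_normal((2, 1000)) + np.array([[0.0], [3.0]])

print("single chain:          ", rank_histograms(one_chain)[0])   # flat by construction
print("offset chains, chain 0:", rank_histograms(two_offset)[0])  # piled up at low ranks
```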
### Plot 6: ESS Plot - Effective Sample Size

**Note:** ESS estimates are more reliable with multiple chains. With a single chain (`num_chains=1`), treat them as rough; the cell below falls back to printing ESS values if the plot cannot be drawn. For production work, use `num_chains=4`.

In [None]:
# ESS plot: effective sample size
try:
    az.plot_ess(
        idata,
        var_names=['Ge', 'Gm', 'eta'],
        kind='local',  # 'local', 'quantile', or 'evolution'
        figsize=(12, 4)
    )
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Note: ESS plot could not be drawn ({type(e).__name__}: {e})")
    print(f"Current setup: {idata.posterior.chain.size} chain(s)")
    print("For production, use num_chains=4.\n")
    
    # Show ESS values instead
    print(f"Effective Sample Size (ESS):")
    print(f"  Ge:  {diagnostics['ess']['Ge']:.0f} / {result.num_samples} samples ({diagnostics['ess']['Ge']/result.num_samples*100:.1f}%)")
    print(f"  Gm:  {diagnostics['ess']['Gm']:.0f} / {result.num_samples} samples ({diagnostics['ess']['Gm']/result.num_samples*100:.1f}%)")
    print(f"  eta: {diagnostics['ess']['eta']:.0f} / {result.num_samples} samples ({diagnostics['ess']['eta']/result.num_samples*100:.1f}%)")

print("""
INTERPRETATION - ESS Plot:
- Quantifies sampling efficiency per parameter
- ESS = number of "independent" samples (accounting for autocorrelation)

What to look for:
✓ GOOD: ESS > 400 (bulk and tail) for all parameters
✗ BAD: Low ESS → need more samples or better mixing

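ESS types:
- BULK: Central posterior region (mean, median estimates)
- TAIL: Extreme quantiles (credible interval estimates)
- LOCAL: ESS for small probability intervals across the posterior (kind='local')
- QUANTILE: ESS for quantile estimates (kind='quantile')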

If ESS is low:
1. Increase num_samples (more iterations)
2. Check autocorrelation plot (poor mixing?)
3. Use multiple chains (num_chains=4) for better estimates
4. Warm-start from NLSQ (already doing this!)
""")

## Physical Interpretation

Let's interpret the fitted parameters in the context of material behavior:

### Parameter Meanings

**Equilibrium Modulus (Ge):**
- Represents long-time elastic response (t→∞)
- For our fit: ~1×10⁴ Pa (10 kPa)
- Physical meaning: Permanent network structure (e.g., chemical crosslinks or trapped entanglements)
- Typical range: 10² - 10⁶ Pa depending on material

**Maxwell Modulus (Gm):**
- Represents transient elastic component
- For our fit: ~5×10⁴ Pa (50 kPa)
- Physical meaning: Temporary elastic storage that relaxes
- Typical range: 10³ - 10⁷ Pa

**Viscosity (η):**
- Represents resistance to flow
- For our fit: ~1×10³ Pa·s
- Physical meaning: Controls rate of stress relaxation
- Typical range: 10⁻² - 10⁶ Pa·s

**Relaxation Time (τ = η/Gm):**
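- Time for the relaxing part of the stress (the Gm contribution) to decay to 1/e (~37%) of its initial value; the total stress levels off at the Ge contribution
- For our fit: ~0.02 s
- Physical meaning: processes much slower than τ see the relaxed, more compliant material (modulus → Ge); much faster processes see the stiffer, unrelaxed modulus Ge + Gm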
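
The 1/e statement applies to the relaxing part of the modulus only. Under a step strain the Zener relaxation modulus is $G(t) = G_e + G_m e^{-t/\tau}$, so at $t = \tau$ the $G_m$ term has fallen to $e^{-1}$ of its initial value while the total modulus is still well above 37% of $G(0) = G_e + G_m$. A short check with the notebook's true parameters; `zener_relaxation_modulus` is a helper written for this illustration.

```python
import numpy as np


def zener_relaxation_modulus(t, Ge, Gm, eta):
    """G(t) = Ge + Gm * exp(-t / tau) with tau = eta / Gm."""
    tau = eta / Gm
    return Ge + Gm * np.exp(-np.asarray(t, dtype=float) / tau)


Ge, Gm, eta = 1e4, 5e4, 1e3
tau = eta / Gm
G0 = zener_relaxation_modulus(0.0, Ge, Gm, eta)      # instantaneous modulus Ge + Gm
G_tau = zener_relaxation_modulus(tau, Ge, Gm, eta)

print(f"Relaxing part at t = tau: {(G_tau - Ge) / Gm:.4f} (1/e = {np.exp(-1):.4f})")
print(f"Total modulus at t = tau: {G_tau / G0:.4f} of G(0)")
```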

### Material Classification

Based on the Ge/Gm ratio (a rough rule of thumb used in this notebook, not a standard classification):
- **Ge/Gm < 0.1**: Predominantly viscous (weak gel, concentrated solution)
- **0.1 < Ge/Gm < 10**: Viscoelastic (soft solids, weak gels)
- **Ge/Gm > 10**: Predominantly elastic (strong gels, elastomers)

Our material (Ge/Gm ≈ 0.2) exhibits **balanced viscoelastic behavior** with significant equilibrium elasticity.

### Model Limitations

The Zener model is valid when:
- ✓ Small strains (linear viscoelastic regime; the limit is material-dependent, so confirm it with an amplitude sweep)
- ✓ Single dominant relaxation time
- ✓ Finite equilibrium modulus

Consider alternative models if:
- ✗ Multiple relaxation times needed → Generalized Maxwell (Prony series)
- ✗ No equilibrium modulus → Maxwell model
- ✗ Power-law relaxation → Fractional Zener models
- ✗ Large strain behavior → Nonlinear models

In [None]:
# Summary table of results
print("\n" + "="*70)
print("FINAL PARAMETER SUMMARY")
print("="*70)
print(f"\n{'Method':<20} {'Ge (Pa)':<15} {'Gm (Pa)':<15} {'eta (Pa·s)':<15} {'tau (s)':<10}")
print("-"*70)
print(f"{'True Values':<20} {Ge_true:<15.4e} {Gm_true:<15.4e} {eta_true:<15.4e} {tau_true:<10.6f}")
print(f"{'NLSQ (Point)':<20} {Ge_modular:<15.4e} {Gm_modular:<15.4e} {eta_modular:<15.4e} {tau_modular:<10.6f}")
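# Bayesian tau here is the ratio of posterior means, which can differ slightly from the posterior mean of eta/Gm
print(f"{'Bayesian (Mean)':<20} {summary['Ge']['mean']:<15.4e} {summary['Gm']['mean']:<15.4e} {summary['eta']['mean']:<15.4e} {summary['eta']['mean']/summary['Gm']['mean']:<10.6f}")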
print("-"*70)

# Uncertainty from Bayesian inference
print(f"\n{'Bayesian 1σ':<20} {summary['Ge']['std']:<15.4e} {summary['Gm']['std']:<15.4e} {summary['eta']['std']:<15.4e}")
print(f"{'Relative 1σ (%)':<20} {summary['Ge']['std']/summary['Ge']['mean']*100:<15.2f} {summary['Gm']['std']/summary['Gm']['mean']*100:<15.2f} {summary['eta']['std']/summary['eta']['mean']*100:<15.2f}")
print("\n" + "="*70)

# Material classification
Ge_Gm_ratio = Ge_modular / Gm_modular
print(f"\nPhysical Interpretation:")
print(f"  Ge/Gm ratio: {Ge_Gm_ratio:.3f}")
if Ge_Gm_ratio < 0.1:
    material_type = "Predominantly viscous (weak gel/solution)"
elif Ge_Gm_ratio < 10:
    material_type = "Balanced viscoelastic (soft solid/gel)"
else:
    material_type = "Predominantly elastic (strong gel/elastomer)"
print(f"  Material Type: {material_type}")
print(f"  Relaxation Time: {tau_modular:.6f} s")
print(f"  Equilibrium Modulus: {Ge_modular:.2e} Pa ({Ge_modular/1e3:.1f} kPa)")
print(f"  Instantaneous Modulus (G0 = Ge + Gm): {(Ge_modular + Gm_modular):.2e} Pa ({(Ge_modular + Gm_modular)/1e3:.1f} kPa)")
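
The Bayesian τ in the table above is the ratio of posterior means, which in general is not the posterior mean of τ; a derived quantity is better computed draw by draw and then summarized. The sketch below uses hypothetical, independent lognormal draws with a 30% spread (not this notebook's posterior) to make the gap visible; with the narrow posteriors of this fit, the two numbers should nearly coincide. For the real result, the same idea would apply to the per-draw arrays in `result.posterior_samples`, assuming they are keyed by parameter name.

```python
import numpy as np

rng = np.random.default_rng(4)

# Hypothetical posterior draws (NOT this notebook's posterior): independent lognormals, 30% spread
n_draws = 100_000
Gm_draws = 5e4 * rng.lognormal(mean=0.0, sigma=0.3, size=n_draws)
eta_draws = 1e3 * rng.lognormal(mean=0.0, sigma=0.3, size=n_draws)

ratio_of_means = eta_draws.mean() / Gm_draws.mean()   # what the summary table reports
tau_draws = eta_draws / Gm_draws                       # derived quantity, computed per draw
lo, hi = np.percentile(tau_draws, [2.5, 97.5])

print(f"ratio of posterior means: {ratio_of_means:.5f} s")
print(f"posterior mean of tau:    {tau_draws.mean():.5f} s")
print(f"95% interval for tau:     [{lo:.5f}, {hi:.5f}] s")
```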

## Key Takeaways

### Main Concepts

1. **Zener Model Characteristics:**
   - Three-parameter model: Ge (equilibrium), Gm (Maxwell), eta (viscosity)
   - Finite equilibrium modulus distinguishes from Maxwell model
   - Single relaxation time describes transient response
   - Applicable to crosslinked polymers, gels, soft solids

2. **Oscillatory Shear Data:**
   - Complex modulus: G* = G' + iG"
   - G' (storage) measures elastic energy storage
   - G" (loss) measures viscous dissipation
   - tan(δ) = G"/G' quantifies viscoelastic character

3. **NLSQ Optimization:**
   - Default backend, with reported speedups of 5-270x vs SciPy depending on problem size and hardware
   - JAX JIT compilation + automatic differentiation
   - GPU acceleration available (reported additional 10-100x for large datasets)
   - Float64 precision enforced via safe_import_jax()

4. **Bayesian Uncertainty Quantification:**
   - Two-stage workflow: NLSQ (fast) → NUTS (warm-start)
   - Warm-start can reduce convergence time (typically 2-5x)
   - Provides credible intervals and parameter correlations
   - Essential for identifying non-identifiability issues

5. **ArviZ Diagnostic Suite:**
   - **6 essential plots** assess MCMC quality comprehensively
   - Must check: R-hat < 1.01, ESS > 400, divergences < 1%
   - Rank plots complement R-hat, but need multiple chains to be informative
   - Pair plot reveals parameter correlations (Gm-eta often correlated)

### When to Use Zener Model

**Appropriate for:**
- ✓ Crosslinked polymers with finite equilibrium modulus
- ✓ Gels and soft solids (physical or chemical networks)
- ✓ Materials with single dominant relaxation time
- ✓ Small strain linear viscoelastic regime

**Consider alternatives for:**
- ✗ Complete stress relaxation (Ge=0) → Maxwell model
- ✗ Multiple relaxation times → Generalized Maxwell
- ✗ Power-law frequency dependence → Fractional Zener models
- ✗ Solids with no measurable stress relaxation (G' independent of frequency) → Kelvin-Voigt

### Common Pitfalls

1. **Parameter Correlation:**
   - Gm and eta often correlated (both determine τ)
   - Check pair plot for non-identifiability
   - May need multi-technique fitting to improve constraint

2. **Frequency Range:**
   - Data should span the relaxation: at least 0.1 < ωτ < 10 (equivalently 0.1τ < 1/ω < 10τ)
   - Insufficient range → poor Ge or Gm estimation
   - Use mastercurve generation to extend range

3. **Model Selection:**
   - Check residuals for systematic trends
   - Zener may be insufficient if multiple relaxation times present
   - Use Bayesian model comparison (WAIC, LOO) to compare alternatives
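
The frequency-range pitfall can be checked before fitting once a rough τ is known (e.g., from the G'' peak, where ωτ ≈ 1). The sketch below reports the range of ωτ a sweep covers and whether it spans 0.1 ≤ ωτ ≤ 10; `decades_covered` is a hypothetical helper, and τ = 200 s is a made-up slow material for contrast. Spanning that window does not guarantee both G' plateaus are resolved, which typically needs coverage further out on each side.

```python
import numpy as np


def decades_covered(omega, tau):
    """Return (log10 min omega*tau, log10 max omega*tau, whether 0.1 <= omega*tau <= 10 is spanned)."""
    wt = np.asarray(omega, dtype=float) * tau
    lo, hi = np.log10(wt.min()), np.log10(wt.max())
    return lo, hi, bool(wt.min() <= 0.1 and wt.max() >= 10.0)


omega_grid = np.logspace(-2, 3, 40)          # frequency grid used in this notebook
for tau in (0.02, 200.0):                    # this notebook's tau, and a hypothetical slower material
    lo, hi, ok = decades_covered(omega_grid, tau)
    print(f"tau = {tau:g} s: omega*tau from 10^{lo:.1f} to 10^{hi:.1f}; 0.1-10 window spanned: {ok}")
```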

## Next Steps

### Explore Related Models
- **[03-springpot-fitting.ipynb](03-springpot-fitting.ipynb)**: Fractional element for power-law behavior
- **[01-maxwell-fitting.ipynb](01-maxwell-fitting.ipynb)**: Special case with Ge=0
- **Advanced fractional models**: See `advanced/04-fractional-models-deep-dive.ipynb`

### Deepen Bayesian Understanding
- **[bayesian/01-bayesian-basics.ipynb](../bayesian/01-bayesian-basics.ipynb)**: Comprehensive NLSQ→NUTS workflow
- **[bayesian/03-convergence-diagnostics.ipynb](../bayesian/03-convergence-diagnostics.ipynb)**: Deep dive into all 6 ArviZ plots
- **[bayesian/04-model-comparison.ipynb](../bayesian/04-model-comparison.ipynb)**: Statistical model selection

### Advanced Workflows
- **[transforms/02-mastercurve-generation.ipynb](../transforms/02-mastercurve-generation.ipynb)**: Extend frequency range via TTS
- **[advanced/01-multi-technique-fitting.ipynb](../advanced/01-multi-technique-fitting.ipynb)**: Constrained fitting across test modes

---

## Session Information

In [None]:
# Print session information for reproducibility
import sys
import rheojax

print(f"Python version: {sys.version}")
print(f"Rheo version: {rheojax.__version__}")
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"NumPy version: {np.__version__}")
print(f"ArviZ version: {az.__version__}")