# SPP LAOS: Complete NLSQ → NUTS Workflow

> **Handbook:** See [SPP LAOS Workflow](../../docs/source/transforms/spp_decomposer.rst#complete-workflow) for pipeline implementation and [Rogers SPP Defaults](../../docs/source/transforms/spp_decomposer.rst#theory) for parameter selection guidance.

This notebook demonstrates the complete Sequence of Physical Processes (SPP) framework for
Large Amplitude Oscillatory Shear (LAOS) analysis using RheoJAX.

## Learning Objectives

After completing this notebook, you will be able to:
- Apply SPP decomposition to LAOS amplitude sweep data
- Extract yield stress scaling relationships using power-law fits
- Use SPPAmplitudeSweepPipeline for automated analysis
- Perform NLSQ → NUTS workflow for SPP parameter uncertainty
- Interpret convergence diagnostics for SPP models
- Understand how a learned noise parameter can cause divergences (funnel geometry)

## SPP Theory Overview

The SPP framework (Rogers et al.) decomposes nonlinear stress responses into:
- **Elastic contribution**: Recoverable strain energy (in-phase with strain)
- **Viscous contribution**: Dissipated energy (in-phase with strain rate)
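To make "in phase with strain" versus "in phase with strain rate" concrete, the sketch below projects a purely sinusoidal stress onto sin(ωt) and cos(ωt). This is a first-harmonic (storage/loss modulus) decomposition used only as linear-regime intuition. It is not the SPP algorithm itself, which works with time-resolved rather than cycle-averaged moduli. The modulus values are hypothetical.

```python
import numpy as np

# Linear-regime illustration (NOT the SPP algorithm): split stress into a part
# in phase with strain (elastic) and a part in phase with strain rate (viscous).
omega, gamma_0 = 1.5, 0.4
G_storage, G_loss = 100.0, 30.0  # hypothetical moduli, Pa
t = np.linspace(0.0, 2 * np.pi / omega, 400, endpoint=False)  # exactly one period

strain = gamma_0 * np.sin(omega * t)
strain_rate = gamma_0 * omega * np.cos(omega * t)
stress = gamma_0 * (G_storage * np.sin(omega * t) + G_loss * np.cos(omega * t))

# Over one full period: mean(sin^2) = mean(cos^2) = 1/2 and mean(sin*cos) = 0
G_storage_est = 2 * np.mean(stress * np.sin(omega * t)) / gamma_0
G_loss_est = 2 * np.mean(stress * np.cos(omega * t)) / gamma_0

elastic_stress = G_storage_est * strain               # in phase with strain
viscous_stress = (G_loss_est / omega) * strain_rate   # in phase with strain rate
print(f"G' ~ {G_storage_est:.2f} Pa, G'' ~ {G_loss_est:.2f} Pa")
```

The two parts sum back to the original stress, which is the property the elastic/viscous split relies on.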

Key outputs:
- `sigma_sy`: Static yield stress (maximum elastic stress)
- `sigma_dy`: Dynamic yield stress (maximum viscous stress)  
- `S_factor`: Stiffening ratio (-1 to 1, 0 = linear)
- `T_factor`: Thickening ratio (-1 to 1, 0 = linear)

**Defaults**: n_harmonics=39, step_size=8, num_mode=2, wrapped strain-rate inference.

**Estimated Time:** 25-30 minutes

In [None]:
# Configure matplotlib for inline plotting in VS Code/Jupyter
# MUST come before importing matplotlib
%matplotlib inline

import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.data import RheoData
from rheojax.pipeline.workflows import SPPAmplitudeSweepPipeline
from rheojax.transforms import SPPDecomposer

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"


## 1. Generate Synthetic LAOS Data

Create amplitude sweep data with power-law yielding behavior:
$$\sigma = A \cdot \gamma_0^n \cdot \sin(\omega t)$$

where $A=60$ Pa and $n=0.7$ (sublinear yielding).
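As a worked example of the formula, the sketch below evaluates the stress amplitude $A\gamma_0^n$ at the five strain amplitudes used in the next cell. Because $n < 1$, each doubling of $\gamma_0$ raises the amplitude by only $2^{0.7} \approx 1.62$.

```python
import numpy as np

A_true, n_true = 60.0, 0.7
gamma_levels = np.array([0.1, 0.2, 0.4, 0.8, 1.6])

# Stress amplitude sigma_amp = A * gamma_0**n at each strain amplitude
stress_amp = A_true * gamma_levels**n_true
for g, s in zip(gamma_levels, stress_amp):
    print(f"gamma_0 = {g:.1f}: stress amplitude = {s:.2f} Pa")

# Each doubling of gamma_0 multiplies the amplitude by 2**n (< 2 because n < 1)
ratios = stress_amp[1:] / stress_amp[:-1]
print(f"ratio per doubling: {ratios[0]:.4f} (2**0.7 = {2**0.7:.4f})")
```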

In [None]:
# Experimental parameters
omega = 1.5  # rad/s
gamma_levels = jnp.array([0.1, 0.2, 0.4, 0.8, 1.6])  # strain amplitudes
n_points = 400  # points per cycle
t = jnp.linspace(0, 2 * jnp.pi / omega, n_points)

# Power-law yield parameters (ground truth)
A_true = 60.0  # Pa
n_true = 0.7   # exponent

def make_dataset(gamma_0):
    """Generate synthetic LAOS dataset with power-law yielding."""
    strain = gamma_0 * jnp.sin(omega * t)
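    # Stress stays in phase with strain; the only nonlinearity is the sublinear
    # amplitude scaling gamma_0**n (the waveform itself remains sinusoidal)
    stress = A_true * gamma_0**n_true * jnp.sin(omega * t)
    # Add Gaussian measurement noise (std 0.5 Pa)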
    noise = np.random.normal(0, 0.5, len(t))
    return RheoData(
        x=np.array(t),
        y=np.array(stress) + noise,
        domain="time",
        metadata={"omega": float(omega), "gamma_0": float(gamma_0), "strain": np.array(strain)}
    )

np.random.seed(42)
datasets = [make_dataset(float(g)) for g in gamma_levels]
print(f"Created {len(datasets)} datasets at γ₀ = {', '.join(f'{float(g):g}' for g in gamma_levels)}")

## 2. Visualize Lissajous Curves

Lissajous (stress-strain) curves reveal nonlinear viscoelastic behavior:
- **Ellipse**: Linear viscoelastic
- **Distorted ellipse**: Nonlinear response
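Note that the synthetic stress in `make_dataset` is exactly in phase with strain, so each Lissajous curve is a straight line through the origin (a degenerate ellipse) with slope $A\gamma_0^{n-1}$. The nonlinearity in this dataset therefore shows up as a falling secant modulus across amplitudes, not as waveform distortion. The check below fits that slope on noise-free copies of the signal.

```python
import numpy as np

A_true, n_true, omega = 60.0, 0.7, 1.5
t = np.linspace(0.0, 2 * np.pi / omega, 400)

# Noise-free version of the synthetic signal: stress and strain share sin(wt)
slopes = {}
for gamma_0 in [0.1, 0.2, 0.4, 0.8, 1.6]:
    strain = gamma_0 * np.sin(omega * t)
    stress = A_true * gamma_0**n_true * np.sin(omega * t)
    slope, _intercept = np.polyfit(strain, stress, 1)
    slopes[gamma_0] = slope
    # Secant modulus A * gamma_0**(n - 1) falls as gamma_0 grows (strain softening)
    print(f"gamma_0 = {gamma_0}: slope = {slope:.2f} Pa "
          f"(A*gamma_0**(n-1) = {A_true * gamma_0**(n_true - 1):.2f} Pa)")
```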

In [None]:
fig, axes = plt.subplots(1, len(datasets), figsize=(3*len(datasets), 3), sharey=True)
if len(datasets) == 1:
    axes = [axes]

for ax, ds, gamma in zip(axes, datasets, gamma_levels):
    strain = ds.metadata["strain"]
    stress = ds.y
    ax.plot(strain, stress, 'b-', lw=1.5)
    ax.set_xlabel(r'Strain $\gamma$')
    ax.set_title(f'$\\gamma_0$ = {gamma:.2f}')
    ax.axhline(0, color='gray', lw=0.5, ls='--')
    ax.axvline(0, color='gray', lw=0.5, ls='--')

axes[0].set_ylabel('Stress (Pa)')
fig.suptitle('Lissajous Curves (Stress vs Strain)', fontsize=12)
plt.tight_layout()
display(fig)
plt.close(fig)

## 3. SPP Decomposition (Single Amplitude)

Demonstrate the SPPDecomposer on a single dataset to understand the outputs.
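As a rough sanity check before running the decomposer: if the static yield stress is the maximum elastic stress (as defined in the overview) and the signal is purely in phase with strain, then $\sigma_{sy}$ should be close to the stress amplitude $A\gamma_0^n$. This is an assumption based on how the data were generated; noise and the numerics of the decomposition mean the reported value will not match exactly.

```python
import numpy as np

A_true, n_true = 60.0, 0.7
gamma_levels = np.array([0.1, 0.2, 0.4, 0.8, 1.6])
idx = len(gamma_levels) // 2      # same middle-amplitude choice as the next cell
demo_gamma = gamma_levels[idx]    # 0.4

# Assumption: for a purely in-phase (elastic) sinusoid, the maximum elastic
# stress equals the stress amplitude A * gamma_0**n.
expected_sigma_sy = A_true * demo_gamma**n_true
print(f"Expected sigma_sy at gamma_0 = {demo_gamma}: ~{expected_sigma_sy:.1f} Pa")  # ~31.6 Pa
```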

In [None]:
# Select middle amplitude for demonstration
idx = len(datasets) // 2
demo_data = datasets[idx]
demo_gamma = float(gamma_levels[idx])

# Create decomposer with Rogers defaults
decomposer = SPPDecomposer(omega=float(omega), gamma_0=demo_gamma)
result = decomposer.transform(demo_data)

# Display key results
print(f"SPP Decomposition at γ₀ = {demo_gamma}:")
print(f"  Static yield stress (σ_sy):  {decomposer.results_['sigma_sy']:.3f} Pa")
print(f"  Dynamic yield stress (σ_dy): {decomposer.results_['sigma_dy']:.3f} Pa")
print(f"  Stiffening factor (S):       {decomposer.results_['S_factor']:.4f}")
print(f"  Thickening factor (T):       {decomposer.results_['T_factor']:.4f}")

## 4. Amplitude Sweep Pipeline

Process all amplitudes and extract yield stress vs strain amplitude relationship.
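A power law $\sigma = a\gamma_0^b$ is a straight line in log-log coordinates, $\log\sigma = \log a + b\log\gamma_0$, which is why the yield stresses are plotted on log axes. A linear fit in log space, sketched below on noise-free stand-in values, gives a quick estimate of $(a, b)$ and a sensible starting point for a nonlinear fit.

```python
import numpy as np

gamma_levels = np.array([0.1, 0.2, 0.4, 0.8, 1.6])
sigma_sy = 60.0 * gamma_levels**0.7    # noise-free stand-in for the SPP output

# log(sigma) = log(a) + b * log(gamma): slope gives b, intercept gives log(a)
b_est, log_a_est = np.polyfit(np.log(gamma_levels), np.log(sigma_sy), 1)
a_est = np.exp(log_a_est)
print(f"a ~ {a_est:.3f} Pa, b ~ {b_est:.4f}")
```

Fitting in log space weights relative rather than absolute errors, so it can differ slightly from a least-squares fit on the original scale when the data are noisy.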

In [None]:
# Run SPP amplitude sweep pipeline
pipeline = SPPAmplitudeSweepPipeline(omega=float(omega))
pipeline.run(datasets, gamma_0_values=list(map(float, gamma_levels)))

# Get yield stresses for all amplitudes
yield_data = pipeline.get_yield_stresses()
sigma_sy = yield_data["sigma_sy"]
sigma_dy = yield_data["sigma_dy"]

print("Amplitude Sweep Results:")
print(f"{'γ₀':>8} {'σ_sy (Pa)':>12} {'σ_dy (Pa)':>12}")
print("-" * 34)
for g, sy, dy in zip(gamma_levels, sigma_sy, sigma_dy):
    print(f"{float(g):>8.3f} {float(sy):>12.3f} {float(dy):>12.3f}")

In [None]:
# Plot yield stress vs strain amplitude
fig, ax = plt.subplots(figsize=(6, 4))

ax.loglog(gamma_levels, sigma_sy, 'o-', label=r'$\sigma_{sy}$ (static)', markersize=8)
ax.loglog(gamma_levels, sigma_dy, 's--', label=r'$\sigma_{dy}$ (dynamic)', markersize=8)

# Reference power-law
gamma_ref = np.linspace(float(gamma_levels.min()), float(gamma_levels.max()), 100)
ax.loglog(gamma_ref, A_true * gamma_ref**n_true, 'k:', lw=2, 
          label=f'True: {A_true}·γ$^{{{n_true}}}$')

ax.set_xlabel(r'Strain Amplitude $\gamma_0$')
ax.set_ylabel('Yield Stress (Pa)')
ax.set_title('SPP Yield Stress vs Strain Amplitude')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

## 5. NLSQ Fitting (Point Estimation)

Fit power-law model to yield stress data using fast NLSQ optimization.
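For intuition about what a nonlinear least-squares fit does with this model, here is a minimal Gauss-Newton sketch for $\sigma = a\gamma^b$, initialized from a log-log linear fit. It is an illustration in plain NumPy, not RheoJAX's NLSQ implementation, and `fit_power_law_gauss_newton` is a hypothetical helper.

```python
import numpy as np

def fit_power_law_gauss_newton(gamma, sigma, n_iter=50, tol=1e-12):
    """Least-squares fit of sigma = a * gamma**b by Gauss-Newton (hypothetical helper)."""
    gamma = np.asarray(gamma, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # Initial guess from a log-log linear fit
    b, log_a = np.polyfit(np.log(gamma), np.log(sigma), 1)
    params = np.array([np.exp(log_a), b])
    for _ in range(n_iter):
        a, b = params
        residual = sigma - a * gamma**b
        # Jacobian of the model with respect to (a, b)
        J = np.column_stack([gamma**b, a * gamma**b * np.log(gamma)])
        # Gauss-Newton step: solve the linearized least-squares problem
        step, *_ = np.linalg.lstsq(J, residual, rcond=None)
        params = params + step
        if np.max(np.abs(step)) < tol:
            break
    return params

gamma_levels = np.array([0.1, 0.2, 0.4, 0.8, 1.6])
rng = np.random.default_rng(42)
sigma_obs = 60.0 * gamma_levels**0.7 + rng.normal(0.0, 0.5, gamma_levels.size)
a_hat, b_hat = fit_power_law_gauss_newton(gamma_levels, sigma_obs)
print(f"a ~ {a_hat:.3f} Pa, b ~ {b_hat:.4f}")
```

Starting from the log-log estimate keeps the undamped Gauss-Newton iteration in the region where it converges quickly; production optimizers typically add damping (Levenberg-Marquardt) or trust regions.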

In [None]:
# Fit power-law model: σ_sy = scale * γ^exp
pipeline.fit_model(bayesian=False, yield_type="static")
model = pipeline.get_model()

# Get fitted parameters
params = model.parameters
scale_nlsq = params["sigma_sy_scale"].value
exp_nlsq = params["sigma_sy_exp"].value

print("NLSQ Fit Results (Point Estimates):")
print(f"  σ_sy_scale: {scale_nlsq:.4f} Pa (true: {A_true})")
print(f"  σ_sy_exp:   {exp_nlsq:.4f} (true: {n_true})")

## 6. Bayesian Inference with NUTS

Quantify parameter uncertainty using Bayesian inference with warm-start from NLSQ.
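The sampler explores an unnormalized log posterior. A minimal sketch of one for this problem is shown below, assuming a Gaussian likelihood with a single noise level. The HalfCauchy(10) noise prior mirrors the note in the divergence-diagnosis cell; the flat priors on scale and exponent are an illustrative assumption, not necessarily RheoJAX's actual priors. Warm-starting from NLSQ means the chains begin near the peak of this function instead of spending warmup searching for it.

```python
import numpy as np

def log_posterior(a, b, noise, gamma, sigma_obs, cauchy_scale=10.0):
    """Unnormalized log posterior for sigma_obs ~ Normal(a * gamma**b, noise**2).

    Assumed priors (illustrative): flat on a > 0 and b, HalfCauchy(cauchy_scale) on noise.
    """
    if a <= 0 or noise <= 0:
        return -np.inf
    mu = a * gamma**b
    n = gamma.size
    # Gaussian log likelihood, dropping the constant -n/2 * log(2*pi)
    log_lik = -n * np.log(noise) - 0.5 * np.sum((sigma_obs - mu) ** 2) / noise**2
    # HalfCauchy log density up to an additive constant
    log_prior_noise = -np.log1p((noise / cauchy_scale) ** 2)
    return log_lik + log_prior_noise

gamma_levels = np.array([0.1, 0.2, 0.4, 0.8, 1.6])
rng = np.random.default_rng(42)
sigma_obs = 60.0 * gamma_levels**0.7 + rng.normal(0.0, 0.5, gamma_levels.size)
lp_true = log_posterior(60.0, 0.7, 0.5, gamma_levels, sigma_obs)
lp_far = log_posterior(40.0, 0.7, 0.5, gamma_levels, sigma_obs)
print(f"log p near truth: {lp_true:.1f}, far from it: {lp_far:.1f}")
```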

In [None]:
# Bayesian inference with proper settings for convergence
# Key settings:
# - num_chains=4: Multiple chains enable proper R-hat computation
# - num_warmup=2000: More warmup for better adaptation
# - num_samples=2000: Enough samples per chain for reliable posterior estimates
# - target_accept_prob=0.99: Very high to minimize divergences
# - max_tree_depth=12: Allow deeper trees for complex posteriors
bayes = model.fit_bayesian(
    gamma_levels, 
    sigma_sy,
    test_mode="oscillation",
    num_chains=4,
    num_warmup=2000,           # More warmup for better step size adaptation
    num_samples=2000,
    target_accept_prob=0.99,   # Very high = very small steps = minimal divergences
    max_tree_depth=12,         # Allow deeper tree exploration
)

# Extract posterior statistics
print("\nBayesian Posterior Summary:")
print(f"{'Parameter':>15} {'Mean':>10} {'Std':>10} {'5%':>10} {'95%':>10}")
print("-" * 57)
for param in ["sigma_sy_scale", "sigma_sy_exp"]:
    if param in bayes.summary:
        s = bayes.summary[param]
        print(f"{param:>15} {s['mean']:>10.4f} {s['std']:>10.4f} "
              f"{s.get('q05', float('nan')):>10.4f} "
              f"{s.get('q95', float('nan')):>10.4f}")

## 7. MCMC Diagnostics

Check convergence using standard diagnostics:
- **R-hat < 1.01**: Chains converged (requires multiple chains for reliable computation)
- **ESS > 400**: Sufficient effective samples
- **Divergences = 0**: No sampling issues
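To show what R-hat measures, here is a sketch of the classic split-R-hat: each chain is split in half, and the between-chain variance of the halves is compared with the within-chain variance. ArviZ defaults to a rank-normalized variant of this statistic, so its values can differ slightly; `split_rhat` is a hypothetical helper.

```python
import numpy as np

def split_rhat(chains):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1] // 2 * 2   # drop a trailing draw if the count is odd
    half = n_draws // 2
    # Split every chain in half so within-chain trends also inflate R-hat
    splits = np.concatenate([chains[:, :half], chains[:, half:n_draws]], axis=0)
    m, n = splits.shape
    chain_means = splits.mean(axis=1)
    W = splits.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_plus = (n - 1) / n * W + B / n      # pooled variance estimate
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 2000))            # four chains, same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [1.0]])  # one chain offset by 1
print(f"mixed chains: R-hat = {split_rhat(mixed):.4f}")
print(f"one offset chain: R-hat = {split_rhat(stuck):.4f}")
```

With a single chain there is only one pair of halves to compare, which is why several chains give a more reliable R-hat.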

**Note on trace plots**: With only 5 data points and a well-constrained model, the posterior
is narrow, so trace plots span a small range of values. That alone is not a problem. Judge the
traces by stationarity (no trends or drifts), overlap between chains, and a unimodal marginal
histogram.

In [None]:
# Check diagnostics - R-hat and ESS are in the diagnostics dict, not summary
diag = bayes.diagnostics
print("MCMC Diagnostics:")
print(f"  Divergences: {diag.get('divergences', 'N/A')}")

# R-hat and ESS are stored in diagnostics['r_hat'] and diagnostics['ess'] dicts
r_hat_dict = diag.get('r_hat', {})
ess_dict = diag.get('ess', {})

print("\nConvergence by Parameter:")
for param in ["sigma_sy_scale", "sigma_sy_exp"]:
    rhat = r_hat_dict.get(param, 'N/A')
    ess = ess_dict.get(param, 'N/A')
    rhat_str = f"{rhat:.4f}" if isinstance(rhat, (int, float)) else str(rhat)
    ess_str = f"{ess:.0f}" if isinstance(ess, (int, float)) else str(ess)
    print(f"  {param}: R-hat={rhat_str}, ESS={ess_str}")

# Interpretation
print("\nInterpretation:")
divergences = diag.get('divergences')
if divergences is None:
    print("  ? Divergence count not reported")
elif divergences == 0:
    print("  ✓ No divergences - sampling is healthy")
elif divergences < 100:
    print(f"  ⚠ {divergences} divergences - likely from noise parameter funnel (see diagnosis below)")
    print("    → scale and exp estimates may still be usable if R-hat < 1.01 and ESS is adequate,")
    print("      but divergences can bias the posterior; check where they occur before trusting it")
else:
    print(f"  ✗ {divergences} divergences - consider reparameterization or more informative priors")

for param in ["sigma_sy_scale", "sigma_sy_exp"]:
    rhat = r_hat_dict.get(param)  # None if missing, so no false "converged" report
    ess = ess_dict.get(param)
    if isinstance(rhat, (int, float)) and rhat < 1.01:
        print(f"  ✓ {param}: R-hat={rhat:.4f} < 1.01 (converged)")
    elif isinstance(rhat, (int, float)):
        print(f"  ⚠ {param}: R-hat={rhat:.4f} >= 1.01 (may need more samples)")
    if isinstance(ess, (int, float)) and ess > 400:
        print(f"  ✓ {param}: ESS={ess:.0f} > 400 (sufficient)")
    elif isinstance(ess, (int, float)):
        print(f"  ⚠ {param}: ESS={ess:.0f} < 400 (may need more samples)")

## 8. Posterior Visualization with ArviZ

ArviZ provides comprehensive MCMC diagnostics:

| Plot | What to Look For |
|------|------------------|
| **Trace** | Chains should overlap ("fuzzy caterpillar"), no trends |
| **Pair** | Correlations between parameters; divergences cluster in problem regions |
| **Forest** | HDI intervals should overlap across chains |
| **Autocorrelation** | Should drop quickly to zero (good mixing) |
| **Rank** | Histograms should be uniform (chains exploring same space) |

**About Divergences**: Divergences indicate the sampler encountered difficult geometry.
Common causes: (1) tight funnels from hierarchical priors, (2) strong parameter correlations,
(3) multimodal posteriors. Solutions: increase `target_accept_prob`, reparameterize, or use
more informative priors.
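For a model with a learned likelihood noise parameter, the funnel has a concrete form: conditional on the noise level, the plausible range of (scale, exp) shrinks in proportion to the noise. The linearized (Laplace-style) sketch below shows this at the true parameters, assuming Gaussian noise and flat priors; it approximates the local posterior shape and is not output from the sampler.

```python
import numpy as np

gamma = np.array([0.1, 0.2, 0.4, 0.8, 1.6])
a, b = 60.0, 0.7

# Jacobian of a * gamma**b with respect to (a, b) at the assumed best fit
J = np.column_stack([gamma**b, a * gamma**b * np.log(gamma)])
JtJ_inv = np.linalg.inv(J.T @ J)

# Under a Gaussian likelihood, the local covariance of (a, b) given the noise
# level is noise**2 * (J^T J)^-1, so posterior widths scale linearly with noise.
widths = {}
for noise in [2.0, 0.5, 0.05]:
    std_a, std_b = noise * np.sqrt(np.diag(JtJ_inv))
    widths[noise] = (std_a, std_b)
    print(f"noise = {noise:>4}: std(a) ~ {std_a:.3f} Pa, std(b) ~ {std_b:.4f}")
```

As the noise shrinks toward zero, the region the sampler must resolve becomes very narrow, and a step size tuned for the wide part of the funnel can overshoot there and produce divergences.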

In [None]:
# Create ArviZ InferenceData for comprehensive diagnostics
# First, reshape samples for ArviZ (needs shape: chains x draws)
samples = bayes.posterior_samples
num_chains = bayes.num_chains
num_samples = bayes.num_samples

# Build posterior dict with correct shape for ArviZ
posterior_dict = {}
for param in ["sigma_sy_scale", "sigma_sy_exp", "noise", "sigma"]:
    if param in samples:
        vals = np.array(samples[param])
        # Reshape to (num_chains, num_draws)
        if vals.ndim == 1:
            posterior_dict[param] = vals.reshape(num_chains, -1)
        else:
            posterior_dict[param] = vals

# Create InferenceData
idata = az.from_dict(posterior=posterior_dict)

print(f"Created ArviZ InferenceData with {num_chains} chains × {num_samples} samples")
print(f"Parameters: {list(posterior_dict.keys())}")
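
The `reshape(num_chains, -1)` in the cell above assumes that each flat 1-D sample array is stored chain by chain (all draws of chain 0, then chain 1, and so on). That layout is an assumption about `bayes.posterior_samples`. The toy check below shows what the reshape recovers under that layout, and how draw-interleaved storage would instead need a Fortran-order reshape.

```python
import numpy as np

num_chains, num_draws = 4, 5
# Toy flat vector laid out chain by chain: chain c holds values c*100 + draw
flat = np.concatenate([c * 100 + np.arange(num_draws) for c in range(num_chains)])

by_chain = flat.reshape(num_chains, -1)   # row c is exactly chain c
print(by_chain)

# If draws were interleaved (draw 0 of every chain, then draw 1, ...),
# a column-major (Fortran-order) reshape restores the (chains, draws) layout
interleaved = by_chain.T.reshape(-1)
recovered = interleaved.reshape(num_chains, -1, order="F")
```

A wrong layout assumption would mix draws from different chains in each row, which corrupts R-hat and rank plots even though the pooled marginal histograms look unchanged.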

In [None]:
# ArviZ diagnostic plots (trace, pair, forest, energy, autocorr, rank)
display_arviz_diagnostics(bayes, ["sigma_sy_scale", "sigma_sy_exp"], fast_mode=FAST_MODE)

In [None]:
# ArviZ Summary - comprehensive statistics with R-hat and ESS
summary_df = az.summary(idata, var_names=["sigma_sy_scale", "sigma_sy_exp"],
                        hdi_prob=0.94, round_to=4)
print("ArviZ Summary Statistics:")
print(summary_df.to_string())

In [None]:
# Diagnose divergences: Check if they correlate with the noise parameter
# Divergences often occur in "funnel" geometries where noise → 0
if "noise" in posterior_dict:
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    noise_samples = posterior_dict["noise"].flatten()
    scale_samples = posterior_dict["sigma_sy_scale"].flatten()
    exp_samples = posterior_dict["sigma_sy_exp"].flatten()
    
    # Plot noise distribution - check for values near zero
    axes[0].hist(noise_samples, bins=50, density=True, alpha=0.7, edgecolor='white')
    axes[0].axvline(noise_samples.mean(), color='r', ls='--', label=f'Mean: {noise_samples.mean():.2f}')
    axes[0].set_xlabel('Noise (σ)')
    axes[0].set_ylabel('Density')
    axes[0].set_title('Noise Posterior\n(HalfCauchy prior can cause funnels)')
    axes[0].legend()
    
    # Noise vs scale - check for funnel shape
    axes[1].scatter(noise_samples, scale_samples, alpha=0.1, s=1)
    axes[1].set_xlabel('Noise (σ)')
    axes[1].set_ylabel('sigma_sy_scale')
    axes[1].set_title('Noise vs Scale\n(Funnel = divergence source)')
    
    # Noise vs exponent
    axes[2].scatter(noise_samples, exp_samples, alpha=0.1, s=1)
    axes[2].set_xlabel('Noise (σ)')
    axes[2].set_ylabel('sigma_sy_exp')
    axes[2].set_title('Noise vs Exponent')
    
    plt.suptitle('Divergence Diagnosis: Noise Parameter Geometry', y=1.02)
    fig.subplots_adjust(top=0.85, wspace=0.3)
    display(fig)
    plt.close(fig)
    
    print(f"Noise posterior: mean={noise_samples.mean():.3f}, std={noise_samples.std():.3f}")
    print(f"Noise range: [{noise_samples.min():.4f}, {noise_samples.max():.3f}]")
    print(f"\nNote: Divergences likely occur when noise → 0 (funnel geometry)")
    print("The HalfCauchy(scale=10) prior is too vague for only 5 data points.")
else:
    print("Noise parameter not found in samples")

## Residual Analysis

The posterior predictive check below overlays the observed static yield stresses on posterior draws of the power-law curve; the residuals (observed minus predicted) are the vertical gaps between the data points and the curves. For the SPP power-law model:

- **Random scatter**: The model adequately captures the yield stress scaling
- **Systematic trends**: A single power law may be insufficient; consider a more complex yielding model
- **Outliers**: Check the SPP decomposition quality at those amplitudes

With only 5 data points, visual inspection of residuals is sufficient. For larger datasets, consider quantile-quantile plots or runs tests.
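A runs test checks whether residual signs alternate as often as random scatter would predict; too few runs suggests a systematic trend. The sketch below computes residuals against stand-in fitted values and counts sign runs, comparing with the Wald-Wolfowitz expected count $2n_+n_-/n + 1$. With 5 points the comparison has little power, which is consistent with relying on visual inspection here. `residual_runs` is a hypothetical helper, and the fitted values are placeholders rather than the notebook's results.

```python
import numpy as np

def residual_runs(residuals):
    """Count sign runs in residuals and the expected count under random signs."""
    signs = np.sign(residuals)
    signs = signs[signs != 0]                  # ignore exact zeros
    n_pos = np.sum(signs > 0)
    n_neg = np.sum(signs < 0)
    runs = 1 + np.count_nonzero(signs[1:] != signs[:-1])
    n = n_pos + n_neg
    expected = 2 * n_pos * n_neg / n + 1       # Wald-Wolfowitz expectation
    return int(runs), float(expected)

gamma_levels = np.array([0.1, 0.2, 0.4, 0.8, 1.6])
rng = np.random.default_rng(1)
sigma_obs = 60.0 * gamma_levels**0.7 + rng.normal(0.0, 0.5, gamma_levels.size)
a_fit, b_fit = 60.0, 0.7   # placeholder fitted values
residuals = sigma_obs - a_fit * gamma_levels**b_fit
runs, expected = residual_runs(residuals)
print(f"residuals: {np.round(residuals, 3)}; runs = {runs}, expected ~ {expected:.2f}")
```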

In [None]:
# Posterior predictive check
samples = bayes.posterior_samples
scale_samples = np.array(samples.get("sigma_sy_scale", [])).flatten()
exp_samples = np.array(samples.get("sigma_sy_exp", [])).flatten()

if len(scale_samples) > 0 and len(exp_samples) > 0:
    fig, ax = plt.subplots(figsize=(6, 4))
    
    # Data
    gamma_plot = np.array(gamma_levels)
    ax.scatter(gamma_plot, np.array(sigma_sy), s=80, c='black', zorder=5, label='Data')
    
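    # Posterior draws of the power-law curve; a random subset spans all chains
    # (the first 100 flattened samples typically come from a single chain)
    gamma_fine = np.geomspace(gamma_plot.min() * 0.8, gamma_plot.max() * 1.2, 100)
    rng = np.random.default_rng(0)
    n_draws = min(100, len(scale_samples))
    draw_idx = rng.choice(len(scale_samples), size=n_draws, replace=False)
    for i in draw_idx:
        pred = scale_samples[i] * gamma_fine ** exp_samples[i]
        ax.plot(gamma_fine, pred, 'b-', alpha=0.05)
    
    # Posterior mean prediction: average the curve over all posterior draws
    all_preds = scale_samples[:, None] * gamma_fine[None, :] ** exp_samples[:, None]
    ax.plot(gamma_fine, all_preds.mean(axis=0), 'r-', lw=2, label='Posterior Mean')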
    
    # True values
    true_pred = A_true * gamma_fine ** n_true
    ax.plot(gamma_fine, true_pred, 'g--', lw=2, label='True Model')
    
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'Strain Amplitude $\gamma_0$')
    ax.set_ylabel(r'Static Yield Stress $\sigma_{sy}$ (Pa)')
    ax.set_title('Posterior Predictive Check')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    display(fig)
    plt.close(fig)

## Further Reading

- **Rogers (2012)**: ["A sequence of physical processes"](https://doi.org/10.1122/1.3662962) — Complete SPP theory and validation
- **RheoJAX SPP Pipeline**: [Documentation](../../docs/source/pipeline/workflows.rst#spp-amplitude-sweep) — Implementation details
- **Hyun et al. (2011)**: ["Nonlinear oscillatory shear review"](https://doi.org/10.1016/j.progpolymsci.2011.02.002) — Context for LAOS methods
- **NumPyro Hierarchical Models**: [Tutorial](https://num.pyro.ai/en/stable/tutorials/bayesian_hierarchical_linear_regression.html) — Understanding noise parameter funnels

## Next Steps

- **[08-spp-laos.ipynb](08-spp-laos.ipynb)**: Detailed SPP theory and single-amplitude decomposition
- **[advanced/10-spp-laos-tutorial.ipynb](../advanced/10-spp-laos-tutorial.ipynb)**: Advanced SPP parameter interpretation
- Apply this workflow to your own LAOS amplitude sweep datasets

**For questions or issues:**
- GitHub: https://github.com/imewei/rheojax/issues
- Documentation: https://rheojax.readthedocs.io

### Key References

- **Rogers, S.A. et al. (2012).** "A sequence of physical processes determined from LAOS." *J. Rheol.* 56:1-25. [Original SPP theory and implementation]
- **Rogers, S.A. & Lettinga, M.P. (2012).** "A sequence of physical processes from LAOS using a simple model." *J. Rheol.* 56:1129-1151. [SPP for complex fluids]
- **Hyun, K. et al. (2011).** "A review of nonlinear oscillatory shear tests." *Prog. Polym. Sci.* 36:1697-1753. [Comprehensive LAOS methods review]

In [None]:
# Final comparison: NLSQ vs Bayesian vs True
print("\n" + "="*60)
print("PARAMETER COMPARISON")
print("="*60)
print(f"{'Parameter':<20} {'True':>10} {'NLSQ':>10} {'Bayesian':>10}")
print("-"*60)
print(f"{'sigma_sy_scale':<20} {A_true:>10.4f} {scale_nlsq:>10.4f} "
      f"{bayes.summary['sigma_sy_scale']['mean']:>10.4f}")
print(f"{'sigma_sy_exp':<20} {n_true:>10.4f} {exp_nlsq:>10.4f} "
      f"{bayes.summary['sigma_sy_exp']['mean']:>10.4f}")
print("="*60)