# HVNM Tutorial 06: LAOS Nonlinear Oscillatory — NLSQ to NUTS

**Fit large-amplitude oscillatory shear with the Hybrid Vitrimer Nanocomposite Model**

Under LAOS ($\gamma(t) = \gamma_0 \sin(\omega t)$), the HVNM captures:
1. **Linear regime** ($\gamma_0 \ll 1$): recovers the SAOS moduli G' and G''
2. **Nonlinear regime** ($\gamma_0 \sim 1$): stress-accelerated bond exchange (transition-state theory, TST) distorts the Lissajous curve
3. **Far-nonlinear regime** ($\gamma_0 \gg 1$): strain-induced softening and increased dissipation
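One common way to quantify the Lissajous distortion is the ratio of the third to the first Fourier harmonic of the stress, $I_3/I_1$. The sketch below uses a hypothetical toy softening law $\sigma = G\gamma(1 - a\gamma^2)$, not the HVNM, only to show how the ratio grows from near zero in the linear regime to order 0.1 when $\gamma_0 \sim 1$. The helper `harmonic_ratio` is illustrative and assumes uniform sampling over an integer number of cycles, so that harmonic $n$ falls exactly in FFT bin $n \times$ `n_cycles`.

```python
import numpy as np


def harmonic_ratio(stress, n_cycles, n=3):
    """|I_n / I_1| for stress sampled uniformly over exactly `n_cycles` periods."""
    spectrum = np.abs(np.fft.rfft(stress))
    # With an integer number of cycles, harmonic k sits in FFT bin k * n_cycles
    return spectrum[n * n_cycles] / spectrum[n_cycles]


omega, n_cycles, n_pts = 1.0, 4, 1024
# endpoint=False keeps the record exactly periodic (no duplicated first sample)
t = np.linspace(0.0, n_cycles * 2 * np.pi / omega, n_pts, endpoint=False)
G, a = 1000.0, 0.5  # hypothetical modulus [Pa] and softening coefficient

for gamma_0 in (0.01, 1.0):
    gamma = gamma_0 * np.sin(omega * t)
    sigma = G * gamma * (1.0 - a * gamma**2)  # toy softening law, NOT the HVNM
    print(f"gamma_0 = {gamma_0:5.2f}: I3/I1 = {harmonic_ratio(sigma, n_cycles):.3e}")
```

For this toy law the ratio is analytically $a\gamma_0^2 / (4 - 3a\gamma_0^2)$, i.e. 0.2 at $\gamma_0 = 1$.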

LAOS constrains the same 6 parameters as startup: **G_P, G_E, G_D, nu_0, k_d_D, V_act**,
but with different sensitivity (oscillatory vs monotone loading).

## Dataset
PNAS 2022 Digital Rheometer Twin — LAOS at ω = 1 rad/s

## Estimated Runtime
- NLSQ: ~30 s (ODE per cycle) | NUTS: skipped in FAST_MODE / ~30 min (production)

## 1. Setup

In [None]:
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax openpyxl
    import os
    os.environ["JAX_ENABLE_X64"] = "true"

In [None]:
%matplotlib inline
import os

import matplotlib.pyplot as plt
import numpy as np

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models import HVNMLocal

jax, jnp = safe_import_jax()
verify_float64()

# HVNM tutorial helpers (path relative to this notebook)
sys.path.insert(0, "../..")
from examples.utils.hvnm_tutorial_utils import (
    configure_hvnm_for_fit,
    get_bayesian_config,
    get_fast_mode,
    get_nlsq_values,
    get_output_dir,
    load_pnas_laos,
    plot_fit_comparison,
    plot_ppc,
    plot_trace_and_forest,
    print_convergence,
    print_parameter_table,
    save_figure,
    save_results,
    setup_style,
)

# Shared ArviZ diagnostics helper (one directory up)
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import display_arviz_diagnostics

setup_style()
print(f"JAX {jax.__version__}, FAST_MODE: {get_fast_mode()}")

## 2. Load Data and Apply QC

In [None]:
max_pts = 80 if get_fast_mode() else 300
data = load_pnas_laos(omega=1.0, strain_amplitude_index=5, max_points=max_pts)

print(data.summary())
print(f"\nomega = {data.protocol_kwargs['omega']} rad/s")
print(f"gamma_0 ~ {data.protocol_kwargs['gamma_0']:.4g}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4))

# Time series
ax1.plot(data.x, data.y, '-', lw=1, color='coral', label='Stress')
ax1b = ax1.twinx()
ax1b.plot(data.x, data.y2, '-', lw=0.8, color='steelblue', alpha=0.5, label='Strain')
ax1.set_xlabel('Time [s]')
ax1.set_ylabel(r'$\sigma$ [Pa]', color='coral')
ax1b.set_ylabel(r'$\gamma$', color='steelblue')
ax1.set_title('LAOS Time Series')
ax1.grid(True, alpha=0.3)

# Lissajous (elastic: stress vs strain)
ax2.plot(data.y2, data.y, '-', lw=0.8, color='purple')
ax2.set_xlabel(r'$\gamma$')
ax2.set_ylabel(r'$\sigma$ [Pa]')
ax2.set_title('Elastic Lissajous Curve')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# QC: chop to last complete cycle(s) for clean fitting
omega_val = data.protocol_kwargs['omega']
T_period = 2 * np.pi / omega_val
t = data.x_masked
t_span = t[-1] - t[0]
n_cycles = int(t_span / T_period)
n_fit_cycles = 1 if get_fast_mode() else min(2, n_cycles)

if n_fit_cycles >= 1 and n_cycles >= n_fit_cycles:
    t_start = t[-1] - n_fit_cycles * T_period
    data.mask = data.x >= t_start
    print(f"Using last {n_fit_cycles} cycle(s): t >= {t_start:.2f} s ({data.n_points} points)")
else:
    print(f"Only {n_cycles} complete cycles available. Using all data.")

print(f"Total cycles in data: {n_cycles}")
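Trimming assumes the last cycles are already at steady state. A quick check, sketched below, is to compare the last two cycles at matched phases. The helper `cycle_to_cycle_change` is hypothetical (not part of `rheojax` or the tutorial utilities) and assumes 1-D time and stress arrays, e.g. `np.asarray(data.x)` and `np.asarray(data.y)` from the cell above.

```python
import numpy as np


def cycle_to_cycle_change(t, stress, omega, n_phase=200):
    """Relative RMS difference between the last two full cycles of `stress`."""
    t = np.asarray(t, dtype=float)
    stress = np.asarray(stress, dtype=float)
    period = 2 * np.pi / omega
    if t[-1] - t[0] < 2 * period:
        raise ValueError("need at least two full cycles")
    phase = np.linspace(0.0, period, n_phase, endpoint=False)
    # Resample both cycles at identical phases so they can be compared point by point
    last = np.interp(t[-1] - period + phase, t, stress)
    prev = np.interp(t[-1] - 2 * period + phase, t, stress)
    return np.sqrt(np.mean((last - prev) ** 2)) / np.sqrt(np.mean(last**2))


omega = 1.0
t = np.linspace(0.0, 3 * 2 * np.pi / omega, 3000)
steady = np.sin(omega * t)
transient = np.sin(omega * t) + np.exp(-t / 5.0)  # decaying startup term
print(f"steady:    {cycle_to_cycle_change(t, steady, omega):.1e}")
print(f"transient: {cycle_to_cycle_change(t, transient, omega):.1e}")
```

A value well above the noise level suggests transients remain in the fitted window, so more early data should be discarded.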

## 3. Configure HVNM and Fit (NLSQ)

LAOS constrains 6 HVNM parameters (same as startup):
- **G_P, G_E, G_D**: Network moduli
- **nu_0**: Bond-exchange reaction (BER) attempt frequency
- **k_d_D**: D-network dissociation rate
- **V_act**: Stress-activation volume (controls nonlinear distortion)
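Why does V_act only matter at large amplitude? A common transition-state (Eyring-type) form, assumed here and not necessarily the exact HVNM expression, multiplies the bond-exchange rate by $\cosh(V_\mathrm{act}\,\sigma / RT)$, which is $\approx 1 + \tfrac{1}{2}(V_\mathrm{act}\sigma/RT)^2$ at small stress. The hypothetical `tst_factor` helper below takes V_act in m³/mol; the numbers are illustrative and only the shape of the dependence matters.

```python
import numpy as np

R = 8.314  # gas constant [J/(mol K)]


def tst_factor(sigma, V_act, T=300.0):
    """Assumed Eyring-type acceleration cosh(V_act * sigma / (R * T)); V_act in m^3/mol."""
    return np.cosh(V_act * sigma / (R * T))


V_act = 1e-4  # hypothetical activation volume [m^3/mol]
for sigma in (1e2, 1e5, 1e7):  # illustrative stress levels [Pa]
    base = tst_factor(sigma, V_act)
    doubled = tst_factor(sigma, 2 * V_act)
    # If doubling V_act barely changes the rate, the data cannot constrain V_act
    print(f"sigma = {sigma:.0e} Pa: factor {base:.6f}, with 2*V_act {doubled:.6f}")
```

At low stress the factor stays at 1 regardless of V_act, which is why small-amplitude data leave V_act unconstrained.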

In [None]:
model = HVNMLocal(kinetics="stress", include_dissociative=True)

# Estimate stress amplitude for initial guesses
sigma_amp = np.max(np.abs(data.y_masked))

fit_params = configure_hvnm_for_fit(
    model,
    protocol="laos",
    overrides={
        "G_P": sigma_amp * 0.2,
        "G_E": sigma_amp * 0.3,
        "G_D": sigma_amp * 0.3,
        "nu_0": 1e9,
        "k_d_D": omega_val,    # Start with D-network time ~ 1/omega
        "V_act": 1e-4,
        "T": 300.0,
        "phi": 0.0,
    },
)
print(f"Fittable: {fit_params}")

t0 = time.time()
model.fit(
    data.x_masked,
    data.y_masked,
    test_mode="laos",
    gamma_0=data.protocol_kwargs['gamma_0'],
    omega=data.protocol_kwargs['omega'],
    max_iter=3000,
    method='scipy',  # diffrax ODE solver incompatible with NLSQ forward-mode AD
)
print(f"NLSQ: {time.time() - t0:.1f} s")

nlsq_vals = get_nlsq_values(model, fit_params)
for p, v in nlsq_vals.items():
    print(f"  {p} = {v:.4g}")

In [None]:
fig = plot_fit_comparison(data, model, title="HVNM LAOS: NLSQ Fit")
save_figure(fig, "hvnm_06_laos_nlsq_fit.png")
plt.show()

In [None]:
# Lissajous comparison: data vs model
y_pred = model.predict(
    data.x_masked, test_mode="laos",
    gamma_0=data.protocol_kwargs['gamma_0'],
    omega=data.protocol_kwargs['omega'],
)
strain = data.y2[data.mask] if data.y2 is not None else None

fig, ax = plt.subplots(figsize=(6, 5))
if strain is not None:
    ax.plot(strain, data.y_masked, '-', lw=1, color='steelblue', alpha=0.6, label='Data')
    ax.plot(strain, y_pred, '--', lw=1.5, color='orangered', label='HVNM fit')
ax.set_xlabel(r'$\gamma$')
ax.set_ylabel(r'$\sigma$ [Pa]')
ax.set_title('Elastic Lissajous: Data vs HVNM')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
save_figure(fig, "hvnm_06_laos_lissajous.png")
plt.show()
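The Lissajous overlay is a visual check. For a numeric complement, a small helper such as the hypothetical `fit_metrics` below (not part of the tutorial utilities) can be called as `fit_metrics(data.y_masked, y_pred)`. R² is dominated by the first harmonic and can stay very close to 1 even when a third harmonic is missed entirely, so the normalized residual is worth reporting as well.

```python
import numpy as np


def fit_metrics(y_obs, y_pred):
    """R^2 and RMS residual normalized by the peak |stress|."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_obs - y_pred
    ss_res = np.sum(resid**2)
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)
    return {
        "R2": 1.0 - ss_res / ss_tot,
        "rel_rms": np.sqrt(np.mean(resid**2)) / np.max(np.abs(y_obs)),
    }


# Synthetic example: the "prediction" misses a small third harmonic
t = np.linspace(0.0, 2 * 2 * np.pi, 400, endpoint=False)
obs = 100.0 * np.sin(t) + 5.0 * np.sin(3 * t)
pred = 100.0 * np.sin(t)
print(fit_metrics(obs, pred))
```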

## 4. Bayesian Inference (NUTS)

LAOS requires an ODE solve, plus its reverse-mode gradient, at every likelihood evaluation, making it
the most expensive protocol for NUTS. In FAST_MODE the NUTS step below is skipped and only the NLSQ fit is reported.
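A rough upper bound shows why. NUTS takes up to $2^d - 1$ leapfrog steps per iteration at tree depth $d$ (NumPyro's default maximum depth, for example, is 10), and each step needs one gradient of the log-posterior, which for LAOS means an ODE solve plus its backward pass. The helper and the settings below are hypothetical; the values returned by `get_bayesian_config()` may differ.

```python
def max_gradient_evals(num_chains, num_warmup, num_samples, max_tree_depth=10):
    """Worst-case count of log-density gradient evaluations for a NUTS run.

    A trajectory of tree depth d uses 2**d - 1 leapfrog steps, and each step
    needs one gradient (here: one forward ODE solve plus its backward pass).
    """
    steps_per_iteration = 2**max_tree_depth - 1
    return num_chains * (num_warmup + num_samples) * steps_per_iteration


# Hypothetical sampler settings, for illustration only
print(max_gradient_evals(num_chains=4, num_warmup=1000, num_samples=1000))
print(max_gradient_evals(num_chains=4, num_warmup=1000, num_samples=1000, max_tree_depth=5))
```

Well-adapted samplers usually stay far below the maximum depth, but even a few dozen ODE solves per draw add up quickly over thousands of draws.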

In [None]:
if get_fast_mode():
    print("FAST_MODE: Skipping NUTS for LAOS (ODE + VJP too memory-intensive for CI).")
    print("NUTS is demonstrated for all other protocols in notebooks 09-13.")
    print("Set FAST_MODE = False in hvnm_tutorial_utils.py for full Bayesian.")
    result = None
else:
    bayes_cfg = get_bayesian_config()
    print(f"Config: {bayes_cfg}")
    t0 = time.time()
    result = model.fit_bayesian(
        data.x_masked,
        data.y_masked,
        test_mode="laos",
        gamma_0=data.protocol_kwargs['gamma_0'],
        omega=data.protocol_kwargs['omega'],
        **bayes_cfg,
    )
    print(f"NUTS: {time.time() - t0:.1f} s")

## 5. Diagnostics and PPC

In [None]:
if result is not None:
    print_convergence(result, fit_params)
    print()
    print_parameter_table(fit_params, nlsq_vals, result.posterior_samples)
else:
    print("NUTS was skipped. Showing NLSQ results only.")
    for p in fit_params:
        print(f"  {p}: {nlsq_vals[p]:.4g}")

In [None]:
if result is not None:
    display_arviz_diagnostics(result, fit_params, fast_mode=get_fast_mode())
else:
    print("Skipped (no NUTS result).")
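For reference, the quantity behind an R-hat check can be sketched in a few lines. Below is the classic split-R-hat (Gelman et al., *Bayesian Data Analysis*, 3rd ed.); ArviZ's default `rhat` uses a rank-normalized variant, and the exact statistic `print_convergence` reports is not shown here, so treat `split_rhat` as a hypothetical illustration. Values near 1 indicate that chains agree; a common rule of thumb flags values above 1.01.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for draws of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split every chain into first and second halves -> (2 * n_chains, half)
    splits = chains[:, : 2 * half].reshape(-1, half)
    W = splits.var(axis=1, ddof=1).mean()       # mean within-chain variance
    B = half * splits.mean(axis=1).var(ddof=1)  # between-chain variance
    var_hat = (half - 1) / half * W + B / half
    return np.sqrt(var_hat / W)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                       # chains agree
stuck = mixed + np.array([0.0, 0.0, 3.0, 3.0])[:, None]  # two chains offset
print(f"mixed: {split_rhat(mixed):.3f}, stuck: {split_rhat(stuck):.3f}")
```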

In [None]:
if result is not None:
    fig = plot_ppc(
        data, model, result.posterior_samples, fit_params,
        title="HVNM LAOS: Posterior Predictive Check",
    )
    save_figure(fig, "hvnm_06_laos_ppc.png")
    plt.show()
else:
    print("Skipped (no NUTS result).")
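A posterior predictive check compares the data with the spread of model predictions across posterior draws. Assuming you already have a `(n_draws, n_points)` array of predictions (how `plot_ppc` builds it internally is not shown here), the hypothetical helper below computes a central band and the fraction of observations it contains. If the likelihood's noise term is not added to the draws, the band reflects parameter uncertainty only, so coverage below the nominal level does not by itself indicate a bad fit.

```python
import numpy as np


def ppc_band(pred_draws, y_obs, level=0.95):
    """Central band of posterior predictions and the fraction of data inside it."""
    pred_draws = np.asarray(pred_draws, dtype=float)  # shape (n_draws, n_points)
    y_obs = np.asarray(y_obs, dtype=float)
    tail = (1.0 - level) / 2.0
    lo, hi = np.quantile(pred_draws, [tail, 1.0 - tail], axis=0)
    coverage = np.mean((y_obs >= lo) & (y_obs <= hi))
    return lo, hi, coverage


# Synthetic check: 101 "draws" equal to 0, 1, ..., 100 at each of 3 points
draws = np.tile(np.arange(101.0)[:, None], (1, 3))
lo, hi, cov = ppc_band(draws, y_obs=[50.0, 1.0, 99.0])
print(lo, hi, cov)
```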

## 6. Save Results

In [None]:
if result is not None:
    save_results(
        get_output_dir("laos"), model, result,
        param_names=fit_params,
        extra_meta={
            "dataset": "PNAS_DRT",
            "protocol": "laos",
            "omega": data.protocol_kwargs['omega'],
            "gamma_0": data.protocol_kwargs['gamma_0'],
        },
    )
else:
    print("Skipped saving (no NUTS result).")

## What to Change for Your Data

1. **Strain amplitude**: Change `strain_amplitude_index` (0-11) in `load_pnas_laos()`; higher indices mean larger strain amplitudes and stronger nonlinearity
2. **Frequency**: Change `omega` (1, 3, or 5 rad/s) for the PNAS dataset
3. **Cycle selection**: Adjust `n_fit_cycles` in the QC cell so that only steady-state cycles are kept
4. **Subsampling**: Increase `max_points` for more resolution per cycle, decrease for faster NUTS
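When subsampling with `max_points`, what matters for LAOS is the number of points per cycle, because the nonlinear signature lives in the higher odd harmonics. Sampling theory gives a hard floor: resolving harmonic $n$ needs more than $2n$ samples per cycle (Nyquist); fitting the waveform shape well needs many more. The helpers below are hypothetical, assume roughly uniform sampling, and could be called as `points_per_cycle(np.asarray(data.x_masked), omega_val)`.

```python
import numpy as np


def points_per_cycle(t, omega):
    """Average number of samples per oscillation period for a time array `t`."""
    dt = np.mean(np.diff(np.asarray(t, dtype=float)))
    return (2 * np.pi / omega) / dt


def highest_resolvable_harmonic(n_per_cycle):
    """Largest harmonic n satisfying Nyquist (more than 2n samples per cycle)."""
    return int(np.ceil(n_per_cycle / 2.0) - 1)


t = np.linspace(0.0, 4 * 2 * np.pi, 350, endpoint=False)  # 4 cycles at omega = 1
n_pc = points_per_cycle(t, omega=1.0)
print(f"{n_pc:.1f} points per cycle, harmonics up to n = {highest_resolvable_harmonic(n_pc)}")
```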

## Troubleshooting

- **ODE diverges**: LAOS at high gamma_0 can be stiff. Try reducing gamma_0 or using a more robust ODE solver
- **Lissajous doesn't close**: Transient startup effects. Ensure you're fitting only steady-state cycles (chop early data)
- **V_act unconstrained**: At small gamma_0 (linear regime), V_act has no effect. Use larger strain amplitudes
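- **Slow NUTS**: Every gradient evaluation (many per posterior draw) requires a full ODE solve plus its backward pass. Reduce `max_points` or fit fewer cycles; FAST_MODE skips NUTS for this notebook entirely
- **Phase mismatch**: If the model predicts the correct amplitude but the wrong phase, check that `omega` and `gamma_0` match between the data and the model call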
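To separate a phase error from an amplitude error, project the stress onto $\sin\omega t$ and $\cos\omega t$. With $\gamma = \gamma_0 \sin\omega t$, the first-harmonic stress is $\gamma_0 (G'_1 \sin\omega t + G''_1 \cos\omega t)$ and the phase lag is $\delta = \arctan(G''_1 / G'_1)$. The hypothetical `first_harmonic` helper below is a least-squares sketch; over an integer number of uniformly sampled cycles the higher odd harmonics are orthogonal to the basis and drop out. Comparing its output for `data.y_masked` and `y_pred` tells phase from amplitude, and running it on the measured strain over the same window should give $G'_1 \approx 1$, $G''_1 \approx 0$ if the strain really is $\gamma_0 \sin\omega t$ on the data's time axis.

```python
import numpy as np


def first_harmonic(t, stress, gamma_0, omega):
    """Least-squares first-harmonic moduli G'_1, G''_1 and phase lag delta [rad]."""
    t = np.asarray(t, dtype=float)
    stress = np.asarray(stress, dtype=float)
    # Basis: sin, cos and a constant offset (absorbs any stress bias)
    A = np.column_stack([np.sin(omega * t), np.cos(omega * t), np.ones_like(t)])
    coef, *_ = np.linalg.lstsq(A, stress, rcond=None)
    G1_p, G1_pp = coef[0] / gamma_0, coef[1] / gamma_0
    return G1_p, G1_pp, np.arctan2(G1_pp, G1_p)


omega, gamma_0 = 1.0, 0.5
t = np.linspace(0.0, 2 * 2 * np.pi / omega, 500, endpoint=False)  # two full cycles
sigma = gamma_0 * (500.0 * np.sin(omega * t) + 200.0 * np.cos(omega * t))
sigma += 30.0 * np.sin(3 * omega * t)  # third harmonic: orthogonal, drops out
G1_p, G1_pp, delta = first_harmonic(t, sigma, gamma_0, omega)
print(f"G'_1 = {G1_p:.2f} Pa, G''_1 = {G1_pp:.2f} Pa, delta = {np.degrees(delta):.1f} deg")
```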