# VLB LAOS: NLSQ → NUTS

**PNAS Digital Rheometer Twin — large-amplitude oscillatory shear**

## Learning Objectives

- Fit the LAOS stress waveform to extract G₀ and k_d
- Understand ODE-based LAOS prediction (diffrax)
- Visualize Lissajous figures from posterior draws

## Data Source

PNAS 2022 Digital Rheometer Twin dataset (LAOS at ω = 1 rad/s).

## Estimated Runtime

- Fast demo (1 chain): ~3 min (ODE integration per sample)
- Full run (4 chains): ~10 min

**Note:** LAOS requires ODE integration via diffrax, making it slower than
analytical protocols.

## 1. Setup

In [None]:
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax openpyxl
    import os
    os.environ["JAX_ENABLE_X64"] = "true"

In [None]:
%matplotlib inline
import os
import sys
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models import VLBLocal

jax, jnp = safe_import_jax()
verify_float64()
warnings.filterwarnings('ignore', category=FutureWarning)
print(f"JAX version: {jax.__version__}")

# Tutorial utilities live in the parent directory (../utils)
sys.path.insert(0, str(Path('..').resolve()))
from utils.plotting_utils import (
    display_arviz_diagnostics,
    plot_nlsq_fit,
    plot_posterior_predictive,
)
from utils.vlb_tutorial_utils import (
    get_bayesian_config,
    get_output_dir,
    load_pnas_laos,
    plot_trace_and_forest,
    print_convergence,
    print_parameter_table,
    save_figure,
    save_results,
)

## 2. Load Experimental Data

In [None]:
OMEGA = 1.0  # rad/s
STRAIN_IDX = 5  # Medium amplitude

time_data, strain_data, stress_data = load_pnas_laos(
    omega=OMEGA, strain_amplitude_index=STRAIN_IDX, max_points=200
)
GAMMA_0 = float(np.max(np.abs(strain_data)))

print(f'Data points: {len(time_data)}')
print(f'omega = {OMEGA} rad/s')
print(f'gamma_0 = {GAMMA_0:.4f}')
print(f'Time range: {time_data.min():.3f} - {time_data.max():.1f} s')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(time_data, stress_data, 'ko', markersize=2)
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Stress [Pa]')
ax1.set_title('LAOS Stress Waveform')
ax1.grid(True, alpha=0.3)

ax2.plot(strain_data, stress_data, 'ko', markersize=2)
ax2.set_xlabel('Strain [-]')
ax2.set_ylabel('Stress [Pa]')
ax2.set_title('Lissajous Figure (data)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

## 3. VLB Forward Model

For LAOS, the VLB distribution tensor ODE is integrated numerically:

$$\dot{\boldsymbol{\mu}} = k_d(\mathbf{I} - \boldsymbol{\mu}) + \mathbf{L}(t) \cdot \boldsymbol{\mu} + \boldsymbol{\mu} \cdot \mathbf{L}^T(t)$$

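with the imposed strain $\gamma(t) = \gamma_0 \sin(\omega t)$. In simple shear the velocity gradient $\mathbf{L}(t)$ has a single nonzero component, $L_{xy}(t) = \dot{\gamma}(t) = \gamma_0 \omega \cos(\omega t)$, and the shear stress is $\sigma_{xy} = G_0\,\mu_{xy}$.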

This uses diffrax (Tsit5 + PID controller) for JAX-compatible ODE integration.
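Starting from equilibrium ($\boldsymbol{\mu} = \mathbf{I}$), the constant-$k_d$ equation above keeps $\mu_{yy} = 1$, so the shear component reduces to the linear Maxwell equation $\dot{\mu}_{xy} = -k_d\,\mu_{xy} + \gamma_0 \omega \cos(\omega t)$. The shear stress is therefore a single harmonic at any amplitude once the start-up transient has decayed; nonlinearity appears only in the normal-stress component $\mu_{xx}$.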
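The shear component of the tensor ODE can be checked numerically. The NumPy sketch below integrates the full 3×3 equation from equilibrium with a fixed-step RK4 loop (a stand-in for diffrax's adaptive Tsit5, not what `VLBLocal` runs internally) and compares $\mu_{xy}$ with the closed-form solution of $\dot{\mu}_{xy} = -k_d\,\mu_{xy} + \gamma_0\omega\cos(\omega t)$, $\mu_{xy}(0) = 0$. The parameter values are hypothetical; multiplying $\mu_{xy}$ by $G_0$ would give the shear stress.

```python
import numpy as np


def vlb_rhs(t, mu, k_d, gamma_0, omega):
    """Right-hand side of the constant-k_d VLB tensor ODE in oscillatory simple shear."""
    L = np.zeros((3, 3))
    L[0, 1] = gamma_0 * omega * np.cos(omega * t)  # only L_xy = gamma_dot(t) is nonzero
    return k_d * (np.eye(3) - mu) + L @ mu + mu @ L.T


def integrate_rk4(k_d, gamma_0, omega, t_end, n_steps):
    """Classical fixed-step RK4 starting from equilibrium, mu = I at t = 0."""
    dt = t_end / n_steps
    t, mu = 0.0, np.eye(3)
    ts, mu_xy, mu_yy = [t], [mu[0, 1]], [mu[1, 1]]
    for _ in range(n_steps):
        k1 = vlb_rhs(t, mu, k_d, gamma_0, omega)
        k2 = vlb_rhs(t + dt / 2, mu + dt / 2 * k1, k_d, gamma_0, omega)
        k3 = vlb_rhs(t + dt / 2, mu + dt / 2 * k2, k_d, gamma_0, omega)
        k4 = vlb_rhs(t + dt, mu + dt * k3, k_d, gamma_0, omega)
        mu = mu + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += dt
        ts.append(t)
        mu_xy.append(mu[0, 1])
        mu_yy.append(mu[1, 1])
    return np.array(ts), np.array(mu_xy), np.array(mu_yy)


def maxwell_mu_xy(t, k_d, gamma_0, omega):
    """Closed-form mu_xy for d(mu_xy)/dt = -k_d mu_xy + gamma_0 omega cos(omega t), mu_xy(0) = 0."""
    pref = gamma_0 * omega / (k_d**2 + omega**2)
    return pref * (k_d * np.cos(omega * t) + omega * np.sin(omega * t) - k_d * np.exp(-k_d * t))


# Hypothetical values, chosen only for illustration (large strain amplitude on purpose)
k_d, gamma_0, omega = 0.5, 2.0, 1.0
t, mu_xy, mu_yy = integrate_rk4(k_d, gamma_0, omega, t_end=4 * 2 * np.pi, n_steps=4000)
print("max |mu_yy - 1|       :", np.max(np.abs(mu_yy - 1.0)))
print("max |mu_xy - Maxwell| :", np.max(np.abs(mu_xy - maxwell_mu_xy(t, k_d, gamma_0, omega))))
```

Even at $\gamma_0 = 2$ the shear component tracks the linear solution to integration accuracy, while $\mu_{xx}$ (not printed) picks up a term quadratic in $\gamma_0$ through $2 L_{xy}\mu_{xy}$.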

## 4. Step 1: NLSQ Point Estimation

In [None]:
model = VLBLocal()

t0 = time.time()
model.fit(time_data, stress_data, test_mode='laos',
          gamma_0=GAMMA_0, omega=OMEGA, method='scipy')
t_nlsq = time.time() - t0

print(f'NLSQ fit time: {t_nlsq:.2f} s')
print(f'G0  = {model.G0:.1f} Pa')
print(f'k_d = {model.k_d:.4f} 1/s')
print(f'eta_0 = {model.G0 / model.k_d:.1f} Pa*s')

# Predict on data grid
stress_pred = model.predict(time_data, test_mode='laos',
                            gamma_0=GAMMA_0, omega=OMEGA)
ss_res = np.sum((stress_data - np.array(stress_pred))**2)
ss_tot = np.sum((stress_data - np.mean(stress_data))**2)
r2 = 1 - ss_res / ss_tot
print(f'R-squared = {r2:.4f}')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(time_data, stress_data, 'ko', markersize=2, label='Data')
ax1.plot(time_data, stress_pred, 'r-', lw=1.5, label='VLB fit')
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Stress [Pa]')
ax1.set_title('LAOS Waveform Fit')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Lissajous comparison
ax2.plot(strain_data, stress_data, 'ko', markersize=2, label='Data')
strain_pred = GAMMA_0 * np.sin(OMEGA * time_data)
ax2.plot(strain_pred, stress_pred, 'r-', lw=1.5, label='VLB fit')
ax2.set_xlabel('Strain [-]')
ax2.set_ylabel('Stress [Pa]')
ax2.set_title('Lissajous Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)
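Because the constant-$k_d$ shear response is that of a single-mode Maxwell element, $G_0$ and $k_d$ can also be recovered algebraically from the first-harmonic moduli, $k_d = \omega G''/G'$ and $G_0 = (G'^2 + G''^2)/G'$, which gives an independent cross-check on the NLSQ estimates. This sketch assumes the strain is exactly $\gamma_0 \sin(\omega t)$ with zero phase and that the start-up transient has decayed; the helper names and the synthetic values ($G_0 = 100$ Pa, $k_d = 0.5$ 1/s) are hypothetical.

```python
import numpy as np


def first_harmonic_moduli(t, stress, gamma_0, omega):
    """Least-squares projection of stress onto gamma_0*sin(wt) and gamma_0*cos(wt).

    Returns (G', G''), assuming the imposed strain is gamma_0*sin(omega*t) with zero phase.
    """
    design = np.column_stack([
        gamma_0 * np.sin(omega * t),  # in phase with strain -> G'
        gamma_0 * np.cos(omega * t),  # in phase with strain rate -> G''
        np.ones_like(t),              # absorbs any constant stress offset
    ])
    coef, *_ = np.linalg.lstsq(design, stress, rcond=None)
    return coef[0], coef[1]


def maxwell_params_from_moduli(G_storage, G_loss, omega):
    """Invert single-mode Maxwell moduli: k_d = omega*G''/G', G0 = (G'^2 + G''^2)/G'."""
    k_d = omega * G_loss / G_storage
    G0 = (G_storage**2 + G_loss**2) / G_storage
    return G0, k_d


# Synthetic steady-state waveform built from hypothetical parameters
omega, gamma_0 = 1.0, 2.0
G0_true, k_d_true = 100.0, 0.5
t = np.linspace(0.0, 4 * np.pi, 400)
Gp = G0_true * omega**2 / (omega**2 + k_d_true**2)
Gpp = G0_true * omega * k_d_true / (omega**2 + k_d_true**2)
stress = gamma_0 * (Gp * np.sin(omega * t) + Gpp * np.cos(omega * t))

Gp_est, Gpp_est = first_harmonic_moduli(t, stress, gamma_0, omega)
G0_est, k_d_est = maxwell_params_from_moduli(Gp_est, Gpp_est, omega)
print(f"G' = {Gp_est:.2f} Pa, G'' = {Gpp_est:.2f} Pa -> G0 = {G0_est:.2f} Pa, k_d = {k_d_est:.4f} 1/s")
```

On the tutorial data this would be `first_harmonic_moduli(time_data, stress_data, GAMMA_0, OMEGA)`, ideally after dropping the first cycle. A large disagreement with the NLSQ values may indicate a phase offset in the measured strain or nonlinearity that the model cannot represent.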

## 4. Step 2: Bayesian Inference (NUTS)

LAOS Bayesian inference is slower because every log-density evaluation, and every gradient NUTS takes of it, requires a full ODE solve.
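A back-of-the-envelope count shows how quickly this adds up. A NUTS iteration whose trajectory tree reaches depth $d$ takes up to $2^d - 1$ leapfrog steps, and each step needs one gradient of the log-density, i.e. one ODE solve plus a backward pass through it. The chain, warmup, sample and depth values below are hypothetical, not the ones returned by `get_bayesian_config()`.

```python
def estimate_gradient_evals(num_chains, num_warmup, num_samples, tree_depth):
    """Count of log-density gradients (one ODE solve + backward pass each).

    Assumes every NUTS iteration builds a full tree of the given depth,
    i.e. 2**tree_depth - 1 leapfrog steps, so this is an upper bound for that depth.
    """
    leapfrog_per_iter = 2**tree_depth - 1
    return num_chains * (num_warmup + num_samples) * leapfrog_per_iter


# Hypothetical settings: 4 chains, 1000 warmup + 1000 draws, tree depth 3
n_evals = estimate_gradient_evals(num_chains=4, num_warmup=1000, num_samples=1000, tree_depth=3)
print(f"~{n_evals:,} ODE solves with gradients")
```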

In [None]:
config = get_bayesian_config()
initial_values = {'G0': float(model.G0), 'k_d': float(model.k_d)}
print(f'Config: {config}')
print(f'Warm-start: {initial_values}')

t0 = time.time()
result = model.fit_bayesian(
    time_data, stress_data, test_mode='laos',
    gamma_0=GAMMA_0, omega=OMEGA,
    initial_values=initial_values,
    seed=42,
    **config,
)
t_bayes = time.time() - t0
print(f'\nBayesian inference time: {t_bayes:.1f} s')

## 5. Convergence Diagnostics

In [None]:
param_names = ["G0", "k_d"]
converged = print_convergence(result, param_names)

In [None]:
display_arviz_diagnostics(result, param_names, fast_mode=os.environ.get("FAST_MODE", "1") == "1")
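The idea behind the R̂ check can be sketched in a few lines of NumPy. This is the classic split-R̂ of Gelman et al. (BDA3); ArviZ's default `rhat` uses a rank-normalized variant, so its numbers will differ slightly, and this helper is not what `print_convergence` runs internally. It expects draws shaped (chains, draws); how `result.posterior_samples` is laid out across chains is an assumption to check before reshaping.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for draws of shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    _, n_draws = draws.shape
    half = n_draws // 2
    # Treat the first and second half of every chain as separate chains
    # (the middle draw is dropped when n_draws is odd)
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()   # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)       # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))                    # four well-mixed chains
bad = good + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain stuck in another mode
print(f"split-R-hat (good) = {split_rhat(good):.3f}")
print(f"split-R-hat (bad)  = {split_rhat(bad):.3f}")
```

Values close to 1.00 indicate the chains agree; the shifted chain pushes R̂ well above the common 1.01 threshold.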

## 6. Posterior Summary

In [None]:
posterior = result.posterior_samples
nlsq_vals = {"G0": model.G0, "k_d": model.k_d}
print_parameter_table(["G0", "k_d"], nlsq_vals, posterior)
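Derived quantities such as the zero-shear viscosity $\eta_0 = G_0/k_d$ (printed after the NLSQ fit) are best summarized draw by draw, which preserves the posterior correlation between $G_0$ and $k_d$, rather than by dividing posterior means. A minimal sketch; the function name and the synthetic draws are hypothetical.

```python
import numpy as np


def summarize_derived_viscosity(G0_draws, k_d_draws, level=0.95):
    """Posterior summary of eta_0 = G0 / k_d, computed draw by draw."""
    eta0 = np.asarray(G0_draws, dtype=float) / np.asarray(k_d_draws, dtype=float)
    tail = 100.0 * (1.0 - level) / 2.0
    lo, hi = np.percentile(eta0, [tail, 100.0 - tail])
    return {"mean": float(eta0.mean()), "median": float(np.median(eta0)),
            "lo": float(lo), "hi": float(hi)}


# Hypothetical posterior draws, for illustration only
rng = np.random.default_rng(1)
G0_draws = rng.normal(100.0, 5.0, size=4000)
k_d_draws = rng.normal(0.5, 0.02, size=4000)
summary = summarize_derived_viscosity(G0_draws, k_d_draws)
print({key: round(val, 1) for key, val in summary.items()})
```

With the tutorial posterior this would be `summarize_derived_viscosity(posterior['G0'], posterior['k_d'])`.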

## 7. Posterior Predictive Check

In [None]:
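posterior = result.posterior_samples
n_total = len(posterior['G0'])
n_draws = min(50, n_total)  # Fewer draws: each prediction is an ODE solve
# Spread the draws over the whole (possibly chain-concatenated) sample set
draw_idx = np.linspace(0, n_total - 1, n_draws).astype(int)

# Remember the NLSQ point estimate so it can be restored afterwards
nlsq_params = {'G0': float(model.G0), 'k_d': float(model.k_d)}

stress_samples = []
for i in draw_idx:
    model.parameters.set_value('G0', float(posterior['G0'][i]))
    model.parameters.set_value('k_d', float(posterior['k_d'][i]))
    stress_samples.append(np.array(
        model.predict(time_data, test_mode='laos',
                      gamma_0=GAMMA_0, omega=OMEGA)
    ))

# Restore the NLSQ values so later cells (e.g. save_results) see the point estimate
for name, value in nlsq_params.items():
    model.parameters.set_value(name, value)

stress_arr = np.array(stress_samples)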

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Waveform PPC
ax1.fill_between(time_data,
    np.percentile(stress_arr, 2.5, axis=0),
    np.percentile(stress_arr, 97.5, axis=0),
    alpha=0.3, color='C0', label='95% CI')
ax1.plot(time_data, np.median(stress_arr, axis=0),
         'C0-', lw=1.5, label='Posterior median')
ax1.plot(time_data, stress_data, 'ko', markersize=2, label='Data')
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Stress [Pa]')
ax1.set_title('Waveform PPC')
ax1.legend(fontsize=9)
ax1.grid(True, alpha=0.3)

# Lissajous PPC (overlay several draws against the imposed strain)
strain_model = GAMMA_0 * np.sin(OMEGA * time_data)
for i in range(min(10, n_draws)):
    ax2.plot(strain_model, stress_samples[i], '-', color='C0', alpha=0.15, lw=0.8,
             label='Posterior draws' if i == 0 else '_nolegend_')
ax2.plot(strain_data, stress_data, 'ko', markersize=2, label='Data')
ax2.set_xlabel('Strain [-]')
ax2.set_ylabel('Stress [Pa]')
ax2.set_title('Lissajous PPC')
ax2.legend(fontsize=9)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)
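A single number can complement the visual check: the fraction of data points inside the 95% band. The band above comes from `model.predict`, so it reflects parameter uncertainty only and no observation noise; coverage well below 95% is therefore expected even for a good fit. If the likelihood in `fit_bayesian` has a noise scale (an assumption here), adding draws of that noise would give a proper predictive band. A minimal sketch with synthetic arrays:

```python
import numpy as np


def band_coverage(pred_draws, observed, level=0.95):
    """Fraction of observed points inside the central `level` band of predicted draws.

    pred_draws: shape (n_draws, n_points); observed: shape (n_points,).
    """
    tail = 100.0 * (1.0 - level) / 2.0
    lo = np.percentile(pred_draws, tail, axis=0)
    hi = np.percentile(pred_draws, 100.0 - tail, axis=0)
    inside = (observed >= lo) & (observed <= hi)
    return float(inside.mean())


# Synthetic check: 50 draws that differ only by a small amplitude factor
rng = np.random.default_rng(2)
t = np.linspace(0.0, 4 * np.pi, 200)
mean_curve = 10.0 * np.sin(t)
draws = mean_curve * (1.0 + 0.01 * rng.normal(size=(50, 1)))
print(band_coverage(draws, mean_curve))        # the curve the draws scatter around
print(band_coverage(draws, mean_curve + 0.5))  # a constant offset the narrow band cannot absorb
```

On the tutorial arrays the call would be `band_coverage(stress_arr, stress_data)`.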

## 8. Save Results

In [None]:
output_dir = get_output_dir("laos")
save_results(output_dir, model, result, ["G0", "k_d"], {"gamma_0": GAMMA_0, "omega": OMEGA})
print("Done.")

## Key Takeaways

1. **LAOS requires ODE integration** — slower than analytical protocols
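2. **VLB predicts elliptical Lissajous figures** at every amplitude: with constant k_d the shear-stress equation is linear, so only the start-up transient departs from an ellipse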
3. **Bayesian inference works** but is computationally heavier
4. **Nonlinear LAOS features** (distorted Lissajous) require
   Bell or FENE extensions to VLB
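
Distortion can be quantified rather than judged by eye. A common LAOS measure is the ratio $I_3/I_1$ of the third to the first stress harmonic (for a material with symmetric response, only odd harmonics appear in the shear stress). The sketch estimates both by least squares on $\sin$ and $\cos$ of $\omega t$ and $3\omega t$, which tolerates non-uniform sampling and a non-integer number of cycles. For constant-$k_d$ VLB the steady-state ratio is essentially zero, so a clearly nonzero value in the data marks the features that need the Bell or FENE extensions. The function name and test waveforms are illustrative.

```python
import numpy as np


def harmonic_ratio(t, stress, omega):
    """I3/I1: 3rd-harmonic stress amplitude relative to the 1st, via least squares."""
    design = np.column_stack([
        np.sin(omega * t), np.cos(omega * t),          # fundamental
        np.sin(3 * omega * t), np.cos(3 * omega * t),  # third harmonic
        np.ones_like(t),                               # constant offset
    ])
    coef, *_ = np.linalg.lstsq(design, stress, rcond=None)
    I1 = np.hypot(coef[0], coef[1])
    I3 = np.hypot(coef[2], coef[3])
    return float(I3 / I1)


omega = 1.0
t = np.linspace(0.0, 4 * np.pi, 400)
elliptical = 50.0 * np.sin(omega * t) + 20.0 * np.cos(omega * t)  # linear response
distorted = elliptical + 5.0 * np.sin(3 * omega * t)              # added 3rd harmonic
print(f"I3/I1 elliptical = {harmonic_ratio(t, elliptical, omega):.2e}")
print(f"I3/I1 distorted  = {harmonic_ratio(t, distorted, omega):.4f}")
```

Applied to the tutorial data, `harmonic_ratio(time_data, stress_data, OMEGA)` (after discarding the first cycle) would show how much of the measured waveform lies outside what constant-k_d VLB can describe.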

## Series Summary

| Protocol | Parameters Identified | Analytical? |
|-|-|-|
| Flow curve | η₀ = G₀/k_d only | Yes |
| Creep | G₀ and k_d | Yes |
| Relaxation | G₀ and k_d | Yes |
| Startup | G₀ and k_d | Yes |
| **SAOS** | **G₀ and k_d (best)** | **Yes** |
| LAOS | G₀ and k_d | No (ODE) |