# TNT Cates (Living Polymers): LAOS

## Objectives
- Fit the TNTCates model to large-amplitude oscillatory shear (LAOS) data
- Analyze the nonlinear response and the shear-banding threshold
- Extract higher harmonics via FFT to complement the Lissajous analysis
- Show the stress plateau at high strain amplitudes
- Perform Bayesian inference with NUTS

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from scipy.fft import fft, fftfreq

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTCates

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_pnas_laos,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_cates_param_names,
    compute_cates_tau_d,
)

param_names = get_tnt_cates_param_names()
print(f"TNTCates parameters: {param_names}")

## Theory: Cates Model for LAOS

The Cates model predicts the nonlinear oscillatory response of living polymers such as wormlike micelles:

**LAOS protocol:**
- Applied strain: $\gamma(t) = \gamma_0 \sin(\omega t)$
- Measure stress: $\sigma(t)$ (nonlinear waveform)

**Shear banding threshold:**
When $\sigma$ exceeds the maximum stress in the flow curve:
$$\sigma_{\text{max}} = \frac{2G_0\tau_d}{3\sqrt{3}}$$

the material may undergo transient shear banding.

**Nonlinear signatures:**
- Stress plateau at high strain amplitudes
- Odd harmonics in FFT (3rd, 5th, 7th, ...)
- Distorted Lissajous curves (stress vs strain)

**Physical interpretation:**
- Higher harmonics indicate departure from linear viscoelasticity
- Chain scission and recombination of the living polymers during each oscillation cycle
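To see why a symmetric nonlinearity produces only odd harmonics, the sketch below drives a toy constitutive law $\sigma = G_1\gamma + G_3\gamma^3$ (hypothetical coefficients, not the Cates model) with $\gamma(t) = \gamma_0 \sin(\omega t)$ and reads the harmonic amplitudes from an FFT taken over a whole number of periods. Because $\sigma(-\gamma) = -\sigma(\gamma)$, the even harmonics vanish to round-off, while the cubic term feeds the 3rd harmonic through $\sin^3 x = (3\sin x - \sin 3x)/4$.

```python
import numpy as np

def harmonic_amplitudes(signal, n_cycles, max_harmonic):
    """Amplitudes of harmonics 1..max_harmonic for a signal sampled over exactly
    n_cycles periods (no endpoint), so harmonic k sits in FFT bin k * n_cycles."""
    spectrum = np.fft.rfft(signal)
    scale = 2.0 / len(signal)  # converts FFT magnitude to sine amplitude
    return np.array([scale * np.abs(spectrum[k * n_cycles])
                     for k in range(1, max_harmonic + 1)])

# Toy odd-symmetric law sigma = G1*gamma + G3*gamma**3 (hypothetical coefficients)
G1, G3 = 100.0, -20.0          # Pa
gamma_0, omega = 2.0, 1.0      # strain amplitude, rad/s
n_cycles, n_per_cycle = 4, 256
t = np.arange(n_cycles * n_per_cycle) * (2 * np.pi / omega) / n_per_cycle
gamma = gamma_0 * np.sin(omega * t)
sigma = G1 * gamma + G3 * gamma**3

amps = harmonic_amplitudes(sigma, n_cycles, max_harmonic=5)
for k, a in enumerate(amps, start=1):
    print(f"harmonic {k}: {a:.4f} Pa")
```

Analytically the 1st harmonic is $G_1\gamma_0 + \tfrac{3}{4}G_3\gamma_0^3 = 80$ Pa and the 3rd is $\tfrac{1}{4}|G_3|\gamma_0^3 = 40$ Pa; the 2nd and 4th stay at zero. The 5th is also zero here only because the toy law stops at $\gamma^3$.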

## Load LAOS Data

In [None]:
time_data, stress, strain, omega_laos, gamma_0 = load_pnas_laos(omega=1.0, strain_amplitude_index=5)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.2f} to {time_data.max():.2f} s")
print(f"Applied frequency: {omega_laos:.3f} rad/s")
print(f"Applied strain amplitude: {gamma_0:.3f}")
print(f"Stress range: {stress.min():.2f} to {stress.max():.2f} Pa")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(time_data, stress, '-', linewidth=1, label='Stress')
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Stress (Pa)', fontsize=12)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('LAOS Stress Time Series', fontsize=14)

ax2.plot(strain, stress, '-', linewidth=1)
ax2.set_xlabel('Strain γ', fontsize=12)
ax2.set_ylabel('Stress σ (Pa)', fontsize=12)
ax2.grid(True, alpha=0.3)
ax2.set_title('Lissajous Curve', fontsize=14)

plt.tight_layout()
display(fig)
plt.close(fig)

## NLSQ Fitting

In [None]:
model = TNTCates()

start_time = time.time()
model.fit(time_data, stress, test_mode='laos', omega=omega_laos, gamma_0=gamma_0, method='scipy')
fit_time = time.time() - start_time

print(f"\nNLSQ Optimization completed in {fit_time:.2f} seconds")

# Extract fitted parameters
nlsq_params = {name: model.parameters.get_value(name) for name in param_names}
print("\nNLSQ Parameters:")
for name, value in nlsq_params.items():
    print(f"  {name}: {value:.4e}")

# Compute fit quality
stress_pred_fit = model.predict(time_data, test_mode='laos', omega=omega_laos, gamma_0=gamma_0)
quality = compute_fit_quality(stress, stress_pred_fit)
print(f"\nFit Quality: R² = {quality['R2']:.6f}, RMSE = {quality['RMSE']:.4e}")
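
`compute_fit_quality` is a tutorial helper whose source is not shown here. The sketch below uses the standard definitions of $R^2$ and RMSE, which are assumed (not confirmed) to be what the helper computes. Note that RMSE carries the units of stress (Pa), while $R^2$ is dimensionless.

```python
import numpy as np

def fit_quality(y_obs, y_pred):
    """Coefficient of determination (R^2) and root-mean-square error (RMSE)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_obs - y_pred
    ss_res = np.sum(residuals**2)                  # residual sum of squares
    ss_tot = np.sum((y_obs - np.mean(y_obs))**2)   # total sum of squares
    return {"R2": 1.0 - ss_res / ss_tot, "RMSE": float(np.sqrt(np.mean(residuals**2)))}

# Tiny hand-made example: ss_res = 0.1, ss_tot = 5.0, so R2 = 0.98 and RMSE = sqrt(0.025)
q = fit_quality([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
print(q)
```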

## Visualize NLSQ Fit

In [None]:
time_pred = jnp.linspace(time_data.min(), time_data.max(), 500)
stress_pred = model.predict(time_pred, test_mode='laos', omega=omega_laos, gamma_0=gamma_0)

strain_pred = gamma_0 * jnp.sin(omega_laos * time_pred)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(time_data, stress, 'o', label='Data', markersize=3, alpha=0.5)
ax1.plot(time_pred, stress_pred, '-', linewidth=2, label='TNTCates fit')
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Stress (Pa)', fontsize=12)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('LAOS Time Series Fit', fontsize=14)

ax2.plot(strain, stress, 'o', label='Data', markersize=3, alpha=0.5)
ax2.plot(strain_pred, stress_pred, '-', linewidth=2, label='TNTCates fit')
ax2.set_xlabel('Strain γ', fontsize=12)
ax2.set_ylabel('Stress σ (Pa)', fontsize=12)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_title('Lissajous Curve Fit', fontsize=14)

plt.tight_layout()
display(fig)
plt.close(fig)

## Physical Analysis: Shear Banding Threshold

In [None]:
tau_d = compute_cates_tau_d(nlsq_params['tau_rep'], nlsq_params['tau_break'])
zeta = nlsq_params['tau_break'] / nlsq_params['tau_rep']

sigma_max_theory = (2 * nlsq_params['G_0'] * tau_d) / (3 * np.sqrt(3))
sigma_max_observed = np.max(np.abs(stress))  # peak |σ| in the measured data

print(f"\nPhysical Analysis:")
print(f"  Reptation time (tau_rep): {nlsq_params['tau_rep']:.4e} s")
print(f"  Breaking time (tau_break): {nlsq_params['tau_break']:.4e} s")
print(f"  Effective relaxation time (tau_d): {tau_d:.4e} s")
print(f"  Fast-breaking parameter (zeta): {zeta:.4f}")
print(f"\nShear Banding Analysis:")
print(f"  Theoretical max stress (σ_max): {sigma_max_theory:.2f} Pa")
print(f"  Observed max stress: {sigma_max_observed:.2f} Pa")
print(f"  Stress ratio (σ_obs/σ_max): {sigma_max_observed/sigma_max_theory:.3f}")

if sigma_max_observed > sigma_max_theory:
    print(f"  ⚠ Stress exceeds σ_max: Potential transient shear banding")
else:
    print(f"  ✓ Stress below σ_max: No shear banding expected")
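
`compute_cates_tau_d` is a tutorial helper whose body is not shown. The sketch below assumes it implements the fast-breaking Cates result $\tau_d = \sqrt{\tau_{\text{rep}}\,\tau_{\text{break}}}$, the same expression the posterior cell later in this notebook applies directly. That approximation is intended for $\zeta = \tau_{\text{break}}/\tau_{\text{rep}} \ll 1$, so $\zeta$ is worth checking before interpreting $\tau_d$ as a single Maxwell time.

```python
import math

def cates_relaxation(tau_rep, tau_break):
    """Fast-breaking Cates estimates (assumed to match compute_cates_tau_d):
    tau_d = sqrt(tau_rep * tau_break) and zeta = tau_break / tau_rep."""
    tau_d = math.sqrt(tau_rep * tau_break)
    zeta = tau_break / tau_rep
    return tau_d, zeta

# Hypothetical times chosen for easy arithmetic
tau_d, zeta = cates_relaxation(tau_rep=10.0, tau_break=0.1)
fast_breaking = zeta < 1.0  # the sqrt result is only meaningful when zeta << 1
print(f"tau_d = {tau_d:.3f} s, zeta = {zeta:.3f}, fast-breaking: {fast_breaking}")
```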

## FFT Harmonics Analysis

In [None]:
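# FFT of the NLSQ prediction on the uniform grid time_pred (not of the raw data)
dt = float(time_pred[1] - time_pred[0])
stress_fft = fft(np.asarray(stress_pred))
freqs = fftfreq(len(time_pred), dt)

# Keep positive frequencies; 2|X|/N converts FFT magnitudes to sine amplitudes (Pa)
positive_mask = freqs > 0
freqs_pos = freqs[positive_mask]
amplitudes = 2.0 * np.abs(stress_fft[positive_mask]) / len(time_pred)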

fundamental_freq = omega_laos / (2 * np.pi)
harmonic_indices = []
harmonic_amplitudes = []
for n in [1, 3, 5, 7]:
    target_freq = n * fundamental_freq
    idx = np.argmin(np.abs(freqs_pos - target_freq))
    harmonic_indices.append(n)
    harmonic_amplitudes.append(amplitudes[idx])

print(f"\nFFT Harmonic Analysis:")
print(f"  Fundamental frequency: {fundamental_freq:.4f} Hz")
for n, amp in zip(harmonic_indices, harmonic_amplitudes):
    print(f"  Harmonic {n}: {amp:.2f} Pa ({amp/harmonic_amplitudes[0]*100:.1f}% of fundamental)")

fig, ax = plt.subplots(figsize=(10, 5))
ax.stem([n * fundamental_freq for n in harmonic_indices], harmonic_amplitudes, basefmt=' ')
ax.set_xlabel('Frequency (Hz)', fontsize=12)
ax.set_ylabel('Amplitude (Pa)', fontsize=12)
ax.set_xlim(0, 8 * fundamental_freq)
ax.grid(True, alpha=0.3)
ax.set_title('LAOS Harmonics', fontsize=14)
display(fig)
plt.close(fig)
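
The cell above reads each harmonic from the nearest FFT bin of a 500-point grid that spans the full record (endpoint included). Unless that record covers a whole number of periods, spectral leakage spreads energy into neighboring bins and biases the amplitudes. One way to avoid this is to trim the record to an integer number of periods and resample it without the endpoint, so each harmonic lands exactly on a bin. The sketch below uses plain NumPy rather than `scipy.fft`, and the helper name `harmonic_ratio` is hypothetical. It could be applied to the prediction or to the measured `stress`.

```python
import numpy as np

def harmonic_ratio(t, signal, omega, n=3, points_per_cycle=256):
    """I_n / I_1 from a periodic record, using only whole periods.

    The record is trimmed to the largest integer number of periods and
    linearly resampled onto a uniform grid without endpoint, so harmonic k
    falls exactly in FFT bin k * n_cycles.
    """
    t = np.asarray(t, dtype=float)
    signal = np.asarray(signal, dtype=float)
    period = 2 * np.pi / omega
    n_cycles = int(np.floor((t[-1] - t[0]) / period))
    if n_cycles < 1:
        raise ValueError("record is shorter than one period")
    n_pts = n_cycles * points_per_cycle
    t_uniform = t[0] + np.arange(n_pts) * (n_cycles * period / n_pts)
    s_uniform = np.interp(t_uniform, t, signal)
    spectrum = np.abs(np.fft.rfft(s_uniform))
    return spectrum[n * n_cycles] / spectrum[n_cycles]

# Synthetic check: 5.3 periods of sin(wt) + 0.1 sin(3wt), densely sampled
omega = 1.0
t = np.linspace(0.0, 5.3 * 2 * np.pi / omega, 5000)
s = np.sin(omega * t) + 0.1 * np.sin(3 * omega * t)
ratio = harmonic_ratio(t, s, omega)
print(f"I3/I1 = {ratio:.4f}")
```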

## Bayesian Inference with NUTS

In [None]:
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

print(f"Starting Bayesian inference with NUTS...")
print(f"Warmup: {NUM_WARMUP}, Samples: {NUM_SAMPLES}, Chains: {NUM_CHAINS}")

start_time = time.time()
bayesian_result = model.fit_bayesian(
    time_data, stress,
    test_mode='laos',
    omega=omega_laos,
    gamma_0=gamma_0,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    seed=42
)
bayes_time = time.time() - start_time

print(f"\nBayesian inference completed in {bayes_time:.2f} seconds")
print(f"Time per MCMC iteration (warmup + sampling): {bayes_time/(NUM_WARMUP + NUM_SAMPLES):.3f} seconds")

## Convergence Diagnostics

In [None]:
posterior = bayesian_result.posterior_samples

bayesian_params = {name: float(jnp.mean(posterior[name])) for name in param_names}
param_std = {name: float(jnp.std(posterior[name])) for name in param_names}

print("\nPosterior Statistics:")
for name in param_names:
    print(f"  {name}: {bayesian_params[name]:.4e} ± {param_std[name]:.4e}")

# Compare NLSQ vs Bayesian using the utility function
print_parameter_comparison(model, posterior, param_names)
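
Posterior means and standard deviations alone do not show whether the single chain has converged. A quick check is split-$\hat R$: each chain is cut in half and between-half variance is compared with within-half variance, so drift shows up even with `NUM_CHAINS = 1`. The sketch below implements the classic (non-rank-normalized) split-$\hat R$. ArviZ's `az.rhat` and `az.summary` report a rank-normalized variant, so numbers will differ slightly. Commonly cited acceptance thresholds range from about 1.01 to 1.1. Applied to this notebook it would be called as `split_rhat(np.asarray(posterior[name]).reshape(NUM_CHAINS, -1))`.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for draws shaped (n_chains, n_draws)."""
    draws = np.atleast_2d(np.asarray(draws, dtype=float))
    half = draws.shape[1] // 2
    # Treat the first and second half of every chain as separate chains
    halves = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    within = np.mean(np.var(halves, axis=1, ddof=1))        # W: mean within-half variance
    between = n * np.var(np.mean(halves, axis=1), ddof=1)    # B: variance of half means, scaled
    var_hat = (n - 1) / n * within + between / n             # pooled variance estimate
    return float(np.sqrt(var_hat / within))

rng = np.random.default_rng(0)
stationary = rng.normal(size=(1, 1000))                # well-mixed chain
drifting = stationary + np.linspace(0.0, 5.0, 1000)    # same chain with a trend
print(f"stationary: {split_rhat(stationary):.3f}, drifting: {split_rhat(drifting):.3f}")
```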

## ArviZ Trace Plot

In [None]:
# ArviZ expects arrays shaped (chain, draw); convert JAX arrays to NumPy first
idata = az.from_dict(posterior={k: np.asarray(posterior[k]).reshape(NUM_CHAINS, -1) for k in param_names})

axes = az.plot_trace(idata, var_names=param_names, compact=False, figsize=(12, 8))
fig = np.ravel(axes)[0].figure  # plot_trace returns an array of Axes, not a Figure
fig.suptitle('MCMC Trace Plots', fontsize=14, y=1.0)
fig.tight_layout()
display(fig)
plt.close(fig)

## ArviZ Pair Plot

In [None]:
axes = az.plot_pair(
    idata,
    var_names=param_names,
    kind='kde',
    marginals=True,
    figsize=(10, 10)
)
fig = np.ravel(axes)[0].figure  # plot_pair returns Axes, not a Figure
fig.suptitle('Posterior Correlations', fontsize=14, y=0.995)
display(fig)
plt.close(fig)

## ArviZ Forest Plot

In [None]:
axes = az.plot_forest(
    idata,
    var_names=param_names,
    combined=True,
    hdi_prob=0.95,
    figsize=(8, 4)
)
fig = np.ravel(axes)[0].figure  # plot_forest returns Axes, not a Figure
fig.suptitle('95% Credible Intervals (HDI)', fontsize=14)
display(fig)
plt.close(fig)
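
The forest plot draws 95% highest-density intervals (HDI): the shortest interval containing 95% of the posterior mass. For a skewed posterior, such as a strictly positive relaxation time, this differs from the equal-tailed 2.5/97.5 percentile band used in the posterior predictive cell. The sketch below shows the shortest-window idea for a unimodal posterior. It is an illustration only, not ArviZ's own implementation.

```python
import numpy as np

def hdi(samples, prob=0.95):
    """Shortest interval containing a fraction `prob` of the samples (unimodal case)."""
    x = np.sort(np.asarray(samples, dtype=float).ravel())
    n = len(x)
    k = int(np.ceil(prob * n))             # number of points inside the interval
    widths = x[k - 1:] - x[: n - k + 1]    # width of every window of k sorted points
    i = int(np.argmin(widths))
    return x[i], x[i + k - 1]

rng = np.random.default_rng(1)
lo, hi = hdi(rng.normal(size=200_000), prob=0.95)
print(f"95% HDI of a standard normal: [{lo:.3f}, {hi:.3f}]")  # close to [-1.96, 1.96]
```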

## Posterior Predictive Distribution

In [None]:
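n_posterior_samples = min(200, NUM_SAMPLES * NUM_CHAINS)
rng = np.random.default_rng(0)  # seeded for reproducible draws
sample_indices = rng.choice(NUM_SAMPLES * NUM_CHAINS, n_posterior_samples, replace=False)

predictions = []
for idx in sample_indices:
    # Set model parameters from one posterior draw
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name].flatten()[idx]))
    # Predict the LAOS stress with these parameters
    pred_i = model.predict(time_pred, test_mode='laos', omega=omega_laos, gamma_0=gamma_0)
    predictions.append(np.array(pred_i))

# Restore the NLSQ point estimates so later cells do not use the last random draw
for name, value in nlsq_params.items():
    model.parameters.set_value(name, value)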

predictions = np.array(predictions)
pred_mean = np.mean(predictions, axis=0)
pred_lower = np.percentile(predictions, 2.5, axis=0)
pred_upper = np.percentile(predictions, 97.5, axis=0)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(time_data, stress, 'o', label='Data', markersize=3, alpha=0.5, zorder=3)
ax1.plot(time_pred, pred_mean, '-', linewidth=2, label='Posterior mean', zorder=2)
ax1.fill_between(time_pred, pred_lower, pred_upper, alpha=0.3, label='95% CI', zorder=1)
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Stress (Pa)', fontsize=12)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('Posterior Predictive (Time Series)', fontsize=14)

ax2.plot(strain, stress, 'o', label='Data', markersize=3, alpha=0.5, zorder=3)
ax2.plot(strain_pred, pred_mean, '-', linewidth=2, label='Posterior mean', zorder=2)
ax2.set_xlabel('Strain γ', fontsize=12)
ax2.set_ylabel('Stress σ (Pa)', fontsize=12)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_title('Posterior Predictive (Lissajous)', fontsize=14)

plt.tight_layout()
display(fig)
plt.close(fig)
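
Each curve in the loop above is a noise-free model prediction, so the shaded band reflects parameter uncertainty only and may be much narrower than the scatter in the data. One way to quantify this is to count the fraction of data points that fall inside the band. The helper `band_coverage` below is hypothetical. With this notebook's variables it would be called as `band_coverage(time_data, stress, time_pred, pred_lower, pred_upper)`.

```python
import numpy as np

def band_coverage(t_data, y_data, t_grid, lower, upper):
    """Fraction of data points inside a band that is defined on another time grid."""
    band_lo = np.interp(t_data, t_grid, lower)   # interpolate band to the data times
    band_hi = np.interp(t_data, t_grid, upper)
    y = np.asarray(y_data, dtype=float)
    return float(np.mean((y >= band_lo) & (y <= band_hi)))

# Synthetic illustration: a tight parameter-only band versus noisy data
rng = np.random.default_rng(2)
t_grid = np.linspace(0.0, 10.0, 500)
truth = np.sin(t_grid)
lower, upper = truth - 0.01, truth + 0.01        # hypothetical narrow band
t_data = np.linspace(0.0, 10.0, 200)
y_noisy = np.sin(t_data) + rng.normal(scale=0.1, size=t_data.size)
print(f"coverage: {band_coverage(t_data, y_noisy, t_grid, lower, upper):.2f}")
```

A coverage far below 0.95 does not by itself mean the model is wrong. It mainly indicates that observation noise is missing from the band.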

## Physical Interpretation from Posterior

In [None]:
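# Cates fast-breaking estimate tau_d = sqrt(tau_rep * tau_break), evaluated for every posterior draw
tau_rep_post = np.asarray(posterior['tau_rep']).ravel()
tau_break_post = np.asarray(posterior['tau_break']).ravel()
G_0_post = np.asarray(posterior['G_0']).ravel()
tau_d_posterior = np.sqrt(tau_rep_post * tau_break_post)
zeta_posterior = tau_break_post / tau_rep_post
sigma_max_posterior = (2 * G_0_post * tau_d_posterior) / (3 * np.sqrt(3))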

print(f"\nPhysical quantities from posterior:")
print(f"  tau_d: {np.mean(tau_d_posterior):.4e} ± {np.std(tau_d_posterior):.4e} s")
print(f"  zeta (tau_break/tau_rep): {np.mean(zeta_posterior):.4f} ± {np.std(zeta_posterior):.4f}")
print(f"  σ_max (shear banding threshold): {np.mean(sigma_max_posterior):.2f} ± {np.std(sigma_max_posterior):.2f} Pa")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.hist(tau_d_posterior, bins=30, alpha=0.7, edgecolor='black')
ax1.axvline(np.mean(tau_d_posterior), color='r', linestyle='--', label='Mean')
ax1.set_xlabel(r'$\tau_d$ (s)', fontsize=12)
ax1.set_ylabel('Count', fontsize=12)
ax1.legend()
ax1.set_title('Effective Relaxation Time Distribution', fontsize=12)

ax2.hist(sigma_max_posterior, bins=30, alpha=0.7, edgecolor='black')
ax2.axvline(np.mean(sigma_max_posterior), color='r', linestyle='--', label='Mean')
ax2.axvline(sigma_max_observed, color='g', linestyle=':', label='Observed max')
ax2.set_xlabel(r'$\sigma_{max}$ (Pa)', fontsize=12)
ax2.set_ylabel('Count', fontsize=12)
ax2.legend()
ax2.set_title('Shear Banding Threshold Distribution', fontsize=12)

plt.tight_layout()
display(fig)
plt.close(fig)
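
The histogram can also be reduced to a single number: the posterior probability that the observed peak stress exceeds the shear-banding threshold. This is the fraction of draws whose threshold lies below the peak. The helper below is hypothetical. With this notebook's variables it would be called as `exceedance_probability(sigma_max_posterior, sigma_max_observed)`.

```python
import numpy as np

def exceedance_probability(threshold_draws, observed_peak):
    """Posterior probability that the observed peak stress exceeds the threshold."""
    draws = np.asarray(threshold_draws, dtype=float).ravel()
    return float(np.mean(observed_peak > draws))

# Synthetic draws standing in for sigma_max_posterior (hypothetical numbers, Pa)
draws = np.array([40.0, 45.0, 50.0, 55.0, 60.0])
print(exceedance_probability(draws, observed_peak=52.0))  # 0.6 (3 of 5 draws below 52)
```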

## Save Results

In [None]:
save_tnt_results(model, bayesian_result, "cates", "laos", param_names)

## Key Takeaways

1. **Nonlinear LAOS response** shows stress plateau and odd harmonics for living polymers
2. **Shear-banding threshold** $\sigma_{\text{max}} = 2G_0\tau_d/(3\sqrt{3})$ flags when transient banding may occur
3. **Odd harmonics** (3rd, 5th, 7th) quantify the departure from linear viscoelasticity
4. **Lissajous distortion** reflects the nonlinear chain scission and recombination dynamics under large-amplitude oscillation
5. **Bayesian inference** quantifies uncertainty in parameters from complex LAOS waveforms

**Cross-protocol validation:** Compare extracted $\tau_d$ with flow curve, startup, relaxation, creep, and SAOS to ensure global consistency of Cates model parameters.