# TNT Cates (Living Polymers): SAOS

## Objectives
- Fit TNTCates model to small-amplitude oscillatory shear (SAOS) data
- Show Maxwell-like behavior in fast-breaking limit
- Analyze Cole-Cole plot (semicircular diagnostic for living polymers)
- Extract tau_d from frequency response G'(ω) and G''(ω)
- Perform Bayesian inference with NUTS

## Setup

In [1]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTCates

sys.path.insert(0, os.path.dirname(os.path.abspath("")))
sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_epstein_saos,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_cates_param_names,
    compute_cates_tau_d,
    plot_cates_cole_cole,
    compute_maxwell_moduli,
)

from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

param_names = get_tnt_cates_param_names()
print(f"TNTCates parameters: {param_names}")

TNTCates parameters: ['G_0', 'tau_rep', 'tau_break', 'eta_s']


## Theory: Cates Model for SAOS

The Cates model predicts frequency-dependent moduli for wormlike micelles:

**Storage and loss moduli (fast-breaking limit):**
$$G'(\omega) = G_0 \frac{(\omega \tau_d)^2}{1 + (\omega \tau_d)^2}$$
$$G''(\omega) = G_0 \frac{\omega \tau_d}{1 + (\omega \tau_d)^2}$$

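where $\tau_d = \sqrt{\tau_{\text{rep}} \cdot \tau_{\text{break}}}$. The fast-breaking limit applies when $\zeta = \tau_{\text{break}}/\tau_{\text{rep}} \ll 1$; the analysis below uses $\zeta < 0.1$ as a practical threshold.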
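As a quick arithmetic check of these definitions, the sketch below plugs in the NLSQ estimates reported later in this notebook ($\tau_{\text{rep}} \approx 0.419$ s, $\tau_{\text{break}} \approx 0.0946$ s). It uses only the Python standard library and does not touch rheojax.

```python
import math

# NLSQ estimates taken from the fit output further down this notebook
tau_rep = 4.1905e-01    # reptation time, s
tau_break = 9.4596e-02  # breaking time, s

tau_d = math.sqrt(tau_rep * tau_break)  # geometric mean of the two time scales
zeta = tau_break / tau_rep              # fast-breaking parameter
omega_c = 1.0 / tau_d                   # Maxwell-limit crossover frequency

print(f"tau_d = {tau_d:.4e} s, zeta = {zeta:.4f}, omega_c = {omega_c:.4e} rad/s")
# tau_d = 1.9910e-01 s, zeta = 0.2257, omega_c = 5.0226e+00 rad/s
```

Note that $\zeta \approx 0.23$ here, above the 0.1 threshold, so the Maxwell-like formulas are only an approximation for this data set.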

**Cole-Cole plot diagnostic:**
Plotting $G''$ against $G'$ gives a semicircle of radius $G_0/2$ centered at $(G_0/2, 0)$ for single-mode Maxwell behavior:
- Ideal living polymers: Perfect semicircle
- Deviations indicate multimode relaxation or structural complexity

**Crossover frequency:**
At $\omega = 1/\tau_d$: $G' = G'' = G_0/2$ (terminal relaxation)
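Both statements can be checked numerically. A minimal NumPy sketch, where `maxwell_moduli` is a hypothetical helper (not the tutorial's `compute_maxwell_moduli`) and $G_0$, $\tau_d$ are illustrative values close to the fit below:

```python
import numpy as np

def maxwell_moduli(omega, G0, tau):
    """Single-mode Maxwell storage and loss moduli G'(omega), G''(omega)."""
    wt = np.asarray(omega, dtype=float) * tau
    return G0 * wt**2 / (1.0 + wt**2), G0 * wt / (1.0 + wt**2)

G0, tau_d = 875.14, 0.19910   # illustrative values, close to the NLSQ fit
omega = np.logspace(-2, 3, 400)
G_p, G_pp = maxwell_moduli(omega, G0, tau_d)

# Cole-Cole semicircle: every (G', G'') pair lies at distance G0/2 from (G0/2, 0)
dist = np.hypot(G_p - G0 / 2, G_pp)
print(np.allclose(dist, G0 / 2))   # True

# Crossover at omega = 1/tau_d, where G' = G'' = G0/2
Gp_c, Gpp_c = maxwell_moduli(1.0 / tau_d, G0, tau_d)
print(Gp_c, Gpp_c)                 # both approximately G0/2
```

The semicircle follows from $(G' - G_0/2)^2 + G''^2 = (G_0/2)^2$, which holds identically for the two Maxwell expressions above.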

## Load SAOS Data

In [2]:
omega, G_prime_data, G_double_prime_data = load_epstein_saos()

G_star_mag = np.sqrt(G_prime_data**2 + G_double_prime_data**2)

print(f"Data points: {len(omega)}")
print(f"Frequency range: {omega.min():.2e} to {omega.max():.2e} rad/s")
print(f"G' range: {G_prime_data.min():.2e} to {G_prime_data.max():.2e} Pa")
print(f"G'' range: {G_double_prime_data.min():.2e} to {G_double_prime_data.max():.2e} Pa")

fig, ax = plt.subplots(figsize=(8, 5))
ax.loglog(omega, G_prime_data, 'o', label="G' (storage)", markersize=5)
ax.loglog(omega, G_double_prime_data, 's', label="G'' (loss)", markersize=5)
ax.set_xlabel(r'Frequency $\omega$ (rad/s)', fontsize=12)
ax.set_ylabel('Moduli (Pa)', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_title('SAOS Data', fontsize=14)
plt.close("all")

Data points: 19
Frequency range: 1.01e-01 to 9.94e+01 rad/s
G' range: 8.28e+00 to 2.93e+03 Pa
G'' range: 2.78e+01 to 2.22e+03 Pa
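The NLSQ fit below passes `G_star_mag`, i.e. $|G^*| = \sqrt{G'^2 + G''^2}$, to `model.fit`. The magnitude alone does not distinguish an elastic-dominated point from a viscous-dominated one, as this toy example with two hypothetical $(G', G'')$ pairs shows:

```python
import numpy as np

# Two hypothetical (G', G'') pairs with the same magnitude but different phase
Gp_a, Gpp_a = 300.0, 400.0   # viscous-dominated: tan(delta) = 4/3
Gp_b, Gpp_b = 400.0, 300.0   # elastic-dominated: tan(delta) = 3/4

mag_a = np.hypot(Gp_a, Gpp_a)  # |G*| = sqrt(G'^2 + G''^2)
mag_b = np.hypot(Gp_b, Gpp_b)
delta_a = np.degrees(np.arctan2(Gpp_a, Gp_a))  # phase angle in degrees
delta_b = np.degrees(np.arctan2(Gpp_b, Gp_b))

print(mag_a, mag_b)                           # 500.0 500.0
print(round(delta_a, 1), round(delta_b, 1))   # 53.1 36.9
```

Fitting the magnitude therefore leaves the phase angle $\delta = \arctan(G''/G')$ out of the objective, which is worth keeping in mind when comparing the fitted $G'$ and $G''$ curves with the data.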


## NLSQ Fitting

In [3]:
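model = TNTCates()

# Fit the complex-modulus magnitude |G*| = sqrt(G'^2 + G''^2); the phase (G''/G')
# is not part of the objective, so the G'/G'' split is only indirectly constrained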

start_time = time.time()
model.fit(omega, G_star_mag, test_mode='oscillation', method='scipy')
fit_time = time.time() - start_time

print(f"\nNLSQ Optimization completed in {fit_time:.2f} seconds")

# Extract fitted parameters
nlsq_params = {name: model.parameters.get_value(name) for name in param_names}
print("\nNLSQ Parameters:")
for name, value in nlsq_params.items():
    print(f"  {name}: {value:.4e}")

# Compute fit quality
G_star_pred_fit = model.predict(omega, test_mode='oscillation')
quality = compute_fit_quality(G_star_mag, G_star_pred_fit)
print(f"\nFit Quality: R² = {quality['R2']:.6f}, RMSE = {quality['RMSE']:.4e}")


NLSQ Optimization completed in 0.37 seconds

NLSQ Parameters:
  G_0: 8.7514e+02
  tau_rep: 4.1905e-01
  tau_break: 9.4596e-02
  eta_s: 4.3921e+01

Fit Quality: R² = 0.953143, RMSE = 2.3937e+02
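A sketch of the two reported metrics, assuming `compute_fit_quality` follows the usual definitions ($R^2 = 1 - SS_{\text{res}}/SS_{\text{tot}}$, RMSE in Pa). On a linear scale both are weighted toward the largest moduli; the log-scale RMSE below is an illustrative addition, not something the tutorial utility is known to return.

```python
import numpy as np

def fit_quality(y_true, y_pred):
    """R^2 and RMSE (standard definitions), plus an illustrative log10-scale RMSE."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_true - y_pred
    ss_res = np.sum(resid**2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return {
        "R2": 1.0 - ss_res / ss_tot,
        "RMSE": float(np.sqrt(np.mean(resid**2))),
        # Error in decades; requires strictly positive data and predictions
        "RMSE_log10": float(np.sqrt(np.mean((np.log10(y_true) - np.log10(y_pred)) ** 2))),
    }

y = np.array([10.0, 100.0, 1000.0])
q = fit_quality(y, 1.1 * y)   # uniform 10% over-prediction at every point
print(q)
```

With a uniform 10% error, the linear RMSE is set almost entirely by the 1000 Pa point, while the log-scale RMSE is the same at every point ($\log_{10} 1.1 \approx 0.041$).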


## Visualize NLSQ Fit

In [4]:
omega_pred = jnp.logspace(jnp.log10(omega.min()), jnp.log10(omega.max()), 200)
G_prime_pred, G_double_prime_pred = model.predict_saos(omega_pred)

# Residuals of the individual components G' and G'' at the measured frequencies
G_prime_fit, G_double_prime_fit = model.predict_saos(omega)
residuals_Gp = G_prime_data - G_prime_fit
residuals_Gpp = G_double_prime_data - G_double_prime_fit

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.loglog(omega, G_prime_data, 'o', label="G' data", markersize=5)
ax1.loglog(omega, G_double_prime_data, 's', label="G'' data", markersize=5)
ax1.loglog(omega_pred, G_prime_pred, '-', linewidth=2, label="G' fit")
ax1.loglog(omega_pred, G_double_prime_pred, '-', linewidth=2, label="G'' fit")
ax1.set_xlabel(r'Frequency $\omega$ (rad/s)', fontsize=12)
ax1.set_ylabel('Moduli (Pa)', fontsize=12)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('SAOS Fit', fontsize=14)

ax2.semilogx(omega, residuals_Gp, 'o', label="G' residuals", markersize=4)
ax2.semilogx(omega, residuals_Gpp, 's', label="G'' residuals", markersize=4)
ax2.axhline(0, color='k', linestyle='--', linewidth=1)
ax2.set_xlabel(r'Frequency $\omega$ (rad/s)', fontsize=12)
ax2.set_ylabel('Residuals (Pa)', fontsize=12)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_title('Residuals', fontsize=14)

plt.tight_layout()
plt.close("all")

## Physical Analysis: Cole-Cole Plot

In [5]:
tau_d = compute_cates_tau_d(nlsq_params['tau_rep'], nlsq_params['tau_break'])
zeta = nlsq_params['tau_break'] / nlsq_params['tau_rep']

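# Numerical crossover: grid point with the smallest |G' - G''|.
# Note: argmin does not check that the curves actually cross; it can return a window edge.
crossover_idx = jnp.argmin(jnp.abs(G_prime_pred - G_double_prime_pred))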
omega_crossover = omega_pred[crossover_idx]
tau_d_from_crossover = 1.0 / omega_crossover

print(f"\nPhysical Analysis:")
print(f"  Reptation time (tau_rep): {nlsq_params['tau_rep']:.4e} s")
print(f"  Breaking time (tau_break): {nlsq_params['tau_break']:.4e} s")
print(f"  Effective relaxation time (tau_d): {tau_d:.4e} s")
print(f"  Crossover frequency (ω_c = 1/tau_d): {omega_crossover:.4e} rad/s")
print(f"  tau_d from crossover: {tau_d_from_crossover:.4e} s")
print(f"  Agreement: {abs(tau_d - tau_d_from_crossover)/tau_d * 100:.1f}% difference")
print(f"  Fast-breaking parameter (zeta): {zeta:.4f}")

if zeta < 0.1:
    print(f"\n  → Fast-breaking limit: Maxwell-like SAOS behavior")
else:
    print(f"\n  → Not in fast-breaking limit: May show deviations from Maxwell")

fig = plot_cates_cole_cole(model)
plt.close("all")


Physical Analysis:
  Reptation time (tau_rep): 4.1905e-01 s
  Breaking time (tau_break): 9.4596e-02 s
  Effective relaxation time (tau_d): 1.9910e-01 s
  Crossover frequency (ω_c = 1/tau_d): 1.0078e-01 rad/s
  tau_d from crossover: 9.9223e+00 s
  Agreement: 4883.6% difference
  Fast-breaking parameter (zeta): 0.2257

  → Not in fast-breaking limit: May show deviations from Maxwell
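The 4883.6% "disagreement" above is a numerical artifact rather than a physical result. `omega_crossover` = 1.0078e-01 rad/s is the first point of `omega_pred` (the grid starts at `omega.min()`), so `argmin(|G' - G''|)` returned the low-frequency edge of the window, where both moduli are small and, for Maxwell-type behavior, $G' \ll G''$. The printed label "ω_c = 1/tau_d" is also misleading: $1/\tau_d \approx 5.02$ rad/s. A more robust check looks for a sign change in $G' - G''$ and reports nothing when the curves do not cross. A sketch, where `find_crossovers` is a hypothetical helper:

```python
import numpy as np

def find_crossovers(omega, G_p, G_pp):
    """Frequencies where G' - G'' changes sign, interpolated linearly in log(omega).

    Returns an empty array when the curves do not cross inside the window,
    instead of reporting the grid point with the smallest |G' - G''|.
    Grid points where the difference is exactly zero are not detected.
    """
    omega = np.asarray(omega, dtype=float)
    diff = np.asarray(G_p, dtype=float) - np.asarray(G_pp, dtype=float)
    idx = np.where(np.sign(diff[:-1]) * np.sign(diff[1:]) < 0)[0]
    lw0, lw1 = np.log(omega[idx]), np.log(omega[idx + 1])
    d0, d1 = diff[idx], diff[idx + 1]
    # Root of the straight line through (lw0, d0) and (lw1, d1)
    return np.exp(lw0 - d0 * (lw1 - lw0) / (d1 - d0))

# Usage with Maxwell-limit curves (illustrative G0 and tau_d)
tau_d = 0.19910
omega = np.logspace(-1, 2, 200)
wt = omega * tau_d
G_p, G_pp = 875.0 * wt**2 / (1 + wt**2), 875.0 * wt / (1 + wt**2)
print(find_crossovers(omega, G_p, G_pp))           # close to 1/tau_d (about 5.02 rad/s)
print(find_crossovers(omega, G_p, G_pp + 1000.0))  # []: G'' > G' everywhere
```

Applied to this notebook's predictions, e.g. `find_crossovers(np.asarray(omega_pred), np.asarray(G_prime_pred), np.asarray(G_double_prime_pred))`, an empty result would mean the fitted curves do not cross inside the measured window, in which case $\tau_d$ is better taken from the fitted parameters than from a crossover.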


## Compare with Single-Mode Maxwell

In [6]:
G_prime_maxwell, G_double_prime_maxwell = compute_maxwell_moduli(
    omega_pred, nlsq_params['G_0'], tau_d
)

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(omega, G_prime_data, 'o', label="G' data", markersize=5, zorder=3)
ax.loglog(omega, G_double_prime_data, 's', label="G'' data", markersize=5, zorder=3)
ax.loglog(omega_pred, G_prime_pred, '-', linewidth=2, label="G' Cates", zorder=2)
ax.loglog(omega_pred, G_double_prime_pred, '-', linewidth=2, label="G'' Cates", zorder=2)
ax.loglog(omega_pred, G_prime_maxwell, '--', linewidth=2, alpha=0.7, label="G' Maxwell", zorder=1)
ax.loglog(omega_pred, G_double_prime_maxwell, '--', linewidth=2, alpha=0.7, label="G'' Maxwell", zorder=1)
ax.axvline(1.0 / tau_d, color='gray', linestyle='-.', alpha=0.5, label=rf'$1/\tau_d$ = {1.0 / tau_d:.2e} rad/s')
ax.set_xlabel(r'Frequency $\omega$ (rad/s)', fontsize=12)
ax.set_ylabel('Moduli (Pa)', fontsize=12)
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)
ax.set_title(f'Cates vs Maxwell (τ={tau_d:.2e}s)', fontsize=14)
plt.close("all")

## Bayesian Inference with NUTS

In [7]:
# FAST_MODE (env var FAST_MODE, default "1") skips NUTS entirely and substitutes a
# placeholder posterior built from the NLSQ point estimate; set FAST_MODE=0 to sample
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

# Configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

if FAST_MODE:
    print("FAST_MODE: Skipping Bayesian inference (JIT compilation takes >600s)")
    print("To run Bayesian analysis, run with FAST_MODE=0")
    # Placeholder: every "draw" equals the NLSQ estimate, so posterior spreads are zero
    # up to floating-point rounding and no uncertainty is quantified
    class BayesianResult:
        def __init__(self, model, param_names):
            self.posterior_samples = {name: np.array([model.parameters.get_value(name)] * NUM_SAMPLES) for name in param_names}
    bayesian_result = BayesianResult(model, param_names)
    bayes_time = 0.0
else:
    print(f"Running NUTS with {NUM_CHAINS} chain(s)...")
    print(f"Warmup: {NUM_WARMUP} samples, Sampling: {NUM_SAMPLES} samples")
    
    start_time = time.time()
    bayesian_result = model.fit_bayesian(
        omega, G_star_mag,
        test_mode='oscillation',
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        seed=42
    )
    bayes_time = time.time() - start_time
    
    print(f"\nBayesian inference completed in {bayes_time:.1f} seconds")


FAST_MODE: Skipping Bayesian inference (JIT compilation takes >600s)
To run Bayesian analysis, run with FAST_MODE=0


## Posterior Summary

In [8]:
posterior = bayesian_result.posterior_samples

bayesian_params = {name: float(jnp.mean(posterior[name])) for name in param_names}
param_std = {name: float(jnp.std(posterior[name])) for name in param_names}

print("\nPosterior Statistics:")
for name in param_names:
    print(f"  {name}: {bayesian_params[name]:.4e} ± {param_std[name]:.4e}")

# Compare NLSQ vs Bayesian using the utility function
print_parameter_comparison(model, posterior, param_names)


Posterior Statistics:
  G_0: 8.7514e+02 ± 3.4106e-13
  tau_rep: 4.1905e-01 ± 1.1102e-16
  tau_break: 9.4596e-02 ± 2.7756e-17
  eta_s: 4.3921e+01 ± 1.4211e-14

Parameter Comparison: NLSQ vs Bayesian
      Parameter          NLSQ        Median                          95% CI
---------------------------------------------------------------------------
            G_0         875.1         875.1  [875.1, 875.1]
        tau_rep        0.4191        0.4191  [0.4191, 0.4191]
      tau_break        0.0946        0.0946  [0.0946, 0.0946]
          eta_s         43.92         43.92  [43.92, 43.92]
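In FAST_MODE every "draw" is the NLSQ value, so the ± values above (10⁻¹³ to 10⁻¹⁷) are floating-point rounding and the 95% intervals collapse to a point. A minimal sketch of the summary statistics, assuming the usual mean/std and a percentile-based 95% interval (how `print_parameter_comparison` computes its CI is not shown here), contrasts constant draws with synthetic draws that have a real spread:

```python
import numpy as np

def summarize_draws(draws):
    """Mean, std, median and central 95% interval of 1-D posterior draws."""
    draws = np.asarray(draws, dtype=float).ravel()
    lo, med, hi = np.percentile(draws, [2.5, 50.0, 97.5])
    return {"mean": draws.mean(), "std": draws.std(), "median": med, "ci95": (lo, hi)}

rng = np.random.default_rng(0)
# Synthetic draws with a genuine spread (illustrative only)
real = summarize_draws(rng.lognormal(np.log(0.42), 0.1, size=500))
# FAST_MODE-style placeholder: 500 identical copies of the point estimate
placeholder = summarize_draws(np.full(500, 0.41905))

print(real["ci95"])
print(placeholder["std"], placeholder["ci95"])  # std is zero or rounding-level
```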


## ArviZ Diagnostics

In [9]:
# ArviZ diagnostics (trace, pair, forest, energy, autocorrelation, rank)
if not FAST_MODE and hasattr(bayesian_result, 'to_inference_data'):
    display_arviz_diagnostics(bayesian_result, param_names, fast_mode=FAST_MODE)
else:
    print("FAST_MODE: Skipping ArviZ diagnostics")

FAST_MODE: Skipping ArviZ diagnostics
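When NUTS is actually run (FAST_MODE=0), R-hat is the first convergence number to check. ArviZ reports a rank-normalized split-R-hat by default; the sketch below is the simpler classic split-R-hat from Gelman et al.'s *Bayesian Data Analysis* (3rd ed.), shown only to make the idea concrete. Because each chain is split in half, it is still defined for NUM_CHAINS = 1, as configured above.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for draws shaped (n_chains, n_draws) or (n_draws,).

    Each chain is split in half, so a single chain still yields two sequences.
    Returns nan if the within-sequence variance is exactly zero.
    """
    draws = np.atleast_2d(np.asarray(draws, dtype=float))
    half = draws.shape[1] // 2
    seqs = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = seqs.shape[1]
    W = seqs.var(axis=1, ddof=1).mean()      # mean within-sequence variance
    B = n * seqs.mean(axis=1).var(ddof=1)    # between-sequence variance
    if W == 0.0:
        return float("nan")
    var_plus = (n - 1) / n * W + B / n       # pooled variance estimate
    return float(np.sqrt(var_plus / W))

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                        # well-mixed chains
stuck = mixed + np.array([0.0, 0.0, 5.0, 5.0])[:, None]   # two chains in another mode
print(split_rhat(mixed), split_rhat(stuck))
```

Values close to 1 (a common rule of thumb with the rank-normalized version is < 1.01) indicate that the sequences agree; the "stuck" example gives a value far above 1. For the FAST_MODE placeholder the draws are constant, so R-hat carries no information.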


## Posterior Predictive Distribution

In [10]:
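n_posterior_samples = 200
rng = np.random.default_rng(0)  # seeded for a reproducible subset of draws
sample_indices = rng.choice(NUM_SAMPLES * NUM_CHAINS, n_posterior_samples, replace=False)

predictions_Gp = []
predictions_Gpp = []
for idx in sample_indices:
    # Set model parameters from the posterior draw
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name].flatten()[idx]))
    # Predict G' and G'' with the updated parameters
    Gp_i, Gpp_i = model.predict_saos(omega_pred)
    predictions_Gp.append(np.array(Gp_i))
    predictions_Gpp.append(np.array(Gpp_i))

# Restore the NLSQ point estimate: the loop above leaves the last draw in the model,
# and later cells (e.g. save_tnt_results) receive the model object
for name, value in nlsq_params.items():
    model.parameters.set_value(name, value)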

predictions_Gp = np.array(predictions_Gp)
predictions_Gpp = np.array(predictions_Gpp)
pred_mean_Gp = np.mean(predictions_Gp, axis=0)
pred_mean_Gpp = np.mean(predictions_Gpp, axis=0)
pred_lower_Gp = np.percentile(predictions_Gp, 2.5, axis=0)
pred_upper_Gp = np.percentile(predictions_Gp, 97.5, axis=0)
pred_lower_Gpp = np.percentile(predictions_Gpp, 2.5, axis=0)
pred_upper_Gpp = np.percentile(predictions_Gpp, 97.5, axis=0)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.loglog(omega, G_prime_data, 'o', label='Data', markersize=5, zorder=3)
ax1.loglog(omega_pred, pred_mean_Gp, '-', linewidth=2, label='Posterior mean', zorder=2)
ax1.fill_between(omega_pred, pred_lower_Gp, pred_upper_Gp, alpha=0.3, label='95% CI', zorder=1)
ax1.set_xlabel(r'Frequency $\omega$ (rad/s)', fontsize=12)
ax1.set_ylabel("G' (Pa)", fontsize=12)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_title('Storage Modulus Posterior', fontsize=14)

ax2.loglog(omega, G_double_prime_data, 's', label='Data', markersize=5, zorder=3)
ax2.loglog(omega_pred, pred_mean_Gpp, '-', linewidth=2, label='Posterior mean', zorder=2)
ax2.fill_between(omega_pred, pred_lower_Gpp, pred_upper_Gpp, alpha=0.3, label='95% CI', zorder=1)
ax2.set_xlabel(r'Frequency $\omega$ (rad/s)', fontsize=12)
ax2.set_ylabel("G'' (Pa)", fontsize=12)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_title('Loss Modulus Posterior', fontsize=14)

plt.tight_layout()
plt.close("all")
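The cell above sets the model parameters draw by draw and calls `model.predict_saos` in a Python loop. When a closed form is available, the same percentile band can be computed with NumPy broadcasting. The sketch below uses the fast-breaking Maxwell formulas with $\tau = \tau_d$ as a stand-in for the TNTCates prediction (an assumption: it uses only $G_0$ and $\tau_d$ and ignores $\eta_s$), with synthetic draws; `maxwell_band` is a hypothetical helper.

```python
import numpy as np

def maxwell_band(omega, G0_draws, tau_draws, q=(2.5, 97.5)):
    """Posterior-predictive mean and percentile band of G', G'' for a Maxwell stand-in.

    Broadcasting: draws have shape (S,), omega has shape (N,) -> predictions (S, N).
    """
    wt = np.asarray(tau_draws, float)[:, None] * np.asarray(omega, float)[None, :]
    G0 = np.asarray(G0_draws, float)[:, None]
    Gp = G0 * wt**2 / (1 + wt**2)
    Gpp = G0 * wt / (1 + wt**2)
    out = {}
    for name, pred in (("Gp", Gp), ("Gpp", Gpp)):
        lo, hi = np.percentile(pred, q, axis=0)  # one value per frequency
        out[name] = (pred.mean(axis=0), lo, hi)
    return out

rng = np.random.default_rng(1)
G0_draws = rng.normal(875.0, 20.0, size=200)                 # synthetic draws
tau_draws = rng.lognormal(np.log(0.199), 0.05, size=200)     # synthetic draws
omega = np.logspace(-1, 2, 50)
band = maxwell_band(omega, G0_draws, tau_draws)
Gp_mean, Gp_lo, Gp_hi = band["Gp"]
```

This avoids mutating the model's parameters at all, so no restore step is needed afterwards.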

## Physical Interpretation from Posterior

In [11]:
tau_d_posterior = np.sqrt(posterior['tau_rep'] * posterior['tau_break'])
zeta_posterior = posterior['tau_break'] / posterior['tau_rep']
omega_c_posterior = 1.0 / tau_d_posterior

print(f"\nPhysical quantities from posterior:")
print(f"  tau_d: {np.mean(tau_d_posterior):.4e} ± {np.std(tau_d_posterior):.4e} s")
print(f"  Crossover frequency (ω_c): {np.mean(omega_c_posterior):.4e} ± {np.std(omega_c_posterior):.4e} rad/s")
print(f"  zeta (tau_break/tau_rep): {np.mean(zeta_posterior):.4f} ± {np.std(zeta_posterior):.4f}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.hist(tau_d_posterior, bins=30, alpha=0.7, edgecolor='black')
ax1.axvline(np.mean(tau_d_posterior), color='r', linestyle='--', label='Mean')
ax1.set_xlabel(r'$\tau_d$ (s)', fontsize=12)
ax1.set_ylabel('Count', fontsize=12)
ax1.legend()
ax1.set_title('Effective Relaxation Time Distribution', fontsize=12)

ax2.hist(omega_c_posterior, bins=30, alpha=0.7, edgecolor='black')
ax2.axvline(np.mean(omega_c_posterior), color='r', linestyle='--', label='Mean')
ax2.set_xlabel(r'$\omega_c$ (rad/s)', fontsize=12)
ax2.set_ylabel('Count', fontsize=12)
ax2.legend()
ax2.set_title('Crossover Frequency Distribution', fontsize=12)

plt.tight_layout()
plt.close("all")


Physical quantities from posterior:
  tau_d: 1.9910e-01 ± 2.7756e-17 s
  Crossover frequency (ω_c): 5.0226e+00 ± 8.8818e-16 rad/s
  zeta (tau_break/tau_rep): 0.2257 ± 0.0000
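The cell above transforms each posterior draw before summarizing, which is the right order: for a nonlinear transform such as $\omega_c = 1/\tau_d$, the mean of the transformed draws is not the transform of the mean. With the FAST_MODE placeholder the draws are constant and the two coincide; with a real posterior they do not, as this sketch with synthetic lognormal draws illustrates:

```python
import numpy as np

rng = np.random.default_rng(2)
tau_rep = rng.lognormal(np.log(0.419), 0.2, size=2000)     # synthetic draws, s
tau_break = rng.lognormal(np.log(0.0946), 0.2, size=2000)  # synthetic draws, s

tau_d = np.sqrt(tau_rep * tau_break)  # transform each draw first
omega_c = 1.0 / tau_d                 # then derive the crossover frequency per draw

# Mean of transformed draws vs. transform of the mean: they differ
print(np.mean(omega_c), 1.0 / np.mean(tau_d))
```

By the arithmetic-harmonic mean inequality, mean(1/τ_d) ≥ 1/mean(τ_d) for positive draws, with equality only when all draws are equal.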


## Save Results

In [12]:
save_tnt_results(model, bayesian_result, "cates", "saos", param_names)

Results saved to /Users/b80985/Projects/rheojax/examples/tnt/../utils/../outputs/tnt/cates/saos/
  nlsq_params_saos.json: 4 parameters
  posterior_saos.json: 500 draws


## Key Takeaways

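1. **Maxwell-like SAOS response** in the fast-breaking limit ($\zeta = \tau_{\text{break}}/\tau_{\text{rep}} \ll 1$), with effective relaxation time $\tau_d = \sqrt{\tau_{\text{rep}} \tau_{\text{break}}}$
2. **Cole-Cole semicircle** is the diagnostic for single-mode Maxwell behavior in living polymers; deviations from it point to multimode relaxation or structural complexity
3. **Crossover frequency** $\omega_c = 1/\tau_d$ gives a direct estimate of the terminal relaxation time, provided $G'$ and $G''$ actually cross within the measured window
4. **Equivalence with single-mode Maxwell** at $\tau = \tau_d$ holds only in the fast-breaking limit; this fit gives $\zeta \approx 0.23$, so the Maxwell comparison is approximate here
5. **Bayesian inference** (run with FAST_MODE=0) quantifies uncertainty in the time scales from the frequency response; the default FAST_MODE placeholder reproduces the NLSQ values with zero spread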

**Next steps:** Compare $\tau_d$ across all protocols (flow, startup, relaxation, creep, SAOS) to verify global consistency.