# STZ Linear Viscoelastic Spectrum (SAOS)

**Shear Transformation Zone model — Small-Amplitude Oscillatory Shear**

## Learning Objectives

- Understand the STZ linear viscoelastic (Maxwell-like) approximation
- Fit G'(omega) and G''(omega) to polystyrene oscillation data
- Interpret the effective relaxation time tau_eff from STZ activation
- Analyze the crossover frequency and the Cole-Cole representation
- Perform Bayesian inference and assess the identifiability of the 6 parameters

## Prerequisites

- Notebook 01 (STZ flow curve) for basic STZ concepts
- Understanding of G', G'' frequency-domain data

## Estimated Runtime

- Fast demo (1 chain): ~1-2 min
- Full run (4 chains): ~4-6 min

## 1. Setup

In [None]:
# Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import os
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.stz import STZConventional

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. Theory: SAOS from STZ Activation

In the **linear viscoelastic limit** (small strain amplitude), the STZ model reduces to a **Maxwell model** with an effective relaxation time determined by STZ activation.

At steady state, $\chi \to \chi_{\infty}$ and the STZ density is:

$$\Lambda_{\text{ss}} = \exp(-e_z / \chi_{\infty})$$

The effective Maxwell relaxation time is:

$$\tau_{\text{eff}} = \frac{\tau_0}{2 \epsilon_0 \Lambda_{\text{ss}}}$$
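
As a quick numerical check of these two expressions, the sketch below evaluates Lambda_ss and tau_eff with plain NumPy for the starting values used in the NLSQ section (chi_inf = 0.15, ez = 1.0, tau0 = 1e-6, epsilon0 = 0.1). The helper `stz_tau_eff` is hypothetical and not part of RheoJAX; it only restates the equations above. It also shows how strongly the exponent ez/chi_inf drives tau_eff.

```python
import numpy as np


def stz_tau_eff(tau0, epsilon0, ez, chi_inf):
    """Return (Lambda_ss, tau_eff) for the single-mode STZ SAOS approximation.

    Mirrors the two equations above:
        Lambda_ss = exp(-ez / chi_inf)
        tau_eff   = tau0 / (2 * epsilon0 * Lambda_ss)
    """
    lambda_ss = np.exp(-ez / chi_inf)
    tau_eff = tau0 / (2.0 * epsilon0 * lambda_ss)
    return lambda_ss, tau_eff


# Starting values used for the NLSQ fit in Section 4
lam_demo, tau_demo = stz_tau_eff(tau0=1e-6, epsilon0=0.1, ez=1.0, chi_inf=0.15)
print(f"Lambda_ss = {lam_demo:.3e}, tau_eff = {tau_demo:.3e} s")

# ez / chi_inf sits in an exponent: a 10% increase in ez multiplies tau_eff
# by exp(0.1 * ez / chi_inf), i.e. roughly 1.95x for these values.
_, tau_demo_hi = stz_tau_eff(tau0=1e-6, epsilon0=0.1, ez=1.1, chi_inf=0.15)
print(f"tau_eff ratio for +10% ez: {tau_demo_hi / tau_demo:.2f}")
```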

The dynamic moduli then follow:

$$G'(\omega) = G_0 \frac{(\omega \tau_{\text{eff}})^2}{1 + (\omega \tau_{\text{eff}})^2}$$

$$G''(\omega) = G_0 \frac{\omega \tau_{\text{eff}}}{1 + (\omega \tau_{\text{eff}})^2}$$
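
A minimal NumPy sketch of these single-mode Maxwell expressions follows. The name `maxwell_moduli` is hypothetical, and RheoJAX's internal SAOS predictor may differ in implementation details. The sketch confirms two properties used later in this notebook: G' and G'' cross at omega = 1/tau_eff, where both equal G0/2, and G' approaches the plateau G0 at high frequency. The (N, 2) column layout matches the `G_star` array built in Section 4.

```python
import numpy as np


def maxwell_moduli(omega, G0, tau_eff):
    """Single-mode Maxwell moduli as an (N, 2) array with columns [G', G'']."""
    wt = np.asarray(omega, dtype=float) * tau_eff
    denom = 1.0 + wt**2
    return np.column_stack([G0 * wt**2 / denom, G0 * wt / denom])


G0_demo, tau_demo = 1e6, 3.9e-3
omega_c_demo = 1.0 / tau_demo

# At the crossover omega * tau_eff = 1, so G' = G'' = G0 / 2
G_at_c = maxwell_moduli([omega_c_demo], G0_demo, tau_demo)[0]
print(f"At omega_c: G' = {G_at_c[0]:.4g} Pa, G'' = {G_at_c[1]:.4g} Pa")

# Far above the crossover, G' approaches the plateau G0
G_high = maxwell_moduli([1e3 * omega_c_demo], G0_demo, tau_demo)[0]
print(f"At 1000 * omega_c: G'/G0 = {G_high[0] / G0_demo:.6f}")
```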

### Parameters for SAOS (6 total)

| Parameter | Role in SAOS |
|-----------|-------------|
| G0 | Sets the high-frequency plateau modulus |
| sigma_y | Absent from the Maxwell expressions above; SAOS constrains it weakly, if at all |
| chi_inf | Controls Lambda_ss and thus tau_eff |
| tau0 | Attempt time — sets absolute timescale |
| epsilon0 | Strain increment — scales tau_eff |
| ez | Formation energy — exponentially controls Lambda_ss |
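
In the expressions above, chi_inf, tau0, epsilon0, and ez enter G' and G'' only through tau_eff. As a result, any two parameter sets that share the same tau_eff produce identical single-mode SAOS curves. The sketch below constructs such a pair using plain NumPy and illustrative values. It doubles epsilon0 and compensates by raising ez by chi_inf * ln 2, so that Lambda_ss halves. The helper `saos_from_stz` is hypothetical. It omits sigma_y because sigma_y does not appear in these expressions. A pair like this suggests the posterior for those four parameters will show ridges rather than isolated peaks.

```python
import numpy as np


def saos_from_stz(omega, G0, chi_inf, tau0, epsilon0, ez):
    """Single-mode STZ SAOS moduli (G', G'') from the Section 2 expressions."""
    lambda_ss = np.exp(-ez / chi_inf)
    tau_eff = tau0 / (2.0 * epsilon0 * lambda_ss)
    wt = omega * tau_eff
    return G0 * wt**2 / (1 + wt**2), G0 * wt / (1 + wt**2)


omega_demo = np.logspace(-1, 4, 40)
base = dict(G0=1e6, chi_inf=0.15, tau0=1e-6, epsilon0=0.1, ez=1.0)

# Double epsilon0, then raise ez by chi_inf * ln(2) so that Lambda_ss halves
# and tau_eff is unchanged.
alt = dict(base, epsilon0=0.2, ez=1.0 + 0.15 * np.log(2.0))

Gp_a, Gpp_a = saos_from_stz(omega_demo, **base)
Gp_b, Gpp_b = saos_from_stz(omega_demo, **alt)
max_rel_diff = max(
    np.max(np.abs(Gp_a - Gp_b) / Gp_a),
    np.max(np.abs(Gpp_a - Gpp_b) / Gpp_a),
)
print(f"Max relative difference between the two parameter sets: {max_rel_diff:.1e}")
```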

**Limitation:** This is a single-relaxation-time approximation. Real polymers near $T_g$ have a broad relaxation spectrum, so deviations from single-Maxwell behavior are expected.

### Material-Model Compatibility

We use **polystyrene at 145 °C**, about 45 °C above its glass transition temperature ($T_g \approx 100$ °C). At this temperature, polystyrene is an amorphous, glass-forming polymer. Its relaxation is governed by cooperative segmental motions analogous to STZ rearrangements. The single-Maxwell approximation captures the dominant relaxation mode but will miss the high-frequency wing from faster local motions.

## 3. Load Data

In [None]:
from stz_tutorial_utils import load_polystyrene_oscillation

omega, G_prime, G_double_prime = load_polystyrene_oscillation(temp=145)

print(f"Data points: {len(omega)}")
print(f"Frequency range: {omega.min():.1f} - {omega.max():.1f} rad/s")
print(f"G' range: {G_prime.min():.0f} - {G_prime.max():.0f} Pa")
print(f"G'' range: {G_double_prime.min():.0f} - {G_double_prime.max():.0f} Pa")

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.loglog(omega, G_prime, "s", markersize=6, color="C0", label="G' (storage)")
ax.loglog(omega, G_double_prime, "o", markersize=6, color="C1", label="G'' (loss)")
ax.set_xlabel("Angular frequency [rad/s]")
ax.set_ylabel("Modulus [Pa]")
ax.set_title("Polystyrene PS145 — SAOS")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

In [None]:
from stz_tutorial_utils import compute_fit_quality

# Prepare y data: stack [G', G''] as columns, shape (N, 2), as expected by the model
G_star = np.column_stack([G_prime, G_double_prime])

model = STZConventional(variant="standard")

# Set bounds BEFORE values — use set_bounds() to update both bounds and constraints
model.parameters.set_bounds("G0", (1e4, 1e8))
model.parameters["G0"].value = 1e6
model.parameters.set_bounds("sigma_y", (1e3, 1e8))
model.parameters["sigma_y"].value = 1e5
model.parameters.set_bounds("chi_inf", (0.02, 0.5))
model.parameters["chi_inf"].value = 0.15
model.parameters.set_bounds("tau0", (1e-12, 1e0))
model.parameters["tau0"].value = 1e-6
model.parameters.set_bounds("epsilon0", (0.01, 1.0))
model.parameters["epsilon0"].value = 0.1
model.parameters.set_bounds("ez", (0.1, 5.0))
model.parameters["ez"].value = 1.0

t0 = time.time()
model.fit(omega, G_star, test_mode="oscillation", use_log_residuals=True)
t_nlsq = time.time() - t0

# Compute fit quality for G' and G'' combined (linear scale, so the largest moduli dominate R^2 and RMSE)
G_pred_at_data = model.predict(omega)
G_all_data = np.concatenate([G_prime, G_double_prime])
G_all_pred = np.concatenate([G_pred_at_data[:, 0], G_pred_at_data[:, 1]])
quality = compute_fit_quality(G_all_data, G_all_pred)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R-squared: {quality['r_squared']:.6f}")
print(f"RMSE: {quality['rmse']:.1f} Pa")
print("\nFitted parameters:")
saos_params = ["G0", "sigma_y", "chi_inf", "tau0", "epsilon0", "ez"]
for name in saos_params:
    val = model.parameters.get_value(name)
    print(f"  {name:10s} = {val:.4g}")
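
The starting values above are chosen by hand. For a single Maxwell mode, the Section 2 expressions also give a data-driven starting point. Dividing them yields G'/G'' = omega * tau_eff, and a little algebra gives G0 = (G'^2 + G''^2) / G'. The sketch below applies both relations pointwise and takes the median in log space. `maxwell_initial_guess` is a hypothetical helper, not a RheoJAX function. Because tau_eff is not itself a model parameter, its estimate only indicates a target value for the combination of tau0, epsilon0, ez, and chi_inf.

```python
import numpy as np


def maxwell_initial_guess(omega, G_prime, G_double_prime):
    """Pointwise single-Maxwell estimates of (G0, tau_eff), combined by log-median.

    Uses G'/G'' = omega * tau_eff and G0 = (G'^2 + G''^2) / G', both exact
    for one Maxwell mode. Real data only approximately satisfy them, so the
    median of the log estimates is used to damp the effect of outliers.
    """
    omega = np.asarray(omega, dtype=float)
    Gp = np.asarray(G_prime, dtype=float)
    Gpp = np.asarray(G_double_prime, dtype=float)
    tau_pts = Gp / (omega * Gpp)
    G0_pts = (Gp**2 + Gpp**2) / Gp
    G0_est = float(np.exp(np.median(np.log(G0_pts))))
    tau_est = float(np.exp(np.median(np.log(tau_pts))))
    return G0_est, tau_est


# Synthetic check: exact Maxwell data with G0 = 2e5 Pa and tau = 0.05 s
w = np.logspace(-1, 3, 30)
wt = w * 0.05
Gp_syn = 2e5 * wt**2 / (1 + wt**2)
Gpp_syn = 2e5 * wt / (1 + wt**2)
G0_guess, tau_guess = maxwell_initial_guess(w, Gp_syn, Gpp_syn)
print(f"G0 ~ {G0_guess:.4g} Pa, tau_eff ~ {tau_guess:.4g} s")
```

With the notebook's arrays, this would be called as `maxwell_initial_guess(omega, G_prime, G_double_prime)`. The polystyrene data are not single-mode, so treat the result as a rough starting point only.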

### 4.1 Fit Quality and Crossover Analysis

In [None]:
omega_fine = np.logspace(
    np.log10(omega.min()) - 0.3,
    np.log10(omega.max()) + 0.3,
    200,
)
G_pred = model.predict(omega_fine)
G_prime_pred = G_pred[:, 0]
G_double_prime_pred = G_pred[:, 1]

# Crossover frequency: a single Maxwell mode has G' = G'' at omega * tau_eff = 1
chi_inf_fit = model.parameters.get_value("chi_inf")
ez_fit = model.parameters.get_value("ez")
tau0_fit = model.parameters.get_value("tau0")
eps0_fit = model.parameters.get_value("epsilon0")
Lambda_ss = np.exp(-ez_fit / chi_inf_fit)
tau_eff = tau0_fit / (2.0 * eps0_fit * Lambda_ss)
omega_c = 1.0 / tau_eff

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: G', G'' fit
ax1.loglog(omega, G_prime, "s", markersize=6, color="C0", label="G' data")
ax1.loglog(omega, G_double_prime, "o", markersize=6, color="C1", label="G'' data")
ax1.loglog(omega_fine, G_prime_pred, "-", lw=2, color="C0", label="G' fit")
ax1.loglog(omega_fine, G_double_prime_pred, "--", lw=2, color="C1", label="G'' fit")
ax1.axvline(omega_c, color="gray", linestyle=":", alpha=0.7, label=f"$\\omega_c$ = {omega_c:.2g} rad/s")
ax1.set_xlabel("Angular frequency [rad/s]")
ax1.set_ylabel("Modulus [Pa]")
ax1.set_title("STZ SAOS Fit")
ax1.legend(fontsize=8)
ax1.grid(True, alpha=0.3, which="both")

# Right: residuals
G_pred_at_data = model.predict(omega)
res_Gp = (G_prime - G_pred_at_data[:, 0]) / G_prime * 100
res_Gpp = (G_double_prime - G_pred_at_data[:, 1]) / G_double_prime * 100

ax2.semilogx(omega, res_Gp, "s-", markersize=4, alpha=0.7, label="G' residual")
ax2.semilogx(omega, res_Gpp, "o-", markersize=4, alpha=0.7, label="G'' residual")
ax2.axhline(0, color="black", linestyle="--", alpha=0.5)
ax2.set_xlabel("Angular frequency [rad/s]")
ax2.set_ylabel("Relative residual [%]")
ax2.set_title("Residual Analysis")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

print(f"\nEffective relaxation time: tau_eff = {tau_eff:.4g} s")
print(f"Crossover frequency: omega_c = {omega_c:.4g} rad/s")
print(f"Lambda_ss = exp(-ez/chi_inf) = {Lambda_ss:.4g}")
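
The omega_c above is a model quantity. It can be compared against the crossover in the data itself, found by linear interpolation of log G' - log G'' against log omega between the two points where the sign changes. The sketch below uses a hypothetical helper, `data_crossover`. It returns None when the measured moduli never cross, and only the first crossing otherwise. For an exact single Maxwell mode, log(G'/G'') is linear in log omega, so the interpolation recovers 1/tau exactly.

```python
import numpy as np


def data_crossover(omega, G_prime, G_double_prime):
    """Estimate the first G' = G'' crossover by log-log linear interpolation.

    Returns None if log(G'/G'') never changes sign over the measured range.
    """
    log_w = np.log10(np.asarray(omega, dtype=float))
    d = np.log10(np.asarray(G_prime, dtype=float)) - np.log10(
        np.asarray(G_double_prime, dtype=float)
    )
    order = np.argsort(log_w)
    log_w, d = log_w[order], d[order]
    sign_change = np.nonzero(np.sign(d[:-1]) != np.sign(d[1:]))[0]
    if sign_change.size == 0:
        return None
    i = sign_change[0]
    # Root of the straight line through (log_w[i], d[i]) and (log_w[i+1], d[i+1])
    frac = d[i] / (d[i] - d[i + 1])
    return float(10 ** (log_w[i] + frac * (log_w[i + 1] - log_w[i])))


# Synthetic single-mode data with tau = 0.02 s, so the exact crossover is 50 rad/s
w_demo = np.logspace(0, 3, 25)
wt = w_demo * 0.02
wc = data_crossover(w_demo, 1e5 * wt**2 / (1 + wt**2), 1e5 * wt / (1 + wt**2))
print(f"Estimated crossover: {wc:.4g} rad/s")
```

In the notebook, `data_crossover(omega, G_prime, G_double_prime)` can be printed next to `omega_c`. A large mismatch is another sign that a single mode is not enough.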

### 4.2 Cole-Cole Plot

In [None]:
fig, ax = plt.subplots(figsize=(7, 7))
ax.plot(G_prime, G_double_prime, "ko", markersize=6, label="Data")
ax.plot(G_prime_pred, G_double_prime_pred, "-", lw=2, color="C0", label="STZ fit")
ax.set_xlabel("G' [Pa]")
ax.set_ylabel("G'' [Pa]")
ax.set_title("Cole-Cole Plot")
ax.set_aspect("equal")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

A single Maxwell element produces a perfect semicircle in Cole-Cole space. Deviations indicate the need for a broader relaxation spectrum (e.g., Generalized Maxwell / KWW), which STZ's single-mode approximation cannot capture.
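
The semicircle follows directly from the Section 2 expressions. Eliminating omega * tau_eff gives (G' - G0/2)^2 + G''^2 = (G0/2)^2, a circle of radius G0/2 centred at (G0/2, 0). One simple way to quantify departures from this arc is the relative radial distance of each point from that circle. The sketch below uses a hypothetical helper name. In the notebook, G0 would be `model.parameters.get_value("G0")`. A mixture of modes always lies on or inside the single-mode circle, because averaging points on a circle lands inside the disk.

```python
import numpy as np


def semicircle_deviation(G_prime, G_double_prime, G0):
    """Relative radial distance of each (G', G'') point from the Maxwell arc.

    Zero for an ideal single Maxwell mode; positive/negative values mean the
    point lies outside/inside the circle of radius G0/2 centred at (G0/2, 0).
    """
    Gp = np.asarray(G_prime, dtype=float)
    Gpp = np.asarray(G_double_prime, dtype=float)
    r = np.hypot(Gp - G0 / 2, Gpp)
    return r / (G0 / 2) - 1.0


# An ideal Maxwell mode lies exactly on the arc...
x = np.logspace(-2, 2, 21)  # omega * tau_eff
dev_single = semicircle_deviation(x**2 / (1 + x**2), x / (1 + x**2), G0=1.0)

# ...while an equal-weight two-mode mixture (same total G0) falls inside it.
x2 = 10 * x
Gp_mix = 0.5 * (x**2 / (1 + x**2) + x2**2 / (1 + x2**2))
Gpp_mix = 0.5 * (x / (1 + x**2) + x2 / (1 + x2**2))
dev_mix = semicircle_deviation(Gp_mix, Gpp_mix, G0=1.0)
print(f"max |dev|, single mode: {np.max(np.abs(dev_single)):.1e}")
print(f"min dev, two modes: {dev_mix.min():.3f}")
```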

## 5. Bayesian Inference

### 5.1 Run NUTS

In [None]:
initial_values = {
    name: model.parameters.get_value(name)
    for name in model.parameters.keys()
}

# --- Fast demo config ---
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1
# NUM_WARMUP = 1000; NUM_SAMPLES = 2000; NUM_CHAINS = 4  # production

t0 = time.time()
result = model.fit_bayesian(
    omega,
    G_star,
    test_mode="oscillation",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"Bayesian inference time: {t_bayes:.1f} s")

### 5.2 Convergence Diagnostics

In [None]:
from stz_tutorial_utils import print_convergence_summary

print_convergence_summary(result, saos_params)
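
`print_convergence_summary` is a tutorial helper, and its exact output is not reproduced here. As a reference for what R-hat measures, below is a minimal NumPy version of the classic split-R-hat (Gelman-Rubin applied to half-chains). ArviZ's `az.rhat` defaults to the rank-normalized variant, so its values can differ slightly. Splitting each chain in half is what makes the diagnostic usable even with the single chain of the fast-demo configuration: a chain that drifts has different first and second halves.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for one scalar parameter.

    draws: array of shape (n_chains, n_draws) or (n_draws,). Each chain is cut
    into two halves, which are then treated as separate chains.
    """
    draws = np.asarray(draws, dtype=float)
    if draws.ndim == 1:
        draws = draws[None, :]
    half = draws.shape[1] // 2
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()    # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * W + B / n
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                      # well-mixed chains
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])  # two chains offset
drifting = np.linspace(0.0, 5.0, 1000) + rng.normal(size=1000)  # one trending chain
print(f"mixed: {split_rhat(mixed):.3f}, stuck: {split_rhat(stuck):.3f}, "
      f"drifting: {split_rhat(drifting):.3f}")
```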

### 5.3 ArviZ Plots

In [None]:
idata = result.to_inference_data()

axes = az.plot_trace(idata, var_names=saos_params, figsize=(12, 12))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
axes = az.plot_pair(
    idata,
    var_names=saos_params,
    kind="scatter",
    divergences=True,
    figsize=(12, 12),
)
fig = axes.ravel()[0].figure
fig.suptitle("Parameter Correlations", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
axes = az.plot_forest(
    idata,
    var_names=saos_params,
    combined=True,
    hdi_prob=0.95,
    figsize=(10, 5),
)
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

### 5.4 Parameter Summary

In [None]:
from stz_tutorial_utils import print_parameter_comparison

posterior = result.posterior_samples
print_parameter_comparison(model, posterior, saos_params)
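
To see which combination the SAOS data actually pin down, the posterior draws can be transformed into log10(tau_eff) with the Section 2 relation. Its spread can then be compared with that of the individual parameters. The sketch assumes, as the posterior predictive code in Section 5.5 does, that `posterior` maps parameter names to 1-D arrays of draws. The helper name is hypothetical. It is demonstrated on synthetic draws constructed to lie exactly on a ridge of constant tau_eff. In that case, tau0 varies over several decades while tau_eff does not vary at all.

```python
import numpy as np


def log10_tau_eff_draws(posterior):
    """log10(tau_eff) per draw, from tau_eff = tau0 / (2 * epsilon0 * exp(-ez / chi_inf))."""
    tau0 = np.asarray(posterior["tau0"], dtype=float)
    eps0 = np.asarray(posterior["epsilon0"], dtype=float)
    ez = np.asarray(posterior["ez"], dtype=float)
    chi_inf = np.asarray(posterior["chi_inf"], dtype=float)
    # Work in logs: exp(-ez / chi_inf) can underflow for large ez / chi_inf
    return (np.log(tau0) - np.log(2.0 * eps0) + ez / chi_inf) / np.log(10.0)


# Synthetic "posterior" lying on a perfect ridge of constant tau_eff = 1e-2 s
rng = np.random.default_rng(1)
n = 2000
fake = {
    "epsilon0": rng.uniform(0.05, 0.5, n),
    "ez": rng.uniform(0.5, 2.0, n),
    "chi_inf": np.full(n, 0.15),
}
fake["tau0"] = 1e-2 * 2.0 * fake["epsilon0"] * np.exp(-fake["ez"] / fake["chi_inf"])

log_tau = log10_tau_eff_draws(fake)
print(f"sd of log10(tau0):    {np.std(np.log10(fake['tau0'])):.2f}")
print(f"sd of log10(tau_eff): {np.std(log_tau):.2e}")
```

On the real posterior, `np.std(log10_tau_eff_draws(posterior))` being much smaller than the spreads of the individual log-parameters would indicate that SAOS constrains tau_eff rather than its ingredients.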

### 5.5 Posterior Predictive Check

In [None]:
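# Use evenly spaced draws across the whole (flattened) posterior rather than only
# the first ones, which may all come from the first chain
n_total = len(posterior["G0"])
n_draws = min(200, n_total)
draw_idx = np.linspace(0, n_total - 1, n_draws).astype(int)
omega_pred = np.logspace(
    np.log10(omega.min()) - 0.3,
    np.log10(omega.max()) + 0.3,
    100,
)
omega_pred_jax = jnp.asarray(omega_pred, dtype=jnp.float64)

Gp_samples = []
Gpp_samples = []
for i in draw_idx: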
    pred_i = STZConventional._predict_saos_jit(
        omega_pred_jax,
        posterior["G0"][i],
        posterior["sigma_y"][i],
        posterior["chi_inf"][i],
        posterior["tau0"][i],
        posterior["epsilon0"][i],
        posterior["ez"][i],
    )
    pred_arr = np.array(pred_i)
    Gp_samples.append(pred_arr[:, 0])
    Gpp_samples.append(pred_arr[:, 1])

Gp_samples = np.array(Gp_samples)
Gpp_samples = np.array(Gpp_samples)

fig, ax = plt.subplots(figsize=(9, 6))
ax.fill_between(omega_pred,
    np.percentile(Gp_samples, 2.5, axis=0),
    np.percentile(Gp_samples, 97.5, axis=0),
    alpha=0.2, color="C0")
ax.fill_between(omega_pred,
    np.percentile(Gpp_samples, 2.5, axis=0),
    np.percentile(Gpp_samples, 97.5, axis=0),
    alpha=0.2, color="C1")
ax.loglog(omega_pred, np.median(Gp_samples, axis=0), "-", lw=2, color="C0", label="G' posterior")
ax.loglog(omega_pred, np.median(Gpp_samples, axis=0), "--", lw=2, color="C1", label="G'' posterior")
ax.loglog(omega, G_prime, "s", markersize=6, color="C0", markeredgecolor="k", label="G' data")
ax.loglog(omega, G_double_prime, "o", markersize=6, color="C1", markeredgecolor="k", label="G'' data")
ax.set_xlabel("Angular frequency [rad/s]")
ax.set_ylabel("Modulus [Pa]")
ax.set_title("Posterior Predictive Check")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)
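
The bands in this plot reflect parameter uncertainty only. Each draw is a noise-free model curve, so the bands are credible bands for the mean response rather than full predictive intervals. A simple numerical summary is the fraction of data points that fall inside the 95% band. The sketch below evaluates the single-mode Maxwell expressions from Section 2 for all draws at once using NumPy broadcasting, which avoids the Python loop. It assumes that `_predict_saos_jit` implements those same expressions. `ppc_band_coverage` is a hypothetical helper.

```python
import numpy as np


def ppc_band_coverage(omega, G_data, G0, tau_eff, column=0, q=(2.5, 97.5)):
    """Fraction of data points inside the central posterior band.

    column=0 checks G', column=1 checks G''. G0 and tau_eff are 1-D arrays
    of posterior draws of equal length.
    """
    wt = np.asarray(tau_eff, dtype=float)[:, None] * np.asarray(omega, dtype=float)[None, :]
    G0 = np.asarray(G0, dtype=float)[:, None]
    curves = G0 * wt**2 / (1 + wt**2) if column == 0 else G0 * wt / (1 + wt**2)
    lo, hi = np.percentile(curves, q, axis=0)  # each of shape (n_omega,)
    G_data = np.asarray(G_data, dtype=float)
    return float(np.mean((G_data >= lo) & (G_data <= hi)))


# Synthetic check: draws scattered around G0 = 1e5 Pa and tau_eff = 0.01 s
rng = np.random.default_rng(2)
w_demo = np.logspace(-1, 3, 25)
G0_draws = rng.lognormal(np.log(1e5), 0.05, 500)
tau_draws = rng.lognormal(np.log(0.01), 0.05, 500)
wt = w_demo * 0.01
Gp_true = 1e5 * wt**2 / (1 + wt**2)
inside = ppc_band_coverage(w_demo, Gp_true, G0_draws, tau_draws, column=0)
outside = ppc_band_coverage(w_demo, 3.0 * Gp_true, G0_draws, tau_draws, column=0)
print(f"coverage of the generating curve: {inside:.2f}, of a 3x-scaled curve: {outside:.2f}")
```

In the notebook, tau_eff draws could be formed from `posterior["tau0"]`, `posterior["epsilon0"]`, `posterior["ez"]`, and `posterior["chi_inf"]` with the Section 2 relation. Low coverage at the frequency extremes would match the single-mode limitations discussed next.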

## 6. Limitations

The STZ SAOS prediction is a **single Maxwell mode** approximation:

- It captures the **dominant relaxation** around the crossover frequency
- It **underestimates G''** at high frequencies (fast beta-relaxation modes not included)
- It **overestimates G'** at low frequencies if the terminal zone hasn't been reached
- For broad spectra, use the Generalized Maxwell model instead

Despite these limitations, the SAOS fit provides valuable estimates of G0 and tau_eff that are physically meaningful within the STZ framework.
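
For comparison, a generalized Maxwell model simply sums single modes: G'(omega) = sum_i G_i (omega tau_i)^2 / (1 + (omega tau_i)^2), and likewise for G''. The minimal NumPy sketch below is not the RheoJAX Generalized Maxwell implementation. It shows that with one mode the sum reduces to the STZ single-mode expressions. With modes spread over several decades, the loss peak flattens and broadens: its height drops well below the single-mode value of G0/2.

```python
import numpy as np


def generalized_maxwell(omega, G_i, tau_i):
    """Storage and loss moduli of a sum of Maxwell modes, returned as (N, 2)."""
    wt = np.asarray(omega, dtype=float)[:, None] * np.asarray(tau_i, dtype=float)[None, :]
    G_i = np.asarray(G_i, dtype=float)[None, :]
    denom = 1.0 + wt**2
    G_prime = np.sum(G_i * wt**2 / denom, axis=1)
    G_double_prime = np.sum(G_i * wt / denom, axis=1)
    return np.column_stack([G_prime, G_double_prime])


w = np.logspace(-2, 4, 61)
one_mode = generalized_maxwell(w, [1e5], [0.01])

# Five modes, one per decade, with equal weights summing to the same G0 = 1e5 Pa
spread = generalized_maxwell(w, np.full(5, 2e4), np.logspace(-4, 0, 5))

# Peak height of G'' relative to G0: 0.5 for one mode, lower when spread out
print(f"G''_max / G0: one mode {one_mode[:, 1].max() / 1e5:.3f}, "
      f"five modes {spread[:, 1].max() / 1e5:.3f}")
```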

## 7. Save Results

In [None]:
from stz_tutorial_utils import save_stz_results

output_dir = os.path.join("..", "outputs", "stz", "saos")
save_stz_results(model, result, output_dir, "saos")

## Key Takeaways

1. **STZ SAOS = Maxwell with activation-controlled relaxation time** — tau_eff = tau0 / (2*epsilon0*Lambda_ss)
2. **Lambda_ss = exp(-ez/chi_inf)** links structural disorder to relaxation dynamics
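3. **6 parameters** enter the SAOS fit (vs 4 for the flow curve), but only G0 and the combination tau_eff are directly constrained. chi_inf, tau0, epsilon0, and ez trade off against each other inside tau_eff, and sigma_y does not appear in the Maxwell expressions.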
4. **Single-mode limitation** — deviations from the semicircular Cole-Cole arc indicate spectral broadening
5. **Crossover frequency omega_c = 1/tau_eff** gives the characteristic relaxation rate of the amorphous solid

## Next Steps

- **Notebook 02**: Startup shear with stress overshoot (requires ODE integration, all 8 parameters)
- **Notebook 06**: LAOS for nonlinear oscillatory response