# SGR Stress Relaxation: Aging in Soft Glassy Materials

**Learning Objectives:**
- Fit SGRConventional to stress relaxation data from aging soft glassy materials
- Understand power-law relaxation G(t) ~ t^(x-2) vs exponential decay
- Track aging through the evolution of the noise temperature x(t_wait)
- Compare SGRConventional vs SGRGeneric and verify thermodynamic consistency

**Prerequisites:** basic/01, bayesian/01

**Runtime:** Fast ~2 min (num_chains=1), Full ~5 min (num_chains=4)

## 1. Setup

In [None]:
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os

    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import json
import os
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.sgr import SGRConventional, SGRGeneric

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. Theory: SGR Stress Relaxation

The Soft Glassy Rheology (SGR) model predicts power-law stress relaxation, distinct from the exponential decay of Maxwell-type models:

$$G(t) = G_0 \cdot G_0(x) \cdot \left(1 + \frac{t}{\tau_0}\right)^{x-2}$$

where:
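- **x**: Noise temperature (controls phase behavior)
  - x < 1: Glass phase (non-ergodic; the material ages and has a yield stress)
  - x = 1: Glass transition
  - 1 < x < 2: Soft glassy, power-law viscoelastic fluid
  - x > 2: Approaching Newtonian fluid behavior
- **G₀**: Elastic modulus scale
- **G₀(x)**: x-dependent prefactor in the expression above
- **τ₀**: Elementary relaxation time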

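The relaxation law above can be evaluated directly to see where the power-law regime begins. This is a minimal sketch, assuming the x-dependent factor G₀(x) can be treated as a constant placeholder (`prefactor`, default 1.0), since its form is not given here; `sgr_relaxation` and `local_loglog_slope` are hypothetical helpers, not RheoJAX functions. Differentiating ln G with respect to ln t gives (x-2)·t/(t+τ₀), so the log-log slope only reaches x-2 once t ≫ τ₀.

```python
def sgr_relaxation(t: float, x: float, G0: float, tau0: float, prefactor: float = 1.0) -> float:
    """G(t) = G0 * G0(x) * (1 + t/tau0)**(x - 2).

    `prefactor` stands in for the x-dependent factor G0(x); it defaults
    to 1.0 here as an assumption, not the model's actual normalization.
    """
    return G0 * prefactor * (1.0 + t / tau0) ** (x - 2.0)


def local_loglog_slope(t: float, x: float, tau0: float) -> float:
    """Exact d ln G / d ln t for the law above: (x - 2) * t / (t + tau0)."""
    return (x - 2.0) * t / (t + tau0)


# Illustrative parameters (not fitted values)
x, G0, tau0 = 1.5, 100.0, 0.01
for t in (1e-3, 1e-2, 1.0, 1e2):
    G = sgr_relaxation(t, x, G0, tau0)
    s = local_loglog_slope(t, x, tau0)
    print(f"t = {t:8.3g} s   G = {G:9.3f} Pa   local slope = {s:+.4f}")
```

With τ₀ = 0.01 s, the local slope at t = 1 s is already within 1% of x-2.
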
**Aging Signature:** As a soft glassy material ages (rests at constant temperature), its structural rearrangements slow down. In SGR, this is reflected by a **decreasing x** with waiting time t_wait (`t_age` in the code), indicating deeper energy traps and slower dynamics.

**Contrast with Maxwell:** Standard viscoelastic models predict G(t) ~ exp(-t/τ), an exponential decay that is a straight line on a log-linear plot. SGR's power-law decay instead appears as a straight line on a log-log plot, with slope (x-2) once t ≫ τ₀.

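The two diagnostic plots can be compared numerically. The sketch below (hypothetical helpers, illustrative parameters rather than fitted ones, G₀(x) taken as 1) estimates d ln G/dt, the slope on a log-linear plot, and d ln G/d ln t, the slope on a log-log plot, by central differences. For a single Maxwell mode the first stays at -1/τ while the second, -t/τ, keeps steepening; for the SGR law the log-log slope levels off near x-2.

```python
import math


def maxwell_relaxation(t: float, G: float, tau: float) -> float:
    """Single-mode Maxwell modulus G * exp(-t / tau)."""
    return G * math.exp(-t / tau)


def loglog_slope(f, t: float, h: float = 1e-5) -> float:
    """Central-difference estimate of d ln f / d ln t at time t."""
    up, down = f(t * math.exp(h)), f(t * math.exp(-h))
    return (math.log(up) - math.log(down)) / (2.0 * h)


def semilog_slope(f, t: float, h: float = 1e-5) -> float:
    """Central-difference estimate of d ln f / d t at time t."""
    return (math.log(f(t + h)) - math.log(f(t - h))) / (2.0 * h)


def maxwell(t: float) -> float:
    # Illustrative Maxwell mode: G = 100 Pa, tau = 1 s
    return maxwell_relaxation(t, 100.0, 1.0)


def sgr(t: float) -> float:
    # Illustrative SGR law: x = 1.5, tau0 = 0.01 s, G0(x) taken as 1
    return 100.0 * (1.0 + t / 0.01) ** (1.5 - 2.0)


for t in (0.5, 2.0, 8.0):
    print(
        f"t = {t:4.1f} s  Maxwell: dlnG/dt = {semilog_slope(maxwell, t):+.3f}, "
        f"dlnG/dlnt = {loglog_slope(maxwell, t):+.3f}   "
        f"SGR: dlnG/dlnt = {loglog_slope(sgr, t):+.3f}"
    )
```
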
## 3. Load Data

In [None]:
aging_times = [600, 1200, 1800, 2400, 3600]
datasets = {}

for t_age in aging_times:
    data_path = os.path.join("..", "data", "relaxation", "clays", f"rel_lapo_{t_age}.csv")
    raw = np.loadtxt(data_path, delimiter="\t", skiprows=1)
    t = raw[:, 0]
    G_t = raw[:, 1]
    datasets[t_age] = {"t": t, "G_t": G_t}
    print(f"t_age={t_age:4d}s: {len(t)} points, t=[{t.min():.3f}, {t.max():.1f}] s")

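Before fitting, it can help to confirm that the loaded arrays suit a log-log analysis: log10(t) and log10(G) need strictly positive values, and a relaxation curve should have strictly increasing times. `check_relaxation_data` is a hypothetical helper, not part of RheoJAX, and it only reports problems rather than fixing them.

```python
def check_relaxation_data(t, G_t):
    """Return a list of problems that would break log-log plotting or power-law fitting.

    Hypothetical helper: it checks only what the log transform and the
    (1 + t/tau0) power law need, namely matching lengths, positive and
    strictly increasing times, and positive moduli.
    """
    problems = []
    if len(t) != len(G_t):
        problems.append(f"length mismatch: {len(t)} times vs {len(G_t)} moduli")
    if any(ti <= 0 for ti in t):
        problems.append("non-positive time values (log10(t) undefined)")
    if any(t2 <= t1 for t1, t2 in zip(t, t[1:])):
        problems.append("time values not strictly increasing")
    if any(g <= 0 for g in G_t):
        problems.append("non-positive G(t) values (cannot be shown on a log axis)")
    return problems


# Notebook usage (kept as a comment so the sketch runs standalone):
# for t_age, d in datasets.items():
#     print(t_age, check_relaxation_data(list(d["t"]), list(d["G_t"])) or "OK")
print(check_relaxation_data([0.01, 0.1, 1.0], [120.0, 95.0, 80.0]))  # clean input
print(check_relaxation_data([0.0, 0.1, 0.1], [120.0, -1.0, 80.0]))  # three problems
```
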
In [None]:
fig, ax = plt.subplots(figsize=(8, 5))
colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(aging_times)))

for i, t_age in enumerate(aging_times):
    d = datasets[t_age]
    ax.loglog(d["t"], d["G_t"], "o", color=colors[i], markersize=4, label=f"t_age={t_age}s")

ax.set_xlabel("Time [s]")
ax.set_ylabel("G(t) [Pa]")
ax.set_title("Laponite Clay Relaxation — 5 Aging Times")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

### 4.1 Single Aging Time (3600 s)

In [None]:
model = SGRConventional()

d = datasets[3600]
t0_fit = time.time()
model.fit(d["t"], d["G_t"], test_mode="relaxation")
t_nlsq = time.time() - t0_fit

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R²: {model._fit_result.r_squared:.6f}")
print(f"RMSE: {model._fit_result.rmse:.4g} Pa")
print("\nFitted parameters:")
for name in ["x", "G0", "tau0"]:
    val = model.parameters.get_value(name)
    print(f"  {name:5s} = {val:.4g}")
print(f"Phase regime: {model.get_phase_regime()}")

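For reference, R² and RMSE as printed above are usually defined as below. This sketch assumes those standard definitions (RheoJAX's internal computation is not shown here) and adds a log-space RMSE as an optional extra: G(t) spans decades, so linear-scale metrics tend to be dominated by the largest moduli, while log10 residuals weight each decade evenly. The helper names are hypothetical.

```python
import math


def r_squared(y_obs, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_obs) / len(y_obs)
    ss_res = sum((o - p) ** 2 for o, p in zip(y_obs, y_pred))
    ss_tot = sum((o - mean) ** 2 for o in y_obs)
    return 1.0 - ss_res / ss_tot


def rmse(y_obs, y_pred):
    """Root-mean-square error in the units of y (here Pa)."""
    return math.sqrt(sum((o - p) ** 2 for o, p in zip(y_obs, y_pred)) / len(y_obs))


def log_rmse(y_obs, y_pred):
    """RMSE of log10 residuals (in decades); requires strictly positive values."""
    return rmse([math.log10(o) for o in y_obs], [math.log10(p) for p in y_pred])


# Notebook usage (assuming model.predict accepts the measured times):
# G_fit = model.predict(d["t"]); print(log_rmse(list(d["G_t"]), list(G_fit)))

# Two 10% errors: linear RMSE sees mostly the one at 100 Pa, log-RMSE counts both equally
obs = [100.0, 10.0, 1.0]
pred = [110.0, 10.0, 1.1]
print(f"R2 = {r_squared(obs, pred):.4f}, RMSE = {rmse(obs, pred):.3f} Pa, "
      f"log-RMSE = {log_rmse(obs, pred):.4f} decades")
```
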
In [None]:
t_fine = np.logspace(np.log10(d["t"].min()) - 0.3, np.log10(d["t"].max()) + 0.3, 200)
G_pred = model.predict(t_fine)

fig, ax = plt.subplots(figsize=(8, 5))
ax.loglog(d["t"], d["G_t"], "ko", markersize=5, label="Data (t_age=3600s)")
ax.loglog(t_fine, G_pred, "-", lw=2, color="C0", label="SGR fit")

# Power-law reference
x_fit = model.parameters.get_value("x")
G0_fit = model.parameters.get_value("G0")
slope = x_fit - 2
ref_t = np.logspace(0, 2, 50)
ref_G = G0_fit * ref_t**slope * 0.5  # Scaled for visibility
ax.loglog(
    ref_t, ref_G, ":", lw=1.5, color="gray", alpha=0.5, label=f"Reference slope {slope:.2f}"
)

ax.set_xlabel("Time [s]")
ax.set_ylabel("G(t) [Pa]")
ax.set_title(f"SGR Relaxation Fit — x={x_fit:.3f}")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

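The dotted reference line is a visual check only. A numerical counterpart is to regress log10 G on log10 t over the long-time tail and compare the result with the fitted x-2. This sketch (hypothetical `tail_loglog_slope`, plain least squares) assumes the tail starts well above τ₀, where the (1 + t/τ₀) factor behaves like a pure power law; the synthetic data are generated from the formula in Section 2, not measured.

```python
import math


def tail_loglog_slope(t, G_t, t_min):
    """Least-squares slope of log10 G vs log10 t using only points with t >= t_min.

    For the SGR law the local slope approaches x - 2 once t >> tau0, so
    choosing t_min well above the fitted tau0 is an assumption of this check.
    """
    pts = [(math.log10(ti), math.log10(gi)) for ti, gi in zip(t, G_t) if ti >= t_min]
    if len(pts) < 2:
        raise ValueError("need at least two points with t >= t_min")
    n = len(pts)
    mx = sum(p[0] for p in pts) / n
    my = sum(p[1] for p in pts) / n
    sxx = sum((p[0] - mx) ** 2 for p in pts)
    sxy = sum((p[0] - mx) * (p[1] - my) for p in pts)
    return sxy / sxx


# Synthetic check: data generated from (1 + t/tau0)**(x - 2) with x = 1.6, tau0 = 0.01 s
t_syn = [10 ** (k / 4) for k in range(-8, 13)]  # 0.01 s to 1000 s
G_syn = [50.0 * (1.0 + ti / 0.01) ** (1.6 - 2.0) for ti in t_syn]
print(f"tail slope = {tail_loglog_slope(t_syn, G_syn, t_min=1.0):.4f} (x - 2 = -0.4)")

# Notebook usage (compare the data's own tail slope with the fitted x - 2):
# tau0_fit = model.parameters.get_value("tau0")
# print(tail_loglog_slope(list(d["t"]), list(d["G_t"]), t_min=100 * tau0_fit), x_fit - 2)
```
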
### 4.2 Aging Time Sweep

In [None]:
fit_results = {}

for t_age in aging_times:
    m = SGRConventional()
    d = datasets[t_age]
    m.fit(d["t"], d["G_t"], test_mode="relaxation")
    fit_results[t_age] = {
        "x": float(m.parameters.get_value("x")),
        "G0": float(m.parameters.get_value("G0")),
        "tau0": float(m.parameters.get_value("tau0")),
        "R2": float(m._fit_result.r_squared),
        "regime": m.get_phase_regime(),
    }

print(f"{'t_age':>6s}  {'x':>6s}  {'G0':>10s}  {'τ₀':>10s}  {'R²':>8s}  {'Regime'}")
print("-" * 60)
for t_age in aging_times:
    r = fit_results[t_age]
    print(
        f"{t_age:6d}  {r['x']:6.3f}  {r['G0']:10.2f}  {r['tau0']:10.2e}  {r['R2']:8.5f}  {r['regime']}"
    )

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ages = list(fit_results.keys())
x_vals = [fit_results[a]["x"] for a in ages]

ax1.plot(ages, x_vals, "o-", markersize=8, lw=2)
ax1.axhline(1.0, color="red", linestyle="--", alpha=0.5, label="Glass transition (x=1)")
ax1.set_xlabel("Aging time [s]")
ax1.set_ylabel("Noise temperature x")
ax1.set_title("x(t_wait): Aging Drives x Toward Glass")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Overlay all fits
for i, t_age in enumerate(aging_times):
    d = datasets[t_age]
    ax2.loglog(d["t"], d["G_t"], "o", color=colors[i], markersize=3, alpha=0.5)
    # Refit to get a model instance for the prediction curve (repeats the sweep fit above)
    m = SGRConventional()
    m.fit(d["t"], d["G_t"], test_mode="relaxation")
    t_f = np.logspace(np.log10(d["t"].min()), np.log10(d["t"].max()), 100)
    ax2.loglog(t_f, m.predict(t_f), "-", color=colors[i], lw=1.5, label=f"t_age={t_age}s")

ax2.set_xlabel("Time [s]")
ax2.set_ylabel("G(t) [Pa]")
ax2.set_title("All SGR Fits")
ax2.legend(fontsize=8)
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

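To summarize the x(t_wait) trend with a single number, one option (not part of the original workflow) is the least-squares slope of x against log10(t_wait), i.e. the change in x per decade of waiting time. `aging_rate` is a hypothetical helper; with only five aging times the estimate is rough, and it does not establish that the dependence is actually logarithmic.

```python
import math


def aging_rate(ages, x_values):
    """Least-squares slope of x against log10(t_wait).

    Hypothetical summary statistic: a negative value means x falls with
    waiting time, i.e. the sample moves toward the glass side (x = 1).
    """
    if len(ages) != len(x_values) or len(ages) < 2:
        raise ValueError("need matching ages and x values, at least two of each")
    u = [math.log10(a) for a in ages]
    n = len(u)
    mu, mx = sum(u) / n, sum(x_values) / n
    num = sum((ui - mu) * (xi - mx) for ui, xi in zip(u, x_values))
    den = sum((ui - mu) ** 2 for ui in u)
    return num / den


# Synthetic demo: x drops by exactly 0.1 per decade of waiting time
print(aging_rate([600, 6000, 60000], [1.30, 1.20, 1.10]))

# Notebook usage with the sweep results above:
# rate = aging_rate(aging_times, [fit_results[a]["x"] for a in aging_times])
# print(f"dx/dlog10(t_wait) = {rate:.3f} per decade")
```
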
### 4.3 SGRGeneric Comparison

In [None]:
d = datasets[3600]
model_gen = SGRGeneric()
model_gen.fit(d["t"], d["G_t"], test_mode="relaxation")

print("SGRConventional vs SGRGeneric (t_age=3600s):")
print(
    f"  Conventional: x={model.parameters.get_value('x'):.4f}, R²={model._fit_result.r_squared:.6f}"
)
print(
    f"  Generic:      x={model_gen.parameters.get_value('x'):.4f}, R²={model_gen._fit_result.r_squared:.6f}"
)

state = np.array([100.0, 0.5])
consistency = model_gen.verify_thermodynamic_consistency(state)
print(
    f"\nThermodynamic consistency: {consistency.get('thermodynamically_consistent', 'N/A')}"
)
for key, val in consistency.items():
    if key != "thermodynamically_consistent":
        print(f"  {key}: {val}")

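GENERIC writes the dynamics as dz/dt = L·∂E/∂z + M·∂S/∂z, with L antisymmetric, M symmetric positive semidefinite, and the degeneracy conditions L·∂S/∂z = 0 and M·∂E/∂z = 0; together these conserve energy and make entropy production non-negative. The sketch below checks those conditions numerically at one state of a textbook toy system (a damped oscillator with state z = (q, p, e), e being internal energy), not SGR. It only illustrates what a consistency check of this kind tests and makes no claim about what `verify_thermodynamic_consistency` does internally. The entropy-production test is a necessary consequence of M being positive semidefinite, not a full semidefiniteness test.

```python
def matvec(A, v):
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def generic_checks(L, M, dE, dS, tol=1e-12):
    """Check the GENERIC structure conditions at a single state."""
    n = len(dE)
    rhs = [a + b for a, b in zip(matvec(L, dE), matvec(M, dS))]  # dz/dt
    return {
        "L_antisymmetric": all(abs(L[i][j] + L[j][i]) < tol for i in range(n) for j in range(n)),
        "M_symmetric": all(abs(M[i][j] - M[j][i]) < tol for i in range(n) for j in range(n)),
        "L_dS_zero": all(abs(c) < tol for c in matvec(L, dS)),
        "M_dE_zero": all(abs(c) < tol for c in matvec(M, dE)),
        "energy_conserved": abs(dot(dE, rhs)) < tol,
        "entropy_production_nonneg": dot(dS, matvec(M, dS)) >= -tol,
    }


# Toy damped oscillator, z = (q, p, e): E = k q^2/2 + p^2/(2m) + e, dS/de = 1/T
k, m, T, gamma = 2.0, 1.0, 300.0, 0.5
q, p = 0.3, 1.2
dE = [k * q, p / m, 1.0]
dS = [0.0, 0.0, 1.0 / T]
L = [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # reversible (Hamiltonian) part
w = [0.0, 1.0, -p / m]
M = [[gamma * T * wi * wj for wj in w] for wi in w]  # irreversible (friction) part

print(generic_checks(L, M, dE, dS))
```
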
## 5. Bayesian Inference with NUTS

### 5.1 Run NUTS

In [None]:
initial_values = {name: model.parameters.get_value(name) for name in ["x", "G0", "tau0"]}
print("Warm-start values:", initial_values)

NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1
# NUM_WARMUP = 1000; NUM_SAMPLES = 2000; NUM_CHAINS = 4  # production

t0 = time.time()
result = model.fit_bayesian(
    d["t"],
    d["G_t"],
    test_mode="relaxation",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

### 5.2 Convergence Diagnostics

In [None]:
diag = result.diagnostics
param_names = ["x", "G0", "tau0"]

print("Convergence Diagnostics")
print("=" * 50)
print(f"{'Parameter':>10s}  {'R-hat':>8s}  {'ESS':>8s}")
print("-" * 50)
for p in param_names:
    r_hat = diag.get("r_hat", {}).get(p, float("nan"))
    ess = diag.get("ess", {}).get(p, float("nan"))
    print(f"{p:>10s}  {r_hat:8.4f}  {ess:8.0f}")

n_div = diag.get("divergences", diag.get("num_divergences", 0))
print(f"\nDivergences: {n_div}")

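With `NUM_CHAINS = 1`, a between-chain R-hat has nothing to compare, so the usual approach is split R-hat, which splits each chain in half and compares the halves. The sketch below is the classic split Gelman-Rubin statistic without the rank normalization that ArviZ's `az.rhat` applies by default, so its values may differ from what the diagnostics table reports; `split_rhat` is a hypothetical helper. Values near 1 mean the halves agree; values well above roughly 1.01 to 1.1 (thresholds vary by convention) suggest drift or poor mixing.

```python
import math


def split_rhat(chains):
    """Classic split-R-hat (Gelman-Rubin on half-chains), without rank normalization.

    `chains` is a list of equal-length lists of draws for one parameter.
    Each chain is split in half (dropping the middle draw if the length is odd).
    """
    halves = []
    for c in chains:
        n_half = len(c) // 2
        halves.append(c[:n_half])
        halves.append(c[len(c) - n_half:])
    m, n = len(halves), len(halves[0])
    means = [sum(h) / n for h in halves]
    grand = sum(means) / m
    B = n / (m - 1) * sum((mu - grand) ** 2 for mu in means)  # between-half variance
    W = sum(sum((v - mu) ** 2 for v in h) / (n - 1) for h, mu in zip(halves, means)) / m
    var_plus = (n - 1) / n * W + B / n  # pooled posterior variance estimate
    return math.sqrt(var_plus / W)


steady = [(7 * i) % 11 for i in range(200)]  # stationary sequence, no drift
drifting = [i % 5 + (10 if i >= 100 else 0) for i in range(200)]  # mean jumps halfway
print(f"steady:   {split_rhat([steady]):.3f}")
print(f"drifting: {split_rhat([drifting]):.3f}")

# Notebook usage (assuming draws are stored in sampling order for a single chain):
# print(split_rhat([list(np.asarray(result.posterior_samples["x"]))]))
```
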
### 5.3 ArviZ Diagnostic Plots

In [None]:
idata = result.to_inference_data()

axes = az.plot_trace(idata, var_names=param_names, figsize=(12, 6))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
axes = az.plot_pair(
    idata, var_names=param_names, kind="scatter", divergences=True, figsize=(9, 9)
)
fig = axes.ravel()[0].figure
fig.suptitle("Parameter Correlations", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
axes = az.plot_forest(
    idata, var_names=param_names, combined=True, hdi_prob=0.95, figsize=(10, 4)
)
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

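`hdi_prob=0.95` in the forest plot requests highest-density intervals: the shortest intervals that hold 95% of the draws. The posterior predictive band below instead uses equal-tailed 2.5/97.5 percentiles, and the two can differ when a posterior is skewed, as can happen for positive scale parameters like G₀ and τ₀. A minimal sample-based sketch follows; `hdi` is a hypothetical helper written in the same spirit as ArviZ's, not its code.

```python
import math


def hdi(samples, prob=0.95):
    """Shortest interval containing a fraction `prob` of the sorted samples."""
    s = sorted(samples)
    n = len(s)
    k = max(1, math.ceil(round(prob * n, 9)))  # number of draws inside the interval
    if k > n:
        raise ValueError("prob too large for the number of samples")
    widths = [s[i + k - 1] - s[i] for i in range(n - k + 1)]
    i_best = min(range(len(widths)), key=widths.__getitem__)
    return s[i_best], s[i_best + k - 1]


# A long right tail pulls an equal-tailed interval outward; the HDI stays on the dense part
print(hdi([1, 2, 3, 4, 100], prob=0.8))

# Notebook usage:
# print(hdi(list(np.asarray(result.posterior_samples["tau0"])), prob=0.95))
```
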
### 5.4 Posterior Predictive Check

In [None]:
posterior = result.posterior_samples
n_draws = min(200, len(list(posterior.values())[0]))
t_pred = np.logspace(np.log10(d["t"].min()) - 0.3, np.log10(d["t"].max()) + 0.3, 100)

pred_samples = []
for i in range(n_draws):
    params_i = jnp.array([posterior["x"][i], posterior["G0"][i], posterior["tau0"][i]])
    pred_i = model.model_function(jnp.array(t_pred), params_i, test_mode="relaxation")
    pred_samples.append(np.array(pred_i))

pred_samples = np.array(pred_samples)
pred_median = np.median(pred_samples, axis=0)
pred_lo = np.percentile(pred_samples, 2.5, axis=0)
pred_hi = np.percentile(pred_samples, 97.5, axis=0)

fig, ax = plt.subplots(figsize=(9, 6))
ax.fill_between(t_pred, pred_lo, pred_hi, alpha=0.3, color="C0", label="95% CI")
ax.loglog(t_pred, pred_median, "-", lw=2, color="C0", label="Posterior median")
ax.loglog(d["t"], d["G_t"], "ko", markersize=5, label="Data")
ax.set_xlabel("Time [s]")
ax.set_ylabel("G(t) [Pa]")
ax.set_title("Posterior Predictive Check (t_age=3600s)")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

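One way to summarize the plot above is the fraction of data points that fall inside the band. The band here is built only from posterior parameter draws passed through `model.model_function`, with no measurement noise added, so coverage well below 95% mainly says that it reflects parameter uncertainty alone. `band_coverage` is a hypothetical helper; the band must first be evaluated at the measured times rather than on `t_pred`.

```python
def band_coverage(observed, lower, upper):
    """Fraction of observations lying inside [lower, upper], point by point."""
    if not (len(observed) == len(lower) == len(upper)):
        raise ValueError("observed, lower and upper must have the same length")
    inside = sum(1 for y, lo, hi in zip(observed, lower, upper) if lo <= y <= hi)
    return inside / len(observed)


# Toy values: the second observation falls below its band
obs = [120.0, 80.0, 30.0, 12.0]
lo = [110.0, 85.0, 25.0, 10.0]
hi = [130.0, 95.0, 35.0, 14.0]
print(band_coverage(obs, lo, hi))

# Notebook usage (hypothetical): rerun the PPC loop with jnp.array(d["t"]) in place of
# t_pred to get lo_at_data / hi_at_data, then
# band_coverage(list(d["G_t"]), list(lo_at_data), list(hi_at_data))
```
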
## 6. Save Results

In [None]:
output_dir = os.path.join("..", "outputs", "sgr", "relaxation")
os.makedirs(output_dir, exist_ok=True)

nlsq_params = {name: float(model.parameters.get_value(name)) for name in ["x", "G0", "tau0"]}
with open(os.path.join(output_dir, "nlsq_params.json"), "w") as f:
    json.dump(nlsq_params, f, indent=2)

aging_results = {str(k): v for k, v in fit_results.items()}
with open(os.path.join(output_dir, "aging_sweep_results.json"), "w") as f:
    json.dump(aging_results, f, indent=2)

posterior_dict = {k: np.array(v).tolist() for k, v in posterior.items()}
with open(os.path.join(output_dir, "posterior_samples.json"), "w") as f:
    json.dump(posterior_dict, f)

print(f"Results saved to {output_dir}/")

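The save cell calls `float()` and `.tolist()` explicitly because the standard `json` module cannot encode NumPy or JAX arrays. An alternative is the `default=` hook of `json.dump`, which is called for any object the encoder does not recognize. In this sketch `to_json_default` is a hypothetical helper, and the fake classes only stand in for array and label types so the block runs without NumPy; whether `get_phase_regime()` ever returns a non-string is not established here.

```python
import json


def to_json_default(obj):
    """Fallback for json.dump/json.dumps, called only for objects json cannot encode.

    Objects exposing .tolist() (e.g. NumPy or JAX arrays) become lists or
    numbers; anything else falls back to its string form.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    return str(obj)


class FakeArray:
    """Dependency-free stand-in for an array type with a .tolist() method."""

    def __init__(self, values):
        self.values = list(values)

    def tolist(self):
        return list(self.values)


class FakeRegime:
    """Hypothetical non-string regime label, used only to exercise the fallback."""

    def __str__(self):
        return "glass"


payload = {"x": FakeArray([1.2, 1.3]), "regime": FakeRegime(), "R2": 0.99}
print(json.dumps(payload, default=to_json_default))

# Notebook usage (same dictionaries as the cell above):
# json.dump(aging_results, f, indent=2, default=to_json_default)
```
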
## Key Takeaways

1. **Power-law relaxation**: SGR predicts G(t) ~ (1 + t/τ₀)^(x-2), fundamentally different from the exponential decay of Maxwell models. For t ≫ τ₀ this appears as a straight line on log-log plots with slope (x-2).

2. **Aging signature**: As Laponite clay ages (increasing t_wait), the noise temperature x decreases, indicating deeper energy traps and slower structural relaxation. This trend is the hallmark of soft glassy aging.

3. **Tracking x(t_wait)**: The evolution of x provides a quantitative measure of aging dynamics; x approaching 1 signals that the material is nearing the glass transition.

4. **SGRGeneric comparison**: SGRGeneric yields fits comparable to SGRConventional (Section 4.3) while adding thermodynamic consistency checks from the GENERIC framework, which is useful for validating conventional SGR predictions.

5. **Log-log diagnostics**: Power-law reference slopes directly encode the exponent (x-2), making visual validation straightforward on log-log plots.

**Next Steps:** Explore SAOS (NB 03) for the frequency-domain view, or creep (NB 04) for a complementary time-domain protocol. Both reveal different aspects of the same underlying SGR dynamics.