# EPM Stress Relaxation

> **Handbook:** See [Lattice EPM — Relaxation Protocol](../../docs/source/models/epm/lattice_epm.rst#epm-relaxation) for mathematical details with boxed governing equations.

**EPM Relaxation Physics:** After a step strain $\gamma_0$ at $t=0$, EPM predicts stress decay $\sigma(t)$ through **cascading plastic events**. Unlike Maxwell models with exponential $G(t) = G_0 e^{-t/\tau}$, EPM produces non-exponential relaxation (power-law or stretched exponential) from the **disorder-induced multi-timescale spectrum**.

### Disorder-Induced Multi-Relaxation

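The Gaussian distribution of local yield thresholds, $\sigma_{c,i} \sim \mathcal{N}(\sigma_{c,\text{mean}}, \sigma_{c,\text{std}}^2)$, creates a spectrum of relaxation timescales. Each threshold is compared with the stress loaded by the step, $\sigma_0 \approx \mu\gamma_0$: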

| Yield Threshold | Yielding Behavior | Relaxation Contribution |
|-----------------|-------------------|------------------------|
| **Low** $\sigma_{c,i} \ll \sigma_0$ | Yield immediately after step strain | **Fast relaxation modes** |
| **Medium** $\sigma_{c,i} \sim \sigma_0$ | Yield during intermediate times | **Broad relaxation spectrum** |
| **High** $\sigma_{c,i} \gg \sigma_0$ | Remain elastic unless loaded past threshold by neighboring events | **Slowest modes** or residual stress (plateau at long times) |
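The table can be made semi-quantitative. The minimal sketch below assumes that the stress right after the step is uniform across the lattice and equal to $\sigma_0 = \mu\gamma_0$ (real EPM lattices start from a heterogeneous stress field). Under that assumption it estimates the fraction of sites that are unstable immediately, $P(\sigma_{c,i} < \sigma_0)$, from the Gaussian CDF. `fraction_unstable` is a hypothetical helper, not part of RheoJAX.

```python
import math


def fraction_unstable(mu, gamma0, sigma_c_mean, sigma_c_std):
    """P(sigma_c < sigma_0) for Gaussian thresholds, with sigma_0 = mu * gamma0.

    Hypothetical helper: assumes a uniform initial stress field (idealization)
    and ignores the small probability mass of negative thresholds.
    """
    sigma0 = mu * gamma0
    z = (sigma0 - sigma_c_mean) / sigma_c_std
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # Gaussian CDF at z


# mu = 1e6 Pa and a 10% step give sigma_0 = 1e5 Pa, right at the mean threshold
print(fraction_unstable(1e6, 0.1, 1e5, 2e4))  # 0.5
```

Smaller steps push $\sigma_0$ into the lower tail, so almost no site yields right away and relaxation has to start from the few weakest sites.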

### Relaxation Modulus

$$G(t) = \frac{\sigma(t)}{\gamma_0}$$

**Typical behavior:**
- **Short times** ($t \to 0$): $G(t) \approx \mu$ (elastic modulus, unrelaxed)
- **Intermediate times**: Power-law or stretched exponential decay
- **Long times**: Plateau (glassy materials) or terminal relaxation (viscoelastic fluids)

### Disorder Controls Spectrum Width

The ratio $\sigma_{c,\text{std}}/\sigma_{c,\text{mean}}$ controls the breadth of the relaxation spectrum:

| Disorder Strength | Relaxation Decay | Physical Origin |
|-------------------|------------------|-----------------|
| Low (< 0.2) | Near-exponential $G(t) \sim e^{-t/\tau}$ | Narrow threshold distribution → quasi-single mode |
| Moderate (0.2–0.5) | Power-law $G(t) \sim t^{-\alpha}$ | Broad distribution → continuous spectrum |
| High (> 0.5) | Stretched exponential $G(t) \sim e^{-(t/\tau)^\beta}$ | Very broad spectrum, $\beta < 1$ |
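To see how spectrum width changes the shape of $G(t)$, the toy below superposes equal-weight Maxwell modes whose times $\tau_i = \tau e^{s z_i}$ (with $z_i \sim \mathcal{N}(0, 1)$) are log-normally spread with width $s$. It illustrates the trend in the table under that assumption only. It is not the mapping RheoJAX uses from $\sigma_{c,\text{std}}/\sigma_{c,\text{mean}}$ to the spectrum. `toy_multimode_G` and `spread` are hypothetical names.

```python
import numpy as np


def toy_multimode_G(t, mu, tau, spread, n_modes=2000, seed=0):
    """Toy G(t): mean of equal-weight Maxwell modes with log-normally spread times.

    Hypothetical illustration: tau_i = tau * exp(spread * z_i), z_i ~ N(0, 1).
    `spread` stands in for spectrum width; it is NOT the EPM parameter map.
    """
    rng = np.random.default_rng(seed)
    tau_i = tau * np.exp(spread * rng.standard_normal(n_modes))
    t = np.asarray(t, dtype=float)
    # Rows: time points, columns: modes; average the exponentials over modes
    return mu * np.exp(-t[:, None] / tau_i[None, :]).mean(axis=1)


t = np.logspace(-2, 2, 50)
G_narrow = toy_multimode_G(t, mu=1e5, tau=1.0, spread=0.0)  # single exponential
G_broad = toy_multimode_G(t, mu=1e5, tau=1.0, spread=2.0)   # long non-exponential tail
```

With `spread=0` the curve reduces to $\mu e^{-t/\tau}$. Increasing the spread leaves $G(0) = \mu$ unchanged but produces a long tail dominated by the slowest modes.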

### Cascading Plastic Events

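In EPM, stress relaxes via **avalanches**. An unstable site yields, and the Eshelby propagator $\mathcal{G}_{ij}$ redistributes the resulting stress change to other sites, which can push them past their own thresholds. These cascades keep the system active long after the initial strain: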

$$\text{Site } i \text{ yields} \to \Delta\sigma_j = \mathcal{G}_{ij} \Delta\gamma_i^{\text{pl}} \to \text{Site } j \text{ may yield} \to \cdots$$

This leads to slow, non-exponential relaxation: power-law ($G \sim t^{-\alpha}$) or logarithmic ($G \sim A - B\ln t$) decay.
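The cascade can be illustrated with a deliberately simplified toy. In a 1D periodic chain, a site above its threshold drops to zero stress and passes a fraction `alpha` of the drop to each nearest neighbor. The short-ranged, purely positive kernel is an assumption that stands in for $\mathcal{G}_{ij}$. The real Eshelby propagator is long-ranged and quadrupolar, so it can also unload some sites. `toy_cascade` is hypothetical and unrelated to `LatticeEPM` internals.

```python
import numpy as np


def toy_cascade(sigma, sigma_c, alpha=0.25, max_events=100_000):
    """Relax a 1D periodic chain of sites by sequential yield events.

    Hypothetical toy: a site with |sigma_i| > sigma_c_i drops to zero and passes
    alpha * (its stress) to each nearest neighbour. This short-ranged positive
    kernel replaces the long-ranged, quadrupolar Eshelby propagator of real EPMs.
    Returns the final stress field and the number of yield events (avalanche size).
    """
    sigma = np.array(sigma, dtype=float)
    sigma_c = np.broadcast_to(np.asarray(sigma_c, dtype=float), sigma.shape)
    n = sigma.size
    n_events = 0
    while n_events < max_events:
        unstable = np.flatnonzero(np.abs(sigma) > sigma_c)
        if unstable.size == 0:
            break  # avalanche over: every site is below its threshold
        i = unstable[0]
        drop = sigma[i]
        sigma[i] = 0.0
        sigma[(i - 1) % n] += alpha * drop  # load left neighbour
        sigma[(i + 1) % n] += alpha * drop  # load right neighbour
        n_events += 1
    return sigma, n_events


# One yield at site 0 pushes site 1 over its threshold: a two-event avalanche
final, n_events = toy_cascade([1.5, 0.9, 0.0, 0.0, 0.0], sigma_c=1.0)
```

Each event lowers $\sum_i |\sigma_i|$ by at least $(1 - 2\alpha)$ times the dropped stress, so with positive thresholds and `alpha < 0.5` every avalanche terminates.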

### Relation to SAOS

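The relaxation modulus $G(t)$ and the dynamic moduli $G'(\omega)$, $G''(\omega)$ are related by one-sided **Fourier (sine and cosine) transforms**:

$$G'(\omega) = \omega \int_0^\infty G(t) \sin(\omega t) \, dt, \qquad G''(\omega) = \omega \int_0^\infty G(t) \cos(\omega t) \, dt$$

These forms assume $G(t) \to 0$ at long times. For a material with an equilibrium plateau $G_e$, replace $G(t)$ by $G(t) - G_e$ in both integrals and add $G_e$ to $G'$.

Consistent EPM parameters should fit both the relaxation and the SAOS protocols, which provides a **cross-validation test**.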
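The transform pair can be checked numerically. The sketch below evaluates both integrals with the trapezoid rule for a single Maxwell mode, $G(t) = G_0 e^{-t/\tau}$. Its exact moduli are $G' = G_0\omega^2\tau^2/(1+\omega^2\tau^2)$ and $G'' = G_0\omega\tau/(1+\omega^2\tau^2)$, which cross at $\omega = 1/\tau$. The sketch assumes $G(t)$ has decayed to zero by the end of the grid. `dynamic_moduli_from_G` is a hypothetical helper. Real data, which are sparse and log-spaced, would need interpolation or a fitted spectrum first.

```python
import numpy as np


def dynamic_moduli_from_G(t, G, omega):
    """G'(w) = w * int G sin(wt) dt and G''(w) = w * int G cos(wt) dt (trapezoid rule).

    Hypothetical helper; assumes G(t) has decayed to ~0 by t[-1] (no plateau).
    """
    t = np.asarray(t, dtype=float)
    G = np.asarray(G, dtype=float)
    dt = np.diff(t)

    def _trapezoid(y):
        return 0.5 * np.sum((y[1:] + y[:-1]) * dt)

    Gp = np.array([w * _trapezoid(G * np.sin(w * t)) for w in omega])
    Gpp = np.array([w * _trapezoid(G * np.cos(w * t)) for w in omega])
    return Gp, Gpp


# Single Maxwell mode; the middle frequency is the expected crossover w = 1/tau
G0, tau = 1e5, 0.5
t = np.linspace(0.0, 50 * tau, 200_001)
omega = np.array([0.2, 1.0 / tau, 20.0])
Gp, Gpp = dynamic_moduli_from_G(t, G0 * np.exp(-t / tau), omega)
Gp_exact = G0 * (omega * tau) ** 2 / (1 + (omega * tau) ** 2)
Gpp_exact = G0 * (omega * tau) / (1 + (omega * tau) ** 2)
```

At $\omega = 1/\tau$ both moduli equal $G_0/2$. This is the sense in which the crossover frequency reads off the characteristic relaxation time.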

### Materials

Typical relaxation signatures by material class:
- Polymer melts: multi-mode Maxwell-like spectrum, represented in EPM by threshold disorder
- Colloidal gels: slow relaxation via cage reorganization
- Metallic glasses: power-law relaxation ($\alpha \approx 0.5$–$1.0$)
- Soft glassy materials: stretched exponential ($\beta \approx 0.3$–$0.8$)

## Learning Objectives

- Understand disorder-induced multi-relaxation in EPM
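- Fit real polymer (polystyrene) relaxation data with NLSQ, then quantify uncertainty with Bayesian inference
- Estimate the (approximate) relaxation time distribution implied by the EPM parameters
- Compare with SAOS results (consistency check via the Fourier relation between $G(t)$ and $G'(\omega)$, $G''(\omega)$)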

## Prerequisites

- Complete `01_epm_flow_curve.ipynb` for EPM basics
- Understanding of stress relaxation $G(t)$ measurements

## Estimated Runtime

- Fast demo (1 chain): ~3-4 min
- Full run (4 chains): ~8-12 min

## 1. Setup & Imports

In [None]:
# Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import time
import sys
import os
import json

sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.core.data import RheoData
from rheojax.models.epm.lattice import LatticeEPM

jax, jnp = safe_import_jax()
verify_float64()

FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"FAST_MODE: {FAST_MODE}")

def compute_fit_quality(y_true, y_pred):
    """Compute R2 and RMSE."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    residuals = y_true - y_pred
    if y_true.ndim > 1:
        residuals = residuals.ravel()
        y_true = y_true.ravel()
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    rmse = np.sqrt(np.mean(residuals**2))
    return {"R2": r2, "RMSE": rmse}

## 2. Theory: Relaxation in EPM

In stress relaxation, we apply a step strain $\gamma_0$ at $t=0$ and monitor the decaying stress $\sigma(t)$. The relaxation modulus is:

$$G(t) = \frac{\sigma(t)}{\gamma_0}$$

### Disorder-Induced Multi-Relaxation

In EPM, the **distribution of yield thresholds** creates a spectrum of relaxation timescales:

- **Sites with low thresholds**: Yield quickly → fast relaxation modes
- **Sites with high thresholds**: Remain elastic longer → slow relaxation modes

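This leads to a stretched-exponential or power-law-like decay of the general form

$$G(t) \sim G_0 \cdot f(t/\tau_{\text{pl}}, \sigma_c/\sigma_0)$$

where $G_0 \approx \mu$ is the unrelaxed modulus, $\sigma_0 \approx \mu\gamma_0$ is the stress just after the step, $\sigma_c$ stands for the threshold distribution (mean and width), and $f$ is a dimensionless decay function with $f \to 1$ as $t \to 0$.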

### Key Parameters

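| Parameter | Code name | Effect on relaxation |
|-----------|-----------|----------------------|
| $\mu$ | `mu` | Sets the initial modulus, $G(0) \approx \mu$ |
| $\tau_{\text{pl}}$ | `tau_pl` | Controls the characteristic relaxation time |
| $\sigma_{c,\text{std}}/\sigma_{c,\text{mean}}$ | `sigma_c_std` / `sigma_c_mean` | Controls the breadth of the relaxation spectrum |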

## 3. Load Relaxation Data

We use stress relaxation data from polystyrene at 145°C (same sample as SAOS in Notebook 02).

In [None]:
data_path = os.path.join("..", "data", "relaxation", "polymers", "stressrelaxation_ps145_data.csv")
if IN_COLAB:
    data_path = "stressrelaxation_ps145_data.csv"
    if not os.path.exists(data_path):
        print("Please upload stressrelaxation_ps145_data.csv or adjust the path.")

# Load data (tab-separated: Time, Relaxation Modulus)
raw = np.loadtxt(data_path, delimiter="\t", skiprows=1)
time_data = raw[:, 0]      # Time [s]
G_data = raw[:, 1]         # Relaxation modulus [Pa]

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.3f} – {time_data.max():.1f} s")
print(f"G(t) range: {G_data.min():.2e} – {G_data.max():.2e} Pa")

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Linear scale
ax1.plot(time_data, G_data / 1e3, "ko-", markersize=4)
ax1.set_xlabel("Time [s]")
ax1.set_ylabel("Relaxation modulus G(t) [kPa]")
ax1.set_title("Stress Relaxation: Polystyrene at 145°C")
ax1.grid(True, alpha=0.3)

# Right: Log-log scale
ax2.loglog(time_data, G_data, "ko-", markersize=4)
ax2.set_xlabel("Time [s]")
ax2.set_ylabel("Relaxation modulus G(t) [Pa]")
ax2.set_title("Log-Log Plot")
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

The relaxation data show:
- A high initial modulus (~143 kPa)
- A gradual decay over the measured time window
- Power-law-like behavior (an approximately straight section) on the log-log plot, pointing to a broad relaxation spectrum
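A quick way to test the "power-law-like" impression is the apparent exponent $\alpha(t) = -d\ln G/d\ln t$. It is constant for a power law and grows with $t$ for a stretched exponential, where $\alpha(t) = \beta (t/\tau)^\beta$. The minimal finite-difference sketch below uses `local_log_slope`, a hypothetical helper. On the measured data it would be called as `local_log_slope(time_data, G_data)`. Differentiation amplifies noise, so some smoothing may be needed.

```python
import numpy as np


def local_log_slope(t, G):
    """Apparent exponent alpha(t) = -d ln G / d ln t (hypothetical helper).

    Constant alpha -> power law G ~ t**-alpha.
    Growing alpha  -> faster-than-power-law decay, e.g. a stretched exponential.
    """
    return -np.gradient(np.log(G), np.log(t))


t = np.logspace(-2, 2, 41)
alpha_pl = local_log_slope(t, 1e5 * t ** -0.4)                 # ~0.4 everywhere
alpha_se = local_log_slope(t, 1e5 * np.exp(-(t / 1.0) ** 0.5))  # increases with t
```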

## 4. NLSQ Fitting

In [None]:
# Initialize LatticeEPM (default parameters; smaller lattice in FAST_MODE)
model = LatticeEPM(
    L=16 if FAST_MODE else 32,
    dt=0.01,
)

# Widen bounds for polymer melts (must update both .bounds and .constraints)
polymer_params = {
    "mu":           ((1e4, 1e7),  1e5),
    "tau_pl":       ((0.01, 100.0), 1.0),
    "sigma_c_mean": ((1e3, 1e7),  1e5),
    "sigma_c_std":  ((1e2, 1e6),  1e4),
}
for name, (new_bounds, init_val) in polymer_params.items():
    param = model.parameters[name]
    param.bounds = new_bounds
    for c in param.constraints:
        if c.type == "bounds":
            c.min_value, c.max_value = new_bounds
    model.parameters.set_value(name, init_val)

print("LatticeEPM initialized for relaxation fitting")
for name in polymer_params:
    print(f"  {name:15s} = {model.parameters.get_value(name):.4g}  bounds={model.parameters[name].bounds}")
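Because the bounds are edited by hand, a cheap sanity check before fitting is to confirm that every initial value lies inside its new bounds. The sketch below works on the plain `((lower, upper), initial)` layout of `polymer_params`, which is re-declared so the block runs on its own. `check_initial_values` is a hypothetical helper and does not inspect the RheoJAX parameter objects themselves.

```python
def check_initial_values(param_spec):
    """Return names whose bounds are inverted or whose initial value lies outside them.

    Hypothetical helper; param_spec maps name -> ((lower, upper), initial).
    """
    bad = []
    for name, ((lower, upper), init) in param_spec.items():
        if lower >= upper or not (lower <= init <= upper):
            bad.append(name)
    return bad


polymer_params = {
    "mu":           ((1e4, 1e7),  1e5),
    "tau_pl":       ((0.01, 100.0), 1.0),
    "sigma_c_mean": ((1e3, 1e7),  1e5),
    "sigma_c_std":  ((1e2, 1e6),  1e4),
}
print(check_initial_values(polymer_params))  # []
```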

In [None]:
# Fit to relaxation data
print("Fitting to relaxation data...")
t0 = time.time()
model.fit(time_data, G_data, test_mode="relaxation", method='scipy')
t_nlsq = time.time() - t0

# Compute fit quality
y_pred = model.predict(time_data, test_mode="relaxation", smooth=True).y
metrics = compute_fit_quality(G_data, y_pred)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R²: {metrics['R2']:.6f}")
print(f"RMSE: {metrics['RMSE']:.2e} Pa")

print("\nFitted parameters:")
param_names = ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]
for name in param_names:
    val = model.parameters.get_value(name)
    print(f"  {name:15s} = {val:.4g}")

In [None]:
# Plot NLSQ fit with uncertainty band
param_names = ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]

# Pass x_pred=time_data to avoid fine-grid shape mismatch in EPM
fig, ax = plot_nlsq_fit(
    time_data, G_data, model, test_mode="relaxation",
    param_names=param_names, x_pred=time_data, log_scale=True,
    xlabel="Time [s]", ylabel="Relaxation modulus G(t) [Pa]",
    title=f"Relaxation Fit (R2 = {metrics['R2']:.5f})",
)
display(fig)
plt.close(fig)

# Fine time grid for later posterior predictive curves (extends 0.3 decades, about 2x, past the last data point)
time_fine = np.logspace(
    np.log10(time_data.min()),
    np.log10(time_data.max()) + 0.3,
    200,
)
rheo_fine = RheoData(
    x=time_fine, y=np.zeros_like(time_fine),
    initial_test_mode="relaxation",
)

## 5. Relaxation Time Distribution

The EPM parameters encode a distribution of relaxation times through the yield threshold distribution.

In [None]:
# Extract parameters
mu = model.parameters.get_value("mu")
tau_pl = model.parameters.get_value("tau_pl")
sigma_c_mean = model.parameters.get_value("sigma_c_mean")
sigma_c_std = model.parameters.get_value("sigma_c_std")

# Coefficient of variation (disorder strength)
cv = sigma_c_std / sigma_c_mean

print("Relaxation Spectrum Analysis")
print("=" * 45)
print(f"Characteristic relaxation time: τ_pl = {tau_pl:.4f} s")
print(f"Disorder strength (CV): σ_c,std/σ_c,mean = {cv:.3f}")
print(f"\nInterpretation:")
print(f"  CV = {cv:.3f} → {'Narrow' if cv < 0.2 else 'Moderate' if cv < 0.5 else 'Broad'} relaxation spectrum")
print(f"  Characteristic frequency: ω_c = 1/τ_pl ≈ {1/tau_pl:.2f} rad/s")

In [None]:
# Visualize the implied relaxation time distribution (heuristic)
# In EPM, a site with threshold σ_c relaxes at rate ~1/τ_pl once its stress exceeds σ_c,
# so the threshold distribution maps onto a distribution of effective relaxation times

# Sample the threshold distribution, truncated to σ_c > 0 (discard negatives rather than fold them with np.abs)
rng = np.random.default_rng(0)
n_samples = 10000
sigma_c_samples = rng.normal(sigma_c_mean, sigma_c_std, 2 * n_samples)
sigma_c_samples = sigma_c_samples[sigma_c_samples > 0][:n_samples]

# Effective relaxation times (rough linear approximation): τ_eff ~ τ_pl * (σ_c / σ_typical)
# Because this map is linear, the right panel is a rescaled copy of the left one
sigma_typical = sigma_c_mean  # Reference stress
tau_eff = tau_pl * (sigma_c_samples / sigma_typical)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Threshold distribution
ax1.hist(sigma_c_samples, bins=50, density=True, alpha=0.7, color="C0", edgecolor="black")
ax1.axvline(sigma_c_mean, color="C1", linestyle="--", lw=2, label=f"Mean: {sigma_c_mean:.2e} Pa")
ax1.set_xlabel("Yield threshold σ_c [Pa]")
ax1.set_ylabel("Probability density")
ax1.set_title("Yield Threshold Distribution")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: Effective relaxation time distribution
ax2.hist(tau_eff, bins=50, density=True, alpha=0.7, color="C2", edgecolor="black")
ax2.axvline(tau_pl, color="C1", linestyle="--", lw=2, label=f"τ_pl: {tau_pl:.3f} s")
ax2.set_xlabel("Effective relaxation time τ_eff [s]")
ax2.set_ylabel("Probability density")
ax2.set_title("Implied Relaxation Time Distribution")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)
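Sampled thresholds must be positive, so negative Gaussian draws have to be discarded or otherwise handled. This only matters when the Gaussian puts noticeable mass below zero. For $\sigma_c \sim \mathcal{N}(\bar\sigma_c, (\text{CV}\,\bar\sigma_c)^2)$ that mass is $\Phi(-1/\text{CV})$, which the hypothetical helper below computes.

```python
import math


def negative_threshold_fraction(cv):
    """P(sigma_c < 0) for sigma_c ~ N(mean, (cv * mean)**2) with mean > 0 (hypothetical helper)."""
    return 0.5 * (1.0 + math.erf(-1.0 / (cv * math.sqrt(2.0))))


for cv in (0.2, 0.5, 1.0):
    print(f"CV = {cv:.1f}: negative mass = {negative_threshold_fraction(cv):.2e}")
```

At CV = 0.5 about 2.3% of the mass is negative. At CV = 1 it is about 16%, and the way negatives are handled then visibly reshapes the histogram.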

## 6. Bayesian Inference

In [None]:
# Warm-start from NLSQ
initial_values = {
    name: model.parameters.get_value(name)
    for name in model.parameters.keys()
}

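# --- Fast demo config (FAST_MODE) vs. full run ---
NUM_WARMUP = 50 if FAST_MODE else 200
NUM_SAMPLES = 100 if FAST_MODE else 500
NUM_CHAINS = 1 if FAST_MODE else 4  # several chains make R-hat a meaningful between-chain check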

print(f"Running Bayesian inference: {NUM_CHAINS} chain(s)")

t0 = time.time()
result = model.fit_bayesian(
    time_data,
    G_data,
    test_mode="relaxation",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"Bayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
diag = result.diagnostics

print("Convergence Diagnostics")
print("=" * 50)
print(f"{'Parameter':>15s}  {'R-hat':>8s}  {'ESS':>8s}")
print("-" * 50)
for p in param_names:
    r_hat = diag.get("r_hat", {}).get(p, float("nan"))
    ess = diag.get("ess", {}).get(p, float("nan"))
    print(f"{p:>15s}  {r_hat:8.4f}  {ess:8.0f}")

n_div = diag.get("divergences", diag.get("num_divergences", 0))
print(f"\nDivergences: {n_div}")

In [None]:
# ArviZ diagnostic plots
display_arviz_diagnostics(result, param_names, fast_mode=FAST_MODE)
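With a single chain, R-hat can only detect drift within that chain, which is one reason to use several chains in the full run. For reference, here is a minimal NumPy sketch of the classic split-$\hat{R}$ for draws of shape `(chains, draws)`. ArviZ reports a rank-normalized variant, so its numbers can differ slightly. `split_rhat` is a hypothetical helper, not the implementation behind `result.diagnostics`.

```python
import numpy as np


def split_rhat(draws):
    """Classic (non-rank-normalized) split R-hat for an array of shape (chains, draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in half so that within-chain drift also inflates R-hat
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)      # between-chain variance
    var_hat = (n - 1) / n * W + B / n           # pooled variance estimate
    return np.sqrt(var_hat / W)


rng = np.random.default_rng(0)
print(split_rhat(rng.standard_normal((4, 1000))))  # close to 1 for well-mixed chains
```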

## 7. Posterior Predictive Check

In [None]:
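posterior = result.posterior_samples
n_total = len(posterior[param_names[0]])
n_draws = min(5 if FAST_MODE else 100, n_total)
# Spread the draws evenly over the stored samples instead of taking the first n,
# which are autocorrelated. Note: this loop overwrites the model's parameter values.
draw_idx = np.linspace(0, n_total - 1, n_draws).astype(int)

print(f"Computing {n_draws} posterior predictive samples...")
pred_samples = []
for i in draw_idx:
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name][i]))
    pred_i = model.predict(rheo_fine, smooth=True).y
    pred_samples.append(np.array(pred_i))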

pred_samples = np.array(pred_samples)
pred_median = np.median(pred_samples, axis=0)
pred_lo = np.percentile(pred_samples, 2.5, axis=0)
pred_hi = np.percentile(pred_samples, 97.5, axis=0)
print("Done.")

fig, ax = plt.subplots(figsize=(9, 6))
ax.fill_between(time_fine, pred_lo, pred_hi, alpha=0.3, color="C0", label="95% CI")
ax.loglog(time_fine, pred_median, "-", lw=2, color="C0", label="Posterior median")
ax.loglog(time_data, G_data, "ko", markersize=5, label="Data")
ax.set_xlabel("Time [s]")
ax.set_ylabel("Relaxation modulus G(t) [Pa]")
ax.set_title("Posterior Predictive Check - Relaxation")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 8. Comparison with SAOS

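If you ran Notebook 02, we can compare $\tau_{\text{pl}}$ estimated from relaxation with $\tau_{\text{pl}}$ from the SAOS fit, along with the implied crossover frequencies $\omega_c \approx 1/\tau_{\text{pl}}$.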

In [None]:
# Try to load SAOS results
saos_params_file = os.path.join("..", "outputs", "epm", "oscillation", "nlsq_params_oscillation.json")

if os.path.exists(saos_params_file):
    with open(saos_params_file) as f:
        saos_params = json.load(f)
    
    tau_pl_relax = float(np.median(posterior["tau_pl"]))
    tau_pl_saos = saos_params.get("tau_pl", None)
    
    if tau_pl_saos is not None:
        print("Cross-Protocol Comparison")
        print("=" * 45)
        print(f"τ_pl from relaxation: {tau_pl_relax:.4f} s")
        print(f"τ_pl from SAOS:       {tau_pl_saos:.4f} s")
        print(f"Ratio:                {tau_pl_relax/tau_pl_saos:.2f}")
        print(f"\nCrossover frequencies:")
        print(f"  ω_c (relaxation): {1/tau_pl_relax:.2f} rad/s")
        print(f"  ω_c (SAOS):       {1/tau_pl_saos:.2f} rad/s")
else:
    print("SAOS results not found. Run Notebook 02 for cross-protocol comparison.")
    tau_pl_relax = float(np.median(posterior["tau_pl"]))
    print(f"\nRelaxation τ_pl: {tau_pl_relax:.4f} s")
    print(f"Implied crossover: ω_c ≈ {1/tau_pl_relax:.2f} rad/s")
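A ratio close to 1 is suggestive, but a sharper check asks whether the SAOS estimate falls inside the relaxation posterior interval. The sketch below treats the SAOS value as a point estimate, since only NLSQ parameters are loaded for SAOS here. `tau_within_interval` is a hypothetical helper. In this notebook it would be called as `tau_within_interval(posterior["tau_pl"], tau_pl_saos)`.

```python
import numpy as np


def tau_within_interval(tau_samples, tau_other, level=0.95):
    """Central `level` interval of tau_samples and whether tau_other lies inside it."""
    tail = 100 * (1 - level) / 2
    lo, hi = np.percentile(np.asarray(tau_samples, dtype=float), [tail, 100 - tail])
    return float(lo), float(hi), bool(lo <= tau_other <= hi)
```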

## 9. Parameter Summary

In [None]:
print("\nParameter Summary")
print("=" * 65)
print(f"{'Param':>15s}  {'Median':>12s}  {'95% CI':>28s}")
print("-" * 65)

for name in param_names:
    samples = posterior[name]
    median = float(np.median(samples))
    lo = float(np.percentile(samples, 2.5))
    hi = float(np.percentile(samples, 97.5))
    print(f"{name:>15s}  {median:12.4g}  [{lo:.4g}, {hi:.4g}]")
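Derived quantities such as the disorder strength should be summarized draw by draw. The median of a ratio is generally not the ratio of the medians, and the draws of $\sigma_{c,\text{std}}$ and $\sigma_{c,\text{mean}}$ may be correlated. The minimal sketch below uses `derived_quantity_summary` and `disorder_cv`, both hypothetical helpers. In this notebook, `derived_quantity_summary(posterior, disorder_cv)` would summarize the posterior CV.

```python
import numpy as np


def derived_quantity_summary(posterior, fn, level=0.95):
    """Median and central `level` interval of fn evaluated on every posterior draw."""
    values = np.asarray(fn(posterior), dtype=float)
    tail = 100 * (1 - level) / 2
    lo, med, hi = np.percentile(values, [tail, 50, 100 - tail])
    return med, lo, hi


def disorder_cv(posterior):
    """Per-draw disorder strength sigma_c_std / sigma_c_mean."""
    return np.asarray(posterior["sigma_c_std"]) / np.asarray(posterior["sigma_c_mean"])
```

The same pattern gives a posterior for the crossover frequency by passing a function that returns `1.0 / np.asarray(p["tau_pl"])`.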

## 10. Key Takeaways

1. **EPM captures multi-relaxation** through disorder in yield thresholds
2. **$\tau_{\text{pl}}$ sets the characteristic relaxation time**, which is directly comparable to the SAOS crossover, $\omega_c \approx 1/\tau_{\text{pl}}$
3. **Disorder strength** ($\sigma_{c,\text{std}}/\sigma_{c,\text{mean}}$) controls the breadth of the relaxation spectrum
4. **Consistency check**: $\tau_{\text{pl}}$ from relaxation should agree, within uncertainty, with $\tau_{\text{pl}}$ from SAOS on the same material, because $G(t)$ and $G'(\omega)$, $G''(\omega)$ are linked by Fourier transforms
5. **Non-exponential relaxation** (power-law-like or stretched-exponential, depending on disorder strength) emerges naturally from the Gaussian threshold distribution, without imposing an ad hoc fitting form

## Next Steps

- **Notebook 06**: Visualization of lattice stress fields during relaxation
- Compare EPM parameters across all protocols (flow curve, SAOS, startup, creep, relaxation) for self-consistency

## Further Reading

**Handbook:**
- [Lattice EPM — Relaxation Protocol](../../docs/source/models/epm/lattice_epm.rst#epm-relaxation) — Boxed governing equations for step strain

**Key References:**
- Nicolas, A., Ferrero, E. E., Martens, K., & Barrat, J.-L. (2018). "Deformation and flow of amorphous solids." *Reviews of Modern Physics*, 90, 045006.
- Martens, K., Bocquet, L., & Barrat, J.-L. (2011). "Connecting diffusion and dynamical heterogeneities in actively deformed amorphous systems." *Physical Review Letters*, 106, 156001.

In [None]:
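# Save results
output_dir = os.path.join("..", "outputs", "epm", "relaxation")
os.makedirs(output_dir, exist_ok=True)

# NLSQ point estimates (captured in `initial_values` before the posterior
# predictive loop overwrote the model parameters)
nlsq_params = {name: float(initial_values[name]) for name in param_names}
with open(os.path.join(output_dir, "nlsq_params_relaxation.json"), "w") as f:
    json.dump(nlsq_params, f, indent=2)

# Posterior medians, saved separately so the two estimates are not confused
median_params = {name: float(np.median(posterior[name])) for name in param_names}
with open(os.path.join(output_dir, "posterior_median_params_relaxation.json"), "w") as f:
    json.dump(median_params, f, indent=2)

print(f"Results saved to {output_dir}/")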