# EPM Startup Shear

**Elasto-Plastic Model (EPM) — Startup flow with stress overshoot**

## Learning Objectives

- Understand stress overshoot in EPM as avalanche onset
- Generate synthetic startup data from calibrated flow curve parameters
- Fit startup data and recover original parameters (parameter recovery test)
- Predict N₁(t) evolution using TensorialEPM

## Prerequisites

- Complete `01_epm_flow_curve.ipynb` first (provides calibrated parameters)
- Understanding of startup shear protocol σ(t) at constant γ̇

## Estimated Runtime

- Fast demo (1 chain): ~3-4 min
- Full run (4 chains): ~8-12 min

## 1. Setup & Imports

In [None]:
# Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import time
import sys
import os
import json

sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.core.data import RheoData
from rheojax.models.epm.lattice import LatticeEPM
from rheojax.models.epm.tensor import TensorialEPM

jax, jnp = safe_import_jax()
verify_float64()

FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"FAST_MODE: {FAST_MODE}")

def compute_fit_quality(y_true, y_pred):
    """Compute R2 and RMSE."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    residuals = y_true - y_pred
    if y_true.ndim > 1:
        residuals = residuals.ravel()
        y_true = y_true.ravel()
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    rmse = np.sqrt(np.mean(residuals**2))
    return {"R2": r2, "RMSE": rmse}

## 2. Theory: Startup in EPM

In startup shear, we apply a constant shear rate γ̇ starting from rest and monitor σ(t).

### Stress Overshoot Physics

The characteristic **stress overshoot** in EPM arises from:

1. **Initial elastic loading**: σ(t) ≈ μγ̇t (linear regime)
2. **Avalanche onset**: First yielding events trigger stress redistribution
3. **Cascading plasticity**: Eshelby propagator causes neighboring sites to yield
4. **Steady state**: Balance between loading and plastic dissipation
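
The loading and yielding steps above can be illustrated with a deliberately simplified toy, which is not the `LatticeEPM` implementation. This sketch assumes independent sites with no Eshelby coupling and instantaneous relaxation (no `tau_pl`), so it covers steps 1, 2 and 4 but not the cascade in step 3. In this toy the overshoot appears only because all sites start unloaded and first yield nearly in phase. The function name `toy_startup` and the dimensionless parameter values are hypothetical, chosen for illustration only.

```python
import numpy as np

def toy_startup(mu=1.0, gamma_dot=1.0, sigma_c_mean=1.0, sigma_c_std=0.2,
                n_sites=2000, dt=0.01, t_end=20.0, seed=0):
    """Toy mean-field startup: independent sites, no Eshelby coupling."""
    rng = np.random.default_rng(seed)
    n_steps = int(round(t_end / dt))
    t = dt * np.arange(1, n_steps + 1)

    def draw(n):
        # Gaussian yield thresholds, clipped so they stay positive
        return np.clip(rng.normal(sigma_c_mean, sigma_c_std, n), 1e-3, None)

    thresholds = draw(n_sites)
    local = np.zeros(n_sites)            # every site starts unloaded (rest state)
    sigma = np.empty(n_steps)
    for k in range(n_steps):
        local += mu * gamma_dot * dt     # elastic loading of every site
        yielded = local > thresholds     # sites that reached their threshold
        local[yielded] = 0.0             # assumed instantaneous full relaxation
        thresholds[yielded] = draw(int(yielded.sum()))  # redraw thresholds
        sigma[k] = local.mean()          # macroscopic stress
    return t, sigma

t, sigma = toy_startup()
sigma_inf = sigma[-len(sigma) // 4:].mean()   # late-time plateau (last 25%)
print(f"peak = {sigma.max():.3f}, plateau = {sigma_inf:.3f}")
```

The early part of the curve follows σ ≈ μγ̇t exactly until the weakest sites yield, and the peak exceeds the late-time plateau once the sites have dephased.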

### Key Metrics

| Metric | Symbol | Physical Meaning |
|--------|--------|------------------|
| Peak stress | σ_max | Maximum stress before yielding cascade |
| Peak time | t_peak | Time to avalanche onset |
| Peak strain | γ_peak = γ̇·t_peak | Strain to yielding (~yield strain) |
| Steady stress | σ_∞ | Flow stress at long times |
| Overshoot ratio | σ_max/σ_∞ | Measures avalanche strength |
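
As a minimal sketch, the metrics in the table can be computed from any sampled σ(t) curve. The helper name `overshoot_metrics`, the `steady_fraction` argument, and the analytic test curve are assumptions made for illustration. The steady stress is estimated as the mean over the final 20% of samples, which presumes the record is long enough to reach a plateau.

```python
import numpy as np

def overshoot_metrics(t, sigma, gamma_dot, steady_fraction=0.2):
    """Peak and steady-state metrics of a startup stress curve."""
    t = np.asarray(t, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    i_peak = int(np.argmax(sigma))
    n_steady = max(1, int(len(sigma) * steady_fraction))
    sigma_inf = float(np.mean(sigma[-n_steady:]))   # late-time average
    return {
        "sigma_max": float(sigma[i_peak]),
        "t_peak": float(t[i_peak]),
        "gamma_peak": gamma_dot * float(t[i_peak]),
        "sigma_inf": sigma_inf,
        "overshoot_ratio": float(sigma[i_peak]) / sigma_inf,
    }

# Hypothetical test curve with a known maximum at t = 1.5 and plateau 1.0
t = np.linspace(0.0, 20.0, 2001)
sigma = 1.0 - np.exp(-t) + 2.0 * t * np.exp(-t)
m = overshoot_metrics(t, sigma, gamma_dot=0.5)
print(m)
```

For this test curve the derivative is e^(-t)(3 - 2t), so the peak sits at t = 1.5 with σ_max = 1 + 2e^(-1.5), which gives a simple check that the helper locates the overshoot correctly.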

### Rate Dependence

At higher shear rates:
- **t_peak decreases**: the yield strain is reached sooner, since t_peak = γ_peak/γ̇
- **γ_peak ≈ constant**: the yield strain is set mainly by the material, not by the rate
- **σ_max/σ_∞ may vary**: the ratio depends on disorder and rate
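
The rate scaling above amounts to simple arithmetic, sketched here under the assumption of a rate-independent yield strain. The values of `gamma_peak` and `dt_grid`, and the factor-of-five resolution margin, are hypothetical choices for illustration.

```python
import numpy as np

gamma_peak = 0.1                       # hypothetical yield strain (dimensionless)
rates = np.array([0.1, 1.0, 10.0])     # applied shear rates [1/s]
t_peak = gamma_peak / rates            # time needed to reach the yield strain [s]

for rate, tp in zip(rates, t_peak):
    print(f"gamma_dot = {rate:5.1f} 1/s  ->  t_peak = {tp:.3f} s")

# A uniform time grid only resolves the peak if its spacing is well below t_peak
dt_grid = 0.125                        # hypothetical sampling interval [s]
resolved = t_peak > 5 * dt_grid        # rule-of-thumb margin (an assumption)
print(resolved)
```

A tenfold increase in rate shortens t_peak tenfold, so a sampling grid that resolves the overshoot at a low rate may miss it entirely at a high rate.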

## 3. Load Calibrated Parameters

We load the parameters fitted in Notebook 01 (flow curve fitting). If the output file is not available, we fall back to default values.
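
One way to make the fallback robust is to merge the file contents over the defaults, so a partially populated file still yields a complete parameter set. This is a sketch, assuming the JSON file holds a flat mapping from parameter name to number. The helper name `load_params` and the temporary file used in the demonstration are hypothetical.

```python
import json
import os
import tempfile

# Default values taken from this notebook's demonstration parameters
DEFAULTS = {"mu": 2.5, "tau_pl": 1.0, "sigma_c_mean": 25.0, "sigma_c_std": 2.5}

def load_params(path, defaults=DEFAULTS):
    """Return defaults overridden by any matching keys found in a JSON file."""
    params = dict(defaults)
    if os.path.exists(path):
        with open(path) as f:
            loaded = json.load(f)
        # Keep only known parameter names; ignore extra entries in the file
        params.update({k: float(v) for k, v in loaded.items() if k in defaults})
    return params

# Demonstration with a partial file and with a missing file
with tempfile.TemporaryDirectory() as tmp:
    partial = os.path.join(tmp, "params.json")
    with open(partial, "w") as f:
        json.dump({"mu": 3.1, "note": "extra key"}, f)
    from_file = load_params(partial)
    from_missing = load_params(os.path.join(tmp, "absent.json"))

print(from_file)
print(from_missing)
```

With this pattern every downstream lookup sees all four keys, whether or not Notebook 01 has been run.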

In [None]:
# Try to load from Notebook 01 output
params_file = os.path.join("..", "outputs", "epm", "flow_curve", "nlsq_params.json")

# Default parameters for demonstration; they also fill in any key missing from the file
default_params = {
    "mu": 2.5,
    "tau_pl": 1.0,
    "sigma_c_mean": 25.0,
    "sigma_c_std": 2.5,
}

if os.path.exists(params_file):
    with open(params_file) as f:
        true_params = {**default_params, **json.load(f)}
    print("Loaded calibrated parameters from Notebook 01:")
else:
    true_params = dict(default_params)
    print("Using default parameters (run Notebook 01 to use calibrated values):")

for k, v in true_params.items():
    if k in ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]:
        print(f"  {k}: {v:.4g}")

## 4. Generate Synthetic Startup Data

We generate synthetic startup data using the calibrated parameters, then add Gaussian noise whose standard deviation is 3% of the mean clean stress.
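
The noise model can be sketched in isolation as follows. The helper name `add_relative_noise` and the linear ramp standing in for a stress curve are hypothetical. Because the noise amplitude is a fixed fraction of the mean signal, the error relative to the local value is larger where the signal is small, as it is early in startup.

```python
import numpy as np

def add_relative_noise(signal, level, seed):
    """Add Gaussian noise with std = level * |mean(signal)|."""
    signal = np.asarray(signal, dtype=float)
    rng = np.random.default_rng(seed)
    scale = level * max(abs(float(signal.mean())), 1e-10)  # guard against zero mean
    return signal + rng.normal(0.0, scale, size=signal.shape), scale

clean = np.linspace(5.0, 15.0, 10_000)    # hypothetical ramp with mean 10
noisy, scale = add_relative_noise(clean, level=0.03, seed=42)

empirical = np.std(noisy - clean)         # should be close to scale
relative_at_start = scale / clean[0]      # 0.3 / 5  = 6% of the local value
relative_at_end = scale / clean[-1]       # 0.3 / 15 = 2% of the local value
print(scale, empirical, relative_at_start, relative_at_end)
```

If a constant relative error at every point is wanted instead, the noise scale would need to follow the local signal rather than its mean.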

In [None]:
# Simulation parameters
GAMMA_DOT = 1.0      # Applied shear rate [1/s]
T_END = 10.0         # End time [s]
N_POINTS = 80        # Number of data points
NOISE_LEVEL = 0.03   # 3% noise
SEED = 42

print(f"Generating startup data at γ̇ = {GAMMA_DOT} 1/s")
print(f"  Time range: 0.1 – {T_END} s")
print(f"  Data points: {N_POINTS}")
print(f"  Noise level: {NOISE_LEVEL*100:.0f}%")

In [None]:
# Create model with true parameters
model_true = LatticeEPM(
    L=16 if FAST_MODE else 32,
    dt=0.01,
    mu=true_params.get("mu", 2.5),
    tau_pl=true_params.get("tau_pl", 1.0),
    sigma_c_mean=true_params.get("sigma_c_mean", 25.0),
    sigma_c_std=true_params.get("sigma_c_std", 2.5),
)

# Generate time array
time_data = np.linspace(0.1, T_END, N_POINTS)  # Start at 0.1 to avoid t=0

# Create RheoData for startup
rheo_data = RheoData(
    x=time_data,
    y=np.zeros_like(time_data),
    initial_test_mode="startup",
    metadata={"gamma_dot": GAMMA_DOT},
)

# Simulate clean signal
print("Running EPM startup simulation...")
t0 = time.time()
result_clean = model_true.predict(rheo_data, smooth=True, seed=SEED)
t_sim = time.time() - t0
print(f"Simulation time: {t_sim:.2f} s")

stress_clean = np.array(result_clean.y)

# Add noise (use absolute mean to handle negative-stress scenarios)
rng = np.random.default_rng(SEED)
noise_scale = NOISE_LEVEL * max(np.abs(np.mean(stress_clean)), 1e-10)
noise = rng.normal(0, noise_scale, size=stress_clean.shape)
stress_noisy = stress_clean + noise

print(f"Stress range: {stress_noisy.min():.2f} - {stress_noisy.max():.2f} Pa")

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))

ax.plot(time_data, stress_clean, "-", lw=2, color="C0", label="Clean signal", alpha=0.7)
ax.plot(time_data, stress_noisy, "ko", markersize=4, label=f"Noisy data ({NOISE_LEVEL*100:.0f}%)")

# Mark overshoot
idx_peak = np.argmax(stress_clean)
ax.axvline(time_data[idx_peak], color="gray", linestyle="--", alpha=0.5)
ax.annotate(
    f"t_peak = {time_data[idx_peak]:.2f} s\nσ_max = {stress_clean[idx_peak]:.2f} Pa",
    xy=(time_data[idx_peak], stress_clean[idx_peak]),
    xytext=(time_data[idx_peak] + 1, stress_clean[idx_peak] * 0.95),
    fontsize=10,
    arrowprops=dict(arrowstyle="->", color="gray"),
)

ax.set_xlabel("Time [s]")
ax.set_ylabel("Stress [Pa]")
ax.set_title(f"Synthetic Startup Data at γ̇ = {GAMMA_DOT} 1/s")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

## 5. Overshoot Metrics

In [None]:
# Compute overshoot metrics
idx_peak = np.argmax(stress_clean)
sigma_max = stress_clean[idx_peak]
t_peak = time_data[idx_peak]
gamma_peak = GAMMA_DOT * t_peak

# Steady state (average of last 20% of data)
n_steady = max(5, len(stress_clean) // 5)
sigma_inf = np.mean(stress_clean[-n_steady:])

overshoot_ratio = sigma_max / sigma_inf

print("Overshoot Metrics")
print("=" * 40)
print(f"Peak stress σ_max:     {sigma_max:.3f} Pa")
print(f"Peak time t_peak:      {t_peak:.3f} s")
print(f"Peak strain γ_peak:    {gamma_peak:.3f}")
print(f"Steady stress σ_∞:     {sigma_inf:.3f} Pa")
print(f"Overshoot ratio:       {overshoot_ratio:.3f}")
print(f"\nInterpretation:")
print(f"  γ_peak ≈ {gamma_peak:.2f} is the effective yield strain")
print(f"  Overshoot of {(overshoot_ratio-1)*100:.1f}% indicates avalanche dynamics")

## 6. NLSQ Fitting (Parameter Recovery)

In [None]:
# Initialize model for fitting (start from different initial values)
model_fit = LatticeEPM(
    L=16 if FAST_MODE else 32,
    dt=0.01,
    mu=1.0,           # Different from true
    tau_pl=0.5,       # Different from true
    sigma_c_mean=10.0,
    sigma_c_std=1.0,
)

# Set bounds
model_fit.parameters["mu"].bounds = (0.1, 20.0)
model_fit.parameters["tau_pl"].bounds = (0.01, 50.0)
model_fit.parameters["sigma_c_mean"].bounds = (0.5, 100.0)
model_fit.parameters["sigma_c_std"].bounds = (0.01, 20.0)

# Fit
print("Fitting to synthetic startup data...")
t0 = time.time()
model_fit.fit(time_data, stress_noisy, test_mode="startup", gamma_dot=GAMMA_DOT, method='scipy')
t_nlsq = time.time() - t0

# Compute fit quality
y_pred = model_fit.predict(time_data, test_mode="startup", gamma_dot=GAMMA_DOT, smooth=True, seed=SEED).y
metrics = compute_fit_quality(stress_noisy, y_pred)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R2: {metrics['R2']:.6f}")
print(f"RMSE: {metrics['RMSE']:.4f} Pa")

In [None]:
# Compare fitted vs true parameters
param_names = ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]

print("\nParameter Recovery")
print("=" * 55)
print(f"{'Parameter':>15s}  {'True':>10s}  {'Fitted':>10s}  {'Error %':>10s}")
print("-" * 55)

for name in param_names:
    true_val = true_params.get(name, 1.0)
    fit_val = model_fit.parameters.get_value(name)
    error_pct = abs(fit_val - true_val) / true_val * 100
    print(f"{name:>15s}  {true_val:10.4g}  {fit_val:10.4g}  {error_pct:10.1f}%")

In [None]:
# Plot NLSQ fit with uncertainty band
param_names = ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]

time_fine = np.linspace(0.1, T_END, 200)
rheo_fine = RheoData(
    x=time_fine, y=np.zeros_like(time_fine),
    initial_test_mode="startup",
    metadata={"gamma_dot": GAMMA_DOT},
)

fig, ax = plt.subplots(figsize=(9, 6))
# Pass x_pred=time_data to avoid fine-grid shape mismatch in EPM
plot_nlsq_fit(
    time_data, stress_noisy, model_fit, test_mode="startup",
    param_names=param_names, x_pred=time_data, log_scale=False,
    xlabel="Time [s]", ylabel="Stress [Pa]",
    title=f"Startup Fit (R2 = {metrics['R2']:.5f})",
    ax=ax, gamma_dot=GAMMA_DOT,
)

# Overlay true model
stress_true = model_true.predict(rheo_fine, smooth=True, seed=SEED).y
ax.plot(time_fine, stress_true, "--", lw=1.5, color="C1", alpha=0.7, label="True model")
ax.legend()

plt.tight_layout()
display(fig)
plt.close(fig)

## 7. Bayesian Inference

In [None]:
# Warm-start from NLSQ
initial_values = {
    name: model_fit.parameters.get_value(name)
    for name in model_fit.parameters.keys()
}

# --- Sampler config: fast demo vs. full run ---
NUM_WARMUP = 50 if FAST_MODE else 200
NUM_SAMPLES = 100 if FAST_MODE else 500
NUM_CHAINS = 1 if FAST_MODE else 4  # 1 chain for the fast demo, 4 for the full run

print(f"Running Bayesian inference: {NUM_CHAINS} chain(s)")

t0 = time.time()
result = model_fit.fit_bayesian(
    time_data,
    stress_noisy,
    test_mode="startup",
    gamma_dot=GAMMA_DOT,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=SEED,
)
t_bayes = time.time() - t0
print(f"Bayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
diag = result.diagnostics

print("Convergence Diagnostics")
print("=" * 50)
print(f"{'Parameter':>15s}  {'R-hat':>8s}  {'ESS':>8s}")
print("-" * 50)
for p in param_names:
    r_hat = diag.get("r_hat", {}).get(p, float("nan"))
    ess = diag.get("ess", {}).get(p, float("nan"))
    print(f"{p:>15s}  {r_hat:8.4f}  {ess:8.0f}")

n_div = diag.get("divergences", diag.get("num_divergences", 0))
print(f"\nDivergences: {n_div}")

In [None]:
# ArviZ diagnostic plots
display_arviz_diagnostics(result, param_names, fast_mode=FAST_MODE)

In [None]:
# Parameter recovery check with posteriors
posterior = result.posterior_samples

print("\nParameter Recovery with Bayesian Inference")
print("=" * 70)
print(f"{'Parameter':>15s}  {'True':>10s}  {'Median':>10s}  {'95% CI':>24s}  {'Recovered?':>12s}")
print("-" * 70)

for name in param_names:
    true_val = true_params.get(name, 1.0)
    samples = posterior[name]
    median = float(np.median(samples))
    lo = float(np.percentile(samples, 2.5))
    hi = float(np.percentile(samples, 97.5))
    
    # Check if true value is within 95% CI
    recovered = "✓" if lo <= true_val <= hi else "✗"
    
    print(f"{name:>15s}  {true_val:10.4g}  {median:10.4g}  [{lo:.4g}, {hi:.4g}]  {recovered:>12s}")

## 8. TensorialEPM: N₁(t) Evolution

Using TensorialEPM, we can predict how the first normal stress difference N₁ evolves during startup.
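
The ratio N₁/σ_xy examined at the end of this section is undefined wherever σ_xy is close to zero, which happens at the very start of loading. A minimal sketch of a guarded ratio is shown here. The helper name `safe_ratio`, the threshold `eps`, and the small arrays are hypothetical and do not come from `TensorialEPM` output.

```python
import numpy as np

def safe_ratio(numerator, denominator, eps=1e-6):
    """Element-wise ratio, NaN wherever |denominator| <= eps."""
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    valid = np.abs(denominator) > eps          # works for either sign of stress
    ratio = np.full(denominator.shape, np.nan)
    ratio[valid] = numerator[valid] / denominator[valid]
    return ratio, valid

sigma_xy = np.array([0.0, 1.0, 2.0, 4.0])     # hypothetical shear stress samples
n1 = np.array([0.0, 0.1, 0.4, 0.8])           # hypothetical N1 samples
ratio, valid = safe_ratio(n1, sigma_xy)
tail_mean = np.nanmean(ratio[-2:])            # late-time average, ignoring NaN
print(ratio, tail_mean)
```

Marking undefined points as NaN rather than zero keeps them out of any later average, so a late-time mean is not biased toward zero.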

In [None]:
# Get median parameters
median_params = {name: float(np.median(posterior[name])) for name in param_names}

# Create TensorialEPM
model_tensor = TensorialEPM(
    L=16 if FAST_MODE else 32,
    dt=0.01,
    mu=median_params["mu"],
    nu=0.48,
    tau_pl=median_params["tau_pl"],
    sigma_c_mean=median_params["sigma_c_mean"],
    sigma_c_std=median_params["sigma_c_std"],
)

# Predict startup with N₁
print("Running TensorialEPM startup simulation...")
result_tensor = model_tensor.predict(rheo_fine, smooth=True, seed=SEED)

sigma_xy_t = np.array(result_tensor.y)
N1_t = np.array(result_tensor.metadata.get("N1", np.zeros_like(sigma_xy_t)))

print(f"σ_xy range: {sigma_xy_t.min():.2f} – {sigma_xy_t.max():.2f} Pa")
print(f"N₁ range: {N1_t.min():.2f} – {N1_t.max():.2f} Pa")

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Shear stress
ax1.plot(time_fine, sigma_xy_t, "-", lw=2, color="C0", label="σ_xy (TensorialEPM)")
ax1.plot(time_data, stress_noisy, "ko", markersize=4, alpha=0.5, label="Data")
ax1.set_xlabel("Time [s]")
ax1.set_ylabel("Shear stress σ_xy [Pa]")
ax1.set_title("Shear Stress Evolution")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: Normal stress
ax2.plot(time_fine, N1_t, "-", lw=2, color="C2", label="N₁ = σ_xx - σ_yy")
ax2.set_xlabel("Time [s]")
ax2.set_ylabel("Normal stress N₁ [Pa]")
ax2.set_title("First Normal Stress Difference")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# N₁/σ_xy ratio over time
# Mask near-zero shear stress of either sign to avoid division by zero
mask = np.abs(sigma_xy_t) > 1e-6
N1_ratio = np.zeros_like(sigma_xy_t)
N1_ratio[mask] = N1_t[mask] / sigma_xy_t[mask]

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(time_fine[mask], N1_ratio[mask], "-", lw=2, color="C3")
ax.set_xlabel("Time [s]")
ax.set_ylabel("N₁ / σ_xy")
ax.set_title("Normal Stress Ratio During Startup")
ax.grid(True, alpha=0.3)

# Steady-state ratio: average the last 20 valid (unmasked) points
ss_ratio = np.mean(N1_ratio[mask][-20:])
ax.axhline(ss_ratio, color="gray", linestyle="--", alpha=0.7, 
           label=f"Steady-state: {ss_ratio:.3f}")
ax.legend()

plt.tight_layout()
display(fig)
plt.close(fig)

## 9. Key Takeaways

1. **Stress overshoot in EPM** arises from avalanche onset: the transition from elastic loading to cascading plasticity
2. **γ_peak** (peak strain) is approximately the yield strain of the material
3. **Parameter recovery** from startup data works best when the time sampling resolves the overshoot; check the recovery tables in Sections 6 and 7 before relying on the fitted values
4. **N₁ evolution** shows normal stress buildup during shear startup
5. The **N₁/σ_xy ratio** approaches a steady value that characterizes the material's nonlinearity

## Next Steps

- **Notebook 04**: Creep response and yield stress estimation
- **Notebook 05**: Stress relaxation and relaxation spectrum
- **Notebook 06**: Visualization of avalanche dynamics

In [None]:
# Save results: posterior-median parameters from Section 7 (not the NLSQ point estimates)
output_dir = os.path.join("..", "outputs", "epm", "startup")
os.makedirs(output_dir, exist_ok=True)

with open(os.path.join(output_dir, "nlsq_params_startup.json"), "w") as f:
    json.dump(median_params, f, indent=2)

print(f"Results saved to {output_dir}/")