# EPM Flow Curve Fitting

**Elasto-Plastic Model (EPM) — Steady-state flow curve with Lattice and Tensorial variants**

## Learning Objectives

- Understand the EPM mesoscopic physics: lattice, Eshelby propagator, plastic avalanches
- Fit steady-state flow curves to real emulsion data using NLSQ with LatticeEPM
- Perform Bayesian inference with NUTS and evaluate convergence
- Use TensorialEPM for forward predictions of normal stress N₁

## Prerequisites

- Basic familiarity with rheological flow curves σ(γ̇)
- Understanding of NLSQ fitting (see `01-basic-maxwell.ipynb`)
- Understanding of Bayesian inference basics (see `05-bayesian-basics.ipynb`)

## Estimated Runtime

- Fast demo (1 chain): ~3-4 min
- Full run (4 chains): ~8-12 min

## 1. Setup & Imports

In [None]:
# Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import warnings
import time

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.core.data import RheoData
from rheojax.models.epm.lattice import LatticeEPM
from rheojax.models.epm.tensor import TensorialEPM

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

def compute_fit_quality(y_true, y_pred):
    """Compute R² and RMSE."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    residuals = y_true - y_pred
    if y_true.ndim > 1:
        residuals = residuals.ravel()
        y_true = y_true.ravel()
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    rmse = np.sqrt(np.mean(residuals**2))
    return {"R2": r2, "RMSE": rmse}

## 2. Theory: EPM Mesoscopic Physics

The **Elasto-Plastic Model (EPM)** is a mesoscopic approach for amorphous solids (glasses, gels, dense emulsions) that explicitly resolves:

1. **Spatial heterogeneity**: A 2D lattice of elastoplastic blocks (L × L)
2. **Plastic avalanches**: Cascading yielding events when local stress exceeds thresholds
3. **Long-range stress redistribution**: Eshelby quadrupolar propagator

### Lattice Dynamics

Each lattice site $i$ has:
- Local stress $\sigma_i$
- Local yield threshold $\sigma_{c,i}$ drawn from a disorder distribution

### Evolution Equation

$$\frac{\partial \sigma_i}{\partial t} = \mu \dot{\gamma} - \frac{\sigma_i}{\tau_{\text{pl}}} \cdot H(|\sigma_i| - \sigma_{c,i}) + \sum_j G_{ij} \cdot \delta\sigma_j^{\text{pl}}$$

where:
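- $\mu$: Shear modulus; the product $\mu \dot{\gamma}$ is the elastic loading rate
- $\dot{\gamma}$: Imposed macroscopic shear rate
- $\tau_{\text{pl}}$: Plastic relaxation time
- $H(\cdot)$: Heaviside step function (yielding criterion)
- $G_{ij}$: Eshelby propagator (quadrupolar, decaying as $\sim 1/r^2$ in 2D)
- $\delta\sigma_j^{\text{pl}}$: Plastic stress release at site $j$, redistributed to site $i$ through $G_{ij}$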
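
The evolution equation can be sketched as a toy explicit-Euler update in NumPy. This is an illustration under stated assumptions, not the RheoJAX implementation: the Fourier-space kernel $\hat{G}(\mathbf{q}) = -4 q_x^2 q_y^2 / |\mathbf{q}|^4$ is the commonly used form for a 2D shear quadrupole, $\delta\sigma^{\text{pl}}$ is identified with the local plastic relaxation rate, and the function names, lattice size, and parameter values are hypothetical.

```python
import numpy as np

def eshelby_propagator_fourier(L):
    """Quadrupolar kernel G(q) = -4 qx^2 qy^2 / |q|^4 on an L x L periodic grid."""
    q = 2.0 * np.pi * np.fft.fftfreq(L)
    qx, qy = np.meshgrid(q, q, indexing="ij")
    q2 = qx**2 + qy**2
    q2[0, 0] = 1.0               # placeholder to avoid 0/0 at q = 0
    G_hat = -4.0 * qx**2 * qy**2 / q2**2
    G_hat[0, 0] = 0.0            # redistribution leaves the mean stress unchanged
    return G_hat

def euler_step(sigma, sigma_c, G_hat, gamma_dot, mu, tau_pl, dt):
    """One explicit Euler update of the evolution equation, term by term."""
    yielding = np.abs(sigma) > sigma_c                      # H(|sigma_i| - sigma_c,i)
    plastic_rate = np.where(yielding, sigma / tau_pl, 0.0)  # local relaxation term
    # Assumption: delta sigma^pl is identified with the local plastic rate.
    redistributed = np.real(np.fft.ifft2(G_hat * np.fft.fft2(plastic_rate)))
    return sigma + dt * (mu * gamma_dot - plastic_rate + redistributed)

rng = np.random.default_rng(0)
L = 16
G_hat = eshelby_propagator_fourier(L)
sigma = np.zeros((L, L))
sigma_c = rng.normal(1.0, 0.1, size=(L, L))  # Gaussian yield thresholds
for _ in range(2000):
    sigma = euler_step(sigma, sigma_c, G_hat, gamma_dot=0.1, mu=1.0, tau_pl=1.0, dt=0.01)
mean_stress = sigma.mean()
print(f"Lattice-averaged stress after loading: {mean_stress:.3f}")
```

Applying the kernel with FFTs keeps each step at $O(L^2 \log L)$ cost instead of the $O(L^4)$ cost of the explicit sum over $j$. How the self-term $G_{ii}$ is split from the local relaxation term varies between EPM formulations, so the sketch should not be read as reproducing `LatticeEPM` numerically.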

### Disorder Distribution

Yield thresholds follow a Gaussian distribution:

$$\sigma_{c,i} \sim \mathcal{N}(\sigma_{c,\text{mean}}, \sigma_{c,\text{std}}^2)$$

The ratio $\sigma_{c,\text{std}}/\sigma_{c,\text{mean}}$ controls the material's disorder strength:
- Low disorder → sharp yielding, stress localization
- High disorder → gradual yielding, distributed plasticity
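
To make the disorder ratio concrete, the snippet below draws two threshold fields with the same mean but different standard deviations and compares them. It is a minimal sketch: the helper `sample_thresholds`, the clipping floor, and the numerical values are assumptions chosen for illustration and are not taken from RheoJAX.

```python
import numpy as np

def sample_thresholds(sigma_c_mean, sigma_c_std, L, rng):
    """Draw an L x L field of Gaussian yield thresholds, clipped to stay positive."""
    thresholds = rng.normal(sigma_c_mean, sigma_c_std, size=(L, L))
    # Assumption: negative thresholds are unphysical, so clip at a small floor.
    return np.clip(thresholds, 1e-6, None)

rng = np.random.default_rng(0)
L = 32
sigma_c_mean = 1.0  # hypothetical value

low_disorder = sample_thresholds(sigma_c_mean, 0.05, L, rng)
high_disorder = sample_thresholds(sigma_c_mean, 0.5, L, rng)

for label, field in [("low", low_disorder), ("high", high_disorder)]:
    ratio = field.std() / field.mean()
    print(f"{label} disorder: ratio = {ratio:.2f}, weakest site = {field.min():.2f}")
```

With high disorder the weakest sites sit far below the mean threshold, so some sites yield at much lower stress than others and yielding spreads over a wider stress range.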

### Parameters

| Parameter | Symbol | Physical Meaning | Typical Range |
|-----------|--------|------------------|---------------|
| `mu` | μ | Shear modulus | 0.5–10 Pa |
| `tau_pl` | τ_pl | Plastic relaxation time | 0.1–10 s |
| `sigma_c_mean` | σ_c,mean | Mean yield threshold | 1–50 Pa |
| `sigma_c_std` | σ_c,std | Disorder strength | 0.05–0.5 Pa |

## 3. Load Flow Curve Data

We use a flow curve from a concentrated oil-in-water emulsion (φ = 0.80). This system exhibits:
- A clear yield stress plateau at low shear rates
- Power-law shear-thinning at high rates

This is ideal for EPM, which was designed for amorphous yield-stress materials.

In [None]:
import os

data_path = os.path.join("..", "data", "flow", "emulsions", "0.80.csv")
if IN_COLAB:
    # The file is not downloaded automatically in Colab; it must already be in the working directory
    data_path = "0.80.csv"
    if not os.path.exists(data_path):
        print("Please upload 0.80.csv or adjust the path.")

raw = np.loadtxt(data_path, delimiter=",", skiprows=1)
gamma_dot = raw[:, 0]  # Shear rate [1/s]
stress = raw[:, 1]      # Stress [Pa]

print(f"Data points: {len(gamma_dot)}")
print(f"Shear rate range: {gamma_dot.min():.4f} – {gamma_dot.max():.1f} 1/s")
print(f"Stress range: {stress.min():.1f} – {stress.max():.1f} Pa")

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))
ax.loglog(gamma_dot, stress, "ko", markersize=6, label="Emulsion φ=0.80")
ax.set_xlabel("Shear rate [1/s]")
ax.set_ylabel("Stress [Pa]")
ax.set_title("Raw Flow Curve")
ax.grid(True, alpha=0.3, which="both")
ax.legend()
plt.tight_layout()
display(fig)
plt.close(fig)

The data shows a clear **yield stress plateau** at low shear rates (~24 Pa) transitioning to power-law shear-thinning at high rates. This is classic behavior for concentrated emulsions and is well-captured by EPM physics.
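
The plateau level and the high-rate slope can be estimated directly from the points before any model is fitted. The sketch below uses a synthetic curve as a stand-in for the emulsion data (the values are hypothetical and not read from `0.80.csv`); the choice of three low-rate points and five high-rate points is likewise an assumption.

```python
import numpy as np

# Synthetic stand-in for the measured flow curve (hypothetical values).
rate = np.logspace(-3, 2, 30)       # shear rate [1/s]
sigma = 24.0 + 5.0 * rate**0.5      # stress [Pa]

# Sort by shear rate so the slices below pick the intended regimes.
order = np.argsort(rate)
rate, sigma = rate[order], sigma[order]

# Plateau: average stress over the lowest-rate points.
plateau = sigma[:3].mean()

# High-rate regime: slope of log(stress) versus log(rate) from a straight-line fit.
slope, _ = np.polyfit(np.log10(rate[-5:]), np.log10(sigma[-5:]), 1)

print(f"plateau = {plateau:.1f} Pa, high-rate log-log slope = {slope:.2f}")
```

A log-log slope below 1 indicates shear-thinning. These quick estimates are also a useful sanity check on the fitted `sigma_c_mean`, which should be of the same order as the plateau.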

## 4. NLSQ Fitting with LatticeEPM

We use **LatticeEPM** for fitting because it supports the full NLSQ + Bayesian pipeline.

Key settings:
- `L=32`: Lattice size (smaller for faster fitting, 64+ for production)
- `smooth=True`: Differentiable yielding for gradient-based optimization (passed to `predict` in the cells below)
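
Why a smooth yielding criterion helps gradient-based methods can be seen in a few lines. The hard Heaviside step has zero gradient everywhere except at the threshold, while a smoothed step has a finite gradient near it. The `tanh` form and the `width` parameter below are assumptions for illustration; the smoothing actually used by RheoJAX is not specified here.

```python
import numpy as np

def hard_step(x):
    """Heaviside yielding criterion: 1 where x > 0, else 0."""
    return (np.asarray(x) > 0).astype(float)

def smooth_step(x, width=0.05):
    """tanh-smoothed step; `width` (hypothetical) sets the transition scale."""
    return 0.5 * (1.0 + np.tanh(np.asarray(x) / width))

def smooth_step_grad(x, width=0.05):
    """Analytic derivative of smooth_step with respect to x."""
    return 0.5 / width / np.cosh(np.asarray(x) / width) ** 2

# x plays the role of |sigma_i| - sigma_c,i
x = np.linspace(-0.2, 0.2, 5)
print("hard:  ", hard_step(x))
print("smooth:", np.round(smooth_step(x), 3))
print("grad:  ", np.round(smooth_step_grad(x), 3))
```

A smaller `width` approaches the hard criterion but concentrates the gradient in a narrower band around the threshold, so the choice trades physical fidelity against optimizer friendliness.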

In [None]:
# Initialize LatticeEPM
model = LatticeEPM(
    L=32,      # Lattice size (32 for tutorials, 64+ for production)
    dt=0.01,   # Time step
    mu=1.0,
    tau_pl=1.0,
    sigma_c_mean=1.0,
    sigma_c_std=0.1,
)

# Set physically motivated bounds
model.parameters["mu"].bounds = (0.1, 20.0)
model.parameters["tau_pl"].bounds = (0.01, 50.0)
model.parameters["sigma_c_mean"].bounds = (0.5, 100.0)
model.parameters["sigma_c_std"].bounds = (0.01, 10.0)

print("LatticeEPM initialized")
print(f"  Lattice size: {model.L}×{model.L}")
print(f"  Parameters: {list(model.parameters.keys())}")

In [None]:
# Fit to flow curve
t0 = time.time()
model.fit(gamma_dot, stress, test_mode="flow_curve", method='scipy')
t_nlsq = time.time() - t0

# Compute fit quality
y_pred = model.predict(gamma_dot, test_mode="flow_curve", smooth=True).y
metrics = compute_fit_quality(stress, y_pred)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R²: {metrics['R2']:.6f}")
print(f"RMSE: {metrics['RMSE']:.4f} Pa")
print("\nFitted parameters:")
for name in ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]:
    val = model.parameters.get_value(name)
    print(f"  {name:15s} = {val:.4g}")

## 5. Parameter Interpretation

The fitted EPM parameters have direct physical meaning:

- **μ (shear modulus)**: Controls the elastic loading rate. Higher μ means faster stress buildup.
- **τ_pl (plastic time)**: Controls how quickly plastic events relax. Shorter τ_pl means faster plastic flow.
- **σ_c,mean (mean yield threshold)**: The average stress required for local yielding. Directly related to the macroscopic yield stress.
- **σ_c,std (disorder strength)**: Controls the width of the yield threshold distribution. Higher values mean more gradual yielding.

In [None]:
# Predict and plot fit quality
gamma_dot_fine = np.logspace(
    np.log10(gamma_dot.min()) - 0.5,
    np.log10(gamma_dot.max()) + 0.2,
    100,
)
stress_pred = model.predict(gamma_dot_fine, test_mode="flow_curve", smooth=True).y

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: flow curve
ax1.loglog(gamma_dot, stress, "ko", markersize=6, label="Data")
ax1.loglog(gamma_dot_fine, stress_pred, "-", lw=2, color="C0", label="LatticeEPM fit")
ax1.set_xlabel("Shear rate [1/s]")
ax1.set_ylabel("Stress [Pa]")
ax1.set_title(f"Flow Curve Fit (R² = {metrics['R2']:.4f})")
ax1.legend()
ax1.grid(True, alpha=0.3, which="both")

# Right: residuals
stress_at_data = model.predict(gamma_dot, test_mode="flow_curve", smooth=True).y
residuals = (stress - np.array(stress_at_data)) / stress * 100

ax2.semilogx(gamma_dot, residuals, "o-", markersize=5, alpha=0.7)
ax2.axhline(0, color="black", linestyle="--", alpha=0.5)
ax2.set_xlabel("Shear rate [1/s]")
ax2.set_ylabel("Relative residual [%]")
ax2.set_title("Residual Analysis")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

## 6. Bayesian Inference with NUTS

We use the NLSQ estimates as a **warm-start** for NUTS sampling. This is critical for EPM because:
1. The parameter space is complex with correlations between μ, τ_pl, and σ_c
2. The forward model is computationally expensive (lattice simulation)
3. Warm-start dramatically reduces the required warmup iterations

In [None]:
# Warm-start values from NLSQ
initial_values = {
    name: model.parameters.get_value(name)
    for name in model.parameters.keys()
}
print("Warm-start values:")
for k, v in initial_values.items():
    print(f"  {k}: {v:.4g}")

In [None]:
# --- Fast demo config (change to 4 chains for production) ---
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1
# NUM_WARMUP = 1000; NUM_SAMPLES = 2000; NUM_CHAINS = 4  # production

print(f"Running Bayesian inference: {NUM_CHAINS} chain(s), {NUM_WARMUP} warmup, {NUM_SAMPLES} samples")

t0 = time.time()
result = model.fit_bayesian(
    gamma_dot,
    stress,
    test_mode="flow_curve",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

## 7. Convergence Diagnostics

We check the standard convergence criteria:
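- **R-hat < 1.05**: Chains have mixed well (most informative with several chains, so rerun with `NUM_CHAINS = 4` before relying on it)
- **ESS > 100**: Effective sample size is sufficient
- **No divergences**: The sampler has not encountered pathological posterior geometry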
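
As a rough illustration of what the R-hat criterion measures, the sketch below implements the classic (non-split) Gelman-Rubin statistic in NumPy and applies it to synthetic chains. It is not the rank-normalized split R-hat that ArviZ reports, and it covers only the R-hat criterion, not ESS; the function name and the synthetic chains are assumptions.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    if n_chains < 2:
        raise ValueError("R-hat compares chains, so at least two are needed")
    within = chains.var(axis=1, ddof=1).mean()            # W: mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)   # B: between-chain variance
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(var_hat / within))

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 500))                        # chains sampling the same target
stuck = mixed + np.array([[0.0], [3.0], [6.0], [9.0]])   # chains stuck at different values

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(f"well-mixed chains: R-hat = {rhat_mixed:.3f}")
print(f"stuck chains:      R-hat = {rhat_stuck:.3f}")
```

Because this formula compares chains with each other, it cannot be evaluated for the single-chain fast demo, which is one reason to rerun with four chains before drawing conclusions.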

In [None]:
diag = result.diagnostics
param_names = ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]

print("Convergence Diagnostics")
print("=" * 50)
print(f"{'Parameter':>15s}  {'R-hat':>8s}  {'ESS':>8s}")
print("-" * 50)
for p in param_names:
    r_hat = diag.get("r_hat", {}).get(p, float("nan"))
    ess = diag.get("ess", {}).get(p, float("nan"))
    print(f"{p:>15s}  {r_hat:8.4f}  {ess:8.0f}")

n_div = diag.get("divergences", diag.get("num_divergences", 0))
print(f"\nDivergences: {n_div}")

# Quality check
r_hat_vals = [diag.get("r_hat", {}).get(p, 2.0) for p in param_names]
ess_vals = [diag.get("ess", {}).get(p, 0) for p in param_names]
if max(r_hat_vals) < 1.05 and min(ess_vals) > 100:
    print("\nConvergence: PASSED")
else:
    print("\nConvergence: CHECK REQUIRED (increase num_warmup/num_samples)")

## 8. ArviZ Diagnostic Plots

In [None]:
idata = result.to_inference_data()

# Trace plot
axes = az.plot_trace(idata, var_names=param_names, figsize=(12, 8))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Pair plot (parameter correlations)
axes = az.plot_pair(
    idata,
    var_names=param_names,
    kind="scatter",
    divergences=True,
    figsize=(10, 10),
)
fig = axes.ravel()[0].figure
fig.suptitle("Parameter Correlations", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Forest plot (credible intervals)
axes = az.plot_forest(
    idata,
    var_names=param_names,
    combined=True,
    hdi_prob=0.95,
    figsize=(10, 4),
)
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Energy plot (NUTS-specific)
if NUM_CHAINS > 1:
    axes = az.plot_energy(idata, figsize=(8, 4))
    fig = axes.ravel()[0].figure if hasattr(axes, 'ravel') else plt.gcf()
    plt.tight_layout()
    display(fig)
    plt.close(fig)
else:
    print("Energy plot requires multiple chains. Run with NUM_CHAINS=4 for full diagnostics.")

## 9. Posterior Predictive Check

In [None]:
# Get credible intervals
intervals = model.get_credible_intervals(
    result.posterior_samples, credibility=0.95
)

# Sample posterior predictions
posterior = result.posterior_samples
n_draws = min(100, len(list(posterior.values())[0]))  # Limit for speed
gamma_dot_pred = np.logspace(
    np.log10(gamma_dot.min()) - 0.3,
    np.log10(gamma_dot.max()) + 0.2,
    50,
)

# Draw posterior predictive samples
print(f"Computing {n_draws} posterior predictive samples...")
pred_samples = []
for i in range(n_draws):
    # Set parameters from posterior
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name][i]))
    
    # Predict
    pred_i = model.predict(gamma_dot_pred, test_mode="flow_curve", smooth=True).y
    pred_samples.append(np.array(pred_i))

pred_samples = np.array(pred_samples)
pred_median = np.median(pred_samples, axis=0)
pred_lo = np.percentile(pred_samples, 2.5, axis=0)
pred_hi = np.percentile(pred_samples, 97.5, axis=0)

print("Done.")

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))
ax.fill_between(
    gamma_dot_pred, pred_lo, pred_hi, alpha=0.3, color="C0", label="95% CI"
)
ax.loglog(gamma_dot_pred, pred_median, "-", lw=2, color="C0", label="Posterior median")
ax.loglog(gamma_dot, stress, "ko", markersize=6, label="Data")
ax.set_xlabel("Shear rate [1/s]")
ax.set_ylabel("Stress [Pa]")
ax.set_title("Posterior Predictive Check")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 10. TensorialEPM Forward Predictions (Sidebar)

The **TensorialEPM** model tracks the full stress tensor [σ_xx, σ_yy, σ_xy], enabling prediction of **normal stress differences** N₁ = σ_xx - σ_yy.

**Note:** TensorialEPM currently supports forward predictions only (fitting is not yet implemented). We use the parameters calibrated from LatticeEPM.
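
The definition of N₁ can be written out for a lattice stress field as follows. This is a minimal sketch assuming the field is stored as an array of shape `(3, L, L)` with component order [σ_xx, σ_yy, σ_xy]; the storage layout inside `TensorialEPM` may differ, and the function name and values are hypothetical.

```python
import numpy as np

def first_normal_stress_difference(stress_field):
    """N1 = <sigma_xx> - <sigma_yy> for a field of shape (3, L, L).

    Assumed component order along axis 0: [sigma_xx, sigma_yy, sigma_xy].
    """
    sigma_xx, sigma_yy, _ = stress_field
    return float(sigma_xx.mean() - sigma_yy.mean())

# Hypothetical uniform field: sigma_xx = 3 Pa, sigma_yy = 1 Pa, sigma_xy = 5 Pa.
L = 4
field = np.stack([
    np.full((L, L), 3.0),
    np.full((L, L), 1.0),
    np.full((L, L), 5.0),
])
n1 = first_normal_stress_difference(field)
print(f"N1 = {n1:.1f} Pa")
```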

In [None]:
# Get median parameters from posterior
median_params = {name: float(np.median(posterior[name])) for name in param_names}
print("Using median posterior parameters:")
for k, v in median_params.items():
    print(f"  {k}: {v:.4g}")

In [None]:
# Create TensorialEPM with calibrated parameters
model_tensor = TensorialEPM(
    L=32,
    dt=0.01,
    mu=median_params["mu"],
    nu=0.48,
    tau_pl=median_params["tau_pl"],
    sigma_c_mean=median_params["sigma_c_mean"],
    sigma_c_std=median_params["sigma_c_std"],
)

# Predict flow curve with N₁
print("Running TensorialEPM forward prediction...")
result_tensor = model_tensor.predict(gamma_dot_fine, test_mode="flow_curve", smooth=True, seed=42)

sigma_xy_tensor = result_tensor.y
N1_tensor = result_tensor.metadata.get("N1", np.zeros_like(sigma_xy_tensor))

print(f"σ_xy range: {np.min(sigma_xy_tensor):.2f} – {np.max(sigma_xy_tensor):.2f} Pa")
print(f"N₁ range: {np.min(N1_tensor):.2f} – {np.max(N1_tensor):.2f} Pa")

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Shear stress comparison
ax1.loglog(gamma_dot, stress, "ko", markersize=6, label="Data")
ax1.loglog(gamma_dot_fine, sigma_xy_tensor, "-", lw=2, color="C1", label="TensorialEPM")
ax1.set_xlabel("Shear rate [1/s]")
ax1.set_ylabel("Shear stress σ_xy [Pa]")
ax1.set_title("Shear Stress: TensorialEPM vs Data")
ax1.legend()
ax1.grid(True, alpha=0.3, which="both")

# Normal stress difference
ax2.loglog(gamma_dot_fine, np.abs(N1_tensor), "s-", lw=2, color="C2", markersize=4, label="N₁ = σ_xx - σ_yy")
ax2.set_xlabel("Shear rate [1/s]")
ax2.set_ylabel("|N₁| [Pa]")
ax2.set_title("First Normal Stress Difference")
ax2.legend()
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

## 11. Save Results

Save the calibrated parameters for use in Notebook 03 (synthetic startup data generation).

In [None]:
import json

output_dir = os.path.join("..", "outputs", "epm", "flow_curve")
os.makedirs(output_dir, exist_ok=True)

# Save NLSQ point estimates. The posterior predictive loop overwrote
# model.parameters, so restore the NLSQ values stored in `initial_values`.
for name, value in initial_values.items():
    model.parameters.set_value(name, float(value))

nlsq_params = {
    name: float(model.parameters.get_value(name))
    for name in model.parameters.keys()
}
with open(os.path.join(output_dir, "nlsq_params.json"), "w") as f:
    json.dump(nlsq_params, f, indent=2)

# Save posterior samples
posterior_dict = {k: np.array(v).tolist() for k, v in posterior.items()}
with open(os.path.join(output_dir, "posterior_samples.json"), "w") as f:
    json.dump(posterior_dict, f)

print(f"Results saved to {output_dir}/")
print(f"  nlsq_params.json: {len(nlsq_params)} parameters")
print(f"  posterior_samples.json: {len(next(iter(posterior_dict.values())))} draws per parameter")
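
A later notebook can read the saved parameters back with the standard `json` module. The sketch below round-trips a dictionary through a temporary directory so that it runs on its own; the parameter values are hypothetical placeholders, and in practice the path would point at the output directory used above.

```python
import json
import os
import tempfile

# Hypothetical parameter values standing in for the fitted ones.
params_to_save = {"mu": 1.2, "tau_pl": 0.8, "sigma_c_mean": 24.0, "sigma_c_std": 0.3}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "nlsq_params.json")
    with open(path, "w") as f:
        json.dump(params_to_save, f, indent=2)

    # Later (for example in another notebook): read the file back into a dict.
    with open(path) as f:
        loaded = json.load(f)

# Check that every parameter the model needs is present before using the dict.
required = ["mu", "tau_pl", "sigma_c_mean", "sigma_c_std"]
missing = [name for name in required if name not in loaded]
print("Loaded parameters:", loaded)
print("Missing parameters:", missing)
```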

## 12. Key Takeaways

1. **EPM captures yield stress physics** via disorder-induced yielding thresholds and plastic avalanches
2. **LatticeEPM** supports full NLSQ + Bayesian fitting; **TensorialEPM** adds normal stress predictions
3. **NLSQ warm-start is critical** for NUTS convergence on this complex model
4. **μ and σ_c,mean** are the most directly interpretable parameters (modulus and yield stress)
5. **τ_pl and σ_c,std** control dynamics and disorder, and may be correlated in the posterior
6. **Normal stress N₁** from TensorialEPM provides additional insight into microstructural anisotropy

## Next Steps

- **Notebook 02**: Fit SAOS (oscillation) data with EPM
- **Notebook 03**: Use calibrated parameters to generate synthetic startup data and observe stress overshoot
- **Notebook 06**: Explore the EPM visualization gallery (stress fields, avalanche animations)