# DMT Startup Shear: Stress Overshoot and Structure Breakdown

> **Handbook:** See [DMT Startup Protocol](../../docs/source/models/dmt/dmt.rst#start-up-of-steady-shear) for complete governing equations and overshoot mechanism.

## Physical Context

In **startup shear**, a constant shear rate $\dot{\gamma}$ is suddenly applied to a material initially at rest ($\lambda = 1$, fully structured). For thixotropic materials with elastic storage (DMT with Maxwell backbone), this produces a characteristic **stress overshoot**:

1. **Initial elastic loading**: $\sigma \approx G_0 \dot{\gamma} t$ (linear rise while $\lambda \approx 1$)
2. **Structure breakdown begins**: $\lambda$ decreases as the rejuvenation term $a\lambda|\dot{\gamma}|^c/t_{\text{eq}}$ becomes significant
3. **Viscosity drop**: $\eta(\lambda)$ decreases exponentially (exponential closure) or polynomially (HB closure)
4. **Peak stress**: Overshoot occurs when viscous relaxation catches up to elastic loading
5. **Steady state**: $\sigma \to \sigma_{\text{ss}}$, $\lambda \to \lambda_{\text{ss}}$

### Key Physics

- **Stress overshoot requires elasticity**: Without `include_elasticity=True`, the model produces a monotonic approach to steady state
- **Overshoot ratio increases with $\dot{\gamma}$**: Higher shear rates → more dramatic transients
- **Thixotropic signature**: Time-dependent evolution distinguishes DMT from purely viscoelastic models (e.g., standard Maxwell)

### Industrial Relevance

- **Emulsion processing**: Startup transients affect pipeline pressure drops
- **3D printing**: Extrusion overshoot controls filament uniformity
- **Coating flows**: Initial shear determines layer quality

## Learning Objectives

- Understand stress overshoot physics in thixotropic materials
- Model structure breakdown during startup shear
- Apply NLSQ + Bayesian workflow to transient protocols
- Analyze the role of Maxwell elasticity in overshoot behavior

## Prerequisites

- Notebook 01: DMT flow curves and calibration

## Runtime Estimate

- Data generation + NLSQ: ~1-2 minutes
- Bayesian inference (1 chain, 200+500 samples): ~2-3 minutes
- Full inference (4 chains, 1000+2000 samples): ~10-15 minutes

## Setup

In [None]:
# Colab setup
import sys

if "google.colab" in sys.modules:
    !pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "1"

In [None]:
# Imports
import os
import sys

from rheojax.core.jax_config import safe_import_jax, verify_float64

jax, jnp = safe_import_jax()
verify_float64()

from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.models import DMTLocal

# Shared plotting utilities
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    display_arviz_diagnostics,
    plot_nlsq_fit,
    plot_posterior_predictive,
)

FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

%matplotlib inline

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Float64 enabled: {jax.config.read('jax_enable_x64')}")
print(f"FAST_MODE: {FAST_MODE}")

## Theory: Startup Shear in DMT Model

### Governing Equations

During startup shear at constant rate $\dot{\gamma}$, the DMT model with Maxwell elasticity evolves as:

**Stress evolution (Maxwell backbone):**
$$
\frac{d\sigma}{dt} = G \dot{\gamma} - \frac{\sigma}{\theta_1(\lambda)}
$$

where $\theta_1(\lambda) = \eta(\lambda)/G$ is the structure-dependent relaxation time.
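
As a consistency check on the stress equation, consider the hypothetical case where the structure is frozen, so that $\theta_1$ (and hence $\eta = G\theta_1$) is constant. This is an assumption made only for this illustration. With $\sigma(0) = 0$ the equation integrates to:

$$
\sigma(t) = G \dot{\gamma}\, \theta_1 \left(1 - e^{-t/\theta_1}\right) = \eta \dot{\gamma} \left(1 - e^{-t/\theta_1}\right)
$$

For $t \ll \theta_1$ this reduces to $\sigma \approx G \dot{\gamma} t$, and for $t \gg \theta_1$ it saturates monotonically at $\eta \dot{\gamma}$. A frozen structure therefore gives no overshoot: the peak can only appear when $\theta_1(\lambda)$ shortens while the stress is still loading.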

**Structure evolution:**
$$
\frac{d\lambda}{dt} = \frac{1 - \lambda}{t_{eq}} - \frac{a \lambda |\dot{\gamma}|^c}{t_{eq}}
$$
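
At constant $\dot{\gamma}$ the structure equation is linear in $\lambda$, so it can be solved in closed form. Setting $d\lambda/dt = 0$ gives $\lambda_{ss} = 1/(1 + a|\dot{\gamma}|^c)$, and the approach to it is exponential with time constant $\tau_\lambda = t_{eq}/(1 + a|\dot{\gamma}|^c)$. The sketch below evaluates both in plain Python. The helper names are illustrative and not part of rheojax; the values of `a`, `c`, and `t_eq` are the ones used later in this notebook.

```python
import math


def lambda_steady(gamma_dot, a, c):
    """Steady-state structure parameter, from d(lambda)/dt = 0."""
    return 1.0 / (1.0 + a * abs(gamma_dot) ** c)


def lambda_timescale(gamma_dot, a, c, t_eq):
    """Time constant of the exponential approach to steady state."""
    return t_eq / (1.0 + a * abs(gamma_dot) ** c)


def lambda_of_t(t, gamma_dot, a, c, t_eq, lam0=1.0):
    """Closed-form lambda(t) at constant shear rate, starting from lam0."""
    lam_ss = lambda_steady(gamma_dot, a, c)
    tau = lambda_timescale(gamma_dot, a, c, t_eq)
    return lam_ss + (lam0 - lam_ss) * math.exp(-t / tau)


# Parameter values taken from the data-generation cell of this notebook
a, c, t_eq = 0.8, 0.7, 50.0
for gamma_dot in (0.1, 1.0, 10.0, 100.0):
    lam_ss = lambda_steady(gamma_dot, a, c)
    tau = lambda_timescale(gamma_dot, a, c, t_eq)
    print(f"gamma_dot = {gamma_dot:6.1f} 1/s  lambda_ss = {lam_ss:.4f}  tau = {tau:.2f} s")
```

Comparing `tau` with the simulated duration (`t_end`) is a quick way to judge whether a run is long enough for $\lambda$ to reach steady state.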

**Viscosity (exponential closure):**
$$
\eta(\lambda) = \eta_\infty \left(\frac{\eta_0}{\eta_\infty}\right)^\lambda
$$
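
The exponential closure is a log-linear interpolation between $\eta_\infty$ (at $\lambda = 0$) and $\eta_0$ (at $\lambda = 1$), so equal steps in $\lambda$ change the viscosity by equal factors. A minimal sketch, using a hypothetical helper name and the $\eta_0$, $\eta_\infty$ values from this notebook:

```python
def eta_exponential(lam, eta_0, eta_inf):
    """Exponential closure: eta = eta_inf * (eta_0 / eta_inf) ** lam."""
    return eta_inf * (eta_0 / eta_inf) ** lam


eta_0, eta_inf = 1.5e4, 0.3  # Pa·s, values used later in this notebook
for lam in (1.0, 0.75, 0.5, 0.25, 0.0):
    eta = eta_exponential(lam, eta_0, eta_inf)
    print(f"lambda = {lam:4.2f}  eta = {eta:10.4g} Pa·s")
```

With a viscosity ratio of $5 \times 10^4$, losing a quarter of the structure already lowers the viscosity by roughly a factor of 15, which is why the relaxation time collapses early in the transient.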

### Stress Overshoot Mechanism

1. **Initial loading**: At $t=0$, $\lambda=1$ (fully structured), $\sigma=0$
2. **Elastic buildup**: Stress increases rapidly as $\sigma \approx G \dot{\gamma} t$
3. **Structure breakdown**: $\lambda$ decreases due to shear-induced rejuvenation
4. **Viscosity drop**: $\eta(\lambda)$ decreases, relaxation time $\theta_1$ shortens
5. **Overshoot**: Peak stress $\sigma_{peak}$ occurs when relaxation catches up to loading
6. **Steady state**: $\sigma \to \sigma_{ss}$, $\lambda \to \lambda_{ss}$

**Key physics**: Stress overshoot requires elastic storage (Maxwell backbone). Without elasticity (`include_elasticity=False`), the stress approaches steady state monotonically, with no peak.
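
The mechanism can be reproduced with a few lines of plain Python that integrate the two equations of this section directly. This is a sketch under stated assumptions, not the rheojax implementation: it uses a constant modulus $G$ exactly as written in the stress equation above (the library's handling of `G0` and `m_G` may differ, so the numbers should not be compared with `simulate_startup` output), the closed-form $\lambda(t)$ for constant $\dot{\gamma}$, and an exponential update for the stress so that the scheme stays stable once $\theta_1$ becomes much shorter than the step. The function name is hypothetical.

```python
import math


def simulate_startup_sketch(gamma_dot, eta_0, eta_inf, a, c, G, t_eq,
                            t_end=200.0, dt=0.01):
    """Integrate the startup equations with lambda(0) = 1 and sigma(0) = 0.

    Assumes a constant modulus G (illustration only, not the rheojax solver).
    """
    rate = 1.0 + a * abs(gamma_dot) ** c
    lam_ss, tau = 1.0 / rate, t_eq / rate
    n_steps = int(round(t_end / dt))
    t, sigma, lam = [0.0], [0.0], [1.0]
    for k in range(n_steps):
        # Structure and relaxation time evaluated at the middle of the step
        t_mid = (k + 0.5) * dt
        lam_mid = lam_ss + (1.0 - lam_ss) * math.exp(-t_mid / tau)
        theta = eta_inf * (eta_0 / eta_inf) ** lam_mid / G  # theta_1 = eta / G
        decay = math.exp(-dt / theta)
        # Exact solution of d(sigma)/dt = G*gamma_dot - sigma/theta over one
        # step, with theta held fixed during that step
        sigma.append(sigma[-1] * decay + G * gamma_dot * theta * (1.0 - decay))
        t.append((k + 1) * dt)
        lam.append(lam_ss + (1.0 - lam_ss) * math.exp(-t[-1] / tau))
    return t, sigma, lam


t, sigma, lam = simulate_startup_sketch(
    gamma_dot=10.0, eta_0=1.5e4, eta_inf=0.3, a=0.8, c=0.7, G=500.0, t_eq=50.0
)
i_peak = max(range(len(sigma)), key=sigma.__getitem__)
print(f"peak stress  {sigma[i_peak]:.4g} Pa at t = {t[i_peak]:.3g} s")
print(f"final stress {sigma[-1]:.4g} Pa, final lambda = {lam[-1]:.4f}")
```

At the peak $d\sigma/dt = 0$, so the stress equation gives $\sigma_{peak} = \eta(\lambda_{peak})\,\dot{\gamma}$: the peak sits where the rising elastic stress meets the falling viscous stress level.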

## Data Generation

In [None]:
# Calibrated parameters from emulsion φ=0.80 flow curve (Notebook 01)
calib_params = {
    "eta_0": 1.5e4,
    "eta_inf": 0.3,
    "a": 0.8,
    "c": 0.7,
    "G0": 500.0,
    "m_G": 1.0,
    "t_eq": 50.0,
}

# Create DMT model with Maxwell elasticity
model_true = DMTLocal(closure="exponential", include_elasticity=True)

# Set calibrated parameters
for name, value in calib_params.items():
    model_true.parameters[name].value = value

print("True model parameters:")
for name, param in model_true.parameters.items():
    print(f"  {name}: {param.value:.4g}")

In [None]:
# Generate startup shear data at 4 shear rates
gamma_dots = jnp.array([0.1, 1.0, 10.0, 100.0])
t_end = 200.0
dt = 1.0

# Storage for synthetic data
startup_data = {}

np.random.seed(42)

for gamma_dot in gamma_dots:
    # Simulate startup
    t, stress, fluidity = model_true.simulate_startup(
        gamma_dot=float(gamma_dot),
        t_end=t_end,
        dt=dt,
    )
    
    # Add Gaussian noise with std equal to 3% of the stress signal's standard deviation
    noise_level = 0.03
    stress_noisy = stress + noise_level * jnp.std(stress) * np.random.randn(len(stress))
    
    startup_data[float(gamma_dot)] = {
        "t": t,
        "stress": stress,
        "stress_noisy": stress_noisy,
        "fluidity": fluidity,
    }

print(f"Generated startup data at {len(gamma_dots)} shear rates")

In [None]:
# Plot all startup curves
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for idx, gamma_dot in enumerate(gamma_dots):
    data = startup_data[float(gamma_dot)]
    ax = axes[idx]
    
    ax.plot(data["t"], data["stress"], "k-", linewidth=2, label="True", alpha=0.6)
    ax.plot(data["t"], data["stress_noisy"], "o", markersize=3, alpha=0.5, label="Noisy data")
    
    # Mark overshoot peak
    peak_idx = jnp.argmax(data["stress"])
    ax.plot(data["t"][peak_idx], data["stress"][peak_idx], "r*", markersize=15, label="Peak")
    
    ax.set_xlabel("Time (s)", fontsize=11)
    ax.set_ylabel("Stress (Pa)", fontsize=11)
    ax.set_title(f"$\\dot{{\\gamma}}$ = {gamma_dot:.1f} s$^{{-1}}$", fontsize=12)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

fig.suptitle("DMT Startup Shear: Stress Overshoot", fontsize=14, fontweight="bold")
plt.tight_layout()
display(fig)
plt.close(fig)

## NLSQ Fitting

Fit to a single shear rate ($\dot{\gamma} = 10.0$ s$^{-1}$) using NLSQ optimization.
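
One caveat follows directly from the structure equation in the theory section: at a single, constant shear rate, `a` and `c` enter only through the product $a|\dot{\gamma}|^c$, so one startup curve cannot separate them. The sketch below (plain Python, hypothetical helper name) builds a second `(a, c)` pair with the same product at $\dot{\gamma} = 10$ s$^{-1}$ and checks that the $\lambda(t)$ trajectories coincide there but differ at another rate.

```python
import math


def lam_closed_form(t, gamma_dot, a, c, t_eq):
    """Closed-form lambda(t) at constant shear rate, starting from lambda = 1."""
    rate = 1.0 + a * abs(gamma_dot) ** c
    lam_ss = 1.0 / rate
    return lam_ss + (1.0 - lam_ss) * math.exp(-t * rate / t_eq)


gamma_dot, t_eq = 10.0, 50.0
a_true, c_true = 0.8, 0.7
c_alt = 0.5
a_alt = a_true * gamma_dot ** (c_true - c_alt)  # keeps a * gamma_dot**c unchanged

times = [0.5 * k for k in range(401)]  # 0 to 200 s
max_diff = max(
    abs(lam_closed_form(t, gamma_dot, a_true, c_true, t_eq)
        - lam_closed_form(t, gamma_dot, a_alt, c_alt, t_eq))
    for t in times
)
print(f"same rate:  max |delta lambda| = {max_diff:.2e}")

# At a different shear rate the two parameter sets no longer agree
diff_other_rate = abs(
    lam_closed_form(50.0, 1.0, a_true, c_true, t_eq)
    - lam_closed_form(50.0, 1.0, a_alt, c_alt, t_eq)
)
print(f"other rate: |delta lambda| at t = 50 s is {diff_other_rate:.3f}")
```

Large relative errors on `a` and `c` in the parameter comparison are therefore not necessarily a sign of a poor fit. Fitting several shear rates jointly would be one way to break the degeneracy.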

In [None]:
# Select single shear rate for fitting
gamma_dot_fit = 10.0
data_fit = startup_data[gamma_dot_fit]

t_fit = data_fit["t"]
stress_fit = data_fit["stress_noisy"]

print(f"Fitting to startup data at γ̇ = {gamma_dot_fit} s⁻¹")
print(f"Data points: {len(t_fit)}")

In [None]:
# Create fresh model for fitting
model_fit = DMTLocal(closure="exponential", include_elasticity=True)

# Helper function for computing fit quality
def compute_fit_quality(y_true, y_pred):
    """Compute R² and RMSE."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    residuals = y_true - y_pred
    if y_true.ndim > 1:
        residuals = residuals.ravel()
        y_true = y_true.ravel()
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    rmse = np.sqrt(np.mean(residuals**2))
    return {"R2": r2, "RMSE": rmse}

# NLSQ fit
model_fit.fit(
    t_fit,
    stress_fit,
    test_mode="startup",
    gamma_dot=gamma_dot_fit,
)

# Compute R²
stress_pred = model_fit.predict(t_fit, test_mode="startup", gamma_dot=gamma_dot_fit)
metrics = compute_fit_quality(stress_fit, stress_pred)

print("\nNLSQ fitting complete")
print(f"R² = {metrics['R2']:.6f}")

In [None]:
# Compare true vs fitted parameters
print("\nParameter comparison:")
print(f"{'Parameter':<10} {'True':>12} {'Fitted':>12} {'Rel. Error':>12}")
print("-" * 50)

for name in calib_params.keys():
    true_val = calib_params[name]
    fitted_val = model_fit.parameters[name].value
    rel_error = abs(fitted_val - true_val) / true_val * 100
    print(f"{name:<10} {true_val:>12.4g} {fitted_val:>12.4g} {rel_error:>11.2f}%")

In [None]:
# Plot fitted vs data
stress_pred = model_fit.predict(t_fit, test_mode="startup", gamma_dot=gamma_dot_fit)

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(t_fit, stress_fit, "o", markersize=4, alpha=0.5, label="Noisy data")
ax.plot(t_fit, data_fit["stress"], "k--", linewidth=2, alpha=0.6, label="True")
ax.plot(t_fit, stress_pred, "r-", linewidth=2, label="NLSQ fit")

ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("Stress (Pa)", fontsize=12)
ax.set_title(f"NLSQ Fit: Startup at $\\dot{{\\gamma}}$ = {gamma_dot_fit} s$^{{-1}}$ (R² = {metrics['R2']:.6f})", fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

## Bayesian Inference

Warm-start from NLSQ and run NUTS to quantify parameter uncertainty.
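
With a single chain there is no second chain to compare against, so it helps to know what a convergence statistic actually measures. The sketch below is a plain-Python version of basic split R-hat: the chain is cut in half and the between-half variance is compared with the within-half variance. It is a simplified illustration with a hypothetical function name and synthetic draws; ArviZ's `r_hat` uses a rank-normalized variant, so values will not match exactly.

```python
import random
import statistics


def split_rhat(chain):
    """Basic (non rank-normalized) split R-hat for a single chain of draws."""
    n = len(chain) // 2
    halves = [chain[:n], chain[n:2 * n]]
    means = [statistics.fmean(h) for h in halves]
    within = statistics.fmean(statistics.variance(h) for h in halves)  # W
    between_over_n = statistics.variance(means)                         # B / n
    var_plus = (n - 1) / n * within + between_over_n
    return (var_plus / within) ** 0.5


# Synthetic draws for illustration only (not output of this notebook)
rng = random.Random(0)
stationary = [rng.gauss(0.0, 1.0) for _ in range(500)]
drifting = [x + 4.0 * k / 500 for k, x in enumerate(stationary)]

print(f"stationary chain: split R-hat = {split_rhat(stationary):.3f}")
print(f"drifting chain:   split R-hat = {split_rhat(drifting):.3f}")
```

A chain that is still drifting yields a value well above 1, while a stationary chain stays close to 1. Running several chains, as in the full-inference configuration, gives a stronger check than any single-chain statistic.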

In [None]:
bayesian_completed = False

if not FAST_MODE:
    # Bayesian inference with warm-start
    # Set protocol-specific attributes for startup
    model_fit._gamma_dot_applied = gamma_dot_fit
    model_fit._startup_lam_init = 1.0

    result_bayes = model_fit.fit_bayesian(
        t_fit,
        stress_fit,
        test_mode="startup",
        num_warmup=200,
        num_samples=500,
        num_chains=1,
        seed=42,
    )

    print("Bayesian inference complete")
    bayesian_completed = True
else:
    print("FAST_MODE: Skipping Bayesian inference")

In [None]:
if bayesian_completed:
    # Convergence diagnostics using ArviZ (`az` is already imported in the setup cell)

    # Convert to InferenceData first
    idata = az.from_dict(
        posterior={k: v[None, :] if v.ndim == 1 else v[None, :, :] for k, v in result_bayes.posterior_samples.items()},
    )

    # Compute summary statistics
    summary = az.summary(idata)
    print("\nConvergence Diagnostics:")
    print(summary[["mean", "sd", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]])
else:
    print("FAST_MODE: Skipping convergence diagnostics")

In [None]:
if bayesian_completed:
    display_arviz_diagnostics(result_bayes, list(calib_params.keys()), fast_mode=FAST_MODE)
else:
    print("FAST_MODE: Skipping ArviZ diagnostics")

### Posterior Predictive Check

In [None]:
if bayesian_completed:
    # Draw posterior samples and simulate
    n_draws = 100
    param_names = list(calib_params.keys())

    # Flatten posterior samples (single chain)
    posterior_flat = {name: result_bayes.posterior_samples[name].flatten() for name in param_names}

    # Random draws
    indices = np.random.choice(len(posterior_flat[param_names[0]]), size=n_draws, replace=False)

    predictions = []
    for idx in indices:
        # Create model with posterior parameters
        model_post = DMTLocal(closure="exponential", include_elasticity=True)
        for name in param_names:
            model_post.parameters[name].value = float(posterior_flat[name][idx])

        # Simulate using the model's simulate_startup method
        t_sim, stress_sim, _ = model_post.simulate_startup(
            gamma_dot=gamma_dot_fit,
            t_end=t_end,
            dt=dt,
        )
        predictions.append(stress_sim)

    predictions = jnp.array(predictions)

    # Compute median and 95% CI
    median_pred = jnp.median(predictions, axis=0)
    lower_pred = jnp.percentile(predictions, 2.5, axis=0)
    upper_pred = jnp.percentile(predictions, 97.5, axis=0)

    print(f"Posterior predictive: {n_draws} samples")
else:
    print("FAST_MODE: Skipping posterior predictive sampling")

In [None]:
if bayesian_completed:
    # Plot posterior predictive
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(t_fit, stress_fit, "o", markersize=4, alpha=0.5, label="Noisy data")
    ax.plot(t_fit, data_fit["stress"], "k--", linewidth=2, alpha=0.6, label="True")
    ax.plot(t_sim, median_pred, "r-", linewidth=2, label="Posterior median")
    ax.fill_between(
        t_sim,
        lower_pred,
        upper_pred,
        color="red",
        alpha=0.2,
        label="95% CI",
    )

    ax.set_xlabel("Time (s)", fontsize=12)
    ax.set_ylabel("Stress (Pa)", fontsize=12)
    ax.set_title("Posterior Predictive Check: Startup Shear", fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    display(fig)
    plt.close(fig)
else:
    print("FAST_MODE: Skipping posterior predictive plot")

## Physics Analysis

### Stress Overshoot Ratio vs Shear Rate

In [None]:
# Compute overshoot ratio for each shear rate
overshoot_ratios = []

for gamma_dot in gamma_dots:
    data = startup_data[float(gamma_dot)]
    stress = data["stress"]
    
    # Peak stress
    sigma_peak = jnp.max(stress)
    
    # Steady-state stress (average of last 10%)
    n_ss = int(len(stress) * 0.1)
    sigma_ss = jnp.mean(stress[-n_ss:])
    
    # Overshoot ratio
    ratio = sigma_peak / sigma_ss
    overshoot_ratios.append(ratio)
    
    print(f"γ̇ = {gamma_dot:>6.1f} s⁻¹:  σ_peak = {sigma_peak:>8.2f} Pa,  σ_ss = {sigma_ss:>8.2f} Pa,  Ratio = {ratio:.3f}")

overshoot_ratios = jnp.array(overshoot_ratios)

In [None]:
# Plot overshoot ratio vs shear rate
fig, ax = plt.subplots(figsize=(8, 6))

ax.plot(gamma_dots, overshoot_ratios, "o-", markersize=8, linewidth=2)

ax.set_xscale("log")
ax.set_xlabel("Shear rate $\\dot{\\gamma}$ (s$^{-1}$)", fontsize=12)
ax.set_ylabel("Overshoot ratio $\\sigma_{peak}/\\sigma_{ss}$", fontsize=12)
ax.set_title("Stress Overshoot Ratio vs Shear Rate", fontsize=13)
ax.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

### Structure Breakdown Dynamics

In [None]:
# Plot structure parameter λ(t) for all shear rates
fig, ax = plt.subplots(figsize=(10, 6))

colors = plt.cm.viridis(np.linspace(0, 1, len(gamma_dots)))

for idx, gamma_dot in enumerate(gamma_dots):
    data = startup_data[float(gamma_dot)]
    lambda_t = 1.0 - data["fluidity"]  # λ = 1 - f
    
    ax.plot(
        data["t"],
        lambda_t,
        "-",
        linewidth=2,
        color=colors[idx],
        label=f"$\\dot{{\\gamma}}$ = {gamma_dot:.1f} s$^{{-1}}$",
    )

ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("Structure parameter $\\lambda$", fontsize=12)
ax.set_title("Structure Breakdown During Startup Shear", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_ylim([0, 1.05])

plt.tight_layout()
display(fig)
plt.close(fig)

### Role of Maxwell Elasticity

Compare DMT-Maxwell (with overshoot) vs DMT-Viscous (monotonic approach).
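
To see why removing the elastic element removes the peak, the sketch below assumes that the inelastic stress is simply $\sigma(t) = \eta(\lambda(t))\,\dot{\gamma}$. This is an assumption about the viscous limit made for illustration (rheojax's internal implementation may differ), combined with the closed-form $\lambda(t)$ at constant shear rate. Because $\lambda(t)$ is monotonic in time and $\eta(\lambda)$ is monotonic in $\lambda$, the stress cannot pass through an interior maximum.

```python
import math


def viscous_stress(t, gamma_dot, eta_0, eta_inf, a, c, t_eq):
    """Inelastic stress eta(lambda(t)) * gamma_dot, starting from lambda = 1."""
    rate = 1.0 + a * abs(gamma_dot) ** c
    lam_ss = 1.0 / rate
    lam = lam_ss + (1.0 - lam_ss) * math.exp(-t * rate / t_eq)
    return eta_inf * (eta_0 / eta_inf) ** lam * gamma_dot


times = [0.5 * k for k in range(401)]  # 0 to 200 s
stress = [viscous_stress(t, 10.0, 1.5e4, 0.3, 0.8, 0.7, 50.0) for t in times]

# Monotonic means every step has the same sign (or is zero)
steps = [s1 - s0 for s0, s1 in zip(stress, stress[1:])]
is_monotonic = all(s <= 0 for s in steps) or all(s >= 0 for s in steps)
print(f"monotonic: {is_monotonic}, start = {stress[0]:.4g} Pa, end = {stress[-1]:.4g} Pa")
```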

In [None]:
# Create DMT model WITHOUT elasticity
model_viscous = DMTLocal(closure="exponential", include_elasticity=False)

# Set same parameters (G0 and m_G are unused)
for name, value in calib_params.items():
    if name in ["G0", "m_G"]:
        continue
    model_viscous.parameters[name].value = value

# Simulate startup at γ̇ = 10.0
t_visc, stress_visc, _ = model_viscous.simulate_startup(
    gamma_dot=10.0,
    t_end=t_end,
    dt=dt,
)

print("Simulated DMT-Viscous (no elasticity) startup")

In [None]:
# Compare Maxwell vs Viscous
data_maxwell = startup_data[10.0]

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(
    data_maxwell["t"],
    data_maxwell["stress"],
    "b-",
    linewidth=2,
    label="DMT-Maxwell (elastic, with overshoot)",
)
ax.plot(
    t_visc,
    stress_visc,
    "r--",
    linewidth=2,
    label="DMT-Viscous (no elasticity, monotonic)",
)

ax.set_xlabel("Time (s)", fontsize=12)
ax.set_ylabel("Stress (Pa)", fontsize=12)
ax.set_title("Effect of Maxwell Elasticity on Startup Response ($\\dot{\\gamma}$ = 10.0 s$^{-1}$)", fontsize=13)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

## Save Results

In [None]:
# Create output directory
output_dir = Path("examples/outputs/dmt/startup")
output_dir.mkdir(parents=True, exist_ok=True)

# Save overshoot analysis
overshoot_file = output_dir / "overshoot_analysis.npz"
np.savez(
    overshoot_file,
    gamma_dots=np.array(gamma_dots),
    overshoot_ratios=np.array(overshoot_ratios),
)

print(f"Saved overshoot analysis to {overshoot_file}")

# Save fitted parameters
params_file = output_dir / "fitted_parameters.txt"
with open(params_file, "w") as f:
    f.write("DMT Startup Shear - Fitted Parameters\n")
    f.write("=" * 50 + "\n\n")
    f.write(f"Shear rate: {gamma_dot_fit} s^-1\n")
    f.write(f"NLSQ R2: {metrics['R2']:.6f}\n\n")
    f.write(f"{'Parameter':<10} {'Value':>12} {'Unit':>10}\n")
    f.write("-" * 35 + "\n")
    for name, param in model_fit.parameters.items():
        f.write(f"{name:<10} {param.value:>12.4g} {'-':>10}\n")

print(f"Saved fitted parameters to {params_file}")

# Save posterior samples (only if Bayesian ran)
if bayesian_completed:
    posterior_file = output_dir / "posterior_samples.npz"
    np.savez(posterior_file, **{k: np.array(v) for k, v in result_bayes.posterior_samples.items()})
    print(f"Saved posterior samples to {posterior_file}")
else:
    print("FAST_MODE: Skipping posterior samples save")

## Key Takeaways

### 1. Stress Overshoot Mechanism
- **Maxwell elasticity required**: Stress overshoot only occurs with elastic storage (`include_elasticity=True`)
- **Viscous models**: Without elasticity, the stress approaches steady state monotonically, with no peak
- **Competition**: Overshoot arises from elastic loading competing with viscous relaxation

### 2. Shear Rate Dependence
- **Overshoot ratio increases with $\dot{\gamma}$**: Higher shear rates → larger peaks
- **Physical interpretation**: Faster loading outpaces structure breakdown initially
- **Typical values**: Overshoot ratio ~ 1.1-1.3 for soft materials

### 3. Structure Breakdown Timescale
- **Characteristic time**: $t_{\text{break}} \sim t_{\text{eq}}/(1 + a|\dot{\gamma}|^c)$, the relaxation time of the structure equation at constant $\dot{\gamma}$
- **Faster at high $\dot{\gamma}$**: Structure parameter $\lambda$ decays more rapidly
- **Steady state**: $\lambda_{\text{ss}}$ decreases with shear rate (more rejuvenation)

### 4. Bayesian Workflow
- **NLSQ warm-start**: Critical for convergence in transient protocols
- **Parameter correlations**: $\eta_0$ and $G_0$ are typically correlated (set overshoot magnitude)
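- **Posterior predictive**: The 95% band is built from noise-free simulations of posterior draws, so it reflects parameter uncertainty; the underlying curve should fall inside it, while individual noisy points may scatter outside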
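
Whether a band covers the data can be checked numerically rather than judged by eye. The sketch below uses a hypothetical helper and toy numbers to count the fraction of observations that fall inside a pointwise band.

```python
def band_coverage(observed, lower, upper):
    """Fraction of observations lying inside [lower, upper], pointwise."""
    inside = sum(lo <= y <= hi for y, lo, hi in zip(observed, lower, upper))
    return inside / len(observed)


# Toy numbers for illustration only (not results from this notebook)
observed = [1.0, 2.1, 2.9, 4.3, 5.0]
lower = [0.8, 1.8, 2.7, 3.6, 4.6]
upper = [1.2, 2.2, 3.3, 4.2, 5.4]
print(f"coverage = {band_coverage(observed, lower, upper):.2f}")
```

Applied to this notebook, the inputs would be `stress_fit`, `lower_pred`, and `upper_pred`, assuming the simulated and fitted time grids coincide. Since the band is computed from simulations without added observation noise, its empirical coverage of the noisy points can be well below 95%.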

### 5. Experimental Signatures
- **Thixotropic materials**: Stress overshoot is a hallmark signature
- **Examples**: Emulsions, suspensions, gels, biological fluids
- **Quantification**: Overshoot ratio and time-to-peak are key metrics
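
Both metrics can be extracted from any startup curve with a small helper. The sketch below is illustrative (the function name is hypothetical); it estimates the steady state as the mean over the last 10% of the record, mirroring the overshoot-ratio cell in this notebook, and is exercised on a synthetic curve whose peak location is known analytically.

```python
import math


def overshoot_metrics(t, stress, tail_fraction=0.1):
    """Return (peak stress, time to peak, peak / steady-state ratio).

    Steady state is estimated as the mean over the last `tail_fraction`
    of the record, so the record must extend well past the transient.
    """
    i_peak = max(range(len(stress)), key=stress.__getitem__)
    n_tail = max(1, int(len(stress) * tail_fraction))
    sigma_ss = sum(stress[-n_tail:]) / n_tail
    return stress[i_peak], t[i_peak], stress[i_peak] / sigma_ss


# Synthetic curve: d(stress)/dt = exp(-t) * (3 - 2t), so the peak is at t = 1.5
# and the plateau value is 1
t = [0.01 * k for k in range(2001)]
stress = [1.0 - math.exp(-x) + 2.0 * x * math.exp(-x) for x in t]

sigma_peak, t_peak, ratio = overshoot_metrics(t, stress)
print(f"sigma_peak = {sigma_peak:.4f}, t_peak = {t_peak:.2f}, ratio = {ratio:.4f}")
```

On coarse time grids the time-to-peak is only resolved to within one sample interval, which matters at high shear rates where the peak occurs early.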

## Further Reading

### DMT Model Documentation

- [DMT Overview](../../docs/source/models/dmt/index.rst) — Model hierarchy and selection guide
- [Startup Protocol Equations](../../docs/source/models/dmt/dmt.rst#start-up-of-steady-shear) — Complete mathematical derivation

### Key References

1. **de Souza Mendes, P. R. (2009).** "Modeling the thixotropic behavior of structured fluids." *J. Non-Newtonian Fluid Mech.*, 164, 66-75.

2. **Mujumdar, A., Beris, A. N., & Metzner, A. B. (2002).** "Transient phenomena in thixotropic systems." *J. Non-Newtonian Fluid Mech.*, 102, 157-178. — Stress overshoot mechanisms

3. **Thompson, R. L., & de Souza Mendes, P. R. (2014).** "Thixotropic behavior of elasto-viscoplastic materials." *Physics of Fluids*, 26, 023101.

4. **Larson, R. G., & Wei, Y. (2019).** "A review of thixotropy and its rheological modeling." *J. Rheology*, 63, 477-501.