# Fluidity Local Model: Startup Shear Flow

## Learning Objectives

1. Understand the startup shear protocol for thixotropic yield-stress fluids
2. Learn the stress overshoot mechanism as a thixotropic signature
3. Generate synthetic startup data from calibrated parameters
4. Fit the startup transient using NLSQ optimization
5. Perform Bayesian inference with NUTS, warm-started from the NLSQ fit
6. Visualize the fluidity evolution f(t) during stress growth

## Prerequisites

- Notebook 01 (flow curve calibration for parameter initialization)
- examples/basic/01_quickstart.ipynb
- examples/bayesian/01_bayesian_inference.ipynb

## Expected Runtime

- **Fast mode** (NUM_CHAINS=1, NUM_SAMPLES=500): ~2-3 minutes
- **Production mode** (NUM_CHAINS=4, NUM_SAMPLES=2000): ~8-12 minutes

## 1. Setup

In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import json
import os
import sys
import time
import warnings
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

# Add utils to path
sys.path.insert(0, str(Path("..").resolve()))

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.fluidity import FluidityLocal
from utils.fluidity_tutorial_utils import (
    compute_fit_quality,
    generate_synthetic_startup,
    get_fluidity_param_names,
    load_fluidity_parameters,
    print_convergence_summary,
    print_parameter_comparison,
    save_fluidity_results,
    set_model_parameters,
)

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. Theory: Startup Shear Flow

### Startup Protocol

In startup shear, a **constant shear rate** γ̇ is suddenly applied to a material initially at rest. The stress σ(t) and fluidity f(t) evolve via coupled ODEs:

$$
\frac{d\sigma}{dt} = G\left(\dot{\gamma} - \sigma f(t)\right)
$$

$$
\frac{df}{dt} = \frac{f_{\text{eq}} - f}{\theta} + a|\dot{\gamma}|^{n_{\text{rejuv}}}(f_{\text{inf}} - f)
$$
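As a cross-check, the two equations above can be integrated directly. The sketch below is a standalone NumPy implementation of the ODEs exactly as written (fixed-step fourth-order Runge-Kutta), not RheoJAX's internal solver. It assumes the material starts at rest with σ(0) = 0 and f(0) = f_eq. The parameter values are illustrative only (G = 10³ Pa instead of the 10⁶ Pa default) so that an explicit fixed step stays stable; with G = 10⁶ Pa the stress equation becomes stiff and an implicit solver would be the safer choice.

```python
import numpy as np


def startup_rhs(state, gamma_dot, G, f_eq, f_inf, theta, a, n_rejuv):
    """Right-hand side of the coupled stress/fluidity ODEs."""
    sigma, f = state
    # Elastic loading at rate G*gamma_dot minus fluidity-controlled relaxation
    dsigma = G * (gamma_dot - sigma * f)
    # Aging toward f_eq plus shear rejuvenation toward f_inf
    df = (f_eq - f) / theta + a * abs(gamma_dot) ** n_rejuv * (f_inf - f)
    return np.array([dsigma, df])


def simulate_startup(t_end, dt, gamma_dot, G, f_eq, f_inf, theta, a, n_rejuv):
    """Fixed-step RK4 from rest: sigma(0) = 0, f(0) = f_eq (assumed initial state)."""
    n_steps = int(round(t_end / dt))
    t = np.linspace(0.0, n_steps * dt, n_steps + 1)
    y = np.zeros((n_steps + 1, 2))
    y[0] = [0.0, f_eq]
    args = (gamma_dot, G, f_eq, f_inf, theta, a, n_rejuv)
    for i in range(n_steps):
        k1 = startup_rhs(y[i], *args)
        k2 = startup_rhs(y[i] + 0.5 * dt * k1, *args)
        k3 = startup_rhs(y[i] + 0.5 * dt * k2, *args)
        k4 = startup_rhs(y[i] + dt * k3, *args)
        y[i + 1] = y[i] + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return t, y[:, 0], y[:, 1]


# Illustrative parameters (not the notebook defaults)
t, sigma, f = simulate_startup(
    t_end=10.0, dt=2e-3, gamma_dot=1.0,
    G=1e3, f_eq=1e-4, f_inf=1e-2, theta=10.0, a=1.0, n_rejuv=1.0,
)
i_peak = int(np.argmax(sigma))
print(f"peak {sigma[i_peak]:.1f} Pa at t = {t[i_peak]:.2f} s, final {sigma[-1]:.1f} Pa")
```

With these values the stress rises almost linearly, passes through a maximum well above its final value, and then relaxes toward γ̇/f_ss, reproducing the stages listed below.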

### Stress Overshoot Mechanism

The **stress overshoot** is a hallmark of thixotropic materials:

1. **Initial Response** (t ≈ 0): Material is "structured" with low fluidity f ≈ f_eq → stress grows linearly σ ≈ Gγ̇t
2. **Peak Overshoot** (t ≈ t_peak): Stress reaches maximum σ_max > σ_ss as structure begins to break
3. **Breakdown** (t > t_peak): Rejuvenation term dominates → f increases → stress decays
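4. **Steady State** (t → ∞): f → f_ss and σ → σ_ss; setting dσ/dt = 0 gives σ_ss = γ̇/f_ss, and the notebook uses the Herschel-Bulkley flow curve τ_y + K|γ̇|^n_flow as its reference value for σ_ss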

### Physical Interpretation

- **Overshoot magnitude** (σ_max - σ_ss) / σ_ss: Degree of structural build-up
- **Time to peak** t_peak: Set by the inverse structural breakdown rate, t_peak ~ 1/(a|γ̇|^n_rejuv)
- **Steady-state fluidity**: f_ss = (f_eq/θ + a|γ̇|^n_rejuv f_inf) / (1/θ + a|γ̇|^n_rejuv), the root of df/dt = 0
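The steady-state expressions can be evaluated directly. The sketch below uses the fallback defaults from Section 3 and also evaluates γ̇/f_ss, which is what setting dσ/dt = 0 gives for the ODEs as displayed. Those displayed equations do not contain τ_y, K or n_flow, so how RheoJAX connects this value to the Herschel-Bulkley expression is not shown in this notebook, and the two numbers need not agree.

```python
def steady_state_fluidity(gamma_dot, f_eq, f_inf, theta, a, n_rejuv):
    """Root of df/dt = 0 for the fluidity equation at constant shear rate."""
    rate = a * abs(gamma_dot) ** n_rejuv  # rejuvenation rate [1/s]
    return (f_eq / theta + rate * f_inf) / (1.0 / theta + rate)


# Fallback defaults from Section 3 (illustrative, not calibrated values)
params = {"f_eq": 1e-6, "f_inf": 1e-3, "theta": 10.0, "a": 1.0, "n_rejuv": 1.0}
gamma_dot = 1.0  # 1/s

f_ss = steady_state_fluidity(gamma_dot, **params)
# Breakdown time that sets the scale of t_peak
t_rejuv = 1.0 / (params["a"] * abs(gamma_dot) ** params["n_rejuv"])
# Steady stress implied by d(sigma)/dt = 0 in the stress equation
sigma_ss_ode = gamma_dot / f_ss

print(f"f_ss           = {f_ss:.4e} 1/(Pa s)")
print(f"1/(a|gd|^n)    = {t_rejuv:.2f} s")
print(f"gamma_dot/f_ss = {sigma_ss_ode:.1f} Pa")
```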

### Model Assumptions

- Homogeneous flow (local model, no spatial gradients)
- Scalar fluidity (no orientation tensors)
- Shear rate applied instantaneously at t = 0 (a step in γ̇ from rest)

## 3. Load Calibrated Parameters

We attempt to load previously calibrated parameters from the flow_curve notebook (NB 01). If not available, we use physically reasonable default values for a yield-stress fluid.

In [None]:
# Attempt to load calibrated parameters from flow_curve notebook
try:
    calib_params = load_fluidity_parameters("local", "flow_curve")
    print("Loaded calibrated parameters from flow_curve:")
    for name, val in calib_params.items():
        print(f"  {name:10s} = {val:.4g}")
except FileNotFoundError:
    # Use default parameters for a typical yield-stress fluid
    calib_params = {
        "G": 1e6,       # Pa (elastic modulus)
        "tau_y": 100.0, # Pa (yield stress)
        "K": 500.0,     # Pa·s^n (flow consistency)
        "n_flow": 0.5,  # dimensionless (flow exponent)
        "f_eq": 1e-6,   # 1/(Pa·s) (equilibrium fluidity)
        "f_inf": 1e-3,  # 1/(Pa·s) (high-shear fluidity)
        "theta": 10.0,  # s (aging timescale)
        "a": 1.0,       # dimensionless (rejuvenation amplitude)
        "n_rejuv": 1.0, # dimensionless (rejuvenation exponent)
    }
    print("Using default parameters (run flow_curve notebook first for calibrated values):")
    for name, val in calib_params.items():
        print(f"  {name:10s} = {val:.4g}")

# Get parameter names for this model variant
param_names = get_fluidity_param_names("local")
print(f"\nModel variant: local (parameters: {len(param_names)})")

## 4. Generate Synthetic Startup Data

Generate synthetic startup data at γ̇ = 1.0 s⁻¹ using the calibrated model with 3% Gaussian noise. This emulates realistic experimental measurements with measurement uncertainty.

In [None]:
# Create calibrated "true" model
model_true = FluidityLocal()
set_model_parameters(model_true, calib_params)
model_true.fitted_ = True

# Generate synthetic data at γ̇ = 1.0 s⁻¹
GAMMA_DOT = 1.0  # s⁻¹
T_END = 10.0     # s
N_POINTS = 100
NOISE_LEVEL = 0.03  # 3%
SEED = 42

t_data, stress_data = generate_synthetic_startup(
    model_true,
    gamma_dot=GAMMA_DOT,
    t_end=T_END,
    n_points=N_POINTS,
    noise_level=NOISE_LEVEL,
    seed=SEED,
)

print(f"Generated {N_POINTS} startup data points at γ̇={GAMMA_DOT} s⁻¹")
print(f"Time range: [{t_data.min():.3f}, {t_data.max():.2f}] s")
print(f"Stress range: [{stress_data.min():.2f}, {stress_data.max():.2f}] Pa")
print(f"Noise level: {NOISE_LEVEL*100}%")

# Compute expected steady-state stress
tau_y = calib_params["tau_y"]
K = calib_params["K"]
n = calib_params["n_flow"]
sigma_ss_expected = tau_y + K * GAMMA_DOT**n
print(f"\nExpected steady-state stress: {sigma_ss_expected:.2f} Pa")
print(f"Peak stress (overshoot): {stress_data.max():.2f} Pa")
print(f"Overshoot ratio: {(stress_data.max() / sigma_ss_expected - 1)*100:.1f}%")
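Taking `stress_data.max()` as the peak biases the overshoot upward, because the largest of many noisy points tends to sit above the underlying curve. A minimal alternative, assuming a short centred moving average is acceptable smoothing (the `window` size is a hypothetical choice, not a RheoJAX setting), is sketched below. It returns the peak stress, the time to peak and the relative overshoot (σ_max - σ_ss)/σ_ss defined in Section 2.

```python
import numpy as np


def overshoot_metrics(t, stress, sigma_ss, window=5):
    """Return (sigma_max, t_peak, relative overshoot) from a sampled stress curve.

    A centred moving average of odd length `window` damps the upward bias of
    taking the raw maximum of noisy data; window=1 means no smoothing.
    """
    if window < 1 or window % 2 == 0:
        raise ValueError("window must be a positive odd integer")
    t = np.asarray(t, dtype=float)
    stress = np.asarray(stress, dtype=float)
    half = window // 2
    # "valid" avoids zero-padding at the edges; align times with window centres
    smooth = np.convolve(stress, np.ones(window) / window, mode="valid")
    t_mid = t[half:half + smooth.size]
    i = int(np.argmax(smooth))
    sigma_max = smooth[i]
    return sigma_max, t_mid[i], sigma_max / sigma_ss - 1.0
```

In this notebook it could be called as `overshoot_metrics(t_data, stress_data, sigma_ss_expected)`. A wider window suppresses more noise but also flattens a sharp peak and can shift an asymmetric one, so it should stay small compared with the number of points that resolve the overshoot.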

In [None]:
# Visualize synthetic data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Linear scale
ax1.plot(t_data, stress_data, "o", markersize=4, alpha=0.6, label=f"Synthetic data ({NOISE_LEVEL:.0%} noise)")
ax1.axhline(sigma_ss_expected, color="gray", linestyle="--", alpha=0.5, label=f"Expected σ_ss={sigma_ss_expected:.1f} Pa")
ax1.set_xlabel("Time [s]")
ax1.set_ylabel("Stress σ(t) [Pa]")
ax1.set_title(f"Startup Shear — γ̇={GAMMA_DOT} s⁻¹ (Linear)")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Log scale (early time detail)
ax2.semilogx(t_data, stress_data, "o", markersize=4, alpha=0.6)
ax2.axhline(sigma_ss_expected, color="gray", linestyle="--", alpha=0.5)
ax2.set_xlabel("Time [s]")
ax2.set_ylabel("Stress σ(t) [Pa]")
ax2.set_title("Startup Shear (Log Time)")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

## 5. NLSQ Fitting

### 5.1 Fit Startup Transient

Fit the startup stress growth using NLSQ optimization. We must specify `test_mode='startup'` and `gamma_dot=1.0` to activate the transient ODE solver.

In [None]:
# Initialize model
model = FluidityLocal()

# Fit with NLSQ
t0_fit = time.time()
model.fit(
    t_data,
    stress_data,
    test_mode="startup",
    gamma_dot=GAMMA_DOT,
)
t_nlsq = time.time() - t0_fit

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R²: {model._fit_result.r_squared:.6f}")
print(f"RMSE: {model._fit_result.rmse:.4g} Pa")
print("\nFitted parameters (NLSQ):")
print("=" * 60)
print(f"{'Parameter':<12s}  {'Fitted':<12s}  {'True':<12s}  {'Rel. Error'}")
print("-" * 60)
for name in param_names:
    fitted_val = model.parameters.get_value(name)
    true_val = calib_params.get(name, float('nan'))
    rel_err = abs(fitted_val - true_val) / true_val * 100 if true_val != 0 else float('nan')
    print(f"{name:<12s}  {fitted_val:<12.4g}  {true_val:<12.4g}  {rel_err:>6.1f}%")

### 5.2 Visualize NLSQ Fit

In [None]:
# Generate fine time grid for smooth prediction
t_fine = np.linspace(t_data.min(), t_data.max(), 200)

# Set protocol for prediction
model._gamma_dot_applied = GAMMA_DOT
model._test_mode = "startup"
stress_pred = model.predict(t_fine)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Linear scale
ax1.plot(t_data, stress_data, "ko", markersize=5, alpha=0.5, label="Data")
ax1.plot(t_fine, stress_pred, "-", lw=2, color="C0", label="NLSQ fit")
ax1.axhline(sigma_ss_expected, color="gray", linestyle="--", alpha=0.5, label=f"σ_ss={sigma_ss_expected:.1f} Pa")
ax1.set_xlabel("Time [s]")
ax1.set_ylabel("Stress σ(t) [Pa]")
ax1.set_title(f"NLSQ Fit — R²={model._fit_result.r_squared:.4f}")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Residuals
stress_pred_data = model.predict(t_data)
residuals = stress_data - np.asarray(stress_pred_data).flatten()
ax2.plot(t_data, residuals, "o", markersize=4, alpha=0.6)
ax2.axhline(0, color="gray", linestyle="--", alpha=0.5)
ax2.set_xlabel("Time [s]")
ax2.set_ylabel("Residuals [Pa]")
ax2.set_title(f"Residuals — RMSE={model._fit_result.rmse:.2f} Pa")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

### 5.3 Fluidity Evolution

Visualize the underlying fluidity field f(t) during startup. This shows how the material transitions from structured (low f) to flowing (high f) state.
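The fluidity equation in Section 2 does not contain σ, so at constant shear rate it is a linear ODE with constant coefficients and has the closed-form solution f(t) = f_ss + (f_0 - f_ss) exp(-k t) with k = 1/θ + a|γ̇|^n_rejuv. The sketch below evaluates it. The initial value f_0 (for example f_eq for a fully aged sample) is an assumption, since the notebook does not state RheoJAX's initial condition; the comparison is only meaningful if the model integrates the equation as displayed.

```python
import numpy as np


def fluidity_closed_form(t, gamma_dot, f0, f_eq, f_inf, theta, a, n_rejuv):
    """Exact f(t) for the displayed fluidity ODE at constant shear rate."""
    rate = a * abs(gamma_dot) ** n_rejuv      # rejuvenation rate [1/s]
    k = 1.0 / theta + rate                    # total relaxation rate [1/s]
    f_ss = (f_eq / theta + rate * f_inf) / k  # steady-state fluidity
    t = np.asarray(t, dtype=float)
    # Exponential approach from f0 to f_ss
    return f_ss + (f0 - f_ss) * np.exp(-k * t)
```

For example, `fluidity_closed_form(t_traj, GAMMA_DOT, f_eq_fit, f_eq_fit, f_inf_fit, theta_fit, a_fit, n_rejuv_fit)` could be overlaid on `f_traj` in the cell below.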

In [None]:
# Check if model stored trajectory
if hasattr(model, "_trajectory") and model._trajectory is not None:
    traj = model._trajectory
    t_traj = traj["t"]
    f_traj = traj["f"]

    # Compute steady-state fluidity
    f_eq_fit = model.parameters.get_value("f_eq")
    f_inf_fit = model.parameters.get_value("f_inf")
    theta_fit = model.parameters.get_value("theta")
    a_fit = model.parameters.get_value("a")
    n_rejuv_fit = model.parameters.get_value("n_rejuv")

    rate_term = a_fit * GAMMA_DOT**n_rejuv_fit
    f_ss = (f_eq_fit / theta_fit + rate_term * f_inf_fit) / (1.0 / theta_fit + rate_term)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(t_traj, f_traj, "-", lw=2, label="Fluidity f(t)")
    ax.axhline(f_eq_fit, color="C1", linestyle="--", alpha=0.7, label=f"f_eq={f_eq_fit:.2e}")
    ax.axhline(f_inf_fit, color="C2", linestyle="--", alpha=0.7, label=f"f_inf={f_inf_fit:.2e}")
    ax.axhline(f_ss, color="gray", linestyle=":", alpha=0.5, label=f"f_ss={f_ss:.2e}")
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Fluidity f(t) [1/(Pa·s)]")
    ax.set_title(f"Fluidity Evolution During Startup (γ̇={GAMMA_DOT} s⁻¹)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    display(fig)
    plt.close(fig)
else:
    print("No trajectory data stored. Fluidity evolution not available.")

## 6. Bayesian Inference

### 6.1 Run NUTS with NLSQ Warm-Start

**Critical**: Use NLSQ fit as `initial_values` for fast MCMC convergence. Set `test_mode='startup'` to ensure correct likelihood computation.

In [None]:
# Extract NLSQ estimates as warm-start
initial_values = {name: model.parameters.get_value(name) for name in param_names}
print("Warm-start values (from NLSQ):")
for name, val in initial_values.items():
    print(f"  {name:10s} = {val:.4g}")

# Bayesian configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1
# NUM_WARMUP = 1000; NUM_SAMPLES = 2000; NUM_CHAINS = 4  # production mode

print(f"\nRunning NUTS: {NUM_CHAINS} chain(s), {NUM_WARMUP} warmup, {NUM_SAMPLES} samples")

# Run Bayesian inference
t0 = time.time()
result = model.fit_bayesian(
    t_data,
    stress_data,
    test_mode="startup",
    gamma_dot=GAMMA_DOT,  # CRITICAL: must pass protocol-specific argument
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

### 6.2 Convergence Diagnostics

Check R-hat (should be < 1.05) and ESS (should be > 100) for reliable posterior samples.
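R-hat compares the variance between chains with the variance within them. A minimal NumPy sketch of the classic split-R-hat follows: every chain is cut in half so that drift within a chain also shows up, which means the statistic is defined even for the single chain used in fast mode. ArviZ's `az.rhat` defaults to a rank-normalised variant, so its values can differ slightly from this sketch.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # First and second half of every chain (middle draw dropped if n_draws is odd)
    split = np.concatenate([chains[:, :half], chains[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()         # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)       # between-chain variance
    var_plus = (n - 1) / n * W + B / n           # pooled variance estimate
    return float(np.sqrt(var_plus / W))
```

Values near 1 indicate that the split halves agree; chains stuck in different regions of parameter space push it well above the threshold.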

In [None]:
# Print convergence summary
converged = print_convergence_summary(result, param_names=param_names)

# Print parameter comparison
print_parameter_comparison(model, result.posterior_samples, param_names=param_names)

### 6.3 ArviZ Diagnostic Plots

#### Trace Plots

In [None]:
idata = result.to_inference_data()
axes = az.plot_trace(idata, var_names=param_names, figsize=(12, 8))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots — Check for Stationarity", fontsize=14, y=1.001)
plt.tight_layout()
display(fig)
plt.close(fig)

#### Pair Plot (Parameter Correlations)

In [None]:
# Subset to key parameters for clarity
key_params = ["G", "tau_y", "K", "f_eq", "theta"]
axes = az.plot_pair(
    idata,
    var_names=key_params,
    kind="scatter",
    divergences=True,
    figsize=(10, 10),
)
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

#### Forest Plot (Credible Intervals)

In [None]:
axes = az.plot_forest(
    idata,
    var_names=param_names,
    combined=True,
    hdi_prob=0.95,
    figsize=(10, 6),
)
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

## 7. Posterior Predictive

Visualize uncertainty in stress predictions by sampling from the posterior distribution.
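The band computed below comes from model curves evaluated at posterior draws, so it reflects parameter uncertainty only, i.e. a credible band for the mean stress response. A band for new measurements also needs the observation noise. The sketch below adds it under a hypothetical noise model: independent relative Gaussian noise, with the level supplied by the caller (the likelihood used inside `fit_bayesian` is not shown here, so matching it is an assumption).

```python
import numpy as np


def predictive_bands(curves, prob=0.95, rel_noise=None, seed=0):
    """Pointwise (lo, median, hi) bands from an (n_draws, n_times) array.

    If rel_noise is given, each curve is multiplied by (1 + rel_noise * eps)
    with eps ~ N(0, 1) before the percentiles are taken (assumed noise model).
    """
    curves = np.asarray(curves, dtype=float)
    if rel_noise is not None:
        rng = np.random.default_rng(seed)
        curves = curves * (1.0 + rel_noise * rng.standard_normal(curves.shape))
    tail = 100.0 * (1.0 - prob) / 2.0
    lo, med, hi = np.percentile(curves, [tail, 50.0, 100.0 - tail], axis=0)
    return lo, med, hi
```

Applied to `pred_samples`, `predictive_bands(pred_samples)` gives the same 2.5/50/97.5 percentiles as the cell below, while `predictive_bands(pred_samples, rel_noise=NOISE_LEVEL)` gives a wider band for new measurements under the stated noise assumption. With only 200 draws the tail percentiles are themselves somewhat noisy.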

In [None]:
posterior = result.posterior_samples
n_total = len(next(iter(posterior.values())))
n_draws = min(200, n_total)
# Thin evenly over all stored draws rather than taking only the first n_draws
draw_idx = np.linspace(0, n_total - 1, n_draws).astype(int)
t_pred = np.linspace(t_data.min(), t_data.max(), 100)

print(f"Generating posterior predictions ({n_draws} draws)...")

# Sample predictions
pred_samples = []
for i in draw_idx:
    # Set parameters from posterior draw i
    params_i = jnp.array([posterior[name][i] for name in param_names])

    # Predict using model_function
    pred_i = model.model_function(
        jnp.array(t_pred),
        params_i,
        test_mode="startup",
        gamma_dot=GAMMA_DOT,
    )
    pred_samples.append(np.array(pred_i))

pred_samples = np.array(pred_samples)
pred_median = np.median(pred_samples, axis=0)
pred_lo = np.percentile(pred_samples, 2.5, axis=0)
pred_hi = np.percentile(pred_samples, 97.5, axis=0)

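# Evaluate the band at the time of the median peak (not max(hi) - max(lo))
i_peak = int(np.argmax(pred_median))
print("Posterior predictive statistics:")
print(f"  Median peak stress: {pred_median[i_peak]:.2f} Pa at t = {t_pred[i_peak]:.2f} s")
print(f"  95% CI width at peak: {pred_hi[i_peak] - pred_lo[i_peak]:.2f} Pa")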

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Posterior predictive bands
ax.fill_between(
    t_pred,
    pred_lo,
    pred_hi,
    alpha=0.3,
    color="C0",
    label="95% Credible Interval",
)
ax.plot(t_pred, pred_median, "-", lw=2, color="C0", label="Posterior Median")

# Data
ax.plot(t_data, stress_data, "ko", markersize=5, alpha=0.5, label="Data")

# Expected steady state
ax.axhline(
    sigma_ss_expected,
    color="gray",
    linestyle="--",
    alpha=0.5,
    label=f"Expected σ_ss={sigma_ss_expected:.1f} Pa",
)

ax.set_xlabel("Time [s]")
ax.set_ylabel("Stress σ(t) [Pa]")
ax.set_title(f"Posterior Predictive — Startup Flow (γ̇={GAMMA_DOT} s⁻¹)")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

## 8. Save Results

Save NLSQ parameters and posterior samples for use in other notebooks.

In [None]:
save_fluidity_results(
    model=model,
    result=result,
    model_variant="local",
    protocol="startup",
    param_names=param_names,
)

# Compute fit quality metrics for the NLSQ prediction at the data points
fit_quality = compute_fit_quality(stress_data, np.asarray(stress_pred_data).flatten())
print("\nFit quality metrics:")
print(f"  R² = {fit_quality['R2']:.6f}")
print(f"  RMSE = {fit_quality['RMSE']:.4g} Pa")

## Key Takeaways

1. **Stress Overshoot**: Hallmark thixotropic signature — peak stress σ_max > steady-state σ_ss due to structural breakdown
2. **Coupled Dynamics**: Stress σ(t) and fluidity f(t) evolve together via viscoelastic and structural ODEs
3. **Physical Interpretation**:
   - **Overshoot magnitude**: Reflects degree of initial structure (low f_eq)
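   - **Time to peak**: Inverse breakdown rate ~ 1/(a|γ̇|^n_rejuv)
   - **Steady state**: f → f_ss = (f_eq/θ + a|γ̇|^n_rejuv f_inf) / (1/θ + a|γ̇|^n_rejuv)
4. **NLSQ Performance**: Converges quickly (the wall time is printed in Section 5.1); note that the fit starts from a fresh `FluidityLocal()`, i.e. from the model's default initial values, while the flow-curve parameters are used only to generate the synthetic data
5. **Bayesian Inference**: Passing `gamma_dot` to `fit_bayesian()` is critical so that the likelihood is computed for the startup protocol at the applied shear rate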
6. **Model Limitation**: Local model assumes homogeneous flow — use FluidityNonlocal for shear banding

### Physical Insights

- **Aging term**: (f_eq - f)/θ → structural recovery at rest (f → f_eq)
- **Rejuvenation term**: a|γ̇|^n_rejuv(f_inf - f) → shear-induced breakdown (f → f_inf)
- **Timescale separation**: θ (aging, ~10 s) vs 1/(a|γ̇|^n_rejuv) (rejuvenation, <1 s at high γ̇)
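The timescale comparison can be made quantitative: aging and rejuvenation rates are equal when 1/θ = a|γ̇_c|^n_rejuv, i.e. at the crossover shear rate γ̇_c = (aθ)^(-1/n_rejuv). The short sketch below uses the Section 3 fallback defaults (θ = 10 s, a = 1, n_rejuv = 1), for which γ̇_c = 0.1 s⁻¹, so the startup rate γ̇ = 1 s⁻¹ used here lies in the rejuvenation-dominated regime.

```python
def crossover_shear_rate(theta, a, n_rejuv):
    """Shear rate where the rejuvenation rate a*|gd|**n_rejuv equals 1/theta."""
    return (1.0 / (a * theta)) ** (1.0 / n_rejuv)


def dominant_process(gamma_dot, theta, a, n_rejuv):
    """Compare the aging time theta with the rejuvenation time 1/(a|gd|^n_rejuv)."""
    t_rejuv = 1.0 / (a * abs(gamma_dot) ** n_rejuv)
    if t_rejuv < theta:
        return "rejuvenation"
    if t_rejuv > theta:
        return "aging"
    return "balanced"


theta, a, n_rejuv = 10.0, 1.0, 1.0  # Section 3 fallback defaults
print(f"crossover shear rate: {crossover_shear_rate(theta, a, n_rejuv):.3g} 1/s")
for gd in (0.01, 1.0, 10.0):
    print(f"gamma_dot = {gd:g} 1/s -> {dominant_process(gd, theta, a, n_rejuv)}-dominated")
```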

### Next Steps

- **Notebook 03**: Creep compliance (stress-controlled deformation)
- **Notebook 04**: Stress relaxation (step strain response)
- **Notebook 05**: SAOS (small amplitude oscillatory shear)
- **Notebook 06**: LAOS (large amplitude oscillatory shear, nonlinear viscoelasticity)
- Explore multi-shear-rate fitting for global parameter estimation
- Investigate shear banding with FluidityNonlocal model