# Tutorial 14: Startup Shear with FluiditySaramitoLocal

## Learning Objectives

This notebook demonstrates startup shear flow analysis using the FluiditySaramitoLocal model with tensorial stress evolution:

1. **Stress Overshoot**: Understand stress overshoot during startup as a signature of thixotropic elastoviscoplastic behavior
2. **Tensorial Evolution**: Track full stress tensor evolution [τ_xx, τ_yy, τ_xy] during transient flow
3. **Normal Stress N₁**: Extract first normal stress difference N₁ = τ_xx - τ_yy (Weissenberg effect)
4. **UCM-like Viscoelasticity**: Observe elastic response from Upper Convected Maxwell backbone
5. **Thixotropic Coupling**: See how fluidity evolution affects transient stress response
6. **NLSQ + Bayesian**: Calibrate model parameters using startup transient data
7. **Parameter Uncertainty**: Quantify uncertainty in relaxation time λ, yield stress τ_y, and fluidity parameters

**Key Physics**: Startup shear reveals both elastic (stress overshoot) and thixotropic (structural breakdown) effects. The tensorial formulation enables prediction of normal stresses, critical for understanding material viscoelasticity.

## Google Colab Setup

If you are using Google Colab, run this cell to install RheoJAX:

In [None]:
# Uncomment and run in Google Colab
# !pip install rheojax jaxopt optax arviz

## Setup and Imports

In [None]:
# JAX float64 configuration (CRITICAL: must come before any JAX imports)
from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()

# Standard imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# RheoJAX imports
from rheojax.models.fluidity import FluiditySaramitoLocal
from rheojax.core.data import RheoData
from rheojax.logging import configure_logging, get_logger

# Bayesian inference
import arviz as az

# Configure logging
configure_logging(level="INFO")
logger = get_logger(__name__)

# Plot styling
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 11

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Float64 enabled: {jax.config.jax_enable_x64}")

## Theory: Startup Shear with Tensorial Stress

### Governing Equations

The FluiditySaramitoLocal model combines:

1. **Tensorial Upper Convected Maxwell (UCM) Viscoelasticity**:
   $$\boldsymbol{\tau} + \lambda(f) \overset{\nabla}{\boldsymbol{\tau}} = 2\eta(f) \mathbf{D}$$
   
   where $\overset{\nabla}{\boldsymbol{\tau}}$ is the upper-convected derivative:
   $$\overset{\nabla}{\boldsymbol{\tau}} = \frac{D\boldsymbol{\tau}}{Dt} - (\nabla \mathbf{v})^T \cdot \boldsymbol{\tau} - \boldsymbol{\tau} \cdot \nabla \mathbf{v}$$

2. **Von Mises Yielding**:
   $$\alpha = \max\left(0, 1 - \frac{\tau_y}{|\boldsymbol{\tau}|}\right)$$
   
   The factor $\alpha$ vanishes for $|\boldsymbol{\tau}| \le \tau_y(f)$, so plastic flow is active only when $|\boldsymbol{\tau}| > \tau_y(f)$.

3. **Fluidity Evolution**:
   $$\frac{df}{dt} = \frac{1 - f}{t_{eq}} + b |\dot{\gamma}|^n f^{-m}$$
   
   Aging (structure buildup) + Shear rejuvenation
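
As a rough illustration of the fluidity equation exactly as written above (the library's internal form and conventions may differ), the sketch below integrates $df/dt$ at constant shear rate with a fixed-step RK4 scheme. The names `fluidity_rhs` and `integrate_fluidity` are hypothetical helpers, not part of RheoJAX; the parameter values mirror the demonstration defaults used later in this notebook. For $m = 0$ and $n = 1$ the equation is linear, so the result can be checked against the closed form $f(t) = f_{ss} + (f_0 - f_{ss}) e^{-t/t_{eq}}$ with $f_{ss} = 1 + b \dot{\gamma} t_{eq}$.

```python
import numpy as np

def fluidity_rhs(f, gamma_dot, t_eq, b, n, m):
    """Right-hand side of df/dt = (1 - f)/t_eq + b*|gamma_dot|**n * f**(-m)."""
    return (1.0 - f) / t_eq + b * abs(gamma_dot) ** n * f ** (-m)

def integrate_fluidity(f0, gamma_dot, t_end, n_steps, t_eq=10.0, b=0.5, n=1.0, m=0.0):
    """Fixed-step classical RK4 integration; returns (t, f) arrays."""
    t = np.linspace(0.0, t_end, n_steps + 1)
    dt = t[1] - t[0]
    f = np.empty_like(t)
    f[0] = f0
    args = (gamma_dot, t_eq, b, n, m)
    for i in range(n_steps):
        k1 = fluidity_rhs(f[i], *args)
        k2 = fluidity_rhs(f[i] + 0.5 * dt * k1, *args)
        k3 = fluidity_rhs(f[i] + 0.5 * dt * k2, *args)
        k4 = fluidity_rhs(f[i] + dt * k3, *args)
        f[i + 1] = f[i] + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
    return t, f

# Demonstration defaults: f0 = 1, gamma_dot = 1 1/s, t_eq = 10 s, b = 0.5, n = 1, m = 0
t_demo, f_demo = integrate_fluidity(f0=1.0, gamma_dot=1.0, t_end=50.0, n_steps=2000)

# Closed-form solution of the linear case (m = 0, n = 1)
f_ss = 1.0 + 0.5 * 1.0 * 10.0
f_exact = f_ss + (1.0 - f_ss) * np.exp(-t_demo / 10.0)
print(f"max |numerical - closed form| = {np.max(np.abs(f_demo - f_exact)):.2e}")
```

With these values the fluidity relaxes toward $f_{ss}$ on the time scale $t_{eq}$, which is why the simulation window should span several multiples of $t_{eq}$ before the final point is treated as a steady state.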

### Startup Protocol

- **Initial Condition**: Material at rest with $f = f_0$ (typically $f_0 = 1$ for a fully structured material)
- **Imposed Shear**: Step to constant $\dot{\gamma}$ at $t = 0$
- **Response**: Stress overshoot as elastic energy builds, then decay to steady state
- **Stress Components**: Track $[\tau_{xx}, \tau_{yy}, \tau_{xy}]$ evolution
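
For the step in shear rate described above, with velocity field $\mathbf{v} = (\dot{\gamma} y, 0)$, the upper-convected derivative reduces to three coupled component equations. The following is a sketch for the pure UCM backbone only, under the assumption that $\lambda$ and $\eta$ are held constant (fluidity frozen) and the yield factor $\alpha$ is omitted:

$$
\begin{aligned}
\lambda \frac{d\tau_{xx}}{dt} &= 2\lambda \dot{\gamma} \tau_{xy} - \tau_{xx} \\
\lambda \frac{d\tau_{yy}}{dt} &= -\tau_{yy} \\
\lambda \frac{d\tau_{xy}}{dt} &= \lambda \dot{\gamma} \tau_{yy} + \eta \dot{\gamma} - \tau_{xy}
\end{aligned}
$$

Starting from a stress-free state, $\tau_{yy}$ stays at zero, $\tau_{xy}$ rises monotonically to $\eta \dot{\gamma}$, and $\tau_{xx}$ approaches $2\lambda \dot{\gamma} \tau_{xy}$. In this frozen limit there is no overshoot; a peak in $\tau_{xy}(t)$ requires $\lambda$ and $\eta$ to change with the evolving fluidity.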

### Normal Stress Differences

From the UCM backbone, the model predicts in steady simple shear:

$$N_1 = \tau_{xx} - \tau_{yy} = 2\lambda(f) \dot{\gamma} \tau_{xy}$$

This captures the **Weissenberg effect** (rod-climbing) in viscoelastic fluids.
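
The steady-state relation above can be checked numerically. The sketch below integrates a constant-parameter UCM fluid in simple shear (fluidity frozen, no yield factor; this is an assumed simplification, not the full FluiditySaramitoLocal model) and compares $N_1/\tau_{xy}$ at late times with $2\lambda\dot{\gamma}$. The helper `ucm_startup` is hypothetical and uses only NumPy.

```python
import numpy as np

def ucm_startup(gamma_dot, lam, eta, t_end, n_steps):
    """Integrate constant-parameter UCM in simple shear from a stress-free state.

    The state vector is [tau_xx, tau_yy, tau_xy], matching the ordering used
    elsewhere in this notebook.
    """
    def rhs(tau):
        tau_xx, tau_yy, tau_xy = tau
        return np.array([
            2.0 * gamma_dot * tau_xy - tau_xx / lam,
            -tau_yy / lam,
            gamma_dot * tau_yy + (eta * gamma_dot - tau_xy) / lam,
        ])

    t = np.linspace(0.0, t_end, n_steps + 1)
    dt = t[1] - t[0]
    tau = np.zeros((n_steps + 1, 3))
    for i in range(n_steps):  # classical RK4 step
        k1 = rhs(tau[i])
        k2 = rhs(tau[i] + 0.5 * dt * k1)
        k3 = rhs(tau[i] + 0.5 * dt * k2)
        k4 = rhs(tau[i] + dt * k3)
        tau[i + 1] = tau[i] + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
    return t, tau

lam_sketch, gamma_dot_sketch = 1.0, 1.0
t_sketch, tau_sketch = ucm_startup(gamma_dot_sketch, lam_sketch, eta=100.0,
                                   t_end=20.0, n_steps=4000)

N1_sketch = tau_sketch[:, 0] - tau_sketch[:, 1]
ratio_sketch = N1_sketch[-1] / tau_sketch[-1, 2]
print(f"N1/tau_xy at t = {t_sketch[-1]:.0f} s: {ratio_sketch:.4f}")
print(f"2*lambda*gamma_dot:        {2.0 * lam_sketch * gamma_dot_sketch:.4f}")
```

During the transient the ratio lies below $2\lambda\dot{\gamma}$ because $\tau_{xx}$ builds up more slowly than $\tau_{xy}$, so the relation should only be used as a late-time consistency check.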

### Key Observables

1. **Stress Overshoot**: Peak in $\tau_{xy}(t)$ indicates elastic storage before yielding
2. **Overshoot Time**: $t_{peak} \sim \lambda(f_0)$ related to relaxation time
3. **Steady State**: Eventual plateau once aging and shear rejuvenation balance
4. **Normal Stress**: $N_1(t)$ follows a similar overshoot and remains positive in steady state

## Load Calibrated Parameters

If parameters from flow curve fitting (Tutorial 11) are available, load them. Otherwise, use sensible defaults.

In [None]:
# Try to load from flow curve calibration
param_file = Path("../outputs/fluidity/saramito_local/flow_curve/parameters.txt")

if param_file.exists():
    logger.info(f"Loading calibrated parameters from {param_file}")
    # Parse parameter file (simple key=value format)
    params = {}
    with open(param_file, 'r') as f:
        for line in f:
            if '=' in line and not line.strip().startswith('#'):
                key, val = line.split('=', 1)  # split on the first '=' only
                params[key.strip()] = float(val.strip())
    
    # Extract relevant parameters
    eta_0 = params.get('eta_0', 100.0)
    tau_y = params.get('tau_y', 50.0)
    lambda_0 = params.get('lambda_0', 1.0)
    t_eq = params.get('t_eq', 10.0)
    b = params.get('b', 0.5)
    n = params.get('n', 1.0)
    m = params.get('m', 0.0)
    
    logger.info(f"Loaded: η₀={eta_0:.2f}, τ_y={tau_y:.2f}, λ₀={lambda_0:.2f}, t_eq={t_eq:.2f}")
else:
    logger.info("No calibrated parameters found, using defaults")
    # Default parameters for demonstration
    eta_0 = 100.0      # Zero-shear viscosity (Pa·s)
    tau_y = 50.0       # Yield stress (Pa)
    lambda_0 = 1.0     # Relaxation time at f=1 (s)
    t_eq = 10.0        # Equilibrium aging time (s)
    b = 0.5            # Rejuvenation coefficient
    n = 1.0            # Shear-rate exponent
    m = 0.0            # Fluidity exponent

print("\n=== Model Parameters ===")
print(f"η₀ (zero-shear viscosity): {eta_0:.2f} Pa·s")
print(f"τ_y (yield stress): {tau_y:.2f} Pa")
print(f"λ₀ (relaxation time): {lambda_0:.2f} s")
print(f"t_eq (aging time): {t_eq:.2f} s")
print(f"b (rejuvenation): {b:.2f}")
print(f"n (shear exponent): {n:.2f}")
print(f"m (fluidity exponent): {m:.2f}")

## Generate Synthetic Startup Data

Simulate startup shear at a single imposed shear rate (γ̇ = 1 s⁻¹) and add relative noise to mimic measured data.

In [None]:
# Create model with known parameters for data generation
model_true = FluiditySaramitoLocal(coupling="full")

# Set true parameters
model_true.parameters['eta_0'].value = eta_0
model_true.parameters['tau_y'].value = tau_y
model_true.parameters['lambda_0'].value = lambda_0
model_true.parameters['t_eq'].value = t_eq
model_true.parameters['b'].value = b
model_true.parameters['n'].value = n
model_true.parameters['m'].value = m

# Startup simulation parameters
gamma_dot_startup = 1.0  # Applied shear rate (1/s)
t_end = 50.0             # Simulation time (s)
n_points = 500           # Time points

# Generate time array (logarithmic spacing for better resolution of overshoot)
t_startup = np.logspace(-2, np.log10(t_end), n_points)

# Simulate startup (returns strain, stress, fluidity)
logger.info(f"Simulating startup at γ̇ = {gamma_dot_startup:.2f} 1/s")
strain_true, stress_true, fluidity_true = model_true.simulate_startup(
    t_startup, 
    gamma_dot=gamma_dot_startup,
    t_wait=100.0  # Wait time before startup for equilibration
)

# Extract shear stress component (τ_xy)
tau_xy_true = stress_true[:, 2]  # Third component is τ_xy

# Add realistic noise (5% relative error)
np.random.seed(42)
noise_level = 0.05
noise = noise_level * np.abs(tau_xy_true) * np.random.randn(len(tau_xy_true))
tau_xy_noisy = tau_xy_true + noise

# Visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Shear stress τ_xy
axes[0, 0].plot(t_startup, tau_xy_true, 'b-', linewidth=2, label='True')
axes[0, 0].plot(t_startup, tau_xy_noisy, 'r.', markersize=4, alpha=0.6, label='Noisy (5%)')
axes[0, 0].set_xlabel('Time (s)')
axes[0, 0].set_ylabel('Shear Stress τ_xy (Pa)')
axes[0, 0].set_xscale('log')
axes[0, 0].set_title('Startup Shear Stress (Overshoot)')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Strain evolution
axes[0, 1].plot(t_startup, strain_true, 'g-', linewidth=2)
axes[0, 1].set_xlabel('Time (s)')
axes[0, 1].set_ylabel('Strain γ (dimensionless)')
axes[0, 1].set_xscale('log')
axes[0, 1].set_title('Accumulated Strain')
axes[0, 1].grid(True, alpha=0.3)

# Fluidity evolution
axes[1, 0].plot(t_startup, fluidity_true, 'purple', linewidth=2)
axes[1, 0].set_xlabel('Time (s)')
axes[1, 0].set_ylabel('Fluidity f (dimensionless)')
axes[1, 0].set_xscale('log')
axes[1, 0].set_title('Structural Breakdown (Fluidity Evolution)')
axes[1, 0].grid(True, alpha=0.3)
axes[1, 0].axhline(y=1.0, color='k', linestyle='--', alpha=0.3, label='Initial')
axes[1, 0].legend()

# Normal stress components
tau_xx_true = stress_true[:, 0]
tau_yy_true = stress_true[:, 1]
N1_true = tau_xx_true - tau_yy_true  # First normal stress difference

axes[1, 1].plot(t_startup, tau_xx_true, 'b-', linewidth=2, label='τ_xx')
axes[1, 1].plot(t_startup, tau_yy_true, 'r-', linewidth=2, label='τ_yy')
axes[1, 1].plot(t_startup, N1_true, 'g-', linewidth=2, label='N₁ = τ_xx - τ_yy')
axes[1, 1].set_xlabel('Time (s)')
axes[1, 1].set_ylabel('Normal Stress Components (Pa)')
axes[1, 1].set_xscale('log')
axes[1, 1].set_title('Normal Stresses (Weissenberg Effect)')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Find overshoot peak
peak_idx = np.argmax(tau_xy_true)
t_peak = t_startup[peak_idx]
tau_peak = tau_xy_true[peak_idx]
tau_steady = tau_xy_true[-1]

print(f"\n=== Startup Characteristics ===")
print(f"Peak stress: {tau_peak:.2f} Pa at t = {t_peak:.3f} s")
print(f"Steady-state stress: {tau_steady:.2f} Pa")
print(f"Overshoot ratio: {tau_peak/tau_steady:.2f}")
print(f"Final fluidity: {fluidity_true[-1]:.3f}")
print(f"Final N₁: {N1_true[-1]:.2f} Pa")

## NLSQ Fitting: Parameter Estimation from Startup Data

Fit the model to synthetic startup data using NLSQ optimization.

In [None]:
# Create fresh model for fitting
model = FluiditySaramitoLocal(coupling="full")

# Prepare data
rheo_data = RheoData(
    x=t_startup,
    y=tau_xy_noisy,
    test_mode='startup'
)

# Set initial guesses (slightly perturbed from truth)
model.parameters['eta_0'].value = eta_0 * 0.8
model.parameters['tau_y'].value = tau_y * 1.2
model.parameters['lambda_0'].value = lambda_0 * 0.9
model.parameters['t_eq'].value = t_eq * 1.1
model.parameters['b'].value = b * 0.95
model.parameters['n'].value = n
model.parameters['m'].value = m

# Fit with NLSQ
logger.info("Starting NLSQ optimization for startup data...")
result = model.fit(
    rheo_data,
    gamma_dot=gamma_dot_startup,  # Pass shear rate for startup simulation
    max_iter=5000,
    ftol=1e-8,
    xtol=1e-8
)

print(f"\n=== NLSQ Fitting Results ===")
print(f"Converged: {result.success}")
print(f"Iterations: {result.nit}")
print(f"Final cost: {result.cost:.6e}")
print(f"R²: {result.r_squared:.6f}")

print("\n=== Fitted Parameters ===")
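for name, param in model.parameters.items():
    true_val = model_true.parameters[name].value
    if true_val != 0:
        error_str = f"{100 * abs(param.value - true_val) / abs(true_val):5.2f}%"
    else:
        # Relative error is undefined when the true value is zero (e.g. m = 0)
        error_str = f"abs {abs(param.value - true_val):.2e}"
    print(f"{name:12s}: {param.value:10.4f}  (true: {true_val:10.4f}, error: {error_str})")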

# Plot fit quality
tau_xy_fit = model.predict(t_startup, test_mode='startup', gamma_dot=gamma_dot_startup)[:, 2]

plt.figure(figsize=(12, 6))
plt.plot(t_startup, tau_xy_noisy, 'ko', markersize=4, alpha=0.5, label='Data (noisy)')
plt.plot(t_startup, tau_xy_true, 'b--', linewidth=2, label='True')
plt.plot(t_startup, tau_xy_fit, 'r-', linewidth=2, label='NLSQ Fit')
plt.xlabel('Time (s)')
plt.ylabel('Shear Stress τ_xy (Pa)')
plt.xscale('log')
plt.title(f'NLSQ Fit Quality (R² = {result.r_squared:.4f})')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Residual analysis
residuals = tau_xy_noisy - tau_xy_fit

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(t_startup, residuals, 'ko', markersize=3, alpha=0.6)
axes[0].axhline(y=0, color='r', linestyle='--', linewidth=2)
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Residuals (Pa)')
axes[0].set_xscale('log')
axes[0].set_title('Residual Plot')
axes[0].grid(True, alpha=0.3)

axes[1].hist(residuals, bins=30, density=True, alpha=0.7, edgecolor='black')
axes[1].set_xlabel('Residuals (Pa)')
axes[1].set_ylabel('Density')
axes[1].set_title(f'Residual Distribution (σ = {np.std(residuals):.2f} Pa)')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
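
Because the synthetic noise above is proportional to the local stress, raw residuals are expected to grow with stress level. One way to account for this, sketched below on a self-contained stand-in curve (the saturating function is purely illustrative and is not the fitted model), is to divide each residual by the assumed noise scale $0.05\,|\tau|$. If the fit is adequate, the standardized residuals should have a standard deviation close to 1.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in "model" curve (hypothetical, for illustration only)
t_resid = np.logspace(-2, np.log10(50.0), 500)
tau_model = 100.0 * (1.0 - np.exp(-t_resid))

# Relative (multiplicative) noise, same construction as the synthetic data above
rel_noise = 0.05
tau_data = tau_model + rel_noise * np.abs(tau_model) * rng.standard_normal(t_resid.size)

raw_resid = tau_data - tau_model
std_resid = raw_resid / (rel_noise * np.abs(tau_model))  # dimensionless

print(f"raw residual std:          {raw_resid.std():.2f} Pa")
print(f"standardized residual std: {std_resid.std():.2f} (expect ~1)")
```

In the notebook itself, the same normalization could be applied to `residuals` using `noise_level * np.abs(tau_xy_fit)` as the scale, assuming the 5% relative noise model is the right description of the data.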

## Bayesian Inference: Parameter Uncertainty Quantification

Use NUTS sampling to quantify parameter uncertainties, with the NLSQ fit as a warm start.

In [None]:
# Bayesian inference with NUTS
logger.info("Starting Bayesian inference with NUTS...")

# Run MCMC (using NLSQ fit as warm-start)
bayes_result = model.fit_bayesian(
    rheo_data,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
    seed=42,
    gamma_dot=gamma_dot_startup  # Pass shear rate for startup
)

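# Convert to ArviZ InferenceData
# Note: ArviZ expects posterior arrays shaped (chain, draw, ...). If posterior_samples
# is flattened across chains, reshape to (num_chains, num_samples) first; otherwise
# the draws are treated as a single chain and R-hat is not meaningful.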
idata = az.from_dict(
    posterior=bayes_result.posterior_samples,
    observed_data={"y": tau_xy_noisy}
)

# Compute diagnostics
print("\n=== MCMC Diagnostics ===")
summary = az.summary(idata, hdi_prob=0.95)
print(summary)

# Check convergence
rhat_max = summary['r_hat'].max()
ess_min = summary['ess_bulk'].min()

print(f"\nMax R-hat: {rhat_max:.4f} (should be < 1.01)")
print(f"Min ESS: {ess_min:.0f} (should be > 400)")

if rhat_max < 1.01 and ess_min > 400:
    print("✓ Convergence achieved!")
else:
    print("⚠ Convergence issues detected. Consider increasing num_warmup/num_samples.")

# Extract credible intervals
intervals = model.get_credible_intervals(bayes_result.posterior_samples, credibility=0.95)

print("\n=== 95% Credible Intervals ===")
for name, (lower, upper) in intervals.items():
    median = np.median(bayes_result.posterior_samples[name])
    true_val = model_true.parameters[name].value
    in_interval = lower <= true_val <= upper
    status = "✓" if in_interval else "✗"
    print(f"{name:12s}: [{lower:8.4f}, {upper:8.4f}]  median: {median:8.4f}  true: {true_val:8.4f} {status}")
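
For reference, an equal-tailed credible interval can be computed directly from posterior draws with percentiles. The sketch below uses a hypothetical helper `equal_tailed_interval` and synthetic stand-in draws; whether `get_credible_intervals` uses equal-tailed or highest-density intervals is not stated in this notebook, so small differences from the values printed above would not be surprising (`az.summary` reports highest-density intervals).

```python
import numpy as np

def equal_tailed_interval(samples, credibility=0.95):
    """Return (lower, upper) percentiles enclosing `credibility` of the samples."""
    tail = 100.0 * (1.0 - credibility) / 2.0
    lower, upper = np.percentile(np.asarray(samples).ravel(), [tail, 100.0 - tail])
    return float(lower), float(upper)

rng = np.random.default_rng(1)
# Stand-in posterior draws (not real MCMC output): tau_y ~ Normal(50, 2)
fake_posterior = {"tau_y": rng.normal(50.0, 2.0, size=8000)}

ci_lower, ci_upper = equal_tailed_interval(fake_posterior["tau_y"])
print(f"tau_y 95% interval: [{ci_lower:.2f}, {ci_upper:.2f}]")
```

For symmetric, unimodal posteriors the two interval types nearly coincide; they diverge for skewed posteriors, which can occur for positive parameters such as $t_{eq}$ or $b$.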

## Stress Tensor Evolution: Full Tensorial Dynamics

Visualize evolution of all stress components during startup.

In [None]:
# Predict with fitted parameters
stress_fit = model.predict(t_startup, test_mode='startup', gamma_dot=gamma_dot_startup)

# Extract components
tau_xx_fit = stress_fit[:, 0]
tau_yy_fit = stress_fit[:, 1]
tau_xy_fit = stress_fit[:, 2]

# Plot tensorial evolution
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# τ_xx evolution
axes[0, 0].plot(t_startup, tau_xx_true, 'b--', linewidth=2, label='True', alpha=0.7)
axes[0, 0].plot(t_startup, tau_xx_fit, 'r-', linewidth=2, label='Fitted')
axes[0, 0].set_xlabel('Time (s)')
axes[0, 0].set_ylabel('τ_xx (Pa)')
axes[0, 0].set_xscale('log')
axes[0, 0].set_title('Normal Stress Component τ_xx')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# τ_yy evolution
axes[0, 1].plot(t_startup, tau_yy_true, 'b--', linewidth=2, label='True', alpha=0.7)
axes[0, 1].plot(t_startup, tau_yy_fit, 'r-', linewidth=2, label='Fitted')
axes[0, 1].set_xlabel('Time (s)')
axes[0, 1].set_ylabel('τ_yy (Pa)')
axes[0, 1].set_xscale('log')
axes[0, 1].set_title('Normal Stress Component τ_yy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# τ_xy evolution
axes[1, 0].plot(t_startup, tau_xy_true, 'b--', linewidth=2, label='True', alpha=0.7)
axes[1, 0].plot(t_startup, tau_xy_fit, 'r-', linewidth=2, label='Fitted')
axes[1, 0].plot(t_startup, tau_xy_noisy, 'ko', markersize=3, alpha=0.3, label='Data')
axes[1, 0].set_xlabel('Time (s)')
axes[1, 0].set_ylabel('τ_xy (Pa)')
axes[1, 0].set_xscale('log')
axes[1, 0].set_title('Shear Stress Component τ_xy (with Overshoot)')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

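# Stress magnitude |τ| used for the yield comparison.
# Here: Euclidean norm of the three tracked components; the model's internal
# von Mises measure may follow a different convention (e.g. deviatoric norm).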
tau_mag_true = np.sqrt(tau_xx_true**2 + tau_yy_true**2 + tau_xy_true**2)
tau_mag_fit = np.sqrt(tau_xx_fit**2 + tau_yy_fit**2 + tau_xy_fit**2)

axes[1, 1].plot(t_startup, tau_mag_true, 'b--', linewidth=2, label='True', alpha=0.7)
axes[1, 1].plot(t_startup, tau_mag_fit, 'r-', linewidth=2, label='Fitted')
axes[1, 1].axhline(y=tau_y, color='k', linestyle=':', linewidth=2, label=f'τ_y = {tau_y:.1f} Pa')
axes[1, 1].set_xlabel('Time (s)')
axes[1, 1].set_ylabel('|τ| (Pa)')
axes[1, 1].set_xscale('log')
axes[1, 1].set_title('Von Mises Stress Magnitude (for Yielding)')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n=== Stress Component Analysis ===")
print(f"Peak τ_xx: {np.max(tau_xx_fit):.2f} Pa")
print(f"Peak τ_yy: {np.max(tau_yy_fit):.2f} Pa")
print(f"Peak τ_xy: {np.max(tau_xy_fit):.2f} Pa")
print(f"Steady τ_xx: {tau_xx_fit[-1]:.2f} Pa")
print(f"Steady τ_yy: {tau_yy_fit[-1]:.2f} Pa")
print(f"Steady τ_xy: {tau_xy_fit[-1]:.2f} Pa")
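
The yield factor $\alpha = \max(0, 1 - \tau_y/|\boldsymbol{\tau}|)$ from the theory section can be evaluated from the stress components once a norm is chosen. The sketch below assumes a two-dimensional deviatoric norm, $|\boldsymbol{\tau}| = \sqrt{((\tau_{xx} - \tau_{yy})/2)^2 + \tau_{xy}^2}$; this choice is an assumption for illustration, and the convention used inside FluiditySaramitoLocal may differ. The helper `yield_factor` is hypothetical.

```python
import numpy as np

def yield_factor(tau_xx, tau_yy, tau_xy, tau_y):
    """alpha = max(0, 1 - tau_y/|tau|) with an assumed 2D deviatoric norm."""
    tau_norm = np.sqrt((0.5 * (tau_xx - tau_yy)) ** 2 + tau_xy ** 2)
    # Guard against division by zero at rest, where alpha must be 0
    safe_norm = np.where(tau_norm > 0.0, tau_norm, 1.0)
    alpha = np.where(tau_norm > 0.0, 1.0 - tau_y / safe_norm, 0.0)
    return np.maximum(alpha, 0.0), tau_norm

# Three states: at rest, below yield, above yield (tau_y = 50 Pa)
alpha_demo, norm_demo = yield_factor(
    tau_xx=np.array([0.0, 0.0, 120.0]),
    tau_yy=np.array([0.0, 0.0, 0.0]),
    tau_xy=np.array([0.0, 25.0, 80.0]),
    tau_y=50.0,
)
print("|tau|:", norm_demo)
print("alpha:", alpha_demo)
```

Applied to `tau_xx_fit`, `tau_yy_fit`, and `tau_xy_fit`, a function of this kind would show when during startup the material crosses from the unyielded ($\alpha = 0$) to the yielded ($\alpha > 0$) regime under the assumed norm.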

## First Normal Stress Difference N₁: Weissenberg Effect

Extract and analyze N₁ = τ_xx - τ_yy, the signature of viscoelastic normal stress.

In [None]:
# Compute N₁
N1_fit = tau_xx_fit - tau_yy_fit

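# Theoretical prediction from UCM: N₁ = 2λγ̇τ_xy in steady state.
# Note: λ₀ (relaxation time at f = 1) is used here; if the fluidity has moved away
# from 1 at steady state, λ(f) differs and this reference value shifts accordingly.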
lambda_fit = model.parameters['lambda_0'].value
N1_ucm_steady = 2 * lambda_fit * gamma_dot_startup * tau_xy_fit[-1]

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# N₁ evolution
axes[0].plot(t_startup, N1_true, 'b--', linewidth=2, label='True', alpha=0.7)
axes[0].plot(t_startup, N1_fit, 'r-', linewidth=2, label='Fitted')
axes[0].axhline(y=N1_ucm_steady, color='g', linestyle=':', linewidth=2, 
                label=f'UCM prediction: {N1_ucm_steady:.1f} Pa')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('N₁ = τ_xx - τ_yy (Pa)')
axes[0].set_xscale('log')
axes[0].set_title('First Normal Stress Difference (Weissenberg Effect)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# N₁/τ_xy ratio (should approach 2λγ̇ in steady state for UCM)
ratio = N1_fit / (tau_xy_fit + 1e-10)  # Avoid division by zero
ratio_true = N1_true / (tau_xy_true + 1e-10)
ratio_ucm = 2 * lambda_fit * gamma_dot_startup

axes[1].plot(t_startup, ratio_true, 'b--', linewidth=2, label='True', alpha=0.7)
axes[1].plot(t_startup, ratio, 'r-', linewidth=2, label='Fitted')
axes[1].axhline(y=ratio_ucm, color='g', linestyle=':', linewidth=2,
                label=f'UCM: 2λγ̇ = {ratio_ucm:.2f}')
axes[1].set_xlabel('Time (s)')
axes[1].set_ylabel('N₁/τ_xy (dimensionless)')
axes[1].set_xscale('log')
axes[1].set_title('Normal Stress Ratio (UCM Consistency Check)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0, ratio_ucm * 2])  # Reasonable y-axis limits

plt.tight_layout()
plt.show()

print("\n=== Normal Stress Analysis ===")
print(f"Peak N₁: {np.max(N1_fit):.2f} Pa")
print(f"Steady-state N₁: {N1_fit[-1]:.2f} Pa")
print(f"UCM prediction (steady): {N1_ucm_steady:.2f} Pa")
print(f"Prediction accuracy: {100*(1 - abs(N1_fit[-1] - N1_ucm_steady)/N1_ucm_steady):.1f}%")
print(f"\nN₁/τ_xy steady ratio: {ratio[-1]:.3f}")
print(f"UCM ratio (2λγ̇): {ratio_ucm:.3f}")

## ArviZ Diagnostics: Convergence and Correlation Analysis

In [None]:
# Trace plots (chain mixing)
az.plot_trace(
    idata,
    var_names=['eta_0', 'tau_y', 'lambda_0', 't_eq', 'b'],
    compact=True,
    figsize=(14, 10)
)
plt.tight_layout()
plt.suptitle('MCMC Trace Plots (Chain Mixing)', y=1.02, fontsize=14)
plt.show()

# Pair plot (parameter correlations)
az.plot_pair(
    idata,
    var_names=['eta_0', 'tau_y', 'lambda_0', 't_eq'],
    kind='hexbin',
    marginals=True,
    figsize=(12, 12)
)
plt.suptitle('Parameter Correlations (Pair Plot)', y=1.00, fontsize=14)
plt.tight_layout()
plt.show()

# Forest plot (credible intervals)
az.plot_forest(
    idata,
    var_names=['eta_0', 'tau_y', 'lambda_0', 't_eq', 'b', 'n', 'm'],
    combined=True,
    hdi_prob=0.95,
    figsize=(10, 6)
)
plt.title('95% Credible Intervals (Forest Plot)')
plt.tight_layout()
plt.show()

# Autocorrelation (should decay quickly)
az.plot_autocorr(
    idata,
    var_names=['eta_0', 'tau_y', 'lambda_0'],
    max_lag=100,
    figsize=(12, 4)
)
plt.tight_layout()
plt.show()

# Rank plots (uniform if well-mixed)
az.plot_rank(
    idata,
    var_names=['eta_0', 'tau_y', 'lambda_0', 't_eq'],
    figsize=(12, 8)
)
plt.tight_layout()
plt.show()
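
To make the R-hat threshold used in the convergence check more concrete, the sketch below implements the classic Gelman-Rubin statistic for an array of shape `(n_chains, n_draws)`. This is a simplified version for illustration: ArviZ's default is a rank-normalized split variant, so the numbers will not match `az.summary` exactly. The helper `basic_rhat` is hypothetical and the chains are synthetic.

```python
import numpy as np

def basic_rhat(chains):
    """Classic Gelman-Rubin R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    chain_means = chains.mean(axis=1)
    between = n_draws * chain_means.var(ddof=1)   # between-chain variance B
    within = chains.var(axis=1, ddof=1).mean()    # within-chain variance W
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(var_hat / within))

rng = np.random.default_rng(2)
mixed = rng.normal(0.0, 1.0, size=(4, 2000))             # four chains that agree
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain offset by 3

rhat_mixed = basic_rhat(mixed)
rhat_stuck = basic_rhat(stuck)
print(f"R-hat (well mixed): {rhat_mixed:.3f}")
print(f"R-hat (one offset): {rhat_stuck:.3f}")
```

The statistic compares between-chain and within-chain variance, which is why it needs at least two chains and why a single chain stuck in a different region pushes the value well above 1.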

## Save Results

Export fitted parameters, posteriors, and diagnostic plots.

In [None]:
# Create output directory
output_dir = Path("../outputs/fluidity/saramito_local/startup")
output_dir.mkdir(parents=True, exist_ok=True)

# Save parameters
param_file = output_dir / "parameters_nlsq.txt"
with open(param_file, 'w') as f:
    f.write("# NLSQ Fitted Parameters\n")
    f.write(f"# Test mode: startup\n")
    f.write(f"# Shear rate: {gamma_dot_startup} 1/s\n")
    f.write(f"# R²: {result.r_squared:.6f}\n\n")
    for name, param in model.parameters.items():
        f.write(f"{name} = {param.value:.6e}\n")

logger.info(f"Parameters saved to {param_file}")

# Save posterior samples
posterior_file = output_dir / "posterior_samples.npz"
np.savez(
    posterior_file,
    **bayes_result.posterior_samples,
    t_startup=t_startup,
    tau_xy_data=tau_xy_noisy,
    gamma_dot=gamma_dot_startup
)
logger.info(f"Posterior samples saved to {posterior_file}")

# Save ArviZ diagnostics
idata.to_netcdf(output_dir / "arviz_inference.nc")
logger.info(f"ArviZ InferenceData saved to {output_dir / 'arviz_inference.nc'}")

# Save summary statistics
summary_file = output_dir / "mcmc_summary.txt"
with open(summary_file, 'w') as f:
    f.write("# MCMC Diagnostics Summary\n\n")
    f.write(summary.to_string())
    f.write(f"\n\nMax R-hat: {rhat_max:.4f}\n")
    f.write(f"Min ESS: {ess_min:.0f}\n")

logger.info(f"MCMC summary saved to {summary_file}")

# Save prediction data
pred_file = output_dir / "predictions.npz"
np.savez(
    pred_file,
    t=t_startup,
    tau_xy_true=tau_xy_true,
    tau_xy_fit=tau_xy_fit,
    tau_xx_fit=tau_xx_fit,
    tau_yy_fit=tau_yy_fit,
    N1_fit=N1_fit,
    gamma_dot=gamma_dot_startup
)
logger.info(f"Predictions saved to {pred_file}")

print(f"\n✓ All results saved to {output_dir}")
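
The saved files can be read back in a later session. The sketch below shows one way to do this under the file layout written above (`name = value` lines with `#` comments, plus `.npz` archives); it writes small stand-in files to a temporary directory so that it runs on its own. The helper `read_parameters` is hypothetical and the numeric values are placeholders.

```python
import tempfile
from pathlib import Path

import numpy as np

def read_parameters(path):
    """Parse a 'name = value' text file, skipping blank lines and '#' comments."""
    values = {}
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, val = line.split("=", 1)
        values[key.strip()] = float(val)
    return values

with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    # Stand-in files mimicking the formats written by the cell above
    (tmp_dir / "parameters_nlsq.txt").write_text(
        "# NLSQ Fitted Parameters\n\neta_0 = 1.000000e+02\ntau_y = 5.000000e+01\n",
        encoding="utf-8",
    )
    np.savez(tmp_dir / "predictions.npz",
             t=np.array([0.1, 1.0]), tau_xy_fit=np.array([9.5, 63.2]))

    reloaded_params = read_parameters(tmp_dir / "parameters_nlsq.txt")
    with np.load(tmp_dir / "predictions.npz") as archive:
        reloaded_t = archive["t"]
        reloaded_tau = archive["tau_xy_fit"]

print(reloaded_params)
print(reloaded_t, reloaded_tau)
```

Passing an explicit `encoding="utf-8"` when reading is a defensive choice here, since the header written above contains non-ASCII characters such as `R²`.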

## Key Takeaways

### Physical Insights

1. **Stress Overshoot**: Peak in τ_xy during startup is a hallmark of elastoviscoplastic materials, reflecting elastic energy storage before yielding.

2. **Normal Stress N₁**: The tensorial formulation enables prediction of N₁ = τ_xx - τ_yy, capturing the Weissenberg effect (rod-climbing) in viscoelastic fluids.

3. **UCM Consistency**: In steady state, N₁/(γ̇τ_xy) → 2λ, consistent with Upper Convected Maxwell predictions.

4. **Thixotropic Signature**: Fluidity evolution during startup shows structural breakdown; the overshoot time scales with the relaxation time λ(f₀), while the subsequent approach to steady state is governed by the equilibrium time t_eq.

5. **Yield Transition**: Von Mises stress magnitude |τ| > τ_y triggers plastic flow via the α factor.

### Numerical Insights

1. **NLSQ Efficiency**: Fast point estimation (seconds) provides a good initial guess for Bayesian inference.

2. **Warm-Start Critical**: Initializing NUTS from the NLSQ fit typically shortens warmup and helps reach R-hat < 1.01 and ESS > 400; both diagnostics should still be checked for every run.

3. **Parameter Identifiability**: Startup data constrains λ (overshoot time), τ_y (yield point), and η₀ (steady viscosity).

4. **Multi-Chain Diagnostic**: 4 chains (default) enable robust R-hat and ESS calculations for production-quality inference.

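5. **Residual Structure**: Residuals that scatter randomly about zero are consistent with an adequate model; systematic patterns indicate missing physics. Because the synthetic noise is proportional to stress, the residual scatter is expected to grow with stress level.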

### Model Capabilities

1. **Tensorial Stress**: Full [τ_xx, τ_yy, τ_xy] tracking enables normal stress predictions unavailable in scalar models.

2. **Protocol Versatility**: Same model handles FLOW_CURVE, STARTUP, CREEP, RELAXATION, OSCILLATION, and LAOS.

3. **Coupling Modes**: "minimal" (λ = 1/f only) vs "full" (λ + τ_y(f) aging) provide different thixotropic behaviors.

4. **JAX Acceleration**: JIT compilation enables fast ODE integration for transient protocols.

5. **Bayesian Uncertainty**: Posterior distributions quantify parameter uncertainties for reliable predictions.

### Experimental Connection

**Startup shear experiments** measure:
- Shear stress τ(t) at fixed γ̇
- Optionally N₁(t) via force transducers
- Overshoot peak time and magnitude

**Model predictions** enable:
- Material parameter extraction (λ, τ_y, η₀)
- Classification (yield stress fluid, viscoelastic liquid)
- Prediction of unmeasured normal stresses
- Design of processing protocols (e.g., preshear conditioning)