# MLIKH Model: Creep Response


> **Handbook:** See [MLIKH Reference](../../docs/source/models/ikh/ml_ikh.rst) for sequential yielding analysis, mode-specific structure evolution, and distributed delayed yielding behavior.

## Protocol Context: Multi-Mode Creep for MLIKH

**Multi-mode creep** with MLIKH reveals **distributed delayed yielding** behavior. With N parallel modes:
- Each mode can yield at different times as its structure $\lambda_i$ breaks down
- Total strain accumulates from all mode contributions
- **Gradual or stepped acceleration**: Fast modes (small $\tau_{thix,i}$) yield first, slow modes later

**Key Physics**:
- **Mode-specific yield stress**: $\sigma_{y,i}(\lambda_i) = \sigma_{y0,i} + \Delta\sigma_{y,i} \cdot \lambda_i$
- **Sequential yielding**: When $\sigma_{app} > \sigma_{y,i}(\lambda_i)$ for mode $i$, that mode begins plastic flow
- **Smooth transition**: Overlapping mode contributions create continuous acceleration curves
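
The yield-stress relation and the yielding criterion above can be sketched in plain Python. This is a minimal illustration, not the rheojax implementation: `mode_yield_stress` and `yielded_modes` are hypothetical helper names, and the parameter values are made up for the example rather than fitted.

```python
def mode_yield_stress(sigma_y0, delta_sigma_y, lam):
    """sigma_y,i(lambda_i) = sigma_y0,i + delta_sigma_y,i * lambda_i."""
    return sigma_y0 + delta_sigma_y * lam


def yielded_modes(sigma_app, modes, lams):
    """Return indices of modes whose current yield stress is below sigma_app."""
    return [
        i
        for i, (mode, lam) in enumerate(zip(modes, lams))
        if sigma_app > mode_yield_stress(mode["sigma_y0"], mode["delta_sigma_y"], lam)
    ]


# Illustrative (assumed, not fitted) parameters for two modes, stresses in Pa
modes = [
    {"sigma_y0": 2.0, "delta_sigma_y": 4.0},
    {"sigma_y0": 3.0, "delta_sigma_y": 6.0},
]
sigma_app = 5.0

# Same applied stress, three different structural states (lambda_1, lambda_2)
fully_structured = yielded_modes(sigma_app, modes, lams=[1.0, 1.0])
partly_broken = yielded_modes(sigma_app, modes, lams=[0.5, 1.0])
fully_broken = yielded_modes(sigma_app, modes, lams=[0.0, 0.0])
print(fully_structured, partly_broken, fully_broken)
```

At a fixed applied stress, a mode starts to flow only once its structure parameter has dropped far enough for its yield stress to fall below that stress, which is why modes can yield at different times.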

> **Further Reading**: See [MLIKH Reference](../../docs/source/models/ikh/ml_ikh.rst) for per-mode yielding mechanics and multi-mode delayed yielding analysis.

## Learning Objectives

1. Fit **MLIKH** to creep data (step to a constant stress; the datasets used here record shear rate versus time)
2. Analyze **multi-mode delayed yielding** behavior
3. Observe how distributed thixotropic timescales affect creep
4. Compare with single-mode MIKH predictions

## Prerequisites

- NB04: MIKH Creep (single-mode understanding)
- NB07: MLIKH Flow Curve (multi-mode basics)

## Estimated Runtime

- **Fast demo**: ~4-5 minutes
- **Full run**: ~15-18 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import os
import sys
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.ikh import MLIKH

# Add examples/utils to path
sys.path.insert(0, os.path.join("..", "utils"))
from ikh_tutorial_utils import (
    load_ml_ikh_creep,
    save_ikh_results,
    print_convergence_summary,
    compute_fit_quality,
    get_mlikh_param_names,
)

# Shared plotting utilities
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit,
    display_arviz_diagnostics,
    plot_posterior_predictive,
)

jax, jnp = safe_import_jax()
verify_float64()

# Suppress equinox DeprecationWarnings for jax.core.mapped_aval/unmapped_aval
# (third-party equinox internals, cannot fix at source — harmless with JAX 0.8.x)
warnings.filterwarnings(
    "ignore",
    message=r"jax\.core\.(mapped|unmapped)_aval",
    category=DeprecationWarning,
    module=r"equinox\..*",
)
print(f"JAX version: {jax.__version__}")
# Startup cleanup: force garbage collection to reclaim memory from previous notebooks
import gc
gc.collect()

**Diagnostic Interpretation:**

| Metric | Target | Meaning |
|--------|--------|---------|
| **R-hat** | < 1.01 | Chain convergence (1.0 = perfect) |
| **ESS (bulk)** | > 400 | Effective independent samples |
| **Divergences** | < 1% | NUTS geometric issues |

For IKH models, watch for correlations between yield stress and hardening parameters in the pair plot.
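
To make the R-hat target in the table above concrete, here is a minimal sketch of the classic Gelman-Rubin statistic in plain Python. It is an illustration only: ArviZ's default `az.rhat` uses a rank-normalized split variant, so its values will differ in detail, and `gelman_rubin_rhat` is a hypothetical helper name. The chains are synthetic Gaussian draws, not MLIKH posteriors.

```python
import random
import statistics


def gelman_rubin_rhat(chains):
    """Classic (non-rank-normalized) R-hat for two or more equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    # W: mean within-chain variance; B: between-chain variance scaled by n
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    between = n * statistics.variance(chain_means)
    var_hat = (n - 1) / n * within + between / n
    return (var_hat / within) ** 0.5


rng = random.Random(0)
# Four chains sampling the same distribution: R-hat should be close to 1
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Chains stuck at two different locations: R-hat should be far above 1.01
stuck = [[rng.gauss(shift, 1.0) for _ in range(500)] for shift in (0.0, 0.0, 3.0, 3.0)]

rhat_mixed = gelman_rubin_rhat(mixed)
rhat_stuck = gelman_rubin_rhat(stuck)
print(f"mixed chains: R-hat = {rhat_mixed:.3f}")
print(f"stuck chains: R-hat = {rhat_stuck:.3f}")
```

Because the statistic compares between-chain and within-chain variance, it needs at least two chains; with a single chain there is no between-chain term to compute.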

## 2. Theory: Multi-Mode Creep

In MLIKH creep:
- Each mode can yield at different times
- Structure evolution for each mode: $\lambda_i(t)$
- Total strain accumulates from all mode contributions

### Delayed Yielding with Multiple Modes

With distributed thixotropic timescales:
- Fast modes yield first (small $\tau_{thix,i}$)
- Slow modes yield later (large $\tau_{thix,i}$)
- The result is a stepped or gradual acceleration of the shear rate

## 3. Load Data

In [None]:
# Load creep data (step stress tests)
creep_datasets = {}

for idx in range(3):
    t, gamma_dot, sigma_i, sigma_f = load_ml_ikh_creep(stress_pair_index=idx)
    creep_datasets[(sigma_i, sigma_f)] = {
        "time": t,
        "shear_rate": gamma_dot,
        "initial_stress": sigma_i,
        "final_stress": sigma_f,
    }
    print(f"Stress: {sigma_i:.0f} -> {sigma_f:.0f} Pa, {len(t)} points")

In [None]:
# Plot creep data
fig, ax = plt.subplots(figsize=(10, 6))
colors = ["C0", "C1", "C2"]

for i, (key, d) in enumerate(creep_datasets.items()):
    sigma_i, sigma_f = key
    ax.semilogy(d["time"], d["shear_rate"], "o-", color=colors[i],
                markersize=3, lw=1, alpha=0.7,
                label=f"$\\sigma$: {sigma_i:.0f} -> {sigma_f:.0f} Pa")

ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Shear rate [1/s]", fontsize=12)
ax.set_title("ML-IKH Creep Data", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

In [None]:
# Select reference dataset
ref_key = (3.0, 7.0)
d = creep_datasets[ref_key]
t_data = d["time"]
gamma_dot_data = d["shear_rate"]
sigma_applied = d["final_stress"]

# Create and fit model
n_modes = 2
model = MLIKH(n_modes=n_modes, yield_mode="per_mode")
param_names = get_mlikh_param_names(n_modes=n_modes, yield_mode="per_mode")

print(f"Fitting MLIKH ({n_modes} modes) to creep at sigma = {sigma_applied} Pa")
t0 = time.time()
model.fit(t_data, gamma_dot_data, test_mode="creep", sigma_applied=sigma_applied, method='scipy')
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")

In [None]:
# Predict and compute fit quality
gamma_pred = model.predict(t_data, test_mode="creep", sigma_applied=sigma_applied)
gamma_dot_pred = np.gradient(np.array(gamma_pred), np.array(t_data))

metrics = compute_fit_quality(gamma_dot_data, gamma_dot_pred)
print(f"\nFit Quality:")
print(f"  R^2:   {metrics['R2']:.6f}")
print(f"  RMSE:  {metrics['RMSE']:.4g} 1/s")

# Residual analysis
residuals = gamma_dot_data - gamma_dot_pred
print(f"\nResidual Analysis:")
print(f"  Mean residual:  {np.mean(residuals):.4g} (should be ~ 0)")
print(f"  Max abs residual: {np.max(np.abs(residuals)):.4g}")
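
`compute_fit_quality` is a tutorial helper whose implementation is not shown in this notebook. As a rough guide to what the printed numbers mean, the sketch below computes R^2 and RMSE from their standard definitions; it assumes, without confirmation, that the helper follows them. `fit_quality` is a hypothetical name and the data are toy values.

```python
import math


def fit_quality(y_true, y_pred):
    """Return R^2 and RMSE using their standard definitions."""
    n = len(y_true)
    mean_true = sum(y_true) / n
    ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((yt - mean_true) ** 2 for yt in y_true)            # total sum of squares
    return {"R2": 1.0 - ss_res / ss_tot, "RMSE": math.sqrt(ss_res / n)}


observed = [1.0, 2.0, 3.0, 4.0]
perfect = fit_quality(observed, observed)
offset = fit_quality(observed, [1.5, 2.5, 3.5, 4.5])  # constant +0.5 error
print(perfect)
print(offset)
```

Both metrics are computed in linear space, so when the shear rate spans several decades they are dominated by the largest values. Comparing `log10` of the data and the prediction is one alternative worth considering for creep curves.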

In [None]:
# Plot fit
fig, ax = plt.subplots(figsize=(10, 6))
ax.semilogy(t_data, gamma_dot_data, "ko", markersize=4, alpha=0.5, label="Data")
ax.semilogy(t_data, np.abs(gamma_dot_pred), "-", lw=2, color="C0", label="MLIKH fit")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Shear rate [1/s]", fontsize=12)
ax.set_title(f"MLIKH Creep Fit ($\\sigma$ = {sigma_applied} Pa)", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 5. Bayesian Inference

In [None]:
# Clear JAX compilation caches before Bayesian inference to reduce peak memory
import gc
gc.collect()
try:
    jax.clear_caches()
except Exception:
    pass

# Bayesian inference
initial_values = {name: model.parameters.get_value(name) for name in param_names}

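# FAST_MODE: reduced samples for CI; set FAST_MODE=0 for production.
# Note: 100 draws are unlikely to reach the ESS > 400 target, and with a single
# chain R-hat cannot measure agreement between chains. For a real analysis,
# consider NUM_CHAINS >= 2 and more samples if memory allows.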
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

if FAST_MODE:
    NUM_WARMUP = 50
    NUM_SAMPLES = 100
    NUM_CHAINS = 1
else:
    NUM_WARMUP = 200
    NUM_SAMPLES = 500
    NUM_CHAINS = 1

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model.fit_bayesian(
    t_data,
    gamma_dot_data,
    test_mode="creep",
    sigma_applied=sigma_applied,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
all_pass = print_convergence_summary(result, param_names)

## 6. Physical Interpretation

### Multi-Mode Delayed Yielding

With multiple modes, the yielding transition can be:
- **Gradual**: Modes yield sequentially as their structure breaks down
- **Stepped**: Distinct jumps when each mode yields
- **Smooth**: Overlapping mode contributions create continuous transition

### Mode-Specific Structure Evolution

Each mode evolves independently:
$$
\frac{d\lambda_i}{dt} = \frac{1 - \lambda_i}{\tau_{thix,i}} - \Gamma_i \lambda_i |\dot{\gamma}^p|
$$
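
For intuition, the structure equation above has a closed-form solution when the plastic shear rate is held constant. That is a simplifying assumption: in an actual creep test the plastic rate evolves with time. The sketch below evaluates the closed form for a single mode and cross-checks it with explicit Euler integration. The function names, `gamma_coef` (standing in for $\Gamma_i$), and all parameter values are illustrative and not part of rheojax.

```python
import math


def lambda_steady_state(tau_thix, gamma_coef, rate_p):
    """Value of lambda at which build-up and breakdown balance."""
    return 1.0 / (1.0 + gamma_coef * tau_thix * abs(rate_p))


def lambda_closed_form(t, lam0, tau_thix, gamma_coef, rate_p):
    """Exact solution of d(lam)/dt = (1 - lam)/tau_thix - gamma_coef*lam*|rate_p|."""
    k = 1.0 / tau_thix + gamma_coef * abs(rate_p)  # total relaxation rate [1/s]
    lam_ss = lambda_steady_state(tau_thix, gamma_coef, rate_p)
    return lam_ss + (lam0 - lam_ss) * math.exp(-k * t)


def lambda_euler(t_end, lam0, tau_thix, gamma_coef, rate_p, n_steps=20000):
    """Explicit Euler integration of the same ODE, used as a cross-check."""
    dt = t_end / n_steps
    lam = lam0
    for _ in range(n_steps):
        lam += dt * ((1.0 - lam) / tau_thix - gamma_coef * lam * abs(rate_p))
    return lam


# Illustrative (assumed) values: tau_thix in s, rate_p in 1/s
tau_thix, gamma_coef, rate_p, t_end = 1.0, 2.0, 0.5, 3.0
lam_ss = lambda_steady_state(tau_thix, gamma_coef, rate_p)
lam_exact = lambda_closed_form(t_end, 1.0, tau_thix, gamma_coef, rate_p)
lam_numeric = lambda_euler(t_end, 1.0, tau_thix, gamma_coef, rate_p)
print(f"steady state: {lam_ss:.4f}")
print(f"lambda(t={t_end}): exact {lam_exact:.6f}, Euler {lam_numeric:.6f}")
```

Under this assumption, each mode relaxes exponentially toward $\lambda_{ss,i} = 1 / (1 + \Gamma_i \tau_{thix,i} |\dot{\gamma}^p|)$ at rate $1/\tau_{thix,i} + \Gamma_i |\dot{\gamma}^p|$.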

## 7. Save Results

In [None]:
# Save results
save_ikh_results(model, result, "mlikh", "creep", param_names)

## Key Takeaways

1. **Multi-mode creep** captures distributed restructuring timescales
2. **Delayed yielding** can show gradual or stepped transitions
3. **Mode-specific structure** ($\lambda_i$) evolves independently
4. **Viscosity bifurcation** is affected by the contributions of all modes

## Next Steps

- **NB11**: MLIKH SAOS (broadened spectra)
- **NB12**: MLIKH LAOS

## Further Reading

**Multi-Mode Creep:**
- [MLIKH Reference](../../docs/source/models/ikh/ml_ikh.rst) — Sequential yielding analysis, mode-specific structure evolution

**Key References:**
1. Wei et al. (2018). *J. Rheol.*, 62(1), 321-342. — ML-IKH creep validation
2. de Souza Mendes & Thompson (2019). *Annu. Rev. Fluid Mech.*, 51, 421-449. — Time-dependent yield stress

In [None]:
# Cleanup: release JAX caches and Python garbage for sequential notebook runs
import gc
try:
    jax.clear_caches()
except Exception:
    pass
gc.collect()
print("Notebook complete. Memory cleaned up.")
