# MLIKH Model: Steady-State Flow Curves


> **Handbook:** See [MLIKH Reference](../../docs/source/models/ikh/ml_ikh.rst) for Kohlrausch-Williams-Watts decomposition, mode selection rules, and industrial applications to materials with distributed thixotropic timescales.

## What is MLIKH (Multi-Lambda IKH)?

The **MLIKH (Multi-Lambda IKH)** model extends the single-mode MIKH framework to **N parallel modes**, capturing materials with **distributed thixotropic timescales**. This is analogous to how the Generalized Maxwell Model (Prony series) extends the single Maxwell element for viscoelasticity.

**Physical Motivation**:
- Many thixotropic materials exhibit **stretched-exponential recovery** (not simple exponential)
- **Hierarchical microstructure**: Primary bonds (fast kinetics) + aggregates (medium) + networks (slow)
- Materials: Complex waxy crude oils (broad wax crystal size distribution), bidisperse colloids (small + large particles), hierarchical clay gels, dense emulsions

**Key Physics**:
- **Distributed timescales**: Different modes have different $\tau_{thix,i}$, spanning 2-4 decades
- **Stretch exponent**: During recovery the structural deficit decays as a stretched exponential, $1 - \lambda(t) \propto \exp[-(t/\tau_c)^\beta]$ with $\beta < 1$; capturing this requires roughly N ~ $(1/\beta)^2$ modes
- **Two yield formulations**:
  1. **per_mode** (default): Each mode has its own yield surface; modes are connected mechanically in parallel
  2. **weighted_sum**: Single global yield surface with $\sigma_y = \sigma_{y0} + k_3 \sum_i w_i\lambda_i$
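
To see why a handful of exponential modes can mimic stretched-exponential kinetics, the sketch below fits nonnegative weights for log-spaced exponentials to a KWW curve with $\beta = 0.5$ using `scipy.optimize.nnls`. It illustrates the Prony-series idea only and is not how RheoJAX selects modes; the time grid, the two-modes-per-decade spacing, and the helper names are assumptions. `n_modes_rule_of_thumb` simply encodes the N ~ $(1/\beta)^2$ heuristic quoted above.

```python
import math

import numpy as np
from scipy.optimize import nnls


def kww(t, tau_c, beta):
    """Stretched-exponential (KWW) function exp[-(t/tau_c)^beta]."""
    return np.exp(-((t / tau_c) ** beta))


def n_modes_rule_of_thumb(beta):
    """Heuristic quoted in the text: N ~ (1/beta)^2, rounded up."""
    return math.ceil((1.0 / beta) ** 2)


def fit_prony_to_kww(t, tau_c, beta, taus):
    """Nonnegative weights w_k such that sum_k w_k exp(-t/tau_k) approximates KWW(t)."""
    design = np.exp(-t[:, None] / taus[None, :])  # one column per exponential mode
    w, _ = nnls(design, kww(t, tau_c, beta))
    return w, design @ w


t = np.logspace(-2, 2, 200)    # time grid in units of tau_c
taus = np.logspace(-3, 3, 13)  # hypothetical mode timescales, two per decade
w, approx = fit_prony_to_kww(t, 1.0, 0.5, taus)
max_err = np.max(np.abs(approx - kww(t, 1.0, 0.5)))

print(n_modes_rule_of_thumb(0.5), n_modes_rule_of_thumb(0.3))  # 4 12
print(f"{np.count_nonzero(w)} active modes, max |error| = {max_err:.1e}")
```

Because KWW with $0 < \beta < 1$ is a positive mixture of exponentials, nonnegative weights are a natural constraint; the smaller $\beta$, the more decades of $\tau$ the weights must cover.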

**When to Use MLIKH**:
- Recovery experiments show $\beta < 0.8$ (stretched exponential fit)
- Yield stress recovery spans >2 decades of time
- Adding modes improves $R^2$ by more than 10% relative to single-mode MIKH
- Materials with hierarchical structure (multiple structural populations)
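
The first two criteria can be checked by fitting a stretched exponential to recovery data. The sketch below uses `scipy.optimize.curve_fit` on a hypothetical normalized recovery signal (for example a normalized modulus or yield stress after cessation of shear); the data, starting guesses, and bounds are assumptions for illustration. For a KWW deficit $\exp[-(t/\tau_c)^\beta]$, the ratio of the 90% and 10% recovery times is $(\ln 10 / -\ln 0.9)^{1/\beta}$, which gives the number of decades the recovery spans.

```python
import numpy as np
from scipy.optimize import curve_fit


def kww_recovery(t, lam0, tau_c, beta):
    """Recovery from lam0 toward 1 with a stretched-exponential deficit."""
    return 1.0 - (1.0 - lam0) * np.exp(-((t / tau_c) ** beta))


# Hypothetical recovery data: normalized structure signal vs. rest time
rng = np.random.default_rng(0)
t_rest = np.logspace(-2, 3, 60)  # rest time [s]
lam_obs = kww_recovery(t_rest, 0.2, 10.0, 0.5) + rng.normal(0.0, 0.005, t_rest.size)

popt, _ = curve_fit(
    kww_recovery,
    t_rest,
    lam_obs,
    p0=[0.1, 1.0, 0.9],
    bounds=([0.0, 1e-6, 0.05], [1.0, 1e6, 1.5]),
)
lam0_fit, tau_fit, beta_fit = popt

# Decades between 10% and 90% recovery: t90/t10 = (ln 10 / -ln 0.9)^(1/beta)
spread_decades = np.log10(np.log(10.0) / -np.log(0.9)) / beta_fit

print(f"beta = {beta_fit:.2f}, tau_c = {tau_fit:.3g} s, spread = {spread_decades:.1f} decades")
print("beta < 0.8: consider MLIKH" if beta_fit < 0.8 else "single-mode MIKH may suffice")
```

For $\beta = 1$ the same formula gives about 1.3 decades, so a spread well beyond 2 decades is itself a sign of distributed timescales.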

> **Further Reading**: See [IKH Handbook](../../docs/source/models/ikh/index.rst) for model hierarchy, [MLIKH Reference](../../docs/source/models/ikh/ml_ikh.rst) for Kohlrausch-Williams-Watts decomposition, mode selection rules ($\beta$ → N mapping), and industrial applications.

## Learning Objectives

1. Fit the **MLIKH (Multi-Lambda IKH)** model to steady-state flow curve data
2. Understand **multi-mode thixotropy** with distributed timescales
3. Compare **per_mode** vs **weighted_sum** yield formulations
4. Analyze how multiple modes capture complex flow behavior
5. Calibrate parameters for downstream synthetic data generation

## Prerequisites

- NB01-06: MIKH tutorials (single-mode understanding)

## Estimated Runtime

- **Fast demo** (`FAST_MODE=1`: NUM_WARMUP=50, NUM_SAMPLES=100, NUM_CHAINS=1): ~3-4 minutes
- **Full run** (`FAST_MODE=0`: NUM_WARMUP=200, NUM_SAMPLES=500, NUM_CHAINS=1): ~15-20 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import os
import sys
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.ikh import MLIKH

# Add examples/utils to path for tutorial utilities
sys.path.insert(0, os.path.join("..", "utils"))
from ikh_tutorial_utils import (
    load_ml_ikh_flow_curve,
    save_ikh_results,
    print_convergence_summary,
    print_parameter_comparison,
    compute_fit_quality,
    get_mlikh_param_names,
)

# Shared plotting utilities
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit,
    display_arviz_diagnostics,
    plot_posterior_predictive,
)

jax, jnp = safe_import_jax()
verify_float64()

# Suppress equinox DeprecationWarnings for jax.core.mapped_aval/unmapped_aval
# (third-party equinox internals, cannot fix at source — harmless with JAX 0.8.x)
warnings.filterwarnings(
    "ignore",
    message=r"jax\.core\.(mapped|unmapped)_aval",
    category=DeprecationWarning,
    module=r"equinox\..*",
)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Startup cleanup: force garbage collection to reclaim memory from previous notebooks
import gc
gc.collect()

### Convergence Diagnostic Interpretation

| Metric | Target | Meaning |
|--------|--------|---------|
| **R-hat** | < 1.01 | Chains converged: multiple chains agree on the posterior (1.0 = perfect) |
| **ESS (bulk)** | > 400 | Sufficient effectively independent samples |
| **Divergences** | < 1% | Well-behaved sampler: no numerical issues in the posterior geometry |

Both sampler configurations in Section 5 run a single chain, so R-hat is computed from split halves of one chain rather than across independent chains, and in FAST_MODE (100 samples) the ESS target is unlikely to be met. Treat these diagnostics as indicative only.

For IKH models, watch for correlations between yield stress and hardening parameters in the pair plot.
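
To make the R-hat target concrete, here is a minimal numpy sketch of the classic split-R-hat (Gelman-Rubin) statistic on synthetic chains. It is not the ArviZ implementation: `az.rhat` by default applies a rank-normalized refinement of the same idea, so its numbers will differ slightly. The chain arrays are made up for illustration.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split each chain in two so within-chain drift also inflates R-hat
    splits = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = splits.shape[1]
    W = splits.var(axis=1, ddof=1).mean()       # mean within-chain variance
    B = n * splits.mean(axis=1).var(ddof=1)     # between-chain variance
    var_hat = (n - 1) / n * W + B / n           # pooled posterior variance estimate
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(1)
good = rng.normal(0.0, 1.0, size=(4, 1000))             # four chains, same target
bad = good + np.array([0.0, 0.0, 0.0, 3.0])[:, None]    # one chain stuck elsewhere
print(f"R-hat (mixed):   {split_rhat(good):.3f}")
print(f"R-hat (unmixed): {split_rhat(bad):.3f}")
```

A single stuck chain pushes R-hat far above 1.01, which is exactly the failure mode the table's threshold is meant to catch.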

## 2. Theory: Multi-Mode IKH

The **MLIKH (Multi-Lambda IKH)** model extends MIKH to N modes connected in parallel, capturing **distributed thixotropic timescales**.

### Motivation

Real materials often exhibit:
- Multiple restructuring timescales
- Broad relaxation spectra
- Complex flow history dependence

### Two Yield Formulations

**1. Per-Mode Yield** (default):
- Each mode has independent yield surface
- Total stress = $\sum_i \sigma_i$ (parallel connection)
- Parameters: 7 per mode + 1 global

**2. Weighted-Sum Yield**:
- Single global yield surface: $\sigma_y = \sigma_{y0} + k_3 \sum_i w_i \lambda_i$
- All modes share elastic/plastic response
- Parameters: 6 global + 3 per mode
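
The two yield rules can be compared side by side for a given structural state. The sketch below evaluates the weighted-sum formula $\sigma_y = \sigma_{y0} + k_3 \sum_i w_i \lambda_i$ and, for per_mode, assumes each mode's threshold takes the form $\sigma_{y,i} = \sigma_{y0,i} + \Delta\sigma_{y,i}\lambda_i$ (inferred from the parameter names in the table that follows, not taken from the RheoJAX source). All numbers are hypothetical.

```python
import numpy as np


def weighted_sum_yield(lam, sigma_y0, k3, w):
    """Single global yield stress: sigma_y0 + k3 * sum_i w_i * lambda_i."""
    return sigma_y0 + k3 * np.dot(w, lam)


def per_mode_yields(lam, sigma_y0_i, delta_sigma_y_i):
    """One threshold per mode (assumed form): sigma_y0_i + delta_sigma_y_i * lambda_i."""
    return sigma_y0_i + delta_sigma_y_i * lam


lam = np.array([1.0, 0.2])  # hypothetical structure levels of a 2-mode state
ws = weighted_sum_yield(lam, sigma_y0=5.0, k3=20.0, w=np.array([0.5, 0.5]))
pm = per_mode_yields(lam, np.array([2.0, 3.0]), np.array([10.0, 10.0]))
print(f"weighted_sum yield: {ws:.1f} Pa; per_mode yields: {pm} Pa (sum {pm.sum():.1f} Pa)")
```

Here the two formulations give the same total threshold, but per_mode keeps two separate thresholds, so one mode can yield while the other is still elastic.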

### Per-Mode Parameters (for each mode $i$)

| Parameter | Symbol | Description |
|-----------|--------|-------------|
| `G_i` | $G_i$ | Mode shear modulus (Pa) |
| `C_i` | $C_i$ | Kinematic hardening modulus (Pa) |
| `gamma_dyn_i` | $\gamma_{dyn,i}$ | Dynamic recovery of the back stress (dimensionless) |
| `sigma_y0_i` | $\sigma_{y0,i}$ | Minimum (fully broken) yield stress (Pa) |
| `delta_sigma_y_i` | $\Delta\sigma_{y,i}$ | Structural yield contribution (Pa) |
| `tau_thix_i` | $\tau_{thix,i}$ | Thixotropic timescale (s) |
| `Gamma_i` | $\Gamma_i$ | Breakdown coefficient (dimensionless) |

### Key Physics

- **Distributed timescales**: Different modes restructure at different rates
- **Parallel stress**: Total stress is the sum of the mode contributions
- **Independent yielding**: Each mode can yield independently (per_mode)
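
A structure-only toy shows how these three ideas shape a flow curve. It assumes MIKH-style kinetics $d\lambda_i/dt = (1-\lambda_i)/\tau_{thix,i} - \Gamma_i \lambda_i |\dot\gamma|$, whose steady state is $\lambda_{ss,i} = 1/(1 + \Gamma_i \tau_{thix,i} \dot\gamma)$, sums per-mode yield contributions in parallel, ignores the kinematic-hardening back stress, and adds a hypothetical Newtonian background `eta_inf`. It is a qualitative picture, not the RheoJAX flow-curve prediction.

```python
import numpy as np


def lambda_ss(gamma_dot, tau_thix, Gamma):
    """Steady state of d(lam)/dt = (1 - lam)/tau_thix - Gamma*lam*|gamma_dot|."""
    return 1.0 / (1.0 + Gamma * tau_thix * np.abs(gamma_dot))


def toy_flow_curve(gamma_dot, modes, eta_inf):
    """Parallel sum of per-mode steady yield stresses plus a Newtonian background."""
    gamma_dot = np.atleast_1d(np.asarray(gamma_dot, dtype=float))
    stress = eta_inf * np.abs(gamma_dot)
    for m in modes:
        lam = lambda_ss(gamma_dot, m["tau_thix"], m["Gamma"])
        stress = stress + m["sigma_y0"] + m["delta_sigma_y"] * lam
    return stress


modes = [
    dict(sigma_y0=1.0, delta_sigma_y=5.0, tau_thix=0.1, Gamma=1.0),    # fast mode
    dict(sigma_y0=0.5, delta_sigma_y=5.0, tau_thix=100.0, Gamma=1.0),  # slow mode
]
gd = np.logspace(-4, 3, 8)  # shear rates [1/s]
print(np.round(toy_flow_curve(gd, modes, eta_inf=0.1), 3))
```

The slow mode loses its structure around $\dot\gamma \approx 10^{-2}$ 1/s and the fast mode around $10$ 1/s, producing two separate shoulders that a single timescale could not place.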

## 3. Load Data

In [None]:
# Load flow curve data (same as MIKH NB01)
gamma_dot, stress = load_ml_ikh_flow_curve(instrument="ARES_up")

print(f"Data points: {len(gamma_dot)}")
print(f"Shear rate range: [{gamma_dot.min():.4f}, {gamma_dot.max():.2f}] 1/s")
print(f"Stress range: [{stress.min():.2f}, {stress.max():.2f}] Pa")

In [None]:
# Plot raw data
fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(gamma_dot, stress, "ko", markersize=7, label="Data")
ax.set_xlabel("Shear rate [1/s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("ML-IKH Flow Curve Data", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

### 4.1 Per-Mode Yield (2 modes)

In [None]:
# Create MLIKH model with 2 modes, per_mode yield
n_modes = 2
model_per_mode = MLIKH(n_modes=n_modes, yield_mode="per_mode")
param_names_pm = get_mlikh_param_names(n_modes=n_modes, yield_mode="per_mode")

print(f"MLIKH (per_mode, {n_modes} modes): {len(param_names_pm)} parameters")
print(f"Parameters: {param_names_pm}")

In [None]:
# Fit model
t0 = time.time()
model_per_mode.fit(gamma_dot, stress, test_mode="flow_curve")
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"\nFitted parameters (per-mode):")
for name in param_names_pm:
    val = model_per_mode.parameters.get_value(name)
    print(f"  {name:18s} = {val:.4g}")

In [None]:
# Compute fit quality
stress_pred_pm = model_per_mode.predict(gamma_dot, test_mode="flow_curve")
metrics_pm = compute_fit_quality(stress, stress_pred_pm)

print(f"\nFit Quality (per_mode):")
print(f"  R^2:   {metrics_pm['R2']:.6f}")
print(f"  RMSE:  {metrics_pm['RMSE']:.4g} Pa")

# Residual analysis
residuals = stress - stress_pred_pm
print(f"\nResidual Analysis:")
print(f"  Mean residual:  {np.mean(residuals):.4g} (should be ~ 0)")
print(f"  Max abs residual: {np.max(np.abs(residuals)):.4g}")
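
`compute_fit_quality` comes from the tutorial utilities and its source is not shown here. Assuming it uses the standard definitions, $R^2 = 1 - SS_{res}/SS_{tot}$ and $\mathrm{RMSE} = \sqrt{\overline{r^2}}$, an equivalent numpy sketch is:

```python
import numpy as np


def fit_quality(y, y_pred):
    """Coefficient of determination and root-mean-square error."""
    y = np.asarray(y, dtype=float)
    resid = y - np.asarray(y_pred, dtype=float)
    ss_res = np.sum(resid**2)                # residual sum of squares
    ss_tot = np.sum((y - y.mean()) ** 2)     # total sum of squares about the mean
    return {"R2": 1.0 - ss_res / ss_tot, "RMSE": float(np.sqrt(np.mean(resid**2)))}


m = fit_quality([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0])
print(f"R2={m['R2']:.3f}, RMSE={m['RMSE']:.3f}")  # R2=0.800, RMSE=0.500
```

Because flow-curve stresses span decades, both metrics in linear units are dominated by the high-rate points; evaluating them on `np.log10(stress)` weights all shear rates more evenly.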

### 4.2 Weighted-Sum Yield (2 modes)

In [None]:
# Create MLIKH model with weighted_sum yield
model_weighted_sum = MLIKH(n_modes=n_modes, yield_mode="weighted_sum")
param_names_ws = get_mlikh_param_names(n_modes=n_modes, yield_mode="weighted_sum")

print(f"MLIKH (weighted_sum, {n_modes} modes): {len(param_names_ws)} parameters")
print(f"Parameters: {param_names_ws}")

In [None]:
# Fit model
t0 = time.time()
model_weighted_sum.fit(gamma_dot, stress, test_mode="flow_curve")
t_nlsq_ws = time.time() - t0

print(f"NLSQ fit time: {t_nlsq_ws:.2f} s")

# Compute fit quality
stress_pred_ws = model_weighted_sum.predict(gamma_dot, test_mode="flow_curve")
metrics_ws = compute_fit_quality(stress, stress_pred_ws)

print(f"\nFit Quality (weighted_sum):")
print(f"  R^2:   {metrics_ws['R2']:.6f}")
print(f"  RMSE:  {metrics_ws['RMSE']:.4g} Pa")

### 4.3 Compare Yield Formulations

In [None]:
# Plot comparison
gamma_dot_fine = np.logspace(
    np.log10(gamma_dot.min()) - 0.5,
    np.log10(gamma_dot.max()) + 0.3,
    200,
)
stress_pm_fine = model_per_mode.predict(gamma_dot_fine, test_mode="flow_curve")
stress_ws_fine = model_weighted_sum.predict(gamma_dot_fine, test_mode="flow_curve")

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(gamma_dot, stress, "ko", markersize=7, label="Data")
ax.loglog(gamma_dot_fine, stress_pm_fine, "-", lw=2.5, color="C0", 
          label=f"per_mode (R$^2$={metrics_pm['R2']:.5f})")
ax.loglog(gamma_dot_fine, stress_ws_fine, "--", lw=2.5, color="C1", 
          label=f"weighted_sum (R$^2$={metrics_ws['R2']:.5f})")

ax.set_xlabel("Shear rate [1/s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("MLIKH: Comparison of Yield Formulations", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)
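
The two formulations have different parameter counts, so $R^2$ alone favors the larger model. One common way to penalize complexity, assuming independent Gaussian residuals on the same data, is AIC/BIC computed from the residual sum of squares (lower is better). The sketch below uses toy arrays; in this notebook one could pass `stress`, `stress_pred_pm`, `len(param_names_pm)` and `stress`, `stress_pred_ws`, `len(param_names_ws)` instead.

```python
import math

import numpy as np


def information_criteria(y, y_pred, n_params):
    """AIC and BIC for least-squares fits with Gaussian errors (additive constants dropped)."""
    y = np.asarray(y, dtype=float)
    resid = y - np.asarray(y_pred, dtype=float)
    n = y.size
    log_term = n * math.log(float(np.sum(resid**2)) / n)
    return {"AIC": log_term + 2 * n_params, "BIC": log_term + n_params * math.log(n)}


# Toy data: identical residuals, two hypothetical parameter counts
y_obs = np.linspace(1.0, 10.0, 20)
y_fit = y_obs + 0.05 * np.sin(np.arange(20))
ic_12 = information_criteria(y_obs, y_fit, 12)
ic_15 = information_criteria(y_obs, y_fit, 15)
print(ic_12, ic_15)
```

With equal residuals the model with fewer parameters always scores lower, so the larger model must earn its extra parameters through a smaller residual sum of squares.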

### 4.4 Mode Contribution Analysis

In [None]:
# Analyze per-mode contributions (for per_mode model)
print("Mode Contributions (per_mode formulation):")
print("=" * 50)

for i in range(1, n_modes + 1):
    G_i = model_per_mode.parameters.get_value(f"G_{i}")
    sigma_y0_i = model_per_mode.parameters.get_value(f"sigma_y0_{i}")
    delta_sigma_y_i = model_per_mode.parameters.get_value(f"delta_sigma_y_{i}")
    tau_thix_i = model_per_mode.parameters.get_value(f"tau_thix_{i}")
    Gamma_i = model_per_mode.parameters.get_value(f"Gamma_{i}")
    
    print(f"\nMode {i}:")
    print(f"  G_{i} = {G_i:.3g} Pa")
    print(f"  sigma_y0_{i} = {sigma_y0_i:.3g} Pa")
    print(f"  delta_sigma_y_{i} = {delta_sigma_y_i:.3g} Pa")
    print(f"  tau_thix_{i} = {tau_thix_i:.3g} s")
    print(f"  Gamma_{i} = {Gamma_i:.4g}")

## 5. Bayesian Inference with NUTS

In [None]:
# Use per_mode model for Bayesian inference
model = model_per_mode
param_names = param_names_pm

initial_values = {
    name: model.parameters.get_value(name)
    for name in param_names
}

# Fast demo config
# FAST_MODE: reduced samples for CI; set FAST_MODE=0 for production
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

if FAST_MODE:
    NUM_WARMUP = 50
    NUM_SAMPLES = 100
    NUM_CHAINS = 1
else:
    NUM_WARMUP = 200
    NUM_SAMPLES = 500
    NUM_CHAINS = 1

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model.fit_bayesian(
    gamma_dot,
    stress,
    test_mode="flow_curve",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
all_pass = print_convergence_summary(result, param_names)

In [None]:
# ArviZ diagnostic plots (trace, pair, forest, energy, autocorrelation, rank)
display_arviz_diagnostics(result, param_names, fast_mode=FAST_MODE)

In [None]:
# Posterior predictive with 95% credible intervals
fig, ax = plot_posterior_predictive(
    gamma_dot, stress, model, result,
    test_mode="flow_curve", param_names=param_names,
    xlabel="Shear rate [1/s]", ylabel="Stress [Pa]",
    log_scale=True, title="Posterior Predictive Check (MLIKH per_mode)",
)
display(fig)
plt.close(fig)
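
`plot_posterior_predictive` is a tutorial utility whose internals are not shown. A common way to build a 95% band is to evaluate the model at every posterior draw and take pointwise 2.5th and 97.5th percentiles. The sketch below does this with numpy on hypothetical draws from a toy Herschel-Bulkley-like curve, not from MLIKH.

```python
import numpy as np


def credible_band(pred_draws, level=0.95):
    """Pointwise equal-tailed band from predictions of shape (n_draws, n_points)."""
    tail = 100.0 * (1.0 - level) / 2.0
    lo, med, hi = np.percentile(pred_draws, [tail, 50.0, 100.0 - tail], axis=0)
    return lo, med, hi


# Hypothetical posterior draws of a toy curve sigma = sigma_y + K * rate^0.5
rng = np.random.default_rng(0)
rate_grid = np.logspace(-2, 2, 30)                 # shear rates [1/s]
sigma_y = rng.normal(10.0, 0.5, size=500)[:, None]
K = rng.normal(2.0, 0.1, size=500)[:, None]
draws = sigma_y + K * rate_grid**0.5               # shape (500, 30)
lo, med, hi = credible_band(draws)
print(np.round(hi - lo, 2))                        # band width grows with rate
```

The band reflects parameter uncertainty only; adding the likelihood noise to each draw would give a wider posterior predictive band for new observations.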

## 6. Physical Interpretation

### Multi-Mode Benefits

1. **Distributed timescales**: Modes with different $\tau_{thix,i}$ capture multiple restructuring rates
2. **Flexible yield**: Per-mode allows different yield stresses for different structural components
3. **Complex flow curves**: Captures inflections and subtle curvature in data

### Mode Identification

- **Fast mode** (small $\tau_{thix}$): Responds quickly, dominates high-rate behavior
- **Slow mode** (large $\tau_{thix}$): Responds slowly, dominates low-rate behavior
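
Under the MIKH-style kinetics assumption (steady structure $\lambda_{ss,i} = 1/(1 + \Gamma_i \tau_{thix,i}\dot\gamma)$), each mode loses half its structure at $\dot\gamma_{c,i} = 1/(\Gamma_i \tau_{thix,i})$. This gives a quick way to rank modes from the values printed in Section 4.4; the numbers below are hypothetical placeholders for `model_per_mode.parameters.get_value(...)`.

```python
import numpy as np


def crossover_rate(tau_thix, Gamma):
    """Shear rate where the assumed lambda_ss = 1/(1 + Gamma*tau*rate) equals 1/2."""
    return 1.0 / (np.asarray(Gamma) * np.asarray(tau_thix))


# Hypothetical fitted values for a 2-mode per_mode model
params = {"tau_thix_1": 0.5, "Gamma_1": 2.0, "tau_thix_2": 200.0, "Gamma_2": 0.5}
rates = {i: crossover_rate(params[f"tau_thix_{i}"], params[f"Gamma_{i}"]) for i in (1, 2)}

for i, r in sorted(rates.items(), key=lambda kv: kv[1]):
    print(f"Mode {i}: structure half broken at ~{r:.3g} 1/s")
```

The mode with the lower crossover rate (here mode 2, the slow one) shapes the low-rate part of the flow curve; the other keeps its structure up to higher rates.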

### Per-Mode vs Weighted-Sum

- **Per-mode**: More flexible, each mode contributes independently
- **Weighted-sum**: More constrained, single yield surface with structure contribution

## 7. Save Results

In [None]:
# Save results for downstream notebooks
save_ikh_results(model, result, "mlikh", "flow_curve", param_names)

print("\nParameters saved for synthetic data generation in:")
print("  - NB09: Stress Relaxation")
print("  - NB11: SAOS")

## Key Takeaways

1. **MLIKH extends MIKH** to N modes with distributed thixotropic timescales

2. **Two yield formulations**:
   - **per_mode**: Independent yield surfaces, more flexible (7N + 1 parameters)
   - **weighted_sum**: Single global yield surface, more constrained (6 + 3N parameters)

3. **Mode identification**: Different modes capture fast vs slow restructuring

4. **Multi-mode benefit**: Distributed timescales reproduce flow-curve curvature spread over several decades of shear rate, which a single thixotropic timescale typically cannot

5. **Bayesian inference**: Harder as the parameter count grows with N, but warm-starting NUTS from the NLSQ fit (`initial_values`) helps

## Next Steps

- **NB08**: MLIKH Startup (richer overshoot dynamics from distributed modes)
- **NB09**: MLIKH Relaxation (multi-exponential decay)
- **NB10**: MLIKH Creep (complex yielding behavior)
- **NB11**: MLIKH SAOS (broadened modulus spectra)
- **NB12**: MLIKH LAOS (enhanced nonlinearity)

## Further Reading

**MLIKH Framework:**
- [IKH Family Overview](../../docs/source/models/ikh/index.rst) — When to use MLIKH vs MIKH
- [MLIKH Reference](../../docs/source/models/ikh/ml_ikh.rst) — Stretched exponential decomposition, mode selection ($\beta$ → N), timescale distribution strategies, industrial applications (complex waxy crudes, bidisperse colloids, hierarchical gels)

**Key References:**
1. Wei et al. (2018). "A multimode structural kinetics constitutive equation for the transient rheology of thixotropic elasto‐viscoplastic fluids." *J. Rheol.*, 62(1), 321-342. https://doi.org/10.1122/1.4996752 — Original ML-IKH model
2. Kohlrausch (1854) / Williams & Watts (1970). — Stretched exponential (KWW) theory
3. Dimitriou & McKinley (2014). *Soft Matter*, 10, 6619-6644. — Single-mode MIKH foundation

In [None]:
# Cleanup: release JAX caches and Python garbage for sequential notebook runs
import gc
try:
    jax.clear_caches()
except Exception:
    pass
gc.collect()
print("Notebook complete. Memory cleaned up.")
