# MLIKH Model: Steady-State Flow Curves

## Learning Objectives

1. Fit the **MLIKH (Multi-Lambda IKH)** model to steady-state flow curve data
2. Understand **multi-mode thixotropy** with distributed timescales
3. Compare **per_mode** vs **weighted_sum** yield formulations
4. Analyze how multiple modes capture complex flow behavior
5. Calibrate parameters for downstream synthetic data generation

## Prerequisites

- NB01-06: MIKH tutorials (single-mode understanding)

## Runtime

- Fast demo (NUM_CHAINS=1, NUM_SAMPLES=500): ~3-4 minutes
- Full run (NUM_CHAINS=4, NUM_SAMPLES=2000): ~15-20 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import os
import sys
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.ikh import MLIKH

# Add examples/utils to path for tutorial utilities
sys.path.insert(0, os.path.join("..", "utils"))
from ikh_tutorial_utils import (
    load_ml_ikh_flow_curve,
    save_ikh_results,
    print_convergence_summary,
    print_parameter_comparison,
    compute_fit_quality,
    get_mlikh_param_names,
)

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. Theory: Multi-Mode IKH

The **MLIKH (Multi-Lambda IKH)** model extends MIKH to N modes connected in parallel, capturing **distributed thixotropic timescales**.

### Motivation

Real materials often exhibit:
- Multiple restructuring timescales
- Broad relaxation spectra
- Complex flow history dependence

### Two Yield Formulations

**1. Per-Mode Yield** (default):
- Each mode has its own independent yield surface
- Total stress = $\sum_i \sigma_i$ (parallel connection)
- Parameters: 7 per mode + 1 global ($7N + 1$ in total for $N$ modes)

**2. Weighted-Sum Yield**:
- Single global yield surface: $\sigma_y = \sigma_{y0} + k_3 \sum_i w_i \lambda_i$
- All modes share a single elastic/plastic response
- Parameters: 5 global + 3 per mode
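
The two tallies above can be checked with a few lines of Python. This is a small sketch that only encodes the counts quoted in this section; `mlikh_param_count` is a hypothetical helper, not part of RheoJAX.

```python
def mlikh_param_count(n_modes: int, yield_mode: str = "per_mode") -> int:
    """Number of free parameters implied by the tallies quoted above.

    Hypothetical helper for illustration only; it does not query RheoJAX.
    """
    if n_modes < 1:
        raise ValueError("n_modes must be a positive integer")
    if yield_mode == "per_mode":
        return 7 * n_modes + 1  # 7 per mode + 1 global
    if yield_mode == "weighted_sum":
        return 5 + 3 * n_modes  # 5 global + 3 per mode
    raise ValueError(f"unknown yield_mode: {yield_mode!r}")


for n in (1, 2, 3):
    pm = mlikh_param_count(n, "per_mode")
    ws = mlikh_param_count(n, "weighted_sum")
    print(f"N={n}: per_mode={pm}, weighted_sum={ws}")
```

For the two-mode fits used later in this notebook the tallies give 15 parameters (per_mode) and 11 (weighted_sum). If the tallies hold, these should match the lengths of the lists returned by `get_mlikh_param_names`.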

### Per-Mode Parameters (for each mode $i$)

| Parameter | Symbol | Description |
|-----------|--------|-------------|
| `G_i` | $G_i$ | Mode shear modulus (Pa) |
| `C_i` | $C_i$ | Kinematic hardening modulus (Pa) |
| `gamma_dyn_i` | $\gamma_{dyn,i}$ | Dynamic recovery |
| `sigma_y0_i` | $\sigma_{y0,i}$ | Minimal yield stress (Pa) |
| `delta_sigma_y_i` | $\Delta\sigma_{y,i}$ | Structural yield contribution (Pa) |
| `tau_thix_i` | $\tau_{thix,i}$ | Thixotropic timescale (s) |
| `Gamma_i` | $\Gamma_i$ | Breakdown coefficient |

### Key Physics

- **Distributed timescales**: Different modes restructure at different rates
- **Parallel stress**: Total stress is the sum of the mode contributions
- **Independent yielding**: Each mode can yield independently (per_mode)
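
To make the parallel-stress picture concrete, here is a minimal NumPy sketch of a steady-state flow curve built from independent modes. It is not the RheoJAX implementation. It assumes first-order structure kinetics, $d\lambda_i/dt = (1-\lambda_i)/\tau_{thix,i} - \Gamma_i \lambda_i |\dot\gamma|$, ignores the kinematic-hardening back stress, and treats the single global parameter as a high-rate viscosity called `eta_inf`. The function names and all numbers are hypothetical.

```python
import numpy as np


def lambda_ss(gamma_dot, tau_thix, Gamma):
    """Steady-state structure parameter under the ASSUMED kinetics
    d(lambda)/dt = (1 - lambda)/tau_thix - Gamma * lambda * |gamma_dot|."""
    return 1.0 / (1.0 + Gamma * tau_thix * np.abs(gamma_dot))


def flow_curve_per_mode(gamma_dot, modes, eta_inf=0.0):
    """Sum of per-mode yield stresses plus an assumed global viscous term."""
    gamma_dot = np.asarray(gamma_dot, dtype=float)
    stress = eta_inf * gamma_dot  # assumed global contribution
    for m in modes:
        lam = lambda_ss(gamma_dot, m["tau_thix"], m["Gamma"])
        # Each mode adds its own structure-dependent yield stress
        stress = stress + m["sigma_y0"] + m["delta_sigma_y"] * lam
    return stress


# Illustrative values only (not fitted results)
modes = [
    {"sigma_y0": 5.0, "delta_sigma_y": 20.0, "tau_thix": 1.0, "Gamma": 1.0},
    {"sigma_y0": 2.0, "delta_sigma_y": 10.0, "tau_thix": 100.0, "Gamma": 1.0},
]
rates = np.logspace(-3, 2, 6)
for rate, sigma in zip(rates, flow_curve_per_mode(rates, modes, eta_inf=0.1)):
    print(f"gamma_dot = {rate:8.3g} 1/s  ->  stress = {sigma:7.3f} Pa")
```

In this simplified picture each mode loses its structural yield contribution around $\dot\gamma \approx 1/(\Gamma_i \tau_{thix,i})$, which is how two modes can produce two distinct features in a single flow curve.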

## 3. Load Data

In [None]:
# Load flow curve data (same as MIKH NB01)
gamma_dot, stress = load_ml_ikh_flow_curve(instrument="ARES_up")

print(f"Data points: {len(gamma_dot)}")
print(f"Shear rate range: [{gamma_dot.min():.4f}, {gamma_dot.max():.2f}] 1/s")
print(f"Stress range: [{stress.min():.2f}, {stress.max():.2f}] Pa")

In [None]:
# Plot raw data
fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(gamma_dot, stress, "ko", markersize=7, label="Data")
ax.set_xlabel("Shear rate [1/s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("MLIKH Flow Curve Data", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

### 4.1 Per-Mode Yield (2 modes)

In [None]:
# Create MLIKH model with 2 modes, per_mode yield
n_modes = 2
model_per_mode = MLIKH(n_modes=n_modes, yield_mode="per_mode")
param_names_pm = get_mlikh_param_names(n_modes=n_modes, yield_mode="per_mode")

print(f"MLIKH (per_mode, {n_modes} modes): {len(param_names_pm)} parameters")
print(f"Parameters: {param_names_pm}")

In [None]:
# Fit model
t0 = time.time()
model_per_mode.fit(gamma_dot, stress, test_mode="flow_curve")
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print("\nFitted parameters (per-mode):")
for name in param_names_pm:
    val = model_per_mode.parameters.get_value(name)
    print(f"  {name:18s} = {val:.4g}")

In [None]:
# Compute fit quality
stress_pred_pm = model_per_mode.predict(gamma_dot, test_mode="flow_curve")
metrics_pm = compute_fit_quality(stress, stress_pred_pm)

print("\nFit Quality (per_mode):")
print(f"  R^2:   {metrics_pm['R2']:.6f}")
print(f"  RMSE:  {metrics_pm['RMSE']:.4g} Pa")

### 4.2 Weighted-Sum Yield (2 modes)

In [None]:
# Create MLIKH model with weighted_sum yield
model_weighted_sum = MLIKH(n_modes=n_modes, yield_mode="weighted_sum")
param_names_ws = get_mlikh_param_names(n_modes=n_modes, yield_mode="weighted_sum")

print(f"MLIKH (weighted_sum, {n_modes} modes): {len(param_names_ws)} parameters")
print(f"Parameters: {param_names_ws}")

In [None]:
# Fit model
t0 = time.time()
model_weighted_sum.fit(gamma_dot, stress, test_mode="flow_curve")
t_nlsq_ws = time.time() - t0

print(f"NLSQ fit time: {t_nlsq_ws:.2f} s")

# Compute fit quality
stress_pred_ws = model_weighted_sum.predict(gamma_dot, test_mode="flow_curve")
metrics_ws = compute_fit_quality(stress, stress_pred_ws)

print("\nFit Quality (weighted_sum):")
print(f"  R^2:   {metrics_ws['R2']:.6f}")
print(f"  RMSE:  {metrics_ws['RMSE']:.4g} Pa")

### 4.3 Compare Yield Formulations
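
The comparison below relies on the R$^2$ and RMSE values reported for each fit. As a reference for what those numbers mean, here is a minimal sketch using the usual linear-space definitions. The tutorial helper `compute_fit_quality` may differ (for example by working in log space), so `fit_quality` here is a hypothetical stand-in, not its source.

```python
import numpy as np


def fit_quality(y_true, y_pred):
    """Coefficient of determination and root-mean-square error.

    Hypothetical stand-in using the standard linear-space definitions.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_true - y_pred
    ss_res = np.sum(resid**2)                      # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean())**2)   # total sum of squares
    return {
        "R2": 1.0 - ss_res / ss_tot,
        "RMSE": float(np.sqrt(np.mean(resid**2))),
    }


# Toy numbers only: a perfect fit and a constant (mean) prediction
print(fit_quality([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))
print(fit_quality([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))
```

Because the per_mode model has more free parameters than the weighted_sum model, a slightly higher R$^2$ on its own does not show that the extra flexibility is needed.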

In [None]:
# Plot comparison
gamma_dot_fine = np.logspace(
    np.log10(gamma_dot.min()) - 0.5,
    np.log10(gamma_dot.max()) + 0.3,
    200,
)
stress_pm_fine = model_per_mode.predict(gamma_dot_fine, test_mode="flow_curve")
stress_ws_fine = model_weighted_sum.predict(gamma_dot_fine, test_mode="flow_curve")

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(gamma_dot, stress, "ko", markersize=7, label="Data")
ax.loglog(gamma_dot_fine, stress_pm_fine, "-", lw=2.5, color="C0", 
          label=f"per_mode (R$^2$={metrics_pm['R2']:.5f})")
ax.loglog(gamma_dot_fine, stress_ws_fine, "--", lw=2.5, color="C1", 
          label=f"weighted_sum (R$^2$={metrics_ws['R2']:.5f})")

ax.set_xlabel("Shear rate [1/s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("MLIKH: Comparison of Yield Formulations", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

### 4.4 Mode Contribution Analysis

In [None]:
# Print the fitted parameters of each mode (per_mode model) for side-by-side inspection
print("Mode Contributions (per_mode formulation):")
print("=" * 50)

for i in range(1, n_modes + 1):
    G_i = model_per_mode.parameters.get_value(f"G_{i}")
    sigma_y0_i = model_per_mode.parameters.get_value(f"sigma_y0_{i}")
    delta_sigma_y_i = model_per_mode.parameters.get_value(f"delta_sigma_y_{i}")
    tau_thix_i = model_per_mode.parameters.get_value(f"tau_thix_{i}")
    Gamma_i = model_per_mode.parameters.get_value(f"Gamma_{i}")
    
    print(f"\nMode {i}:")
    print(f"  G_{i} = {G_i:.3g} Pa")
    print(f"  sigma_y0_{i} = {sigma_y0_i:.3g} Pa")
    print(f"  delta_sigma_y_{i} = {delta_sigma_y_i:.3g} Pa")
    print(f"  tau_thix_{i} = {tau_thix_i:.3g} s")
    print(f"  Gamma_{i} = {Gamma_i:.4g}")

## 5. Bayesian Inference with NUTS

In [None]:
# Use per_mode model for Bayesian inference
model = model_per_mode
param_names = param_names_pm

initial_values = {
    name: model.parameters.get_value(name)
    for name in param_names
}

# Fast demo config (for the full run listed under Runtime, use NUM_CHAINS=4, NUM_SAMPLES=2000)
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model.fit_bayesian(
    gamma_dot,
    stress,
    test_mode="flow_curve",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
all_pass = print_convergence_summary(result, param_names)

In [None]:
# Trace plots for key mode parameters
idata = result.to_inference_data()
mode_params = ["G_1", "G_2", "tau_thix_1", "tau_thix_2", "sigma_y0_1", "sigma_y0_2"]
axes = az.plot_trace(idata, var_names=mode_params, figsize=(12, 10))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots (Mode Parameters)", fontsize=14, y=1.00)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Forest plot
axes = az.plot_forest(
    idata,
    var_names=mode_params,
    combined=True,
    hdi_prob=0.95,
    figsize=(10, 6),
)
fig = axes.ravel()[0].figure
fig.suptitle("95% Credible Intervals (Mode Parameters)", fontsize=13)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Posterior predictive
posterior = result.posterior_samples
# FAST mode: 10 draws for batch testing
# FULL mode: 100 draws for publication
FAST_POSTERIOR_PREDICTIVE = True
n_draws = min(10 if FAST_POSTERIOR_PREDICTIVE else 100, len(list(posterior.values())[0]))

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(gamma_dot, stress, "ko", markersize=7, label="Data")

# Plot posterior samples
for i in range(n_draws):
    params_i = jnp.array([posterior[name][i] for name in param_names])
    pred_i = model.model_function(jnp.array(gamma_dot_fine), params_i, test_mode="flow_curve")
    ax.loglog(gamma_dot_fine, pred_i, "-", color="C0", alpha=0.05, lw=0.5)

ax.set_xlabel("Shear rate [1/s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("Posterior Predictive Check (MLIKH per_mode)", fontsize=13)
ax.legend(fontsize=10)  # show the "Data" label, as in the earlier plots
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 6. Physical Interpretation

### Multi-Mode Benefits

1. **Distributed timescales**: Modes with different $\tau_{thix,i}$ capture multiple restructuring rates
2. **Flexible yield**: Per-mode allows different yield stresses for different structural components
3. **Complex flow curves**: Captures inflections and subtle curvature in data

### Mode Identification

- **Fast mode** (small $\tau_{thix}$): Responds quickly, dominates high-rate behavior
- **Slow mode** (large $\tau_{thix}$): Responds slowly, dominates low-rate behavior
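
One way to apply this labelling to a fitted model is to sort the modes by `tau_thix_i`. The sketch below also reports a crossover shear rate $1/(\Gamma_i \tau_{thix,i})$, which is where the steady-state structure would drop to one half if the kinetics were $d\lambda_i/dt = (1-\lambda_i)/\tau_{thix,i} - \Gamma_i \lambda_i |\dot\gamma|$. That kinetic form is an assumption, not something confirmed from the RheoJAX source. `rank_modes` is a hypothetical helper, and the numbers are illustrative rather than fitted.

```python
def rank_modes(values, n_modes):
    """Sort modes from fastest to slowest restructuring.

    `values` maps parameter names such as "tau_thix_1" and "Gamma_1" to
    floats. The crossover rate relies on the ASSUMED first-order kinetics
    described in the text.
    """
    rows = []
    for i in range(1, n_modes + 1):
        tau = values[f"tau_thix_{i}"]
        Gamma = values[f"Gamma_{i}"]
        rows.append({
            "mode": i,
            "tau_thix": tau,
            "crossover_rate": 1.0 / (Gamma * tau),  # lambda_ss = 0.5 here
        })
    return sorted(rows, key=lambda row: row["tau_thix"])


# Illustrative values only; in the notebook they could be read with
# model_per_mode.parameters.get_value(name)
example = {"tau_thix_1": 50.0, "Gamma_1": 0.2, "tau_thix_2": 0.5, "Gamma_2": 2.0}
for row in rank_modes(example, n_modes=2):
    print(f"mode {row['mode']}: tau_thix = {row['tau_thix']:.3g} s, "
          f"crossover rate = {row['crossover_rate']:.3g} 1/s")
```

Under the same assumption a steady-state flow curve only constrains the product $\Gamma_i \tau_{thix,i}$, so a fast/slow ranking taken from flow-curve fits alone is best confirmed with transient data such as startup or relaxation.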

### Per-Mode vs Weighted-Sum

- **Per-mode**: More flexible, each mode contributes independently
- **Weighted-sum**: More constrained, single yield surface with structure contribution
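
The single yield surface of the weighted-sum formulation follows directly from the expression in Section 2, $\sigma_y = \sigma_{y0} + k_3 \sum_i w_i \lambda_i$. A minimal sketch of that formula is shown below. The function and argument names are hypothetical and need not match the RheoJAX parameter names, and the numbers are illustrative.

```python
import numpy as np


def weighted_sum_yield(sigma_y0, k3, weights, lambdas):
    """Global yield stress sigma_y = sigma_y0 + k3 * sum_i(w_i * lambda_i)."""
    weights = np.asarray(weights, dtype=float)
    lambdas = np.asarray(lambdas, dtype=float)
    if weights.shape != lambdas.shape:
        raise ValueError("weights and lambdas must have the same shape")
    return sigma_y0 + k3 * float(np.sum(weights * lambdas))


# Fully structured material: every lambda_i = 1
print(weighted_sum_yield(5.0, 30.0, [0.7, 0.3], [1.0, 1.0]))
# Partially broken down: the heavily weighted mode has lost most structure
print(weighted_sum_yield(5.0, 30.0, [0.7, 0.3], [0.2, 0.9]))
```

All modes feed one scalar yield stress here, which is why this formulation needs fewer parameters than the per-mode version but cannot give each mode its own yield point.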

## 7. Save Results

In [None]:
# Save results for downstream notebooks
save_ikh_results(model, result, "mlikh", "flow_curve", param_names)

print("\nParameters saved for synthetic data generation in:")
print("  - NB09: Stress Relaxation")
print("  - NB11: SAOS")

## Key Takeaways

1. **MLIKH extends MIKH** to N modes with distributed thixotropic timescales

2. **Two yield formulations**:
   - **per_mode**: Independent yield surfaces, more flexible
   - **weighted_sum**: Single global yield, more constrained

3. **Parameter count**: 7N + 1 (per_mode) vs 5 + 3N (weighted_sum)

4. **Mode identification**: Different modes capture fast vs slow restructuring

5. **Multi-mode benefit**: Captures complex flow behavior that single-mode cannot

6. **Bayesian inference**: Sampling gets harder as the parameter count grows; initializing NUTS from the NLSQ fit (`initial_values`) helps

### Next Steps

- **NB08**: MLIKH Startup (richer overshoot dynamics)
- **NB09**: MLIKH Relaxation (multi-mode decay)
- **NB10**: MLIKH Creep
- **NB11**: MLIKH SAOS (broadened spectra)
- **NB12**: MLIKH LAOS