# FMLIKH Model: Steady-State Flow Curves

## Learning Objectives

1. Fit the **FMLIKH (Fractional Multi-Layer IKH)** model to flow curve data
2. Understand **multi-mode relaxation** with fractional memory
3. Analyze mode contributions to total stress
4. Compare single-mode FIKH vs multi-mode FMLIKH predictions
5. Use Bayesian inference for multi-modal parameter estimation
6. Calibrate parameters for downstream synthetic data generation

## Prerequisites

- NB01-NB06: FIKH tutorials (single-mode concepts)
- Bayesian inference fundamentals

## Runtime

- Fast demo (NUM_CHAINS=1, NUM_SAMPLES=500): ~5-8 minutes
- Full run (NUM_CHAINS=4, NUM_SAMPLES=2000): ~20-30 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import os
import sys
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.fikh import FMLIKH

# Robust path resolution for execution from any directory:
# use the script location when available, otherwise the current working directory
_nb_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()
_utils_candidates = [_nb_dir / ".." / "utils", Path("examples/utils"), _nb_dir.parent / "utils"]
for _p in _utils_candidates:
    if (_p / "fikh_tutorial_utils.py").exists():
        sys.path.insert(0, str(_p.resolve()))
        break
from fikh_tutorial_utils import (
    load_ml_ikh_flow_curve,
    save_fikh_results,
    print_convergence_summary,
    print_parameter_comparison,
    compute_fit_quality,
    get_fmlikh_param_names,
    print_alpha_interpretation,
)

# Shared plotting utilities
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    plot_nlsq_fit,
    display_arviz_diagnostics,
    plot_posterior_predictive,
)

# FAST_MODE: controlled by environment variable
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

jax, jnp = safe_import_jax()
verify_float64()

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"FAST_MODE: {FAST_MODE}")

## 2. Theory: FMLIKH Model

The **Fractional Multi-Layer IKH (FMLIKH)** model extends FIKH with multiple viscoelastic modes:

### Total Stress
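$$
\sigma_{\text{total}} = \sum_{i=1}^{N} \sigma_i + \eta_{\infty} \dot{\gamma}
$$

where $\sigma_i$ is the stress carried by mode $i$, $N$ is the number of modes, and $\eta_{\infty}$ is a background viscosity acting on the shear rate $\dot{\gamma}$.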

Each mode $i$ has its own:
- $G_i$: Shear modulus
- $\eta_i$: Viscosity (defines $\tau_i = \eta_i / G_i$)
- $C_i$: Kinematic hardening modulus
- $\gamma_{dyn,i}$: Dynamic recovery parameter
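As a worked check on the definitions above, the sketch below computes each mode's relaxation time $\tau_i = \eta_i / G_i$ and assembles the total stress from per-mode stresses plus the background term $\eta_\infty \dot{\gamma}$. All numbers are hypothetical placeholders, not fitted values, and the per-mode stresses are simply given rather than computed from the FMLIKH evolution equations.

```python
import numpy as np

# Hypothetical per-mode parameters (placeholders, not fitted values)
G = np.array([100.0, 50.0, 10.0])    # shear moduli G_i [Pa]
eta = np.array([1.0, 5.0, 10.0])     # viscosities eta_i [Pa.s]

# Relaxation time of each mode: tau_i = eta_i / G_i
tau = eta / G                        # [s]

# Hypothetical per-mode stresses at one shear rate (given, not computed)
sigma_modes = np.array([20.0, 12.0, 3.0])   # [Pa]
eta_inf = 0.05                              # background viscosity [Pa.s]
gamma_dot = 10.0                            # shear rate [1/s]

# Total stress: sum over modes plus the viscous background term
sigma_total = sigma_modes.sum() + eta_inf * gamma_dot
print("tau_i [s]:", tau)
print(f"sigma_total = {sigma_total:.2f} Pa")
```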

### Shared Parameters
- Yield stress: $\sigma_{y0}$, $\Delta\sigma_y$
- Thixotropy: $\tau_{thix}$, $\Gamma$
- Fractional order: $\alpha$ (parameter `alpha_structure`), either one value shared by all modes or one per mode
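A rough count shows why sharing $\alpha$ matters as $N$ grows. The sketch counts only the symbols listed in this section (four per-mode parameters, four shared yield/thixotropy parameters, $\eta_\infty$, and $\alpha$); the actual `FMLIKH` parameter set may contain additional entries, so the numbers are illustrative and need not match what `model.parameters` reports. The helper `count_listed_params` is hypothetical.

```python
def count_listed_params(n_modes: int, shared_alpha: bool) -> int:
    """Count only the parameters named in this section (illustrative, hypothetical helper)."""
    per_mode = 4 * n_modes          # G_i, eta_i, C_i, gamma_dyn_i
    shared = 4                      # sigma_y0, delta_sigma_y, tau_thix, Gamma
    background = 1                  # eta_inf
    alpha = 1 if shared_alpha else n_modes
    return per_mode + shared + background + alpha


for n in (1, 3, 5):
    print(f"N={n}: shared alpha -> {count_listed_params(n, True)}, "
          f"per-mode alpha -> {count_listed_params(n, False)}")
```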

### Why Multi-Mode?

1. **Broad relaxation spectra**: Real materials relax over several time scales rather than one
2. **Wide-frequency SAOS**: A single mode cannot capture the full frequency range
3. **Complex startup**: Several modes can produce multiple overshoot features
4. **Prony-series-like**: The sum over modes resembles a Prony series, i.e. generalized Maxwell behavior
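For reference, the classical generalized Maxwell (Prony-series) relaxation modulus that point 4 alludes to is $G(t) = \sum_i G_i e^{-t/\tau_i}$. The sketch below evaluates it for a hypothetical three-mode spectrum. It deliberately leaves out fractional memory, yield, and thixotropy, so it illustrates only the multi-exponential shape, not an FMLIKH prediction.

```python
import numpy as np


def prony_modulus(t, G, tau):
    """Generalized Maxwell relaxation modulus G(t) = sum_i G_i * exp(-t / tau_i)."""
    t = np.asarray(t, dtype=float)[:, None]            # shape (n_t, 1) for broadcasting
    return np.sum(np.asarray(G) * np.exp(-t / np.asarray(tau)), axis=1)


# Hypothetical three-mode spectrum spanning two decades of time
G = [100.0, 50.0, 10.0]      # [Pa]
tau = [0.01, 0.1, 1.0]       # [s]
t = np.logspace(-3, 1, 5)    # [s]
print(prony_modulus(t, G, tau))
```

Each mode dominates the decay near its own $\tau_i$, which is why a single exponential cannot reproduce a spectrum spread over several decades.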

## 3. Load Data

In [None]:
# Load flow curve data
gamma_dot, stress = load_ml_ikh_flow_curve(instrument="ARES_up")

print(f"Data points: {len(gamma_dot)}")
print(f"Shear rate range: [{gamma_dot.min():.4f}, {gamma_dot.max():.2f}] 1/s")
print(f"Stress range: [{stress.min():.2f}, {stress.max():.2f}] Pa")

In [None]:
# Plot data
fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(gamma_dot, stress, "ko", markersize=8, label="Data")
ax.set_xlabel("Shear rate [1/s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("ML-IKH Flow Curve Data", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

In [None]:
# Create FMLIKH model with 3 modes and a single fractional order shared by all modes
N_MODES = 3
SHARED_ALPHA = True
model = FMLIKH(
    n_modes=N_MODES,
    include_thermal=False,
    shared_alpha=SHARED_ALPHA,
    alpha_structure=0.7,
)

print(f"FMLIKH with {N_MODES} modes")
print(f"Shared alpha: {SHARED_ALPHA}")
print(f"Total parameters: {len(list(model.parameters.keys()))}")

In [None]:
# Fit model
t0 = time.time()
model.fit(gamma_dot, stress, test_mode="flow_curve")
t_nlsq = time.time() - t0

param_names = get_fmlikh_param_names(n_modes=N_MODES, shared_alpha=True)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print("\nFitted parameters:")
for name in param_names:
    try:
        val = model.parameters.get_value(name)
        print(f"  {name:15s} = {val:.4g}")
    except KeyError:
        pass

In [None]:
# Get mode information
mode_info = model.get_mode_info()

print("\nMode Information:")
print("=" * 60)
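alpha_shared = mode_info.get("alpha_shared")
# Guard the format spec: ":.3f" would raise ValueError on a string fallback like "N/A"
print(f"Shared alpha: {alpha_shared:.3f}" if alpha_shared is not None else "Shared alpha: N/A")
print("\nPer-mode parameters:")
print(f"{'Mode':>6s}  {'G [Pa]':>12s}  {'η [Pa·s]':>12s}  {'τ [s]':>12s}")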
print("-" * 60)
for mode in mode_info["modes"]:
    print(f"{mode['mode']:>6d}  {mode['G']:>12.4g}  {mode['eta']:>12.4g}  {mode['tau']:>12.4g}")

In [None]:
# Compute fit quality and plot with uncertainty band
stress_pred = model.predict(gamma_dot, test_mode="flow_curve")
metrics = compute_fit_quality(stress, stress_pred)

print("\nFit Quality:")
print(f"  R^2:   {metrics['R2']:.6f}")
print(f"  RMSE:  {metrics['RMSE']:.4g} Pa")

fig, ax = plot_nlsq_fit(
    gamma_dot, stress, model, test_mode="flow_curve",
    param_names=param_names, log_scale=True,
    xlabel="Shear rate [1/s]", ylabel="Stress [Pa]",
    title=f"FMLIKH Flow Curve Fit ({N_MODES} modes, R$^2$ = {metrics['R2']:.5f})",
)
display(fig)
plt.close(fig)
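`compute_fit_quality` is a tutorial helper whose internals are not shown here. A common definition of the two metrics it reports is sketched below; that the helper uses exactly these formulas is an assumption. Because the flow curve spans several decades of stress, the sketch also computes a hypothetical extra metric, R² on $\log_{10}$ stress, which weights the low-rate points more evenly than linear R² does.

```python
import numpy as np


def fit_metrics(y, y_pred):
    """R^2 and RMSE in linear space, plus R^2 on log10 values (assumed definitions).

    The log-space metric requires strictly positive data and predictions.
    """
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y - y_pred
    r2 = 1.0 - np.sum(resid**2) / np.sum((y - y.mean()) ** 2)
    rmse = np.sqrt(np.mean(resid**2))
    ly, lp = np.log10(y), np.log10(y_pred)
    r2_log = 1.0 - np.sum((ly - lp) ** 2) / np.sum((ly - ly.mean()) ** 2)
    return {"R2": r2, "RMSE": rmse, "R2_log": r2_log}


# Usage with the arrays from the cell above:
# fit_metrics(stress, stress_pred)
```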

## 5. Mode Decomposition

In [None]:
# Visualize relaxation time spectrum
mode_info = model.get_mode_info()

taus = [m["tau"] for m in mode_info["modes"]]
Gs = [m["G"] for m in mode_info["modes"]]

fig, ax = plt.subplots(figsize=(8, 5))
ax.bar(range(len(taus)), Gs, tick_label=[f"Mode {i}\nτ={t:.2g}s" for i, t in enumerate(taus)])
ax.set_xlabel("Mode", fontsize=12)
ax.set_ylabel("Modulus G [Pa]", fontsize=12)
ax.set_title("Relaxation Time Spectrum", fontsize=13)
ax.set_yscale("log")
plt.tight_layout()
display(fig)
plt.close(fig)
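The bar chart orders modes by index, not by time scale. The sketch below sorts the entries returned by `get_mode_info()` by $\tau$ and reports each mode's fraction of the total modulus, $G_i / \sum_j G_j$. It assumes only the dictionary layout used in the cells above (keys `mode`, `G`, `tau`); `mode_weights` is a hypothetical helper and the example input is made up. In the linear (Maxwell) limit this fraction is each mode's share of the instantaneous modulus, which is not the same as its share of the steady-state stress.

```python
def mode_weights(modes):
    """Sort modes by relaxation time and attach G_i / sum(G) (hypothetical helper)."""
    total_G = sum(m["G"] for m in modes)
    ordered = sorted(modes, key=lambda m: m["tau"])
    return [
        {"mode": m["mode"], "tau": m["tau"], "G": m["G"], "weight": m["G"] / total_G}
        for m in ordered
    ]


# Hypothetical input mimicking mode_info["modes"]
example = [
    {"mode": 0, "G": 10.0, "tau": 1.0},
    {"mode": 1, "G": 100.0, "tau": 0.01},
    {"mode": 2, "G": 50.0, "tau": 0.1},
]
for row in mode_weights(example):
    print(f"mode {row['mode']}: tau={row['tau']:.3g} s, weight={row['weight']:.3f}")

# With the fitted model: mode_weights(model.get_mode_info()["modes"])
```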

## 6. Bayesian Inference

In [None]:
# Bayesian inference
initial_values = {}
for name in param_names:
    try:
        initial_values[name] = model.parameters.get_value(name)
    except KeyError:
        pass

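# Fast demo vs. full run (see Runtime above); FAST_MODE is set in the imports cell
NUM_WARMUP = 200
NUM_SAMPLES = 500 if FAST_MODE else 2000
NUM_CHAINS = 1 if FAST_MODE else 4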

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model.fit_bayesian(
    gamma_dot,
    stress,
    test_mode="flow_curve",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence (check key parameters; the G_i names follow N_MODES)
key_params = [f"G_{i}" for i in range(N_MODES)] + ["alpha_structure", "sigma_y0", "tau_thix"]
all_pass = print_convergence_summary(result, key_params)
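With `NUM_CHAINS = 1` there is no second chain to compare against, so convergence checks rely on splitting the chain into halves. The sketch below is the classic split-$\hat{R}$ described by Gelman et al. in *Bayesian Data Analysis*, without rank normalization; ArviZ's default `rhat` uses the rank-normalized variant, so the two can differ slightly. Values near 1 indicate that the halves agree, and values above roughly 1.01 are commonly flagged. `split_rhat` is a hypothetical helper, and the test chains below are synthetic.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat (no rank normalization).

    draws: array of shape (n_draws,) for one chain or (n_chains, n_draws).
    """
    x = np.asarray(draws, dtype=float)
    if x.ndim == 1:
        x = x[None, :]
    half = x.shape[1] // 2
    # Split every chain into two halves -> 2 * n_chains sub-chains of length `half`
    sub = np.concatenate([x[:, :half], x[:, half:2 * half]], axis=0)
    n = sub.shape[1]
    within = sub.var(axis=1, ddof=1).mean()           # W: mean within-sub-chain variance
    between = n * sub.mean(axis=1).var(ddof=1)        # B: between-sub-chain variance
    var_plus = (n - 1) / n * within + between / n     # pooled posterior variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
stationary = rng.normal(size=1000)                          # well-mixed chain
drifting = rng.normal(size=1000) + np.linspace(0, 5, 1000)  # chain with a trend
print(f"stationary: {split_rhat(stationary):.3f}")
print(f"drifting:   {split_rhat(drifting):.3f}")
```

To apply it here, something like `split_rhat(np.asarray(result.posterior_samples["G_0"]))` should work, assuming the samples are a 1-D array of draws; reshape to `(n_chains, n_draws)` if several chains were pooled.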

In [None]:
# ArviZ diagnostic plots (trace, pair, forest, energy, autocorr, rank)
display_arviz_diagnostics(result, key_params, fast_mode=FAST_MODE)

In [None]:
# Parameter comparison
posterior = result.posterior_samples
print_parameter_comparison(model, posterior, key_params)
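Derived quantities such as $\tau_i = \eta_i / G_i$ should be computed draw by draw rather than as a ratio of posterior means, because the mean of a ratio generally differs from the ratio of means. The sketch assumes posterior draws are available as 1-D arrays; the draws below are synthetic lognormal samples, not results from the fit above. The viscosity parameter names are not printed in this notebook, so check `param_names` before substituting real samples. `derived_tau` is a hypothetical helper.

```python
import numpy as np


def derived_tau(G_draws, eta_draws, ci=0.95):
    """Posterior summary of tau = eta / G computed per draw (hypothetical helper)."""
    tau = np.asarray(eta_draws, dtype=float) / np.asarray(G_draws, dtype=float)
    lo, hi = np.quantile(tau, [(1 - ci) / 2, (1 + ci) / 2])
    return {"mean": float(tau.mean()), "median": float(np.median(tau)),
            "lo": float(lo), "hi": float(hi)}


# Synthetic draws for illustration only (not from the fit above)
rng = np.random.default_rng(1)
G_draws = rng.lognormal(mean=np.log(100.0), sigma=0.3, size=4000)
eta_draws = rng.lognormal(mean=np.log(10.0), sigma=0.3, size=4000)

summary = derived_tau(G_draws, eta_draws)
naive = eta_draws.mean() / G_draws.mean()    # ratio of means, for comparison
print(summary, f"ratio of means = {naive:.4f}")
```

With skewed posteriors like these, the per-draw mean of $\tau$ sits noticeably above the naive ratio of means.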

## 7. Save Results

In [None]:
save_fikh_results(model, result, "fmlikh", "flow_curve", param_names)
print("\nParameters saved for synthetic-data generation in NB09 and NB11.")

## Key Takeaways

1. **FMLIKH extends FIKH with multiple relaxation modes**
2. **Multi-mode captures broad relaxation spectra** in real materials
3. **Shared alpha** reduces the parameter count while retaining fractional memory effects
4. **Mode decomposition** shows how the modulus is distributed across relaxation time scales
5. **More parameters require careful initialization** and longer inference; here the NLSQ estimates warm-start NUTS

### Next Steps

- **NB08**: Startup shear (multi-mode overshoot)
- **NB09**: Stress relaxation (multi-exponential decay)
- **NB10**: Creep (distributed time scales)
- **NB11**: SAOS (broadened spectra)
- **NB12**: LAOS (multi-mode harmonics)