# MLIKH Model: Stress Relaxation

## Learning Objectives

1. Generate **synthetic relaxation** data from calibrated MLIKH parameters
2. Understand **multi-mode relaxation** behavior
3. Compare with **single-mode Maxwell** relaxation
4. Observe **Prony-series-like** decay from parallel modes

## Prerequisites

- NB07: MLIKH Flow Curve (provides calibrated parameters)

## Runtime

- Fast demo: ~3-4 minutes
- Full run: ~12-15 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import os
import sys
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.ikh import MLIKH

# Add examples/utils to path
sys.path.insert(0, os.path.join("..", "utils"))
from ikh_tutorial_utils import (
    load_ikh_parameters,
    set_model_parameters,
    generate_synthetic_relaxation,
    save_ikh_results,
    print_convergence_summary,
    compute_fit_quality,
    get_mlikh_param_names,
)

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")

## 2. Theory: Multi-Mode Relaxation

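With $N$ modes in parallel, the total relaxation modulus is the sum of the single-mode contributions, where $G_i$ is the modulus and $\tau_{M,i}$ the Maxwell relaxation time of mode $i$: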

$$
G(t) = \sum_{i=1}^N G_i \exp\left(-\frac{t}{\tau_{M,i}}\right)
$$

This is a **Prony series** representation, providing:
- Multi-exponential decay
- Broader relaxation spectrum
- Better fit to complex materials
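
The sum above can be evaluated directly with NumPy. The sketch below compares a two-mode Prony series with a single Maxwell mode of the same total modulus. The moduli and relaxation times are illustrative values, not the calibrated MLIKH parameters, and `prony_modulus` is a hypothetical helper rather than part of RheoJAX.

```python
import numpy as np


def prony_modulus(t, G, tau):
    """Evaluate G(t) = sum_i G_i * exp(-t / tau_i) for an array of times."""
    t = np.asarray(t, dtype=float)[:, None]      # shape (n_times, 1)
    G = np.asarray(G, dtype=float)[None, :]      # shape (1, n_modes)
    tau = np.asarray(tau, dtype=float)[None, :]  # shape (1, n_modes)
    return np.sum(G * np.exp(-t / tau), axis=1)


# Illustrative values only (assumed, not fitted)
t = np.logspace(-2, 3, 6)
G_two = prony_modulus(t, G=[80.0, 20.0], tau=[1.0, 100.0])
G_one = prony_modulus(t, G=[100.0], tau=[1.0])

for ti, g2, g1 in zip(t, G_two, G_one):
    print(f"t = {ti:8.2f} s   two-mode: {g2:10.4g} Pa   single-mode: {g1:10.4g} Pa")
```

Both curves start near the same total modulus, but the two-mode curve retains stress long after the single mode has decayed, which is the broadened spectrum described above.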

## 3. Load Calibrated Parameters

In [None]:
# Load calibrated parameters from NB07
n_modes = 2
try:
    calibrated_params = load_ikh_parameters("mlikh", "flow_curve")
    print("Loaded calibrated parameters from NB07")
except FileNotFoundError:
    print("NB07 results not found. Using default parameters.")
    calibrated_params = None

In [None]:
# Create model and set parameters
model = MLIKH(n_modes=n_modes, yield_mode="per_mode")
param_names = get_mlikh_param_names(n_modes=n_modes, yield_mode="per_mode")

if calibrated_params is not None:
    set_model_parameters(model, calibrated_params)

## 4. Generate Synthetic Data

In [None]:
# Generate synthetic relaxation data
sigma_0 = 100.0
t_end = 500.0
n_points = 200

t_data, stress_data = generate_synthetic_relaxation(
    model,
    sigma_0=sigma_0,
    t_end=t_end,
    n_points=n_points,
    noise_level=0.02,
    seed=42,
)

print("Generated synthetic data:")
print(f"  Time range: [{t_data.min():.2f}, {t_data.max():.1f}] s")
print(f"  Stress range: [{stress_data.min():.2f}, {stress_data.max():.2f}] Pa")

In [None]:
# Plot synthetic data
fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(t_data, stress_data, "ko", markersize=5, alpha=0.7, label="Synthetic data")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("Synthetic MLIKH Relaxation Data", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 5. NLSQ Fitting

In [None]:
# Fit model
model_fit = MLIKH(n_modes=n_modes, yield_mode="per_mode")

t0 = time.time()
model_fit.fit(t_data, stress_data, test_mode="relaxation", sigma_0=sigma_0, method="scipy")
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")

# Compute fit quality
stress_pred = model_fit.predict_relaxation(t_data, sigma_0=sigma_0)
metrics = compute_fit_quality(stress_data, stress_pred)

print("\nFit Quality:")
print(f"  R^2:   {metrics['R2']:.6f}")
print(f"  RMSE:  {metrics['RMSE']:.4g} Pa")

In [None]:
# Plot fit
t_fine = np.logspace(np.log10(t_data.min()), np.log10(t_data.max()), 200)
stress_fine = model_fit.predict_relaxation(t_fine, sigma_0=sigma_0)

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(t_data, stress_data, "ko", markersize=5, alpha=0.7, label="Data")
ax.loglog(t_fine, stress_fine, "-", lw=2.5, color="C0", label="MLIKH fit")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title(f"MLIKH Relaxation Fit (R$^2$ = {metrics['R2']:.5f})", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

## 6. Bayesian Inference

In [None]:
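# Bayesian inference, warm-started from the NLSQ point estimates
initial_values = {name: model_fit.parameters.get_value(name) for name in param_names}

# Fast-demo settings; for a full run increase these and use several chains,
# since cross-chain diagnostics such as R-hat need more than one chain
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1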

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model_fit.fit_bayesian(
    t_data,
    stress_data,
    test_mode="relaxation",
    sigma_0=sigma_0,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
all_pass = print_convergence_summary(result, param_names)

## 7. Physical Interpretation

### Multi-Mode Relaxation

- **Short times**: Fast mode dominates (small $\tau_{M,1}$)
- **Long times**: Slow mode dominates (large $\tau_{M,2}$)
- **Transition**: Both modes contribute, creating broader decay
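
One way to make "short" and "long" times concrete is the crossover time at which both modes carry equal stress. Assuming purely exponential mode decay as in the Prony form of Section 2, with $G_1 > G_2$ and $\tau_{M,1} < \tau_{M,2}$, setting the two contributions equal gives $t^* = \ln(G_1/G_2) / (1/\tau_{M,1} - 1/\tau_{M,2})$. The values below are illustrative, and `crossover_time` is a hypothetical helper, not a RheoJAX function.

```python
import math


def crossover_time(G1, tau1, G2, tau2):
    """Time at which G1*exp(-t/tau1) equals G2*exp(-t/tau2).

    Assumes mode 1 is the fast, stiffer mode (G1 > G2, tau1 < tau2),
    so that the crossover time is positive.
    """
    if not (G1 > G2 > 0 and 0 < tau1 < tau2):
        raise ValueError("expected G1 > G2 > 0 and 0 < tau1 < tau2")
    return math.log(G1 / G2) / (1.0 / tau1 - 1.0 / tau2)


# Illustrative values only (assumed, not fitted)
t_star = crossover_time(G1=80.0, tau1=1.0, G2=20.0, tau2=100.0)
fast = 80.0 * math.exp(-t_star / 1.0)
slow = 20.0 * math.exp(-t_star / 100.0)
print(f"crossover at t* = {t_star:.3f} s (fast = {fast:.3f} Pa, slow = {slow:.3f} Pa)")
```

Before $t^*$ the fast mode carries most of the stress; after it, the slow mode does.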

### Thixotropic Effects

During relaxation, the structure parameter $\lambda_i$ of each mode rebuilds toward its fully structured value of 1 on the timescale $\tau_{thix,i}$:
$$
\frac{d\lambda_i}{dt} = \frac{1 - \lambda_i}{\tau_{thix,i}}
$$
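
For constant $\tau_{thix,i}$, the recovery equation above integrates to $\lambda_i(t) = 1 - (1 - \lambda_{i,0})\,e^{-t/\tau_{thix,i}}$, where $\lambda_{i,0}$ is the structure at the start of relaxation. The sketch below checks this closed form against a simple forward-Euler integration. The initial structure `lam0` and timescale `tau_thix` are illustrative values, and both helper names are hypothetical.

```python
import numpy as np


def lambda_closed_form(t, lam0, tau_thix):
    """Analytic solution of d(lambda)/dt = (1 - lambda) / tau_thix."""
    return 1.0 - (1.0 - lam0) * np.exp(-np.asarray(t, dtype=float) / tau_thix)


def lambda_euler(t_end, lam0, tau_thix, n_steps=20_000):
    """Forward-Euler integration of the same ODE, returning lambda(t_end)."""
    dt = t_end / n_steps
    lam = lam0
    for _ in range(n_steps):
        lam += dt * (1.0 - lam) / tau_thix
    return lam


lam0, tau_thix, t_end = 0.2, 50.0, 100.0  # illustrative values (assumed)
exact = float(lambda_closed_form(t_end, lam0, tau_thix))
approx = lambda_euler(t_end, lam0, tau_thix)
print(f"lambda({t_end:.0f} s): closed form = {exact:.5f}, Euler = {approx:.5f}")
```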

## 8. Save Results

In [None]:
# Save results
save_ikh_results(model_fit, result, "mlikh", "relaxation", param_names)

## Key Takeaways

1. **Multi-mode relaxation** provides Prony-series-like behavior
2. **Distributed timescales** create broader relaxation spectrum
3. **Mode contributions** shift with time: the fast mode dominates early, the slow mode late
4. **Thixotropic restructuring** modifies late-stage relaxation

### Next Steps

- **NB10**: MLIKH Creep
- **NB11**: MLIKH SAOS
- **NB12**: MLIKH LAOS