# FIKH Model: Stress Relaxation

## Protocol-Specific Context

**Stress relaxation** is the **signature protocol** for fractional models. After a step strain, FIKH predicts:

1. **Power-law tails**: $\sigma(t) \sim t^{-\alpha}$ at long times (not exponential decay)
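2. **Mittag-Leffler function**: $E_{\alpha}(-(t/\tau)^{\alpha})$ interpolates between stretched-exponential decay at short times and the power-law tail at long times
3. **Structure recovery**: at rest, $D_t^{\alpha} \lambda = (1-\lambda)/\tau_{thix}$ (with $D_t^{\alpha}$ the Caputo derivative) drives slow rebuilding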

**Why this matters**: Classical models predict exponential decay $\exp(-t/\tau)$, which plunges on a log-log plot. FIKH's power-law tail instead appears as a straight line of slope $-\alpha$, so it reveals the fractional order directly; no other protocol isolates this effect as cleanly.
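
To see how the tail slope encodes $\alpha$, here is a minimal standalone sketch in plain NumPy (not the RheoJAX fitting path). It assumes $\alpha = 1/2$ because $E_{1/2}(-x) = e^{x^2}\,\mathrm{erfc}(x)$ has a closed form; $\tau$, $\sigma_0$, the 3% noise, and the fitting window are illustrative choices.

```python
import math

import numpy as np

ALPHA_TRUE = 0.5   # alpha = 1/2: E_{1/2}(-x) = exp(x**2) * erfc(x)
TAU = 1.0          # illustrative relaxation time [s]
S0 = 100.0         # illustrative initial stress [Pa]


def relaxation_alpha_half(t):
    """sigma(t) = S0 * E_{1/2}(-(t/TAU)**0.5), evaluated via the erfc closed form."""
    x = math.sqrt(t / TAU)
    return S0 * math.exp(x * x) * math.erfc(x)


# Long-time window only: t**-alpha is an asymptote, not valid near t ~ tau.
# t/tau is capped at 600 so exp(x**2) stays finite in double precision.
t_tail = np.logspace(2, math.log10(600.0), 40) * TAU
sigma_tail = np.array([relaxation_alpha_half(t) for t in t_tail])

# 3% multiplicative Gaussian noise with a fixed seed
rng = np.random.default_rng(0)
sigma_noisy = sigma_tail * (1.0 + 0.03 * rng.standard_normal(t_tail.size))

# On log-log axes the tail is a straight line of slope -alpha
slope, intercept = np.polyfit(np.log(t_tail), np.log(sigma_noisy), 1)
alpha_est = -slope
print(f"alpha from tail slope: {alpha_est:.3f} (true {ALPHA_TRUE})")
```

Restricting the fit to $t \gg \tau$ matters: closer to $t \sim \tau$ the local slope is shallower than $-\alpha$, which would bias the estimate low.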

> **Physical insight**: During relaxation at rest, structure slowly rebuilds via fractional kinetics. Lower $\alpha$ means stronger memory: the material "remembers" its broken state longer, slowing recovery.
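
The rest-state structure equation can be integrated directly. The sketch below applies the standard implicit L1 discretization of the Caputo derivative to $u = 1 - \lambda$, which obeys $D_t^\alpha u = -u/\tau_{thix}$, and compares the result with the exact solution $\lambda(t) = 1 - (1-\lambda_0)\,E_\alpha(-t^\alpha/\tau_{thix})$. It takes the equation literally as written above, so $\tau_{thix}$ carries units of $\mathrm{s}^\alpha$; the function name and all parameter values are hypothetical, and RheoJAX's internal solver may work differently.

```python
import math


def caputo_relaxation_l1(alpha, tau_thix, lam0, t_end, n_steps):
    """Integrate D_t^alpha lam = (1 - lam) / tau_thix (Caputo, 0 < alpha < 1), implicit L1 scheme.

    Works on u = 1 - lam, which obeys D_t^alpha u = -u / tau_thix.
    Returns the uniform time grid and lam on that grid.
    """
    h = t_end / n_steps
    c = h ** (-alpha) / math.gamma(2.0 - alpha)
    # L1 weights b_j = (j+1)^(1-alpha) - j^(1-alpha), with b_0 = 1
    b = [(j + 1) ** (1.0 - alpha) - j ** (1.0 - alpha) for j in range(n_steps)]
    u = [1.0 - lam0]
    for n in range(1, n_steps + 1):
        # Memory term: sum_{j=1}^{n-1} b_j (u_{n-j} - u_{n-j-1}) over the whole history
        hist = sum(b[j] * (u[n - j] - u[n - j - 1]) for j in range(1, n))
        # Solve c * (u_n - u_{n-1} + hist) = -u_n / tau_thix for u_n
        u.append(c * (u[n - 1] - hist) / (c + 1.0 / tau_thix))
    times = [n * h for n in range(n_steps + 1)]
    return times, [1.0 - v for v in u]


# Hypothetical values; alpha = 1/2 so the exact answer has an erfc closed form
ALPHA, TAU_THIX, LAM0, T_END = 0.5, 1.0, 0.2, 2.0
times, lam = caputo_relaxation_l1(ALPHA, TAU_THIX, LAM0, T_END, n_steps=400)

# Exact: E_{1/2}(-z) = exp(z**2) * erfc(z) with z = sqrt(t) / tau_thix
z = math.sqrt(T_END) / TAU_THIX
lam_exact = 1.0 - (1.0 - LAM0) * math.exp(z * z) * math.erfc(z)
lam_classical = 1.0 - (1.0 - LAM0) * math.exp(-T_END / TAU_THIX)  # alpha = 1 limit
print(f"lambda(t_end): L1 = {lam[-1]:.4f}, exact = {lam_exact:.4f}, classical = {lam_classical:.4f}")
```

Each step sums over the entire history, so the cost grows as $O(N^2)$; that history sum is the memory. By $t = 2$ the fractional structure lags the classical one, although at early times the fractional curve actually moves faster (since $t^\alpha > t$ for $t < 1$), so the slowdown described above is a long-time effect.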

> **Handbook:** See [FIKH Stress Relaxation](../../docs/source/models/fikh/fikh.rst#stress-relaxation) for Mittag-Leffler asymptotics and memory kernel details.

## Learning Objectives

1. Generate synthetic relaxation data showing power-law tails
2. Compare FIKH ($\alpha < 1$) vs classical IKH ($\alpha=1$) relaxation
3. Fit relaxation data to extract $\alpha$ from long-time decay
4. Validate Mittag-Leffler function against numerical predictions
5. Understand connection to Cole-Cole depression in frequency domain

## Prerequisites

- NB01: Flow curve (parameter calibration)
- Basic fractional calculus (Caputo derivative)

**Estimated Time:** 3-5 minutes (fast), 10-15 minutes (full)

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import os
import sys
import time
import warnings

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.fikh import FIKH

# Robust path resolution for execution from any directory
_nb_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()
_utils_candidates = [_nb_dir / ".." / "utils", Path("examples/utils"), _nb_dir.parent / "utils"]
for _p in _utils_candidates:
    if (_p / "fikh_tutorial_utils.py").exists():
        sys.path.insert(0, str(_p.resolve()))
        break
from fikh_tutorial_utils import (
    compute_fit_quality,
    generate_synthetic_relaxation,
    get_fikh_param_names,
    load_fikh_parameters,
    plot_alpha_sweep,
    print_alpha_interpretation,
    print_convergence_summary,
    print_parameter_comparison,
    save_fikh_results,
    set_model_parameters,
)

# Shared plotting utilities
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
from utils.plotting_utils import (
    display_arviz_diagnostics,
    plot_nlsq_fit,
    plot_posterior_predictive,
)

jax, jnp = safe_import_jax()
verify_float64()

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# ============================================================
# FAST_MODE Configuration
# ============================================================
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

if FAST_MODE:
    print("FAST_MODE: reduced data/samples for quick validation")
    N_DATA_POINTS = 50  # Reduced from 200
    NUM_WARMUP = 50
    NUM_SAMPLES = 100
    NUM_CHAINS = 1
else:
    print("FULL mode: complete Bayesian inference")
    N_DATA_POINTS = 200  # Full resolution
    NUM_WARMUP = 200
    NUM_SAMPLES = 500
    NUM_CHAINS = 1

## 2. Theory: Fractional Relaxation

Stress relaxation reveals the **memory kernel** most clearly:

### Classical IKH ($\alpha = 1$)
$$
\sigma(t) \sim \sigma_0 \exp(-t/\tau)
$$

### FIKH ($0 < \alpha < 1$)
$$
\sigma(t) \sim \sigma_0 E_\alpha(-(t/\tau)^\alpha)
$$

where $E_\alpha$ is the Mittag-Leffler function with asymptotic behavior:
- **Short times**: $E_\alpha(-x) \approx 1 - x/\Gamma(1+\alpha) \approx \exp(-x/\Gamma(1+\alpha))$, a stretched exponential in $t$ since $x = (t/\tau)^\alpha$
- **Long times**: $E_\alpha(-x) \sim x^{-1}/\Gamma(1-\alpha)$ (power-law tail)
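
A minimal pure-Python sketch for evaluating $E_\alpha$ and checking these asymptotics. It is not the routine RheoJAX uses internally (we have not confirmed how the library evaluates $E_\alpha$); the checks rely on the known closed forms $E_1(-x) = e^{-x}$ and $E_{1/2}(-x) = e^{x^2}\,\mathrm{erfc}(x)$, and the function names are illustrative.

```python
import math


def mittag_leffler(alpha, z, n_terms=300):
    """One-parameter Mittag-Leffler function E_alpha(z) = sum_k z**k / Gamma(alpha*k + 1).

    Plain power series, with each term evaluated in log space to avoid overflow.
    The series converges for every z, but its alternating terms cancel badly
    for large negative z, so this sketch is only trustworthy for moderate |z|.
    """
    if z == 0.0:
        return 1.0
    log_abs_z = math.log(abs(z))
    total = 0.0
    for k in range(n_terms):
        log_term = k * log_abs_z - math.lgamma(alpha * k + 1.0)
        sign = -1.0 if (z < 0.0 and k % 2 == 1) else 1.0
        total += sign * math.exp(log_term)
    return total


def mittag_leffler_tail(alpha, x):
    """Leading asymptote E_alpha(-x) ~ x**-1 / Gamma(1 - alpha), for 0 < alpha < 1 and x >> 1."""
    return 1.0 / (x * math.gamma(1.0 - alpha))


# Series vs closed forms
print(mittag_leffler(1.0, -2.0), math.exp(-2.0))                     # E_1(-x) = exp(-x)
print(mittag_leffler(0.5, -1.5), math.exp(1.5**2) * math.erfc(1.5))  # E_1/2(-x) = exp(x^2) erfc(x)

# Tail check for alpha = 1/2, using the closed form where the raw series would fail
x = 10.0
print(math.exp(x * x) * math.erfc(x) / mittag_leffler_tail(0.5, x))  # ratio close to 1
```

For large arguments, dedicated algorithms (for example, numerical inversion of the Laplace transform) are the usual choice instead of the raw series.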

### Key Observation

The **long-time tail** distinguishes FIKH from IKH:
- $\alpha = 1$: Exponential decay (fast)
- $\alpha < 1$: Power-law decay $\sim t^{-\alpha}$ (slow)
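
As a worked example using the long-time asymptote above, compare how long each model takes to relax to 1% of $\sigma_0$ (the asymptote is accurate here because the answer satisfies $t \gg \tau$):

$$
\alpha = 1:\qquad e^{-t/\tau} = 0.01 \;\Rightarrow\; t = \tau \ln 100 \approx 4.6\,\tau
$$

$$
\alpha = \tfrac{1}{2}:\qquad \frac{(t/\tau)^{-1/2}}{\Gamma(1/2)} = 0.01 \;\Rightarrow\; t = \frac{\tau}{\pi\,(0.01)^2} \approx 3.2 \times 10^{3}\,\tau
$$

using $\Gamma(1/2) = \sqrt{\pi}$. More generally, the tail reaches a fraction $p$ of $\sigma_0$ at $t_p \approx \tau\,[\,p\,\Gamma(1-\alpha)\,]^{-1/\alpha}$, valid only when the result satisfies $t_p \gg \tau$.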

## 3. Load Calibrated Parameters

In [None]:
# Try to load parameters from NB01, fall back to defaults
try:
    calibrated_params = load_fikh_parameters("fikh", "flow_curve")
    print("Loaded calibrated parameters from NB01:")
    for name, val in calibrated_params.items():
        print(f"  {name:15s} = {val:.4g}")
except FileNotFoundError:
    print("NB01 parameters not found. Using defaults.")
    calibrated_params = None

In [None]:
# Create model and set parameters
model = FIKH(include_thermal=False, alpha_structure=0.7)

if calibrated_params is not None:
    set_model_parameters(model, calibrated_params)

param_names = get_fikh_param_names(include_thermal=False)
print("\nModel parameters:")
for name in param_names:
    print(f"  {name:15s} = {model.parameters.get_value(name):.4g}")

## 4. Generate Synthetic Data

In [None]:
# Generate synthetic relaxation data with 3% noise
SIGMA_0 = 100.0  # Initial stress [Pa]
T_END = 100.0    # End time [s]
NOISE_LEVEL = 0.03

time_data, stress_data = generate_synthetic_relaxation(
    model,
    sigma_0=SIGMA_0,
    t_end=T_END,
    n_points=N_DATA_POINTS,
    noise_level=NOISE_LEVEL,
    seed=42,
)

print("Generated synthetic relaxation data:")
print(f"  Initial stress: {SIGMA_0} Pa")
print(f"  Time range: [{time_data.min():.4f}, {time_data.max():.2f}] s")
print(f"  Noise level: {NOISE_LEVEL*100:.0f}%")
print(f"  Data points: {len(time_data)}")

In [None]:
# Plot synthetic data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Linear scale
ax1.plot(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax1.set_xlabel("Time [s]", fontsize=12)
ax1.set_ylabel("Stress [Pa]", fontsize=12)
ax1.set_title("Stress Relaxation (Linear Scale)", fontsize=13)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Log-log scale (shows power-law tail)
ax2.loglog(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax2.set_xlabel("Time [s]", fontsize=12)
ax2.set_ylabel("Stress [Pa]", fontsize=12)
ax2.set_title("Stress Relaxation (Log-Log Scale)", fontsize=13)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

## 5. Alpha Sweep: Power-Law vs Exponential

In [None]:
# Show alpha effect on relaxation; alpha = 0.99 approximates the classical IKH limit (alpha = 1)
alpha_values = [0.3, 0.5, 0.7, 0.9, 0.99]

fig = plot_alpha_sweep(
    model,
    protocol="relaxation",
    alpha_values=alpha_values,
    x_data=time_data,
    sigma_0=SIGMA_0,
    figsize=(14, 5),
)

# Add data to left panel
fig.axes[0].loglog(time_data, stress_data, "ko", markersize=3, alpha=0.5, label="Data")
fig.axes[0].legend(fontsize=8, loc="best")

display(fig)
plt.close(fig)
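
Why should the $\alpha = 0.99$ curve sit close to the classical exponential? One way to see it, using only the long-time asymptote from Section 2, is to compare the amplitude $1/\Gamma(1-\alpha)$ of the power-law tail across the swept values. This is a standalone check with illustrative names, independent of `plot_alpha_sweep`.

```python
import math

alpha_values = [0.3, 0.5, 0.7, 0.9, 0.99]

# Long-time tail: E_alpha(-x) ~ A(alpha) / x with A(alpha) = 1 / Gamma(1 - alpha),
# and x = (t / tau)**alpha, so sigma(t) / sigma_0 ~ A(alpha) * (t / tau)**-alpha
tail_amplitude = {a: 1.0 / math.gamma(1.0 - a) for a in alpha_values}
for a, amp in tail_amplitude.items():
    print(f"alpha = {a:4.2f}: 1/Gamma(1 - alpha) = {amp:.4f}")
```

The amplitude vanishes as $\alpha \to 1$ because $\Gamma(1-\alpha)$ diverges; at $\alpha = 0.99$ it is more than fifty times smaller than at $\alpha = 0.5$. Lower $\alpha$ therefore gives both a larger tail amplitude and a shallower slope, which is why the low-$\alpha$ curves are expected to dominate at long times.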

## 6. NLSQ Fitting

In [None]:
# Fit to synthetic data (verify recovery of true parameters)
model_fit = FIKH(include_thermal=False, alpha_structure=0.5)  # Start from different alpha

t0 = time.time()
model_fit.fit(time_data, stress_data, test_mode="relaxation", sigma_0=SIGMA_0, method='scipy')
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print("\nFitted vs True parameters:")
for name in param_names:
    fitted = model_fit.parameters.get_value(name)
    true_val = model.parameters.get_value(name)
    rel_err = abs(fitted - true_val) / (abs(true_val) + 1e-10) * 100
    print(f"  {name:15s}: fitted={fitted:.4g}, true={true_val:.4g} (err={rel_err:.1f}%)")

In [None]:
# Plot NLSQ fit with uncertainty band
stress_pred = model_fit.predict_relaxation(time_data, sigma_0=SIGMA_0)
metrics = compute_fit_quality(stress_data, stress_pred)

fig, ax = plot_nlsq_fit(
    time_data, stress_data, model_fit, test_mode="relaxation",
    param_names=param_names, log_scale=True,
    xlabel="Time [s]", ylabel="Stress [Pa]",
    title=f"FIKH Relaxation Fit (R$^2$ = {metrics['R2']:.5f})",
    sigma_0=SIGMA_0,
)
display(fig)
plt.close(fig)
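
`compute_fit_quality` is a tutorial helper whose implementation is not shown here. Assuming its `R2` entry is the standard coefficient of determination computed on raw stresses (an assumption), the hypothetical `r_squared` below shows why a log-space R² is a useful complement for relaxation data spanning several decades: in linear space the large early-time stresses dominate both sums, so a poorly fitted tail barely moves the score.

```python
import numpy as np


def r_squared(y_obs, y_pred):
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot (hypothetical helper)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_obs - y_pred) ** 2)
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


# Toy example: a "model" that matches the early decay but misses the tail by a factor of 2
t = np.logspace(-2, 2, 50)
y_true = 100.0 * (1.0 + t) ** -0.5   # slow, power-law-like decay
y_model = y_true.copy()
y_model[t > 10.0] *= 0.5             # tail under-predicted by 50%

r2_lin = r_squared(y_true, y_model)
r2_log = r_squared(np.log(y_true), np.log(y_model))
print(f"linear R^2 = {r2_lin:.4f}, log R^2 = {r2_log:.4f}")
```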

## 7. Bayesian Inference

In [None]:
# Bayesian inference — clamp initial values inside bounds for stable NUTS init
initial_values = {}
for name in param_names:
    val = model_fit.parameters.get_value(name)
    lo, hi = model_fit.parameters[name].bounds
    eps = 1e-6 * max(abs(hi - lo), 1.0)
    initial_values[name] = float(np.clip(val, lo + eps, hi - eps))

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
print(f"  FAST_MODE={FAST_MODE}, data points={len(time_data)}")
t0 = time.time()
result = model_fit.fit_bayesian(
    time_data,
    stress_data,
    test_mode="relaxation",
    sigma_0=SIGMA_0,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence and ArviZ diagnostics
all_pass = print_convergence_summary(result, param_names)

print("\n### Diagnostic Interpretation")
print("| Metric | Target | Meaning |")
print("|--------|--------|---------|")
print("| R-hat | < 1.01 | Chain mixing (< 1.05 acceptable) |")
print("| ESS | > 400 | Independent samples (> 100 min) |")
print("| Divergences | < 1% | Sampling quality indicator |")
print("\nAll diagnostics passing indicates reliable posterior estimates.")

display_arviz_diagnostics(result, param_names, fast_mode=FAST_MODE)
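
Both modes of this notebook run `NUM_CHAINS = 1`, so R-hat can only compare the two halves of a single chain. The sketch below implements the classic split-$\hat{R}$ from Gelman et al. (*Bayesian Data Analysis*, 3rd ed.) in NumPy to show what the statistic measures and what a single chain can miss. ArviZ's default `az.rhat` uses the newer rank-normalized variant, so its numbers will differ slightly; the function name and toy draws are illustrative.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for draws of shape (n_chains, n_draws).

    Each chain is cut in half, so even a single chain yields two sequences to compare.
    """
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # First and second halves of every chain -> shape (2 * n_chains, half)
    seqs = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = seqs.shape[1]
    within = seqs.var(axis=1, ddof=1).mean()        # W: mean within-sequence variance
    between = n * seqs.mean(axis=1).var(ddof=1)     # B: between-sequence variance
    var_plus = (n - 1) / n * within + between / n   # pooled posterior-variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(1)
well_mixed = rng.standard_normal((4, 1000))
stuck = well_mixed + np.array([[0.0], [0.0], [3.0], [3.0]])  # two chains in a different mode

print(f"4 well-mixed chains:      {split_rhat(well_mixed):.3f}")
print(f"4 chains, 2 stuck:        {split_rhat(stuck):.3f}")
print(f"1 chain in the wrong mode: {split_rhat(stuck[2:3]):.3f}")
```

The last line is the caveat for single-chain runs: a chain that explores only one mode still produces a near-perfect split-$\hat{R}$, because both of its halves agree with each other.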

In [None]:
# Alpha posterior (key result for relaxation)
posterior = result.posterior_samples
alpha_samples = posterior["alpha_structure"]
alpha_median = np.median(alpha_samples)
alpha_lo, alpha_hi = np.percentile(alpha_samples, [2.5, 97.5])
true_alpha = model.parameters.get_value("alpha_structure")

print("\nFractional Order Recovery:")
print("=" * 50)
print(f"  True alpha:      {true_alpha:.3f}")
print(f"  Posterior:       {alpha_median:.3f} [{alpha_lo:.3f}, {alpha_hi:.3f}]")
print(f"  True in 95% CI:  {alpha_lo <= true_alpha <= alpha_hi}")

In [None]:
# Alpha posterior histogram
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(alpha_samples, bins=30, density=True, alpha=0.7, color="C0", edgecolor="white")
ax.axvline(true_alpha, color="red", lw=2, linestyle="--", label=f"True α = {true_alpha:.3f}")
ax.axvline(alpha_median, color="C0", lw=2, label=f"Median = {alpha_median:.3f}")
ax.set_xlabel("alpha_structure", fontsize=12)
ax.set_ylabel("Density", fontsize=12)
ax.set_title("Posterior Distribution of Fractional Order", fontsize=13)
ax.legend(fontsize=10)
plt.tight_layout()
display(fig)
plt.close(fig)

## 8. Save Results

In [None]:
save_fikh_results(model_fit, result, "fikh", "relaxation", param_names)
print("\nResults saved.")

## Key Takeaways

1. **Relaxation is the signature protocol for fractional models**
2. **Power-law tails $\sigma(t) \sim t^{-\alpha}$ distinguish FIKH from classical IKH**
3. **Mittag-Leffler function** governs transient → long-time crossover
4. **$\alpha$ directly measurable** from the log-log slope ($-\alpha$) at long times
5. **Structure recovery drives stress evolution** via fractional rebuilding
6. **Residual plots** confirm power-law vs exponential decay distinction

---

## Further Reading

- **[FIKH Relaxation Protocol](../../docs/source/models/fikh/fikh.rst#stress-relaxation)**: Mittag-Leffler solutions and asymptotic behavior
- **[Mittag-Leffler Function](../../docs/source/models/fikh/fikh.rst#mittag-leffler-relaxation)**: Generalized exponential and its properties

### Key References

1. Mainardi, F. (2010). *Fractional Calculus and Waves in Linear Viscoelasticity*. Imperial College Press.
2. Podlubny, I. (1999). *Fractional Differential Equations*. Academic Press.

### Next Steps

**Next**: NB04 (Creep) — delayed yielding and viscosity bifurcation with fractional memory