# FIKH Model: Stress Relaxation

## Learning Objectives

1. Generate **synthetic relaxation data** from NB01 calibrated parameters
2. Understand **power-law vs exponential decay** controlled by alpha_structure
3. Fit FIKH to relaxation data and infer fractional order
4. Visualize how alpha affects long-time relaxation tails
5. Use Bayesian inference to quantify uncertainty in alpha

## Prerequisites

- **NB01**: FIKH Flow Curve (provides calibrated parameters)
- Bayesian inference fundamentals

## Runtime

- Fast mode (`FAST_MODE = True`: 50 data points, 50 warmup + 100 samples, 1 chain): ~5-10 minutes
- Full mode (`FAST_MODE = False`: 200 data points, 200 warmup + 500 samples, 1 chain): 30+ minutes on CPU

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import os
import sys
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.fikh import FIKH

# Robust path resolution for execution from any directory
from pathlib import Path
_nb_dir = Path(__file__).parent if "__file__" in dir() else Path.cwd()
_utils_candidates = [_nb_dir / ".." / "utils", Path("examples/utils"), _nb_dir.parent / "utils"]
for _p in _utils_candidates:
    if (_p / "fikh_tutorial_utils.py").exists():
        sys.path.insert(0, str(_p.resolve()))
        break
from fikh_tutorial_utils import (
    load_fikh_parameters,
    generate_synthetic_relaxation,
    save_fikh_results,
    set_model_parameters,
    print_convergence_summary,
    print_parameter_comparison,
    compute_fit_quality,
    get_fikh_param_names,
    plot_alpha_sweep,
    print_alpha_interpretation,
)

jax, jnp = safe_import_jax()
verify_float64()

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# ============================================================
# FAST_MODE Configuration
# ============================================================
# True  = quick validation (~5-10 min, reduced data + samples)
# False = full Bayesian run (may need 30+ min on CPU)
FAST_MODE = True

if FAST_MODE:
    print("FAST_MODE: reduced data/samples for quick validation")
    N_DATA_POINTS = 50  # Reduced from 200
    NUM_WARMUP = 50
    NUM_SAMPLES = 100
    NUM_CHAINS = 1
else:
    print("FULL mode: complete Bayesian inference")
    N_DATA_POINTS = 200  # Full resolution
    NUM_WARMUP = 200
    NUM_SAMPLES = 500
    NUM_CHAINS = 1

## 2. Theory: Fractional Relaxation

Stress relaxation reveals the **memory kernel** most clearly:

### Classical IKH ($\alpha = 1$)
$$
\sigma(t) \sim \sigma_0 \exp(-t/\tau)
$$

### FIKH ($0 < \alpha < 1$)
$$
\sigma(t) \sim \sigma_0 E_\alpha(-(t/\tau)^\alpha)
$$

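where $E_\alpha$ is the Mittag-Leffler function. Writing $x = (t/\tau)^\alpha$, its asymptotic behavior is:
- **Short times** ($x \ll 1$): $E_\alpha(-x) \approx 1 - x/\Gamma(1+\alpha) \approx \exp\left(-x/\Gamma(1+\alpha)\right)$ (stretched-exponential onset)
- **Long times** ($x \gg 1$): $E_\alpha(-x) \sim x^{-1}/\Gamma(1-\alpha)$ (power-law tail)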
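These limits can be checked numerically. The sketch below evaluates $E_\alpha(-x)$ from its power series $\sum_k (-x)^k/\Gamma(\alpha k+1)$, which is only numerically reliable for moderate $x$, and compares it with the closed form $E_{1/2}(-x) = e^{x^2}\operatorname{erfc}(x)$, available in SciPy as `erfcx`. It uses NumPy/SciPy only and is independent of RheoJAX; the helper name `mittag_leffler_series` is hypothetical.

```python
import math

from scipy.special import erfcx


def mittag_leffler_series(z, alpha, n_terms=100):
    """E_alpha(z) via its power series; reliable only for moderate |z| (roughly |z| < 3)."""
    return sum(z**k / math.gamma(alpha * k + 1) for k in range(n_terms))


alpha = 0.5
x_small = 0.01
x_large = 1.0e3

# Series vs. closed form E_{1/2}(-x) = exp(x^2) * erfc(x) = erfcx(x) at moderate x
series_val = mittag_leffler_series(-1.5, alpha)
exact_val = erfcx(1.5)

# Short-time limit: E_alpha(-x) ~ exp(-x / Gamma(1 + alpha))
short_approx = math.exp(-x_small / math.gamma(1 + alpha))

# Long-time limit: E_alpha(-x) ~ 1 / (x * Gamma(1 - alpha))
long_approx = 1.0 / (x_large * math.gamma(1 - alpha))

print(f"series E_0.5(-1.5) = {series_val:.8f}, erfcx(1.5) = {exact_val:.8f}")
print(f"short time: exact {erfcx(x_small):.6f} vs approx {short_approx:.6f}")
print(f"long time:  exact {erfcx(x_large):.3e} vs approx {long_approx:.3e}")
```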

### Key Observation

The **long-time tail** distinguishes FIKH from IKH:
- $\alpha = 1$: Exponential decay (fast)
- $\alpha < 1$: Power-law decay $\sigma(t) \approx \sigma_0 (t/\tau)^{-\alpha}/\Gamma(1-\alpha)$, i.e. $\sim t^{-\alpha}$ (slow)
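On a log-log plot this difference shows up as the local slope $d\ln\sigma/d\ln t$: for $\exp(-t/\tau)$ it equals $-t/\tau$ and steepens without bound, while for $0 < \alpha < 1$ it levels off at $-\alpha$. The sketch below estimates that slope by central differences for the $\alpha = 1/2$ case, using the closed form $E_{1/2}(-x) = \operatorname{erfcx}(x)$ as a stand-in for the general Mittag-Leffler function; the helper names are hypothetical.

```python
import numpy as np
from scipy.special import erfcx


def log_log_slope(f, s, h=1e-4):
    """Central-difference estimate of d ln f / d ln s at s."""
    return (np.log(f(s * np.exp(h))) - np.log(f(s * np.exp(-h)))) / (2.0 * h)


def relax_ml_half(s):
    """alpha = 1/2: sigma / sigma_0 = E_{1/2}(-(t/tau)^{1/2}) = erfcx(sqrt(s)), s = t/tau."""
    return erfcx(np.sqrt(s))


for s in [1e-2, 1e0, 1e2, 1e4, 1e6]:
    slope_ml = log_log_slope(relax_ml_half, s)
    # For exp(-s) the log-log slope is exactly -s, so it is written down analytically
    print(f"t/tau = {s:8.0e}   alpha=1 slope = {-s:10.3g}   alpha=0.5 slope = {slope_ml:7.4f}")
```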

## 3. Load Calibrated Parameters

In [None]:
# Try to load parameters from NB01, fall back to defaults
try:
    calibrated_params = load_fikh_parameters("fikh", "flow_curve")
    print("Loaded calibrated parameters from NB01:")
    for name, val in calibrated_params.items():
        print(f"  {name:15s} = {val:.4g}")
except FileNotFoundError:
    print("NB01 parameters not found. Using defaults.")
    calibrated_params = None

In [None]:
# Create model and set parameters
model = FIKH(include_thermal=False, alpha_structure=0.7)

if calibrated_params is not None:
    set_model_parameters(model, calibrated_params)

param_names = get_fikh_param_names(include_thermal=False)
print("\nModel parameters:")
for name in param_names:
    print(f"  {name:15s} = {model.parameters.get_value(name):.4g}")

## 4. Generate Synthetic Data

In [None]:
# Generate synthetic relaxation data with 3% noise
SIGMA_0 = 100.0  # Initial stress
T_END = 100.0    # End time
NOISE_LEVEL = 0.03

time_data, stress_data = generate_synthetic_relaxation(
    model,
    sigma_0=SIGMA_0,
    t_end=T_END,
    n_points=N_DATA_POINTS,
    noise_level=NOISE_LEVEL,
    seed=42,
)

print("Generated synthetic relaxation data:")
print(f"  Initial stress: {SIGMA_0} Pa")
print(f"  Time range: [{time_data.min():.4f}, {time_data.max():.2f}] s")
print(f"  Noise level: {NOISE_LEVEL*100:.0f}%")
print(f"  Data points: {len(time_data)}")
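The internals of `generate_synthetic_relaxation` are not shown here. The sketch below illustrates one common recipe under stated assumptions: a log-spaced time grid, multiplicative Gaussian noise of relative size `noise_level`, and the $\alpha = 1/2$ closed form $\sigma_0\,\operatorname{erfcx}(\sqrt{t/\tau})$ standing in for the FIKH prediction. The function name `make_noisy_relaxation` and the values of `t_min` and `tau` are hypothetical; the actual helper may use a different grid or noise model.

```python
import numpy as np
from scipy.special import erfcx


def make_noisy_relaxation(sigma_0, t_end, n_points, noise_level, tau=1.0, t_min=1e-3, seed=42):
    """Hypothetical stand-in: alpha = 1/2 Mittag-Leffler relaxation with multiplicative noise."""
    rng = np.random.default_rng(seed)
    t = np.logspace(np.log10(t_min), np.log10(t_end), n_points)
    clean = sigma_0 * erfcx(np.sqrt(t / tau))
    # Multiplicative noise keeps the stress positive (for small noise), which log-log plots need
    noisy = clean * (1.0 + noise_level * rng.standard_normal(n_points))
    return t, noisy, clean


t, noisy, clean = make_noisy_relaxation(sigma_0=100.0, t_end=100.0, n_points=50, noise_level=0.03)
rel_resid = noisy / clean - 1.0
print(f"{t.size} points, t in [{t.min():.0e}, {t.max():.0f}] s")
print(f"empirical relative noise: {rel_resid.std():.3f}")
```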

In [None]:
# Plot synthetic data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Linear scale
ax1.plot(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax1.set_xlabel("Time [s]", fontsize=12)
ax1.set_ylabel("Stress [Pa]", fontsize=12)
ax1.set_title("Stress Relaxation (Linear Scale)", fontsize=13)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Log-log scale (shows power-law tail)
ax2.loglog(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax2.set_xlabel("Time [s]", fontsize=12)
ax2.set_ylabel("Stress [Pa]", fontsize=12)
ax2.set_title("Stress Relaxation (Log-Log Scale)", fontsize=13)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

## 5. Alpha Sweep: Power-Law vs Exponential

In [None]:
# Show alpha effect on relaxation
alpha_values = [0.3, 0.5, 0.7, 0.9, 0.99]

fig = plot_alpha_sweep(
    model,
    protocol="relaxation",
    alpha_values=alpha_values,
    x_data=time_data,
    sigma_0=SIGMA_0,
    figsize=(14, 5),
)

# Add data to left panel
fig.axes[0].loglog(time_data, stress_data, "ko", markersize=3, alpha=0.5, label="Data")
fig.axes[0].legend(fontsize=8, loc="best")

display(fig)
plt.close(fig)
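Why does $\alpha \to 1$ recover exponential decay when the tail formula from Section 2 predicts a power law for every $\alpha < 1$? The long-time amplitude $1/\Gamma(1-\alpha)$ vanishes as $\alpha \to 1$, because $\Gamma(1-\alpha)$ diverges. The short sketch below tabulates that prefactor for the sweep values used above. It relies only on the asymptotic formula, which holds deep in the tail; as $\alpha$ approaches 1, the crossover to power-law behavior moves to ever longer times.

```python
import math

alpha_values = [0.3, 0.5, 0.7, 0.9, 0.99]

# Long-time tail: sigma / sigma_0 ~ (t/tau)^(-alpha) / Gamma(1 - alpha)
prefactors = {a: 1.0 / math.gamma(1.0 - a) for a in alpha_values}

for a, c in prefactors.items():
    print(f"alpha = {a:4.2f}: tail ~ {c:.4f} * (t/tau)^(-{a})")
```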

## 6. NLSQ Fitting

In [None]:
# Fit to synthetic data (verify recovery of true parameters)
model_fit = FIKH(include_thermal=False, alpha_structure=0.5)  # Start from different alpha

t0 = time.time()
model_fit.fit(time_data, stress_data, test_mode="relaxation", sigma_0=SIGMA_0, method='scipy')
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print("\nFitted vs True parameters:")
for name in param_names:
    fitted = model_fit.parameters.get_value(name)
    true_val = model.parameters.get_value(name)
    rel_err = abs(fitted - true_val) / (abs(true_val) + 1e-10) * 100
    print(f"  {name:15s}: fitted={fitted:.4g}, true={true_val:.4g} (err={rel_err:.1f}%)")
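A quick model-free cross-check on the fitted fractional order: per Section 2, a straight-line fit of $\ln\sigma$ against $\ln t$ over the tail should have slope close to $-\alpha$. The sketch below applies this to self-generated $\alpha = 1/2$ data (closed form via `erfcx`, 3% multiplicative noise, hypothetical tail window $t/\tau \in [10^3, 10^5]$). It is a different technique from the FIKH NLSQ fit above and only works when the data actually reach the power-law regime; the notebook's own data stop at `T_END = 100.0` s, so whether they are deep enough in the tail depends on the model's time scale.

```python
import numpy as np
from scipy.special import erfcx

rng = np.random.default_rng(0)
tau = 1.0
t = np.logspace(3, 5, 50) * tau          # tail window, t/tau from 1e3 to 1e5
sigma = 100.0 * erfcx(np.sqrt(t / tau))  # alpha = 1/2 relaxation
sigma_noisy = sigma * (1.0 + 0.03 * rng.standard_normal(t.size))

# Least-squares line in log-log coordinates: ln(sigma) = slope * ln(t) + intercept
slope, intercept = np.polyfit(np.log(t), np.log(sigma_noisy), deg=1)
alpha_est = -slope
print(f"tail-slope estimate of alpha: {alpha_est:.3f} (true 0.5)")
```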

In [None]:
# Plot fit
stress_pred = model_fit.predict_relaxation(time_data, sigma_0=SIGMA_0)
metrics = compute_fit_quality(stress_data, stress_pred)

time_fine = np.logspace(np.log10(time_data.min()), np.log10(time_data.max()), 300)
stress_pred_fine = model_fit.predict_relaxation(time_fine, sigma_0=SIGMA_0)

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax.loglog(time_fine, stress_pred_fine, "-", lw=2.5, color="C0", label="FIKH fit")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title(f"FIKH Relaxation Fit (R$^2$ = {metrics['R2']:.5f})", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)
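The title reports `metrics['R2']` from `compute_fit_quality`, whose exact definition is not shown here. Assuming it is the usual coefficient of determination $R^2 = 1 - SS_{res}/SS_{tot}$ computed on stress in linear units, relaxation data spanning several decades are dominated by the large early-time stresses, so a tail misfit barely moves $R^2$. The sketch below (hypothetical helper `r_squared`, toy numbers) computes $R^2$ both on $\sigma$ and on $\ln\sigma$; the log version weights each decade more evenly.

```python
import numpy as np


def r_squared(y, y_pred):
    """Coefficient of determination 1 - SS_res / SS_tot."""
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


# Toy example: a prediction that is right early but 2x too high in the tail
y = np.array([100.0, 50.0, 10.0, 1.0, 0.1])
y_pred = np.array([100.0, 50.0, 10.0, 2.0, 0.2])

r2_lin = r_squared(y, y_pred)
r2_log = r_squared(np.log(y), np.log(y_pred))
print(f"R^2 (linear stress): {r2_lin:.4f}")
print(f"R^2 (log stress):    {r2_log:.4f}")
```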

## 7. Bayesian Inference

In [None]:
# Bayesian inference — clamp initial values inside bounds for stable NUTS init
initial_values = {}
for name in param_names:
    val = model_fit.parameters.get_value(name)
    lo, hi = model_fit.parameters[name].bounds
    eps = 1e-6 * max(abs(hi - lo), 1.0)
    initial_values[name] = float(np.clip(val, lo + eps, hi - eps))

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
print(f"  FAST_MODE={FAST_MODE}, data points={len(time_data)}")
t0 = time.time()
result = model_fit.fit_bayesian(
    time_data,
    stress_data,
    test_mode="relaxation",
    sigma_0=SIGMA_0,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence
all_pass = print_convergence_summary(result, param_names)
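With `NUM_CHAINS = 1` there is no second chain to compare against, so convergence checks have to split the single chain. How `print_convergence_summary` handles this is not shown; ArviZ's `az.rhat`, for example, uses a rank-normalized split-$\hat R$. The sketch below implements the simpler classic split-$\hat R$: cut each chain in half, treat the halves as separate chains, and compare between-half and within-half variances. The function name `split_rhat` is hypothetical.

```python
import numpy as np


def split_rhat(chains):
    """Classic (non-rank-normalized) split-R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.atleast_2d(np.asarray(chains, dtype=float))
    n_half = chains.shape[1] // 2
    # Split every chain into two halves and treat them as independent chains
    halves = np.concatenate([chains[:, :n_half], chains[:, n_half:2 * n_half]], axis=0)
    m, n = halves.shape
    chain_means = halves.mean(axis=1)
    within = halves.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)           # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n    # pooled posterior variance estimate
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(1)
stationary = rng.normal(size=(1, 1000))                        # well-mixed chain
drifting = stationary + np.linspace(0.0, 5.0, 1000)[None, :]  # chain with a trend

print(f"split R-hat, stationary chain: {split_rhat(stationary):.3f}")
print(f"split R-hat, drifting chain:   {split_rhat(drifting):.3f}")
```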

In [None]:
# Alpha posterior (key result for relaxation)
posterior = result.posterior_samples
alpha_samples = posterior["alpha_structure"]
alpha_median = np.median(alpha_samples)
alpha_lo, alpha_hi = np.percentile(alpha_samples, [2.5, 97.5])
true_alpha = model.parameters.get_value("alpha_structure")

print("\nFractional Order Recovery:")
print("=" * 50)
print(f"  True alpha:      {true_alpha:.3f}")
print(f"  Posterior:       {alpha_median:.3f} [{alpha_lo:.3f}, {alpha_hi:.3f}]")
print(f"  True in 95% CI:  {alpha_lo <= true_alpha <= alpha_hi}")
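The interval printed above is equal-tailed (2.5th to 97.5th percentile). When the posterior of `alpha_structure` piles up against a bound, for instance near $\alpha = 1$, a highest-density interval (HDI) can be noticeably narrower and shifted. ArviZ provides `az.hdi` for this; the NumPy sketch below (hypothetical helper `hdi_interval`, synthetic skewed samples) shows the idea: among all windows of consecutive sorted samples that contain the requested mass, pick the narrowest.

```python
import math

import numpy as np


def hdi_interval(samples, prob=0.95):
    """Narrowest interval containing a fraction `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float).ravel())
    n = x.size
    k = math.ceil(prob * n)               # samples inside each candidate window
    widths = x[k - 1:] - x[: n - k + 1]   # width of every window of k consecutive samples
    i = int(np.argmin(widths))
    return float(x[i]), float(x[i + k - 1])


rng = np.random.default_rng(2)
skewed = 1.0 - rng.exponential(scale=0.05, size=20_000)  # mass piled up just below 1

eq_lo, eq_hi = np.percentile(skewed, [2.5, 97.5])
hdi_lo, hdi_hi = hdi_interval(skewed, prob=0.95)
print(f"equal-tailed 95%: [{eq_lo:.3f}, {eq_hi:.3f}]  width {eq_hi - eq_lo:.3f}")
print(f"HDI 95%:          [{hdi_lo:.3f}, {hdi_hi:.3f}]  width {hdi_hi - hdi_lo:.3f}")
```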

In [None]:
# Alpha posterior histogram
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(alpha_samples, bins=30, density=True, alpha=0.7, color="C0", edgecolor="white")
ax.axvline(true_alpha, color="red", lw=2, linestyle="--", label=f"True α = {true_alpha:.3f}")
ax.axvline(alpha_median, color="C0", lw=2, label=f"Median = {alpha_median:.3f}")
ax.set_xlabel("alpha_structure", fontsize=12)
ax.set_ylabel("Density", fontsize=12)
ax.set_title("Posterior Distribution of Fractional Order", fontsize=13)
ax.legend(fontsize=10)
plt.tight_layout()
display(fig)
plt.close(fig)

## 8. Save Results

In [None]:
save_fikh_results(model_fit, result, "fikh", "relaxation", param_names)
print("\nResults saved.")

## Key Takeaways

1. **Stress relaxation reveals power-law memory** most clearly in long-time behavior
2. **Lower alpha** → slower power-law decay $\sigma \sim t^{-\alpha}$
3. **Higher alpha (→1)** → classical exponential decay
4. **Log-log plots essential** for seeing power-law tails
5. **Relaxation data constrain alpha more directly** than flow-curve data, because the long-time tail exponent is $-\alpha$ itself
6. **The synthetic-data pipeline** checks that fitting recovers known true parameters

### Next Steps

- **NB04**: Creep (delayed yielding with memory effects)
- **NB05**: SAOS (frequency response)
- **NB06**: LAOS (nonlinear oscillatory)