# ITT-MCT Isotropic Model: Stress Relaxation

## Learning Objectives

1. Generate **synthetic relaxation data** from the parameters calibrated in NB07
2. Understand **two-step relaxation** in ISM (β and α processes)
3. Analyze the **k-resolved non-ergodicity parameter**
4. Fit the model to verify parameter recovery

## Prerequisites

- **NB07: ISM Flow Curve** (required for calibrated parameters)

## Runtime

- Fast demo (NUM_CHAINS=1, NUM_SAMPLES=500): ~3-5 minutes
- Full run (NUM_CHAINS=4, NUM_SAMPLES=2000): ~15-20 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import os
import sys
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.itt_mct import ITTMCTIsotropic

# Add examples/utils to path
sys.path.insert(0, os.path.join("..", "utils"))
from itt_mct_tutorial_utils import (
    load_itt_mct_parameters,
    set_model_parameters,
    generate_synthetic_relaxation_isotropic,
    save_itt_mct_results,
    print_convergence_summary,
    print_parameter_comparison,
    print_glass_state_summary,
    compute_fit_quality,
    get_isotropic_param_names,
)

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. Theory: ISM Stress Relaxation

### k-Resolved Two-Step Relaxation

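In the ISM, each wave-vector mode k has its own density correlator Φ(k,t). Near the glass transition, each correlator decays in two steps:

1. **Fast β-process**: short-time decay onto a plateau f(k) as particles rattle inside the cages formed by their neighbours
2. **Slow α-process**: long-time decay from the plateau as the cages rearrange (structural relaxation); in a quiescent glass this step is arrested and Φ(k,t) stays at f(k)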

### Key Equations

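**k-resolved correlator** (Brownian dynamics, with initial decay rate $\Gamma(k) = D_0 k^2 / S(k)$ and memory kernel $m(k,t)$, a quadratic functional of the correlators):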
$$
\frac{\partial \Phi(k,t)}{\partial t} + \Gamma(k)\left[\Phi(k,t) + \int_0^t m(k,t-s) \frac{\partial \Phi(k,s)}{\partial s} ds\right] = 0
$$
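To see the two steps emerge, here is a minimal sketch that integrates the *schematic* single-mode analogue of this equation, assuming Γ = 1 and the F2 memory kernel m(t) = λΦ(t)², a toy stand-in for the k-resolved ISM vertex and not RheoJAX's solver. It uses a plain explicit Euler step with a rectangle-rule memory integral on a uniform grid, which costs O(N²) and only suits short time windows; MCT solvers commonly use a time-decimation (grid-doubling) scheme instead. The name `solve_f2_correlator` is hypothetical.

```python
import numpy as np


def solve_f2_correlator(lam, gamma=1.0, dt=0.01, n_steps=4000):
    """Explicit-Euler sketch of the schematic F2 MCT equation (toy model).

    dPhi/dt = -gamma * [Phi(t) + int_0^t m(t - s) dPhi/ds ds],  m = lam * Phi**2
    Returns Phi on the uniform grid t = dt * arange(n_steps + 1).
    """
    phi = np.empty(n_steps + 1)
    phi[0] = 1.0  # normalized correlator starts at 1
    for n in range(n_steps):
        # Memory kernel at lags 1..n (in units of dt), from the correlator history
        m_lags = lam * phi[1 : n + 1] ** 2
        # Correlator increments Phi_{j+1} - Phi_j for j = 0..n-1
        dphi = np.diff(phi[: n + 1])
        # Rectangle rule: sum_j m(lag n - j) * dPhi_j; reversing m_lags aligns the lags
        memory = np.dot(m_lags[::-1], dphi)
        phi[n + 1] = phi[n] - dt * gamma * (phi[n] + memory)
    return phi


t = 0.01 * np.arange(4001)
phi_liquid = solve_f2_correlator(lam=2.0)  # below lambda_c = 4: full decay
phi_glass = solve_f2_correlator(lam=5.0)   # above lambda_c = 4: arrests on a plateau
```

For λ below the F2 critical value λ_c = 4 the correlator decays to zero (liquid); above it, Φ(t) first drops quickly and then arrests on the plateau f = (1 + √(1 − 4/λ))/2 (glass).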

**Stress relaxation** (generalized shear modulus $G(t)$, with $S'(k) = dS/dk$):
$$
\sigma(t) = \gamma_0 \, h(\gamma_0) \, G(t), \qquad G(t) = \frac{k_BT}{60\pi^2} \int_0^\infty dk \, k^4 \left[\frac{S'(k)}{S(k)^2}\right]^2 \Phi(k,t)^2
$$
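A minimal sketch of how the modulus integral can be evaluated on a discrete k grid, assuming the standard ITT-MCT form G(t) = k_BT/(60π²) ∫dk k⁴ [S'(k)/S(k)²]² Φ(k,t)² and correlators already sampled on the same grid. The single-peak structure factor below is a hypothetical toy, and `modulus_from_correlators` is an illustrative name, not a RheoJAX function.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal rule along the last axis (avoids np.trapz / np.trapezoid version differences)."""
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1)


def modulus_from_correlators(k, S_k, phi_kt, kBT):
    """Sketch of G(t) = kBT/(60 pi^2) * int dk k^4 [S'(k)/S(k)^2]^2 Phi(k,t)^2.

    k      : (n_k,) wave-vector grid
    S_k    : (n_k,) static structure factor on that grid
    phi_kt : (n_t, n_k) normalized correlators Phi(k, t) on the same grid
    Returns G(t) with shape (n_t,).
    """
    dS_dk = np.gradient(S_k, k)                  # finite-difference S'(k)
    weight = k**4 * (dS_dk / S_k**2) ** 2        # vertex weight per mode
    return kBT / (60.0 * np.pi**2) * trapezoid(weight * phi_kt**2, k)


# Toy example (hypothetical S(k) with one Gaussian peak, k in units of 1/sigma_d)
k = np.linspace(0.5, 20.0, 400)
S_k = 1.0 + 0.8 * np.exp(-((k - 7.0) / 1.0) ** 2)
phi_kt = np.ones((2, k.size))
phi_kt[1] *= 0.5                                  # every mode decayed to half
G = modulus_from_correlators(k, S_k, phi_kt, kBT=1.0)
```

Because the weight depends on S'(k), a flat S(k) contributes nothing, and the integral is dominated by wave vectors around the main peak of S(k), where it varies most steeply. Multiplying G(t) by γ₀h(γ₀) gives the relaxation stress.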

### ISM vs Schematic Relaxation

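| Aspect | Schematic | ISM |
|--------|-----------|-----|
| Correlator | Single Φ(t) | One Φ(k,t) per wave vector k |
| Decay | One effective rate | k-dependent rates Γ(k) and memory kernels m(k,t) |
| Plateau | Single f | k-dependent f(k), entering the stress through the k-integral |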

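The plateau row can be made concrete with the long-time limit of the correlator equation, f/(1 − f) = m[f], solved by the standard fixed-point iteration f ← m[f]/(1 + m[f]) starting from f = 1, which converges to the largest solution. The sketch below vectorizes it over an array of independent schematic F2 modes (m = λf²). This is a toy stand-in: the real ISM couples all k through a vertex built from S(k), which is not implemented here, and `schematic_plateau` is a hypothetical helper.

```python
import numpy as np


def schematic_plateau(lam, tol=1e-12, max_iter=100_000):
    """Non-ergodicity parameters for independent schematic F2 modes (toy model).

    Solves f / (1 - f) = m[f] with m = lam * f**2 by iterating
    f <- m / (1 + m) from f = 1, which converges to the largest solution.
    """
    lam = np.asarray(lam, dtype=float)
    f = np.ones_like(lam)
    for _ in range(max_iter):
        m = lam * f**2
        f_new = m / (1.0 + m)
        if np.max(np.abs(f_new - f)) < tol:
            return f_new
        f = f_new
    return f  # not converged within max_iter (e.g. exactly at lam = 4)


# One coupling per "mode": liquid (lam < 4) gives f = 0, glass (lam > 4) gives f > 0
f_modes = schematic_plateau([2.0, 4.5, 5.0, 8.0])
```

Exactly at λ = 4 the iteration converges only algebraically, a numerical echo of critical slowing down at the glass transition.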
## 3. Load Calibrated Parameters from NB07

In [None]:
# Load parameters calibrated in NB07
try:
    params = load_itt_mct_parameters("isotropic", "flow_curve")
    print("Loaded parameters from NB07:")
    for name, val in params.items():
        print(f"  {name:10s} = {val:.4g}")
except FileNotFoundError as e:
    print(f"Warning: {e}")
    print("Using default parameters (run NB07 first for calibrated values)")
    params = {
        "phi": 0.55,
        "sigma_d": 1e-6,
        "D0": 1e-12,
        "kBT": 4.11e-21,
        "gamma_c": 0.1,
    }

In [None]:
# Create model and set parameters
model = ITTMCTIsotropic(phi=params.get("phi", 0.55))
set_model_parameters(model, params)

print("\nModel state:")
print(model)
print()
print_glass_state_summary(model)

## 4. Generate Synthetic Relaxation Data

In [None]:
# Generate synthetic data with noise
SIGMA_0 = 100.0  # Initial stress (Pa)
T_END = 100.0    # End time (s)
NOISE_LEVEL = 0.02  # 2% noise

time_data, stress_data = generate_synthetic_relaxation_isotropic(
    model,
    sigma_0=SIGMA_0,
    t_end=T_END,
    n_points=200,
    noise_level=NOISE_LEVEL,
    seed=42,
)

print(f"Generated {len(time_data)} data points")
print(f"Time range: [{time_data.min():.4f}, {time_data.max():.2f}] s")
print(f"Stress range: [{stress_data.min():.2f}, {stress_data.max():.2f}] Pa")
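The helper's noise model is not shown here. One plausible reading of `NOISE_LEVEL = 0.02`, assumed in this sketch, is multiplicative Gaussian noise with a 2% relative standard deviation, which keeps the scatter proportional to the signal across the whole decay (visible as roughly constant scatter on the log-log plot). `add_relative_noise` is hypothetical.

```python
import numpy as np


def add_relative_noise(clean, noise_level, seed=42):
    """Multiplicative Gaussian noise: y_noisy = y * (1 + noise_level * N(0, 1))."""
    rng = np.random.default_rng(seed)
    clean = np.asarray(clean, dtype=float)
    return clean * (1.0 + noise_level * rng.standard_normal(clean.shape))


# Usage on a toy exponential decay (not the ISM prediction)
t_demo = np.logspace(-2, 2, 200)
noisy_demo = add_relative_noise(100.0 * np.exp(-t_demo / 10.0), noise_level=0.02)
```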

In [None]:
# Compute pre-shear strain for fitting
kBT = model.parameters.get_value("kBT")
sigma_d = model.parameters.get_value("sigma_d")
G_approx = kBT / sigma_d**3  # Natural modulus scale kBT/σ_d³ (order-of-magnitude estimate)
gamma_pre = SIGMA_0 / G_approx  # Strain that would produce σ₀ with modulus G_approx
print(f"Approximate modulus: G ≈ {G_approx:.4g} Pa")
print(f"Pre-shear strain: γ₀ ≈ {gamma_pre:.4f}")
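As a quick unit check of the modulus scale k_BT/σ_d³ (J/m³ = Pa), here is the same arithmetic with the fallback defaults from Section 3, assuming SI units (σ_d in m, k_BT in J). The names carry a `_default` suffix so they do not overwrite the notebook's variables; NB07-calibrated values will give different numbers.

```python
# Fallback defaults from Section 3, assumed SI units
kBT_default = 4.11e-21      # thermal energy at room temperature, J
sigma_d_default = 1e-6      # particle size (assumed diameter), m
sigma0_demo = 100.0         # target initial stress, Pa (same value as SIGMA_0)

G_scale_default = kBT_default / sigma_d_default**3   # J/m^3 = Pa
gamma_default = sigma0_demo / G_scale_default        # dimensionless strain

print(f"G ~ {G_scale_default:.3g} Pa, gamma_pre ~ {gamma_default:.3g}")
# prints: G ~ 0.00411 Pa, gamma_pre ~ 2.43e+04
```

With these defaults the modulus scale is only a few mPa, so the estimated pre-shear strain is of order 10⁴, far outside the linear-response regime; check the printed γ₀ when running on the fallback values.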

In [None]:
# Plot synthetic data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Linear scale
ax1.plot(time_data, stress_data, "ko", markersize=4, alpha=0.7)
ax1.set_xlabel("Time [s]", fontsize=12)
ax1.set_ylabel("Stress [Pa]", fontsize=12)
ax1.set_title("ISM Stress Relaxation (Linear)", fontsize=13)
ax1.grid(True, alpha=0.3)

# Right: Log-log scale
ax2.loglog(time_data, stress_data, "ko", markersize=4, alpha=0.7)
ax2.set_xlabel("Time [s]", fontsize=12)
ax2.set_ylabel("Stress [Pa]", fontsize=12)
ax2.set_title("ISM Stress Relaxation (Log-Log)", fontsize=13)
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

## 5. NLSQ Fitting

In [None]:
# Fit to relaxation data (NLSQ via SciPy)
# param_names lists the parameters sampled in the Bayesian step (Section 6)
param_names = ["phi", "D0", "gamma_c"]

t0 = time.time()
model.fit(time_data, stress_data, test_mode="relaxation", gamma_pre=gamma_pre, method='scipy')
t_nlsq = time.time() - t0

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"\nFitted parameters:")
for name in get_isotropic_param_names():
    val = model.parameters.get_value(name)
    orig = params.get(name, val)
    print(f"  {name:10s} = {val:.4g}  (original: {orig:.4g})")

In [None]:
# Compute fit quality
stress_pred = model.predict(time_data, test_mode="relaxation", gamma_pre=gamma_pre)
metrics = compute_fit_quality(stress_data, stress_pred)

print(f"\nFit Quality:")
print(f"  R²:   {metrics['R2']:.6f}")
print(f"  RMSE: {metrics['RMSE']:.4g} Pa")
print(f"  NRMSE: {metrics['NRMSE']:.4%}")
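`compute_fit_quality` is a tutorial helper whose exact definitions are not shown. A minimal sketch of the usual formulas, assuming NRMSE is the RMSE normalized by the data range (other conventions divide by the mean or the standard deviation); `fit_metrics` is hypothetical:

```python
import numpy as np


def fit_metrics(y_obs, y_pred):
    """R^2, RMSE, and range-normalized RMSE (one common NRMSE convention)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_obs - y_pred
    ss_res = np.sum(resid**2)                        # residual sum of squares
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)     # total sum of squares
    rmse = np.sqrt(np.mean(resid**2))
    return {
        "R2": float(1.0 - ss_res / ss_tot),
        "RMSE": float(rmse),
        "NRMSE": float(rmse / (y_obs.max() - y_obs.min())),
    }
```

For observations [1, 2, 3] and predictions [1, 2, 4], this gives R² = 0.5 and RMSE = √(1/3) ≈ 0.577.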

In [None]:
# Plot fit
time_fine = np.logspace(-2, np.log10(T_END), 200)
stress_pred_fine = model.predict(time_fine, test_mode="relaxation", gamma_pre=gamma_pre)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Linear
ax1.plot(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax1.plot(time_fine, stress_pred_fine, "-", lw=2, color="C0", label="ISM fit")
ax1.set_xlabel("Time [s]", fontsize=12)
ax1.set_ylabel("Stress [Pa]", fontsize=12)
ax1.set_title(f"ISM Relaxation Fit (R² = {metrics['R2']:.4f})", fontsize=13)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Right: Log-log
ax2.loglog(time_data, stress_data, "ko", markersize=5, label="Synthetic data")
ax2.loglog(time_fine, stress_pred_fine, "-", lw=2, color="C0", label="ISM fit")
ax2.set_xlabel("Time [s]", fontsize=12)
ax2.set_ylabel("Stress [Pa]", fontsize=12)
ax2.set_title("Log-Log Scale", fontsize=13)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

## 6. Bayesian Inference

In [None]:
# Warm-start NUTS from the NLSQ estimates
initial_values = {
    name: model.parameters.get_value(name)
    for name in param_names
}

# Fast demo config
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model.fit_bayesian(
    time_data,
    stress_data,
    test_mode="relaxation",
    gamma_pre=gamma_pre,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
all_pass = print_convergence_summary(result, param_names)
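`print_convergence_summary` comes from the tutorial utilities and its thresholds are not shown. For reference, here is a minimal NumPy sketch of the classic split-R̂ diagnostic; splitting each chain in half is what makes R̂ meaningful even in the single-chain fast demo. ArviZ's `az.rhat` defaults to the rank-normalized variant, so its values can differ slightly. `split_rhat` is a hypothetical helper.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for one parameter; draws has shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    n_half = draws.shape[1] // 2
    # Split every chain into two halves (drops the last draw if the count is odd)
    halves = np.concatenate([draws[:, :n_half], draws[:, n_half : 2 * n_half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()          # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)       # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n        # pooled variance estimate
    return float(np.sqrt(var_hat / within))
```

Values close to 1 (a common threshold is below 1.01) indicate that all chain halves sample the same distribution.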

In [None]:
# Trace plots
idata = result.to_inference_data()
axes = az.plot_trace(idata, var_names=param_names, figsize=(12, 6))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots (ISM Relaxation)", fontsize=14, y=1.00)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Parameter comparison
posterior = result.posterior_samples
print_parameter_comparison(model, posterior, param_names)
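Because the synthetic data were generated from the NB07 parameters stored in `params`, recovery can also be checked directly: does each parameter's central credible interval contain the value used to generate the data? A minimal sketch, assuming `posterior` maps parameter names to arrays of samples; `interval_covers` is hypothetical.

```python
import numpy as np


def interval_covers(samples, true_value, prob=0.95):
    """Return (covered, (lo, hi)) for the central `prob` credible interval of `samples`."""
    samples = np.asarray(samples, dtype=float).ravel()
    lo, hi = np.quantile(samples, [(1.0 - prob) / 2.0, (1.0 + prob) / 2.0])
    return bool(lo <= true_value <= hi), (float(lo), float(hi))
```

In the notebook this could be applied as `interval_covers(posterior[name], params[name])` for each `name` in `param_names`. An occasional miss is expected even when recovery works, since a 95% interval excludes the true value roughly 1 time in 20.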

## 7. Physical Interpretation

### ISM Relaxation Features

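1. **k-resolved decay**: each wave-vector mode relaxes at its own rate, set by Γ(k) and its memory kernel m(k,t)
2. **Structure-factor weighting**: the stress integral weights each mode by the slope of S(k), so wave vectors around the main peak of S(k), where it varies most steeply, dominate
3. **Glass plateau**: in the glass, arrested modes (f(k) > 0) never fully decay and leave a residual stress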

In [None]:
# Summary
print("ISM Relaxation Summary")
print("=" * 50)
print_glass_state_summary(model)
print(f"\nRelaxation Characteristics:")
print(f"  Initial stress: σ₀ = {SIGMA_0:.1f} Pa")
print(f"  Pre-shear strain: γ₀ ≈ {gamma_pre:.4f}")

## 8. Save Results

In [None]:
# Save results
save_itt_mct_results(model, result, "isotropic", "relaxation", param_names)
print("\nISM relaxation results saved.")

## Key Takeaways

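1. **ISM** resolves the correlators by wave vector, Φ(k,t), rather than using a single schematic Φ(t), which makes its predictions more quantitative

2. **Two-step relaxation** arises from the fast β-process (decay onto the plateau) and the slow α-process (cage rearrangement)

3. **Structure factor S(k)** sets how strongly each mode contributes to the stress

4. **Parameter recovery** from synthetic data checks the fitting pipeline and parameter identifiability; it does not validate the model against experimental data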

### Next Steps

- **NB10:** ISM Creep
- **NB11:** ISM SAOS
- **NB12:** ISM LAOS