# TNT Multi-Species: Stress Relaxation

**Objectives:**
- Fit TNT multi-species model to stress relaxation data
- Understand bi-exponential and multi-exponential relaxation
- Decompose total relaxation into per-species contributions
- Discuss spectrum resolution and data information content
- Compare NLSQ and Bayesian inference

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt

from rheojax.core.jax_config import safe_import_jax, verify_float64
jax, jnp = safe_import_jax()
verify_float64()

from rheojax.models.tnt import TNTMultiSpecies

sys.path.insert(0, os.path.dirname(os.path.abspath("")))
sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_laponite_relaxation,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_multi_species_param_names,
    plot_multi_species_spectrum,
    plot_mode_decomposition,
)

from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

## Theory: Multi-Exponential Stress Relaxation

For a multi-species TNT model, the relaxation modulus is a sum of exponentials:

**Relaxation modulus:**
$$G(t) = \sum_{i=0}^{N-1} G_i \exp\left(-\frac{t}{\tau_{b,i}}\right)$$

For 2 species, this is a **bi-exponential decay**:
$$G(t) = G_0 e^{-t/\tau_{b,0}} + G_1 e^{-t/\tau_{b,1}}$$
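
The sum above can be evaluated directly with NumPy. The following is a minimal sketch, not the `TNTMultiSpecies` implementation: the helper `relaxation_modulus` and the numerical values are hypothetical and chosen only to illustrate the formula.

```python
import numpy as np

def relaxation_modulus(t, G, tau_b):
    """Return total G(t) and per-mode terms G_i * exp(-t / tau_b_i)."""
    t = np.asarray(t, dtype=float)
    G = np.asarray(G, dtype=float)
    tau_b = np.asarray(tau_b, dtype=float)
    # Broadcast times (rows) against modes (columns), then sum over modes.
    modes = G * np.exp(-t[:, None] / tau_b)
    return modes.sum(axis=1), modes

# Illustrative values only (assumed, not fitted to the Laponite data)
G = [100.0, 20.0]     # Pa
tau_b = [0.1, 10.0]   # s
t = np.geomspace(1e-3, 1e2, 6)

total, modes = relaxation_modulus(t, G, tau_b)
for ti, gi, mi in zip(t, total, modes):
    dominant = int(np.argmax(mi))  # index of the larger per-mode term
    print(f"t = {ti:9.3e} s   G(t) = {gi:9.3e} Pa   dominant species: {dominant}")
```

The `modes` array keeps one column per species, which is the same per-species decomposition that the notebook plots after fitting.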

**Key physics:**
- Each species contributes an exponential decay mode
- Fast species (short $\tau_{b,0}$) dominates early relaxation
- Slow species (long $\tau_{b,1}$) dominates late relaxation
- Broader spectrum than single-mode (single exponential)
- Captures hierarchical or multi-scale relaxation processes
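
The early and late regimes in the list above are separated by the time at which the two modes contribute equally. Setting the two terms of the bi-exponential equal, and assuming $\tau_{b,1} > \tau_{b,0}$, gives a crossover time (the symbol $t_\times$ is introduced here for illustration):

$$G_0 e^{-t_\times/\tau_{b,0}} = G_1 e^{-t_\times/\tau_{b,1}} \;\Rightarrow\; t_\times = \frac{\ln(G_0/G_1)}{1/\tau_{b,0} - 1/\tau_{b,1}} = \frac{\tau_{b,0}\,\tau_{b,1}}{\tau_{b,1} - \tau_{b,0}} \ln\frac{G_0}{G_1}$$

A positive $t_\times$ exists only when $G_0 > G_1$. Otherwise the slow mode carries the larger share of the stress at every $t \ge 0$ and there is no crossover.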

**Spectrum resolution:**
- Requires data spanning multiple decades in time
- Timescale separation $\tau_{b,1}/\tau_{b,0}$ must be resolvable
- Too close → modes may be indistinguishable (identifiability issue)
- Data noise and range limit effective resolution
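
One way to see why closely spaced modes are hard to separate is to look at the conditioning of the bi-exponential sensitivity matrix. The sketch below is a local-sensitivity heuristic under assumed parameter values, not a diagnostic provided by rheojax; the helper names are hypothetical.

```python
import numpy as np

def biexp_jacobian(t, G0, tau0, G1, tau1):
    """Columns: dG/dG0, dG/dtau0, dG/dG1, dG/dtau1 for a bi-exponential G(t)."""
    e0 = np.exp(-t / tau0)
    e1 = np.exp(-t / tau1)
    return np.column_stack([
        e0,
        G0 * t / tau0**2 * e0,
        e1,
        G1 * t / tau1**2 * e1,
    ])

def scaled_condition_number(t, G0, tau0, G1, tau1):
    """Condition number of the Jacobian in relative (log-parameter) units."""
    J = biexp_jacobian(t, G0, tau0, G1, tau1)
    # Multiply each column by its parameter value: d/d(ln p) = p * d/dp.
    J = J * np.array([G0, tau0, G1, tau1])
    return np.linalg.cond(J)

# Assumed time window of five decades; unit moduli and tau_b_0 = 1 s
t = np.geomspace(1e-3, 1e2, 200)
ratios = (100.0, 10.0, 3.0, 1.2)
kappas = {r: scaled_condition_number(t, 1.0, 1.0, 1.0, r) for r in ratios}

for r in ratios:
    print(f"tau_b_1/tau_b_0 = {r:6.1f}   condition number = {kappas[r]:.3e}")
```

A large condition number means that small amounts of noise in $G(t)$ translate into large parameter uncertainty. The value grows sharply as the timescale ratio approaches 1, where the two modes become indistinguishable.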

## Load Data

In [None]:
time_data, G_t = load_laponite_relaxation(aging_time=1800)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.4e} to {time_data.max():.2f} s")
print(f"Time decades: {np.log10(time_data.max()/time_data.min()):.1f}")
print(f"G(t) range: {G_t.min():.2e} to {G_t.max():.2e} Pa")
print(f"Decay factor: {G_t.max()/G_t.min():.1f}x")

fig, ax = plt.subplots(figsize=(8, 6))
ax.loglog(time_data, G_t, 'o', label='Data', markersize=6)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('G(t) [Pa]', fontsize=12)
ax.set_title('Laponite Stress Relaxation (Aging Time = 1800 s)', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.close("all")

## NLSQ Fitting

In [None]:
# Pre-shear rate used before relaxation starts
gamma_dot = 10.0  # s⁻¹ (typical pre-shear rate)

model = TNTMultiSpecies(n_species=2)
param_names = get_tnt_multi_species_param_names(n_species=2)
print(f"Parameters: {param_names}")

start_time = time.time()
model.fit(time_data, G_t, test_mode="relaxation", gamma_dot=gamma_dot, method='scipy')
nlsq_time = time.time() - start_time

print("\nNLSQ fit finished (inspect the model state to confirm convergence)")
print(f"Optimization time: {nlsq_time:.2f} s")
print(f"\nFitted parameters:")
for name in param_names:
    print(f"  {name}: {model.parameters.get_value(name):.6e}")

In [None]:
# Compute metrics for plot title
metrics = compute_fit_quality(G_t, model.predict(time_data, test_mode='relaxation', gamma_dot=gamma_dot))

# Plot NLSQ fit with uncertainty band
fig, ax = plot_nlsq_fit(
    time_data, G_t, model, test_mode="relaxation",
    param_names=param_names, log_scale=True,
    xlabel='Time (s)',
    ylabel=r'Relaxation modulus $G(t)$ (Pa)',
    title=f'NLSQ Fit (R² = {metrics["R2"]:.4f})',
    gamma_dot=gamma_dot
)
plt.close("all")

## Physical Analysis: Bi-Exponential Decomposition

In [None]:
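# Generate a fine, log-spaced grid: the data span several decades and are
# plotted on log axes, so a linear grid would undersample the early times
time_pred = np.geomspace(time_data.min(), time_data.max(), 200)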

fig = plot_mode_decomposition(model, time_pred, "relaxation")
plt.close("all")

G_0 = model.parameters.get_value('G_0')
tau_b_0 = model.parameters.get_value('tau_b_0')
G_1 = model.parameters.get_value('G_1')
tau_b_1 = model.parameters.get_value('tau_b_1')

print("\nBi-exponential decay analysis:")
print(f"\nSpecies 0 (fast relaxation):")
print(f"  G_0 = {G_0:.3e} Pa")
print(f"  tau_b_0 = {tau_b_0:.3e} s")
print(f"  Dominates at t << tau_b_1")

print(f"\nSpecies 1 (slow relaxation):")
print(f"  G_1 = {G_1:.3e} Pa")
print(f"  tau_b_1 = {tau_b_1:.3e} s")
print(f"  Dominates at t >> tau_b_0")

print(f"\nTimescale separation: tau_b_1/tau_b_0 = {tau_b_1/tau_b_0:.2f}")
print(f"Modulus ratio: G_1/G_0 = {G_1/G_0:.2f}")

# Crossover time estimate
if G_0 > 0 and G_1 > 0:
    # Approximate crossover where G_0*exp(-t/tau_b_0) ~ G_1*exp(-t/tau_b_1)
    # Solving: ln(G_0/G_1) = t*(1/tau_b_0 - 1/tau_b_1)
    if tau_b_1 > tau_b_0:
        t_cross = np.log(G_0/G_1) / (1/tau_b_0 - 1/tau_b_1)
        if t_cross > 0:
            print(f"\nCrossover time (fast → slow dominance): {t_cross:.3e} s")

## Spectrum Resolution Discussion

In [None]:
fig = plot_multi_species_spectrum(model)
plt.close("all")

print("\nSpectrum resolution and data information content:")
print(f"\nData time window: {time_data.min():.2e} to {time_data.max():.2e} s")
print(f"  ({np.log10(time_data.max()/time_data.min()):.1f} decades)")

print(f"\nFitted relaxation times:")
print(f"  tau_b_0 = {tau_b_0:.3e} s (fast)")
print(f"  tau_b_1 = {tau_b_1:.3e} s (slow)")
print(f"  Separation: {np.log10(tau_b_1/tau_b_0):.1f} decades")

# Check whether each fitted relaxation time falls inside the measured time window
t_min, t_max = time_data.min(), time_data.max()
print("\nMode resolution:")
for label, tau in (("Fast mode (tau_b_0)", tau_b_0), ("Slow mode (tau_b_1)", tau_b_1)):
    if t_min <= tau <= t_max:
        print(f"  {label}: WELL RESOLVED (within data window)")
    elif tau < t_min:
        print(f"  {label}: EXTRAPOLATED (below data window)")
    else:
        print(f"  {label}: EXTRAPOLATED (above data window)")

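print("\nIdentifiability considerations:")
# Use max/min so the check does not depend on the order in which the
# optimizer returns the two modes (tau_b_1 is not guaranteed to be the larger)
tau_ratio = max(tau_b_0, tau_b_1) / min(tau_b_0, tau_b_1)
if tau_ratio < 3:
    print("  WARNING: Modes too close, may be difficult to distinguish")
elif tau_ratio < 10:
    print("  MODERATE: Modes moderately separated, identifiable with good data")
else:
    print("  GOOD: Modes well separated, clearly identifiable")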

print(f"\nNote: Laponite may have hierarchical relaxation (clay platelets,")
print(f"      particle clusters, network structure). Multi-exponential fit")
print(f"      captures this broader spectrum compared to single exponential.")

## Bayesian Inference

In [None]:
# FAST_MODE (env var FAST_MODE, default "1") skips the slow Bayesian step for quick validation
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

# Configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

if FAST_MODE:
    print("FAST_MODE: Skipping Bayesian inference (JIT compilation takes >600s)")
    print("To run Bayesian analysis, run with FAST_MODE=0")
    # Placeholder result: repeat the NLSQ point estimate so downstream cells still run
    class BayesianResult:
        def __init__(self, model, param_names):
            self.posterior_samples = {
                name: np.full(NUM_SAMPLES, model.parameters.get_value(name))
                for name in param_names
            }
    result_bayes = BayesianResult(model, param_names)
    bayes_time = 0.0
else:
    print(f"Running NUTS with {NUM_CHAINS} chain(s)...")
    print(f"Warmup: {NUM_WARMUP} samples, Sampling: {NUM_SAMPLES} samples")
    
    start_time = time.time()
    result_bayes = model.fit_bayesian(
        time_data, G_t,
        test_mode='relaxation',
        gamma_dot=gamma_dot,
        
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        seed=42
    )
    bayes_time = time.time() - start_time
    
    print(f"\nBayesian inference completed in {bayes_time:.1f} seconds")


## Convergence Diagnostics

In [None]:
# Skip convergence diagnostics when FAST_MODE is on (no MCMC chains were run)
if not FAST_MODE:
    print_convergence_summary(result_bayes, param_names)
else:
    print("FAST_MODE: Skipping convergence diagnostics")


## ArviZ Diagnostics

In [None]:
# ArviZ diagnostics (trace, pair, forest, energy, autocorrelation, rank)
if FAST_MODE:
    print("FAST_MODE: Skipping ArviZ diagnostics")
elif hasattr(result_bayes, 'to_inference_data'):
    display_arviz_diagnostics(result_bayes, param_names, fast_mode=FAST_MODE)
else:
    print("Result has no to_inference_data(); skipping ArviZ diagnostics")

## NLSQ vs Bayesian Parameter Comparison

In [None]:
print_parameter_comparison(model, result_bayes.posterior_samples, param_names)

## Posterior Predictive: Relaxation Modulus

In [None]:
# Posterior predictive check (needs real posterior draws, so skipped in FAST_MODE)
if FAST_MODE:
    print("FAST_MODE: Skipping posterior predictive")
elif hasattr(result_bayes, 'posterior_samples'):
    fig, ax = plot_posterior_predictive(
        time_data, G_t, model, result_bayes,
        test_mode="relaxation", param_names=param_names,
        log_scale=True,
        xlabel=r'Time (s)',
        ylabel=r'Relaxation modulus $G(t)$ (Pa)',
        gamma_dot=gamma_dot
    )
    plt.close("all")
else:
    print("Result has no posterior_samples; skipping posterior predictive")

## Physical Interpretation

**Bi-exponential relaxation:**
- Two distinct relaxation timescales capture hierarchical dynamics
- Fast mode: Early-time relaxation (short-lived bonds, small $\tau_{b,0}$)
- Slow mode: Late-time relaxation (long-lived bonds, large $\tau_{b,1}$)
- Broader spectrum than single exponential (more realistic)

**Spectrum resolution:**
- Requires data spanning both timescales
- Modes must be sufficiently separated (as a rule of thumb, $\tau_{b,1}/\tau_{b,0} \gtrsim 3$, the warning threshold used in the resolution check)
- Data noise and range limit identifiability
- Bayesian posteriors reveal parameter correlations and uncertainty
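
Parameter correlations can be read off a dictionary of posterior draws, such as `result_bayes.posterior_samples`, with a correlation matrix. The sketch below uses synthetic stand-in draws with a built-in negative correlation rather than output from the Laponite fit; the helper `posterior_correlation` is hypothetical.

```python
import numpy as np

def posterior_correlation(samples, names):
    """Pearson correlation matrix of posterior draws, one row/column per name."""
    draws = np.vstack([np.asarray(samples[n]).ravel() for n in names])
    return np.corrcoef(draws)

# Synthetic draws (assumed, NOT from the fit): correlated log-normal parameters
rng = np.random.default_rng(0)
cov = [[1.0, -0.8], [-0.8, 1.0]]
z = rng.multivariate_normal([0.0, 0.0], cov, size=2000)
samples = {
    "G_0": 100.0 * np.exp(0.1 * z[:, 0]),    # Pa
    "tau_b_0": 0.1 * np.exp(0.1 * z[:, 1]),  # s
}

names = ["G_0", "tau_b_0"]
corr = posterior_correlation(samples, names)
print(f"corr(G_0, tau_b_0) = {corr[0, 1]:+.2f}")
```

With the FAST_MODE placeholder, every draw is identical, so the variance is zero and `np.corrcoef` returns NaN. The matrix is only meaningful after a real MCMC run.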

**Material interpretation (Laponite):**
- Clay platelet rearrangements (fast)
- Network restructuring (slow)
- Hierarchical gel structure reflected in multi-modal spectrum

## Save Results

In [None]:
save_tnt_results(model, result_bayes, "multi_species", "relaxation", param_names)
print("Results saved successfully.")

## Key Takeaways

1. **Bi-exponential relaxation**: Sum of two exponential decay modes
2. **Timescale hierarchy**: Fast and slow relaxation processes
3. **Broader spectrum**: Captures more realistic multi-scale dynamics
4. **Resolution limits**: Data range and quality constrain identifiability
5. **Crossover dynamics**: Transition from fast to slow mode dominance
6. **Bayesian inference**: Quantifies parameter uncertainty and correlations
7. **Material physics**: Hierarchical structure (Laponite gel) reflected in spectrum