# TNT Multi-Species: Creep Response

**Objectives:**
- Fit a TNT multi-species model to creep strain data
- Understand multi-mode compliance and sequential yielding
- Decompose total creep into per-species contributions
- Analyze fast vs slow species yielding dynamics
- Compare NLSQ and Bayesian inference

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from scipy.integrate import cumulative_trapezoid

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTMultiSpecies

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_ml_ikh_creep,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_multi_species_param_names,
    plot_multi_species_spectrum,
    plot_mode_decomposition,
)

## Theory: Multi-Mode Creep Compliance

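For an $N$-species TNT model, the creep compliance is taken as:

$$J(t) = \sum_{i=0}^{N-1} \frac{1}{G_i}\left[1 - \exp\left(-\frac{t}{\tau_{b,i}}\right)\right] + \frac{t}{\eta_0}$$

where $G_i$ and $\tau_{b,i}$ are the modulus and bond lifetime of species $i$, and $\eta_0$ is the viscosity governing terminal flow.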

For 2 species:
$$J(t) = \frac{1}{G_0}\left[1 - e^{-t/\tau_{b,0}}\right] + \frac{1}{G_1}\left[1 - e^{-t/\tau_{b,1}}\right] + \frac{t}{\eta_0}$$

**Strain response to constant stress $\sigma_0$:**
$$\gamma(t) = \sigma_0 \cdot J(t)$$
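As a concrete reference for the sum above, here is a minimal NumPy sketch that evaluates $J(t)$ and $\gamma(t) = \sigma_0 J(t)$ for any number of species. The function name `creep_compliance` and all parameter values are illustrative assumptions; this is not how rheojax computes its prediction internally.

```python
import numpy as np


def creep_compliance(t, G, tau_b, eta0):
    """Sum of retarded species terms plus a viscous term (illustrative sketch)."""
    t = np.atleast_1d(np.asarray(t, dtype=float))[:, None]  # (n_t, 1): broadcasts over species
    G = np.asarray(G, dtype=float)[None, :]                 # (1, n_species)
    tau_b = np.asarray(tau_b, dtype=float)[None, :]         # (1, n_species)
    retarded = (1.0 / G) * (1.0 - np.exp(-t / tau_b))       # per-species compliance, (n_t, n_species)
    return retarded.sum(axis=1) + t[:, 0] / eta0            # add the viscous term t/eta0


# Illustrative two-species values, not fitted results
G = [100.0, 1000.0]   # Pa
tau_b = [0.1, 10.0]   # s
eta0 = 1.0e5          # Pa·s
sigma0 = 50.0         # Pa

t = np.logspace(-3, 3, 7)
gamma = sigma0 * creep_compliance(t, G, tau_b, eta0)  # gamma(t) = sigma0 * J(t)
```

At $t = 0$ every retarded term vanishes, and once both species have yielded $J(t) - t/\eta_0$ approaches $1/G_0 + 1/G_1$.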

**Key physics:**
- **Sequential yielding**: Fast bonds break first, then slow bonds
- Early time: Fast species dominates ($t \sim \tau_{b,0}$)
- Intermediate time: Slow species kicks in ($t \sim \tau_{b,1}$)
- Late time: Viscous flow dominates (linear growth $\sim t/\eta_0$)

**Multi-stage dynamics:**
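1. Initial response ($t \ll \tau_{b,0}$): $J(0) = 0$, so this form has no instantaneous elastic jump; strain grows linearly, $\gamma(t) \approx \sigma_0 t\left[\sum_i 1/(G_i \tau_{b,i}) + 1/\eta_0\right]$
2. Fast species yields ($t \sim \tau_{b,0}$): strain rises by $\sigma_0/G_0$ toward a first quasi-plateau
3. Slow species yields ($t \sim \tau_{b,1}$): strain rises by a further $\sigma_0/G_1$ toward a second quasi-plateau near $\sigma_0(1/G_0 + 1/G_1)$
4. Terminal flow ($t \gg \tau_{b,1}$): $\gamma(t) \approx \sigma_0\left[\sum_i 1/G_i + t/\eta_0\right]$ (linear in $t$)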
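The stages can be checked numerically against the compliance sum. Differentiating each term at $t = 0$ gives $\frac{d}{dt}\frac{1}{G_i}\left[1 - e^{-t/\tau_{b,i}}\right]_{t=0} = 1/(G_i\tau_{b,i})$, so $J(0) = 0$ and the initial slope is $\sum_i 1/(G_i\tau_{b,i}) + 1/\eta_0$. The sketch below uses illustrative parameter values (assumptions, not fitted results). It compares that slope with a finite difference and checks that $J$ sits near $1/G_0$ between the two timescales.

```python
import numpy as np


def creep_compliance(t, G, tau_b, eta0):
    """J(t) = sum_i (1/G_i)[1 - exp(-t/tau_b_i)] + t/eta0 for 1-D arrays G, tau_b."""
    t = np.atleast_1d(np.asarray(t, dtype=float))[:, None]
    return np.sum((1.0 / G) * (1.0 - np.exp(-t / tau_b)), axis=1) + t[:, 0] / eta0


G = np.array([100.0, 1000.0])  # Pa (illustrative)
tau_b = np.array([0.1, 10.0])  # s (illustrative)
eta0 = 1.0e5                   # Pa·s (illustrative)

# Stage 1: J(0) = 0 and the initial slope is sum_i 1/(G_i tau_b_i) + 1/eta0
J0 = creep_compliance(0.0, G, tau_b, eta0)[0]
slope_analytic = np.sum(1.0 / (G * tau_b)) + 1.0 / eta0
t_small = 1e-6
slope_numeric = creep_compliance(t_small, G, tau_b, eta0)[0] / t_small

# Stage 2: for tau_b_0 << t << tau_b_1 the compliance sits near 1/G_0
J_between = creep_compliance(1.0, G, tau_b, eta0)[0]
```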

## Load Data and Compute Strain

In [None]:
time_data, shear_rate = load_ml_ikh_creep(stress_pair_index=0)

# Integrate shear rate to get strain
strain = cumulative_trapezoid(shear_rate, time_data, initial=0)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.4f} to {time_data.max():.2f} s")
print(f"Strain range: {strain.min():.4f} to {strain.max():.4f}")
print(f"Final strain: {strain[-1]:.4f}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(time_data, strain, 'o-', markersize=4)
ax1.set_xlabel('Time [s]', fontsize=12)
ax1.set_ylabel('Strain', fontsize=12)
ax1.set_title('Creep Strain Response', fontsize=14)
ax1.grid(True, alpha=0.3)

ax2.loglog(time_data, strain, 'o-', markersize=4)
ax2.set_xlabel('Time [s]', fontsize=12)
ax2.set_ylabel('Strain', fontsize=12)
ax2.set_title('Creep Strain (log-log)', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)
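Because `strain` is obtained by numerically integrating the shear rate, it is worth checking the integration on a case with a known answer. This sketch uses a synthetic single-mode rate $\dot\gamma(t) = \frac{\sigma_0}{G\tau}e^{-t/\tau}$, whose integral is $\frac{\sigma_0}{G}\left(1 - e^{-t/\tau}\right)$. The values are illustrative assumptions, not properties of the loaded dataset.

```python
import numpy as np
from scipy.integrate import cumulative_trapezoid

sigma0, G, tau = 50.0, 100.0, 0.5    # illustrative values: Pa, Pa, s
t = np.linspace(0.0, 5.0, 2001)      # dense uniform grid starting at t = 0

rate = (sigma0 / (G * tau)) * np.exp(-t / tau)          # analytic shear rate of one retarded mode
strain_num = cumulative_trapezoid(rate, t, initial=0)  # same call as in the cell above
strain_exact = (sigma0 / G) * (1.0 - np.exp(-t / tau))
max_abs_err = np.max(np.abs(strain_num - strain_exact))
```

Note that `initial=0` pins the strain at the first sample to zero. If the first time point comes after the start of loading, any strain accumulated before it is missing from `strain`. Coarse sampling where the rate changes quickly also increases the trapezoid error.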

## NLSQ Fitting

In [None]:
model = TNTMultiSpecies(n_species=2)
param_names = get_tnt_multi_species_param_names(n_species=2)
print(f"Parameters: {param_names}")

start_time = time.time()
result_nlsq = model.fit(time_data, strain, test_mode="creep")
nlsq_time = time.time() - start_time

print(f"\nNLSQ converged: {result_nlsq.success}")
print(f"Optimization time: {nlsq_time:.2f} s")
print(f"\nFitted parameters:")
for name in param_names:
    print(f"  {name}: {result_nlsq.params[name]:.6e}")
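To make the NLSQ step concrete without relying on rheojax internals, here is a sketch that fits the two-species strain $\gamma(t) = \sigma_0 J(t)$ to synthetic data with `scipy.optimize.curve_fit`. This swaps in SciPy's trust-region least squares for whatever optimizer `model.fit` uses. The stress, parameter values, noise level, and log10 parameterization are assumptions for illustration. The log10 parameterization keeps all parameters positive and on comparable scales.

```python
import numpy as np
from scipy.optimize import curve_fit

SIGMA0 = 50.0  # applied stress in Pa (illustrative; assumed known)


def strain_model(t, log_G0, log_G1, log_tau0, log_tau1, log_eta0):
    """gamma(t) = sigma0 * J(t) for two species, with parameters in log10 space."""
    G0, G1, tau0, tau1, eta0 = 10.0 ** np.array([log_G0, log_G1, log_tau0, log_tau1, log_eta0])
    J = (1.0 / G0) * (1.0 - np.exp(-t / tau0)) + (1.0 / G1) * (1.0 - np.exp(-t / tau1)) + t / eta0
    return SIGMA0 * J


true_log = np.array([2.0, 3.0, -1.0, 1.0, 5.0])  # G: 100, 1000 Pa; tau: 0.1, 10 s; eta0: 1e5 Pa·s
t = np.logspace(-2, 3, 80)
rng = np.random.default_rng(0)
y = strain_model(t, *true_log) + rng.normal(0.0, 1e-4, t.size)  # synthetic noisy strain

p0 = true_log + np.array([0.3, -0.2, 0.3, -0.3, -0.3])  # deliberately perturbed, species ordered
popt, pcov = curve_fit(strain_model, t, y, p0=p0,
                       bounds=([0, 0, -3, -3, 2], [6, 6, 3, 3, 9]))
recovered = 10.0 ** popt
rel_err = np.abs(recovered - 10.0 ** true_log) / 10.0 ** true_log
```

Starting with the species ordered (short $\tau$ first) tends to keep the optimizer away from the equivalent solution with the species labels swapped.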

In [None]:
time_pred = np.linspace(time_data.min(), time_data.max(), 500)
strain_pred = model.predict_creep(time_pred)

fit_metrics = compute_fit_quality(strain, result_nlsq.y_pred)
print(f"\nFit quality:")
print(f"  R² = {fit_metrics['r_squared']:.6f}")
print(f"  RMSE = {fit_metrics['rmse']:.6e}")
print(f"  NRMSE = {fit_metrics['nrmse']:.6f}")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, strain, 'o', label='Data', markersize=6, alpha=0.7)
ax.plot(time_pred, strain_pred, '-', label='NLSQ Fit', linewidth=2)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('Strain', fontsize=12)
ax.set_title(f'TNT Multi-Species Creep (R² = {fit_metrics["r_squared"]:.4f})', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)
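For reference, the three metrics printed above can be computed as shown below. The exact definitions inside `compute_fit_quality` are not shown here. In particular, normalizing RMSE by the data range (rather than, say, the mean or standard deviation) is an assumption of this sketch.

```python
import numpy as np


def fit_quality(y, y_hat):
    """R², RMSE and range-normalized RMSE (the normalization choice is an assumption)."""
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    resid = y - y_hat
    ss_res = np.sum(resid ** 2)             # residual sum of squares
    ss_tot = np.sum((y - y.mean()) ** 2)    # total sum of squares about the mean
    rmse = np.sqrt(np.mean(resid ** 2))
    return {"r_squared": 1.0 - ss_res / ss_tot,
            "rmse": rmse,
            "nrmse": rmse / (y.max() - y.min())}


m = fit_quality([0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 4.0])
```

On a monotonically rising creep curve, $R^2$ can be close to 1 even when the early-time stage is poorly fitted, because the total variance is dominated by the overall trend. Inspecting residuals on a log time axis is a useful complement.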

## Physical Analysis: Sequential Species Yielding

In [None]:
fig = plot_mode_decomposition(model, time_pred, "creep")
display(fig)
plt.close(fig)

G_0 = result_nlsq.params['G_0']
tau_b_0 = result_nlsq.params['tau_b_0']
G_1 = result_nlsq.params['G_1']
tau_b_1 = result_nlsq.params['tau_b_1']
eta_s = result_nlsq.params['eta_s']

print("\nSequential yielding analysis:")
print(f"\nSpecies 0 (fast bonds - yield first):")
print(f"  G_0 = {G_0:.3e} Pa")
print(f"  tau_b_0 = {tau_b_0:.3e} s")
print(f"  Compliance contribution: 1/G_0 = {1/G_0:.3e} Pa^-1")
print(f"  Yield timescale ~ tau_b_0 = {tau_b_0:.3e} s")

print(f"\nSpecies 1 (slow bonds - yield second):")
print(f"  G_1 = {G_1:.3e} Pa")
print(f"  tau_b_1 = {tau_b_1:.3e} s")
print(f"  Compliance contribution: 1/G_1 = {1/G_1:.3e} Pa^-1")
print(f"  Yield timescale ~ tau_b_1 = {tau_b_1:.3e} s")

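print("\nSolvent viscosity:")
print(f"  eta_s = {eta_s:.3e} Pa·s")
print(f"  Fluidity 1/eta_s = {1/eta_s:.3e} (Pa·s)^-1")

print("\nTimescale hierarchy:")
# The fit does not necessarily return species ordered by lifetime; flag it if reversed
if tau_b_0 > tau_b_1:
    print("  Note: tau_b_0 > tau_b_1, so species 0 is the SLOWER species in this fit")
print(f"  tau_b_1/tau_b_0 = {tau_b_1/tau_b_0:.2f}")
print(f"  Slow bonds live {max(tau_b_0, tau_b_1)/min(tau_b_0, tau_b_1):.1f}x longer than fast bonds")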

# Estimate compliance contributions
J_0 = 1/G_0
J_1 = 1/G_1
J_total = J_0 + J_1
print(f"\nCompliance partitioning:")
print(f"  Fast species: {100*J_0/J_total:.1f}% of elastic compliance")
print(f"  Slow species: {100*J_1/J_total:.1f}% of elastic compliance")
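The analysis above assumes species 0 is the fast one, but a least-squares fit does not necessarily return the species in order of $\tau_b$. A small helper, sketched here, reorders the fitted values by lifetime before computing the compliance shares. The helper names are hypothetical; the `G_i`/`tau_b_i` keys follow the parameter names used in the cell above.

```python
import numpy as np


def order_species_by_tau(params, n_species=2):
    """Return (G, tau_b) arrays sorted so that index 0 is the fastest species."""
    G = np.array([params[f"G_{i}"] for i in range(n_species)], dtype=float)
    tau = np.array([params[f"tau_b_{i}"] for i in range(n_species)], dtype=float)
    order = np.argsort(tau)  # ascending lifetime: fast species first
    return G[order], tau[order]


def compliance_fractions(G):
    """Share of the total elastic compliance sum(1/G_i) carried by each species."""
    J = 1.0 / np.asarray(G, dtype=float)
    return J / J.sum()


# Hypothetical fitted values where the optimizer returned the slow species first
params = {"G_0": 1000.0, "tau_b_0": 10.0, "G_1": 100.0, "tau_b_1": 0.1}
G_sorted, tau_sorted = order_species_by_tau(params)
fractions = compliance_fractions(G_sorted)
```

With the fitted model, `order_species_by_tau(result_nlsq.params)` would apply the same ordering, assuming `params` supports lookup by name as in the cell above.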

## Multi-Stage Creep Dynamics

In [None]:
fig = plot_multi_species_spectrum(model)
display(fig)
plt.close(fig)

print("\nMulti-stage creep dynamics:")
print("\n1. Initial response (t → 0+):")
print("   J(0) = 0: no instantaneous jump; strain initially grows linearly in t")

print(f"\n2. Fast species yield (t ~ tau_b_0 = {tau_b_0:.3e} s):")
print("   Fast bonds break, strain increases by σ_0/G_0")
print("   First quasi-plateau")

print(f"\n3. Slow species yield (t ~ tau_b_1 = {tau_b_1:.3e} s):")
print("   Slow bonds break, strain increases by a further σ_0/G_1")
print("   Second quasi-plateau")

print("\n4. Terminal flow (t >> tau_b_1):")
print("   Linear strain growth: γ(t) ≈ σ_0 * (1/G_0 + 1/G_1 + t/η_0)")
print("   All species have yielded; viscous flow dominates")
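The stages can be given concrete times. Species $i$ has completed a fraction $f$ of its strain step when $1 - e^{-t/\tau_{b,i}} = f$, that is at $t_f = -\tau_{b,i}\ln(1 - f)$, so 95% completion takes $\tau_{b,i}\ln 20 \approx 3\tau_{b,i}$. This sketch uses illustrative lifetimes and a hypothetical helper name. It also measures how far the slow species has progressed by then, which indicates whether the two stages are well separated.

```python
import numpy as np


def time_to_fraction(tau_b, fraction):
    """Time at which 1 - exp(-t/tau_b) reaches `fraction` of a species' final strain step."""
    return -np.asarray(tau_b, dtype=float) * np.log1p(-fraction)


tau_b = np.array([0.1, 10.0])        # illustrative fast/slow bond lifetimes, s
t95 = time_to_fraction(tau_b, 0.95)  # about 3 * tau_b for each species

# Fraction of the slow species' step already completed when the fast one is 95% done
slow_progress = 1.0 - np.exp(-t95[0] / tau_b[1])
```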

## Bayesian Inference

In [None]:
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

print(f"Running Bayesian inference with {NUM_CHAINS} chain(s)...")
start_time = time.time()
result_bayes = model.fit_bayesian(
    time_data,
    strain,
    test_mode="creep",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    seed=42,
)
bayes_time = time.time() - start_time
print(f"Bayesian inference time: {bayes_time:.2f} s")

## Convergence Diagnostics

In [None]:
print_convergence_summary(result_bayes, param_names)
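With `NUM_CHAINS = 1`, R-hat can only compare the first and second halves of the single chain. A minimal NumPy sketch of the classic split-R-hat shows what that comparison measures. ArviZ defaults to a rank-normalized variant, so its values can differ slightly from this one; the function name and test chains are assumptions.

```python
import numpy as np


def split_rhat(chains):
    """Classic (non-rank-normalized) split-R-hat for draws of shape (n_chains, n_draws)."""
    chains = np.atleast_2d(np.asarray(chains, dtype=float))
    n_half = chains.shape[1] // 2
    # Cut every chain in half, so a single chain still yields two sequences to compare
    halves = np.concatenate([chains[:, :n_half], chains[:, n_half:2 * n_half]], axis=0)
    n = halves.shape[1]
    W = halves.var(axis=1, ddof=1).mean()    # mean within-sequence variance
    B = n * halves.mean(axis=1).var(ddof=1)  # between-sequence variance
    var_hat = (n - 1) / n * W + B / n        # pooled posterior variance estimate
    return np.sqrt(var_hat / W)


rng = np.random.default_rng(1)
stationary = rng.normal(size=(1, 1000))  # one well-mixed chain
drifting = np.concatenate([rng.normal(0.0, 1.0, 500), rng.normal(3.0, 1.0, 500)])[None, :]
```

Values near 1 indicate that the two halves agree. Running several chains makes the diagnostic much more sensitive to chains stuck in different regions.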

## ArviZ Diagnostics: Trace Plots

In [None]:
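idata = result_bayes.to_arviz()
# az.plot_trace returns an array of Axes, not a Figure; recover the parent Figure
axes = az.plot_trace(idata, var_names=param_names, compact=True)
fig = np.ravel(axes)[0].figure
fig.tight_layout()
display(fig)
plt.close(fig)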

## ArviZ Diagnostics: Posterior Distributions

In [None]:
axes = az.plot_posterior(idata, var_names=param_names, hdi_prob=0.95)
fig = np.ravel(axes)[0].figure  # plot_posterior returns Axes, not a Figure
fig.tight_layout()
display(fig)
plt.close(fig)

## ArviZ Diagnostics: Pair Plot

In [None]:
axes = az.plot_pair(idata, var_names=param_names, divergences=True)
fig = np.ravel(axes)[0].figure  # plot_pair returns Axes, not a Figure
display(fig)
plt.close(fig)
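The pair plot shows correlations visually; a numeric summary is the Pearson correlation matrix of the draws. This sketch assumes that the samples are available as a mapping from parameter name to an array of draws, as `result_bayes.posterior_samples` is indexed below. The synthetic draws simply illustrate a $G_0$/$G_1$ trade-off.

```python
import numpy as np


def posterior_correlation(samples, names):
    """Pearson correlation matrix of posterior draws, rows/columns ordered as `names`."""
    X = np.vstack([np.asarray(samples[n], dtype=float).ravel() for n in names])
    return np.corrcoef(X)  # each row of X is one parameter's draws


rng = np.random.default_rng(2)
z = rng.normal(size=2000)
# Synthetic draws: G_0 and G_1 anti-correlated (a trade-off), eta_s independent
samples = {
    "G_0": 100.0 + 5.0 * z,
    "G_1": 1000.0 - 50.0 * z + rng.normal(0.0, 5.0, 2000),
    "eta_s": rng.normal(1e5, 1e3, 2000),
}
corr = posterior_correlation(samples, ["G_0", "G_1", "eta_s"])
```

Applied to the fit, `posterior_correlation(result_bayes.posterior_samples, param_names)` would give one row and column per parameter. Entries near ±1 flag pairs that the data cannot separate.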

## NLSQ vs Bayesian Parameter Comparison

In [None]:
print_parameter_comparison(result_nlsq, result_bayes, param_names)

## Posterior Predictive: Creep Response

In [None]:
posterior = result_bayes.posterior_samples
n_total = len(posterior[param_names[0]])  # count the stored draws instead of assuming NUM_SAMPLES
n_draws = min(200, n_total)
draw_indices = np.linspace(0, n_total - 1, n_draws, dtype=int)

x_pred = jnp.array(time_pred)
y_pred_samples = []

for i in draw_indices:
    params_i = jnp.array([posterior[name][i] for name in param_names])
    y_pred_i = model.model_function(x_pred, params_i, test_mode="creep")
    y_pred_samples.append(np.array(y_pred_i))

y_pred_samples = np.array(y_pred_samples)
y_pred_mean = np.mean(y_pred_samples, axis=0)
y_pred_lower = np.percentile(y_pred_samples, 2.5, axis=0)
y_pred_upper = np.percentile(y_pred_samples, 97.5, axis=0)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, strain, 'o', label='Data', markersize=6, alpha=0.7, zorder=3)
ax.plot(time_pred, y_pred_mean, '-', label='Posterior Mean', linewidth=2, zorder=2)
ax.fill_between(time_pred, y_pred_lower, y_pred_upper, alpha=0.3, label='95% credible band (parameter uncertainty)', zorder=1)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('Strain', fontsize=12)
ax.set_title('Posterior Predictive: Creep Response', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)

## Physical Interpretation

**Multi-mode compliance:**
- Each species contributes an independent retarded compliance term, $(1/G_i)\left[1 - e^{-t/\tau_{b,i}}\right]$
- Fast species (short $\tau_{b,0}$) yields first → early strain growth
- Slow species (long $\tau_{b,1}$) yields later → delayed strain growth
- Terminal flow: once all species have yielded, the viscous term $t/\eta_0$ dominates

**Sequential yielding:**
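- Strain accumulates in stages, one per species timescale
- Compliance partitioning: the share of $\sum_i 1/G_i$ carried by each species shows which bond population contributes most of the elastic compliance
- Yield order is set by $\tau_{b,i}$: the shortest-lived bonds break first, whatever their $G_i$
- $G_i$ sets the size of each strain step ($\sigma_0/G_i$): a stiffer species (high $G_i$) adds less strain when it yields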

**Uncertainty quantification:**
- Bayesian posteriors capture parameter correlations
- Early-time data constrains fast species
- Late-time data constrains slow species and viscosity
- When $\tau_{b,0}$ and $\tau_{b,1}$ are close, the two modes become hard to separate and $G_0$, $G_1$ can trade off against each other (strong posterior correlation)
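The trade-off is sharpest when the lifetimes coincide. If $\tau_{b,0} = \tau_{b,1}$, the two retarded terms share the factor $1 - e^{-t/\tau_b}$, so $J(t)$ depends only on $1/G_0 + 1/G_1$, and any split with the same sum fits equally well. The sketch below checks this numerically with illustrative values (assumptions, not fitted results) and shows that separating the lifetimes breaks the degeneracy.

```python
import numpy as np


def two_species_J(t, G0, G1, tau0, tau1, eta0):
    """Two-species creep compliance from the theory section."""
    return ((1.0 / G0) * (1.0 - np.exp(-t / tau0))
            + (1.0 / G1) * (1.0 - np.exp(-t / tau1))
            + t / eta0)


t = np.logspace(-3, 3, 200)
eta0 = 1.0e5
# Both pairs share 1/G0 + 1/G1 = 0.011 Pa^-1 (illustrative values)
pair_a = (100.0, 1000.0)
pair_b = (200.0, 1.0 / 0.006)

# Equal lifetimes: the curves coincide, so only the sum of compliances is identifiable
diff_equal = np.max(np.abs(two_species_J(t, *pair_a, 1.0, 1.0, eta0)
                           - two_species_J(t, *pair_b, 1.0, 1.0, eta0)))
# Well-separated lifetimes: the same two pairs give clearly different curves
diff_separated = np.max(np.abs(two_species_J(t, *pair_a, 0.1, 10.0, eta0)
                               - two_species_J(t, *pair_b, 0.1, 10.0, eta0)))
```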

## Save Results

In [None]:
save_tnt_results(model, result_bayes, "multi_species", "creep", param_names)
print("Results saved successfully.")

## Key Takeaways

1. **Multi-mode compliance**: Sum of independent species contributions
2. **Sequential yielding**: Fast bonds break before slow bonds
3. **Timescale hierarchy**: Multiple yield stages with distinct timescales
4. **Compliance partitioning**: The relative sizes of $1/G_i$ show how elastic compliance is shared between bond populations
5. **Terminal flow**: Long-time viscous behavior from solvent viscosity
6. **Bayesian inference**: Quantifies parameter uncertainty and correlations
7. **Data requirements**: The time window must span all yield timescales, from well below $\tau_{b,0}$ to well beyond $\tau_{b,1}$, to resolve every species
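One practical way to apply the data-requirements point is to compare the sampled time window with the fitted lifetimes before trusting per-species parameters. The helper below is a hypothetical heuristic. Its factors are assumptions, not rheojax rules: the first positive sample should lie at or below $0.1\,\tau_{b,0}$, and the last sample at or beyond $5\,\tau_{b,1}$, where the slowest step is about 99% complete since $1 - e^{-5} \approx 0.993$.

```python
import numpy as np


def window_covers_timescales(t, tau_b, early_factor=0.1, late_factor=5.0):
    """Heuristic check (factors are assumptions): does t span every species' yield stage?"""
    t = np.asarray(t, dtype=float)
    tau_b = np.sort(np.asarray(tau_b, dtype=float))   # fastest species first
    t_pos = t[t > 0]                                  # t = 0 carries no timescale information
    early_ok = t_pos.min() <= early_factor * tau_b[0]  # resolves the onset of the fastest stage
    late_ok = t.max() >= late_factor * tau_b[-1]       # slowest step is ~99% complete
    return bool(early_ok), bool(late_ok)


log_window = window_covers_timescales(np.logspace(-3, 2, 100), [0.1, 10.0])
linear_window = window_covers_timescales(np.linspace(0.0, 20.0, 200), [0.1, 10.0])
```

For this notebook, `window_covers_timescales(time_data, [tau_b_0, tau_b_1])` would check the loaded data against the NLSQ lifetimes.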