# TNT Multi-Species: Startup Shear

**Objectives:**
- Fit TNT multi-species model to startup shear data
- Understand multi-species stress overshoot dynamics
- Analyze fast vs slow species peak times
- Decompose composite overshoot into species contributions
- Compare NLSQ and Bayesian inference

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az
from IPython.display import display  # built in under Jupyter; explicit import for other runners

from rheojax.core.jax_config import safe_import_jax, verify_float64
jax, jnp = safe_import_jax()
verify_float64()

from rheojax.models.tnt import TNTMultiSpecies

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_pnas_startup,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_multi_species_param_names,
    plot_multi_species_spectrum,
    plot_mode_decomposition,
)

## Theory: Multi-Species Startup Dynamics

During startup shear at constant $\dot{\gamma}$, each bond species evolves independently:

**Per-species stress evolution:**
$$\frac{d\sigma_i}{dt} = G_i \dot{\gamma} - \frac{\sigma_i}{\tau_{b,i}}$$
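As a worked check, integrating the per-species equation as written from a quiescent start, $\sigma_i(0) = 0$, gives

$$\sigma_i(t) = G_i \tau_{b,i} \dot{\gamma}\left(1 - e^{-t/\tau_{b,i}}\right)$$

This solution rises monotonically to the plateau $G_i \tau_{b,i} \dot{\gamma}$ and reaches about 63% of it at $t = \tau_{b,i}$. In this linear form, $\tau_{b,i}$ therefore sets the buildup time and no overshoot occurs; an overshoot in the data has to come from nonlinear terms that this equation omits (for example, stress-dependent bond breakage).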

**Total stress:**
$$\sigma(t) = \sum_{i=0}^{N-1} \sigma_i(t) + \eta_s \dot{\gamma}$$
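A minimal NumPy sketch of the total-stress sum, assuming each species follows the closed-form solution of the linear equation, $\sigma_i(t) = G_i \tau_{b,i} \dot{\gamma}(1 - e^{-t/\tau_{b,i}})$. The helper `startup_stress_linear` and the parameter values are hypothetical illustrations, not part of rheojax, and the curve shows buildup only (no overshoot).

```python
import numpy as np


def startup_stress_linear(t, G, tau_b, eta_s, gamma_dot):
    """Total startup stress for N species in the linear (constant-breakage) limit.

    Hypothetical helper for illustration; not a rheojax API.
    Returns (total, per_species) where per_species has shape (N, len(t)).
    """
    t = np.asarray(t, dtype=float)
    G = np.asarray(G, dtype=float)[:, None]          # shape (N, 1)
    tau_b = np.asarray(tau_b, dtype=float)[:, None]  # shape (N, 1)
    # sigma_i(t) = G_i * tau_b_i * gamma_dot * (1 - exp(-t / tau_b_i))
    per_species = G * tau_b * gamma_dot * (1.0 - np.exp(-t[None, :] / tau_b))
    # Sum over species, then add the solvent contribution eta_s * gamma_dot
    total = per_species.sum(axis=0) + eta_s * gamma_dot
    return total, per_species


t = np.logspace(-3, 3, 400)
total, parts = startup_stress_linear(t, G=[100.0, 20.0], tau_b=[0.1, 5.0],
                                     eta_s=0.5, gamma_dot=1.0)
eta_0 = 100.0 * 0.1 + 20.0 * 5.0 + 0.5  # zero-shear viscosity of this example
print(total[-1], eta_0 * 1.0)           # late-time stress approaches eta_0 * gamma_dot
```

At $t \to 0^+$ the species terms vanish and only the solvent jump $\eta_s \dot{\gamma}$ remains, which is what separates $\eta_s$ from the species in this form.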

**Key physics:**
- Each species builds up stress on its own timescale $\sim \tau_{b,i}$
- Fast species (short $\tau_{b,0}$) peaks first
- Slow species (long $\tau_{b,1}$) peaks later
- Composite overshoot may show multiple peaks or a single broadened peak
- Stress scale: in the linear form above, species $i$ levels off at $G_i \tau_{b,i} \dot{\gamma}$, so $G_i$ alone sets the order of magnitude only when $\tau_{b,i} \dot{\gamma} \sim 1$

**Timescale hierarchy:**
- $t_{\text{peak},0} \sim \tau_{b,0}$ (fast species)
- $t_{\text{peak},1} \sim \tau_{b,1}$ (slow species)
- Large $\tau_{b,1}/\tau_{b,0}$ ratio → well-separated peaks
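In the linear form of the species equation, each species follows $\sigma_i(t) = G_i \tau_{b,i} \dot{\gamma}(1 - e^{-t/\tau_{b,i}})$, so the time to reach a fixed fraction of its plateau depends on $\tau_{b,i}$ alone. A sketch with illustrative lifetimes (not fitted values; `rise_time` is a hypothetical helper) showing that the rise-time ratio equals the lifetime ratio:

```python
import numpy as np


def rise_time(tau_b, fraction=0.5):
    """Time for a linear species to reach `fraction` of its plateau.

    From sigma_i(t) = G_i tau_b_i gamma_dot (1 - exp(-t / tau_b_i)):
    t_f = -tau_b_i * ln(1 - fraction), independent of G_i and gamma_dot.
    """
    return -np.asarray(tau_b, dtype=float) * np.log1p(-fraction)


tau_b = np.array([0.1, 5.0])            # fast, slow lifetimes [s] (illustrative)
t_half = rise_time(tau_b)               # about 0.693 * tau_b
print(t_half, t_half[1] / t_half[0])    # ratio equals tau_b_1 / tau_b_0
```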

## Load Data

In [None]:
time_data, stress = load_pnas_startup(gamma_dot=1.0)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.4f} to {time_data.max():.2f} s")
print(f"Stress range: {stress.min():.2f} to {stress.max():.2f} Pa")
print(f"Peak stress: {stress.max():.2f} Pa at t = {time_data[np.argmax(stress)]:.2f} s")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, stress, 'o', label='Data', markersize=6)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('Stress [Pa]', fontsize=12)
ax.set_title(r'Startup Shear ($\dot{\gamma}$ = 1.0 s$^{-1}$)', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)
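A small helper to quantify the overshoot directly from a measured curve: peak time, peak stress, and the ratio of peak stress to the late-time plateau. The function name and the choice of averaging the last `n_tail` points as the plateau estimate are assumptions for illustration; if a run ends before steady state, the ratio is only a lower bound. The demo uses a synthetic curve, not the PNAS data.

```python
import numpy as np


def overshoot_metrics(t, sigma, n_tail=10):
    """Locate the stress maximum and compare it to a late-time plateau estimate."""
    t = np.asarray(t, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    i_peak = int(np.argmax(sigma))
    sigma_ss = float(np.mean(sigma[-n_tail:]))  # plateau estimated from the tail
    return {
        "t_peak": float(t[i_peak]),
        "sigma_peak": float(sigma[i_peak]),
        "sigma_ss": sigma_ss,
        "overshoot_ratio": float(sigma[i_peak] / sigma_ss),
    }


# Synthetic curve with a single overshoot at t = 1 (peak 2, plateau 1)
t_demo = np.linspace(0.0, 30.0, 3001)
sigma_demo = 1.0 + t_demo * np.exp(1.0 - t_demo)
print(overshoot_metrics(t_demo, sigma_demo))
```

Applied to this notebook's data, the call would be `overshoot_metrics(time_data, stress)`.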

## NLSQ Fitting

In [None]:
model = TNTMultiSpecies(n_species=2)
param_names = get_tnt_multi_species_param_names(n_species=2)
print(f"Parameters: {param_names}")

start_time = time.time()
result_nlsq = model.fit(time_data, stress, test_mode="startup", gamma_dot=1.0)
nlsq_time = time.time() - start_time

print(f"\nNLSQ converged: {result_nlsq.success}")
print(f"Optimization time: {nlsq_time:.2f} s")
print("\nFitted parameters:")
for name in param_names:
    print(f"  {name}: {result_nlsq.params[name]:.6e}")
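To make the fitting step concrete without relying on rheojax internals, here is a stand-alone sketch that fits the linear two-species startup expression to synthetic, noise-free data with SciPy's `curve_fit`. The parameter values, the log-parametrization, and the helper name `linear_startup` are assumptions for illustration; `TNTMultiSpecies.fit` may use a different model form and optimizer.

```python
import numpy as np
from scipy.optimize import curve_fit

GAMMA_DOT = 1.0  # shear rate [1/s], matching the notebook


def linear_startup(t, log_G0, log_tau0, log_G1, log_tau1, log_eta_s):
    """Two-species linear startup stress, parametrized by natural logs."""
    G0, tau0, G1, tau1, eta_s = np.exp([log_G0, log_tau0, log_G1, log_tau1, log_eta_s])
    fast = G0 * tau0 * GAMMA_DOT * (1.0 - np.exp(-t / tau0))
    slow = G1 * tau1 * GAMMA_DOT * (1.0 - np.exp(-t / tau1))
    return fast + slow + eta_s * GAMMA_DOT


true = np.array([100.0, 0.1, 20.0, 5.0, 0.5])   # G0, tau0, G1, tau1, eta_s
t_syn = np.logspace(-3, 2, 200)                 # log-spaced to resolve both modes
y_syn = linear_startup(t_syn, *np.log(true))

p0 = np.log([80.0, 0.2, 30.0, 3.0, 1.0])        # rough initial guess
popt, pcov = curve_fit(linear_startup, t_syn, y_syn, p0=p0)
print("recovered:", np.exp(popt))
```

Fitting log-parameters is a design choice: it keeps every parameter positive without bounds and puts moduli and lifetimes that differ by orders of magnitude on a comparable scale.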

In [None]:
time_pred = np.linspace(time_data.min(), time_data.max(), 500)
stress_pred = model.predict_startup(time_pred, gamma_dot=1.0)

fit_metrics = compute_fit_quality(stress, result_nlsq.y_pred)
print("\nFit quality:")
print(f"  R² = {fit_metrics['r_squared']:.6f}")
print(f"  RMSE = {fit_metrics['rmse']:.6e}")
print(f"  NRMSE = {fit_metrics['nrmse']:.6f}")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, stress, 'o', label='Data', markersize=6, alpha=0.7)
ax.plot(time_pred, stress_pred, '-', label='NLSQ Fit', linewidth=2)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('Stress [Pa]', fontsize=12)
ax.set_title(f'TNT Multi-Species Startup (R² = {fit_metrics["r_squared"]:.4f})', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)
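For reference, the three metrics printed above can be computed as follows. This is a sketch: normalizing the NRMSE by the data range is an assumption, and the tutorial's `compute_fit_quality` may normalize differently. Note that the metrics are evaluated at the measured time points (`result_nlsq.y_pred`), not on the dense 500-point prediction grid.

```python
import numpy as np


def fit_quality(y_obs, y_fit):
    """R^2, RMSE and range-normalized RMSE (NRMSE); normalization is an assumption."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_fit = np.asarray(y_fit, dtype=float)
    resid = y_obs - y_fit
    ss_res = np.sum(resid ** 2)                  # residual sum of squares
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2) # total sum of squares
    rmse = np.sqrt(np.mean(resid ** 2))
    return {
        "r_squared": 1.0 - ss_res / ss_tot,
        "rmse": rmse,
        "nrmse": rmse / (y_obs.max() - y_obs.min()),
    }


y = np.array([1.0, 2.0, 3.0, 4.0])
print(fit_quality(y, y + np.array([0.1, -0.1, 0.1, -0.1])))
```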

## Physical Analysis: Multi-Species Overshoot

In [None]:
fig = plot_mode_decomposition(model, time_pred, "startup", gamma_dot=1.0)
display(fig)
plt.close(fig)

G_0 = result_nlsq.params['G_0']
tau_b_0 = result_nlsq.params['tau_b_0']
G_1 = result_nlsq.params['G_1']
tau_b_1 = result_nlsq.params['tau_b_1']

print("\nSpecies overshoot analysis:")
print("\nSpecies 0 (fast):")
print(f"  G_0 = {G_0:.3e} Pa")
print(f"  tau_b_0 = {tau_b_0:.3e} s")
print(f"  Expected peak time ~ tau_b_0 = {tau_b_0:.3e} s")
# Linear-limit plateau of this species: G_0 * tau_b_0 * gamma_dot (gamma_dot = 1.0 s^-1)
print(f"  Stress scale ~ G_0 * tau_b_0 * gamma_dot = {G_0 * tau_b_0 * 1.0:.3e} Pa")

print("\nSpecies 1 (slow):")
print(f"  G_1 = {G_1:.3e} Pa")
print(f"  tau_b_1 = {tau_b_1:.3e} s")
print(f"  Expected peak time ~ tau_b_1 = {tau_b_1:.3e} s")
print(f"  Stress scale ~ G_1 * tau_b_1 * gamma_dot = {G_1 * tau_b_1 * 1.0:.3e} Pa")

# Slow/fast lifetime ratio; max/min guards against the fit swapping species labels
tau_ratio = max(tau_b_0, tau_b_1) / min(tau_b_0, tau_b_1)
print(f"\nTimescale separation: tau_slow/tau_fast = {tau_ratio:.2f}")

if tau_ratio > 10:
    print("  → Well-separated peaks expected")
elif tau_ratio > 3:
    print("  → Moderately separated peaks, composite overshoot")
else:
    print("  → Overlapping peaks, single broadened overshoot")

## Discrete Relaxation Spectrum

In [None]:
fig = plot_multi_species_spectrum(model)
display(fig)
plt.close(fig)

print("\nSpectrum interpretation:")
print("- Two discrete relaxation times control startup dynamics")
print("- Fast species relaxes on short timescale (early overshoot)")
print("- Slow species relaxes on long timescale (delayed overshoot)")
print("- Total overshoot is superposition of individual species")
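The discrete spectrum $\{G_i, \tau_{b,i}\}$ also defines the linear relaxation modulus $G(t) = \sum_i G_i e^{-t/\tau_{b,i}}$ and the zero-shear viscosity $\eta_0 = \sum_i G_i \tau_{b,i} + \eta_s$, which in the linear limit equals the steady startup stress divided by $\dot{\gamma}$. A short sketch with illustrative values (not fitted results); with these numbers the slow species carries most of $\eta_0$ despite its smaller modulus.

```python
import numpy as np

G = np.array([100.0, 20.0])      # moduli G_i [Pa] (illustrative)
tau_b = np.array([0.1, 5.0])     # bond lifetimes tau_b_i [s] (illustrative)
eta_s = 0.5                      # solvent viscosity [Pa s] (illustrative)


def relaxation_modulus(t):
    """G(t) = sum_i G_i exp(-t / tau_b_i) evaluated at an array of times t."""
    t = np.atleast_1d(np.asarray(t, dtype=float))
    return np.sum(G[:, None] * np.exp(-t[None, :] / tau_b[:, None]), axis=0)


eta_0 = np.sum(G * tau_b) + eta_s   # zero-shear viscosity [Pa s]
share = G * tau_b / eta_0           # fraction of eta_0 carried by each species
print(relaxation_modulus([0.0, 0.1, 5.0]))
print(eta_0, share)
```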

## Bayesian Inference

In [None]:
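NUM_WARMUP = 200   # short warmup for a quick demo
NUM_SAMPLES = 500
NUM_CHAINS = 1     # quick demo; several chains (e.g. 4) make R-hat far more informative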

print(f"Running Bayesian inference with {NUM_CHAINS} chain(s)...")
start_time = time.time()
result_bayes = model.fit_bayesian(
    time_data,
    stress,
    test_mode="startup",
    gamma_dot=1.0,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    seed=42,
)
bayes_time = time.time() - start_time
print(f"Bayesian inference time: {bayes_time:.2f} s")

## Convergence Diagnostics

In [None]:
print_convergence_summary(result_bayes, param_names)
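For reference, the classic split-$\hat{R}$ (Gelman-Rubin) statistic can be computed from an array of draws as below. This is a sketch of the textbook version; ArviZ's `az.rhat` defaults to a rank-normalized variant, so values can differ slightly. Splitting each chain in half is what makes any R-hat possible with `NUM_CHAINS = 1`, although several independent chains give a much stronger check. The demo uses synthetic draws.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for draws of shape (n_chains, n_draws)."""
    draws = np.atleast_2d(np.asarray(draws, dtype=float))
    n_half = draws.shape[1] // 2
    # Split every chain into two halves -> 2 * n_chains sub-chains
    halves = np.concatenate([draws[:, :n_half], draws[:, n_half:2 * n_half]], axis=0)
    m, n = halves.shape
    chain_means = halves.mean(axis=1)
    W = halves.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_plus = (n - 1) / n * W + B / n      # pooled posterior variance estimate
    return np.sqrt(var_plus / W)


rng = np.random.default_rng(0)
good = rng.normal(size=(1, 1000))                 # one well-mixed chain
drifting = good + np.linspace(0.0, 3.0, 1000)     # same chain with a trend
print(split_rhat(good), split_rhat(drifting))
```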

## ArviZ Diagnostics: Trace Plots

In [None]:
idata = result_bayes.to_arviz()
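axes = az.plot_trace(idata, var_names=param_names, compact=True)
# plot_trace returns an array of Axes, not a Figure; grab the parent figure
fig = np.ravel(axes)[0].figure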
plt.tight_layout()
display(fig)
plt.close(fig)

## ArviZ Diagnostics: Posterior Distributions

In [None]:
axes = az.plot_posterior(idata, var_names=param_names, hdi_prob=0.95)
fig = np.ravel(axes)[0].figure  # plot_posterior returns Axes, not a Figure
plt.tight_layout()
display(fig)
plt.close(fig)

## ArviZ Diagnostics: Pair Plot

In [None]:
axes = az.plot_pair(idata, var_names=param_names, divergences=True)
fig = np.ravel(axes)[0].figure  # plot_pair returns Axes, not a Figure
display(fig)
plt.close(fig)
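The pair plot shows correlations visually; a numerical summary is the posterior correlation matrix. The sketch assumes the samples are a dict of 1-D arrays keyed by parameter name, as the indexing `posterior[name][i]` later in this notebook suggests (treat that as an assumption). Synthetic draws stand in for real ones; when the data mainly pin the product $G_i \tau_{b,i}$ (a species' linear plateau), one would expect $G_i$ and $\tau_{b,i}$ to be anti-correlated, which is what the demo mimics.

```python
import numpy as np


def posterior_correlation(samples, names):
    """Pearson correlation matrix of posterior draws, ordered by `names`."""
    stacked = np.vstack([np.asarray(samples[n], dtype=float).ravel() for n in names])
    return np.corrcoef(stacked)  # each row of `stacked` is one parameter


rng = np.random.default_rng(1)
G_draws = rng.normal(100.0, 5.0, size=2000)
# Synthetic anti-correlation: roughly fixed product G * tau, plus small noise
tau_draws = 10.0 / G_draws + rng.normal(0.0, 0.002, size=2000)
demo = {"G_0": G_draws, "tau_b_0": tau_draws}
print(posterior_correlation(demo, ["G_0", "tau_b_0"]).round(2))
```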

## NLSQ vs Bayesian Parameter Comparison

In [None]:
print_parameter_comparison(result_nlsq, result_bayes, param_names)
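A quick consistency check behind such a comparison is whether each NLSQ estimate falls inside the central 95% posterior interval. The sketch below uses an equal-tailed percentile interval (the posterior plot above draws HDIs, which can differ for skewed posteriors); the helper name is hypothetical and the draws are synthetic.

```python
import numpy as np


def nlsq_within_interval(nlsq_value, draws, prob=0.95):
    """Check whether a point estimate lies inside the equal-tailed posterior interval."""
    draws = np.asarray(draws, dtype=float).ravel()
    lo, hi = np.percentile(draws, [50 * (1 - prob), 50 * (1 + prob)])
    return bool(lo <= nlsq_value <= hi), (float(lo), float(hi))


rng = np.random.default_rng(2)
draws = rng.normal(5.0, 0.5, size=4000)    # synthetic posterior for one parameter
print(nlsq_within_interval(5.1, draws))    # near the posterior bulk
print(nlsq_within_interval(8.0, draws))    # far in the tail
```

In this notebook the call would look like `nlsq_within_interval(result_nlsq.params['tau_b_1'], result_bayes.posterior_samples['tau_b_1'])`.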

## Posterior Predictive: Startup Response

In [None]:
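posterior = result_bayes.posterior_samples
# Use the number of stored draws (NUM_SAMPLES x NUM_CHAINS), not NUM_SAMPLES alone
n_total = len(posterior[param_names[0]])
n_draws = min(200, n_total)
draw_indices = np.linspace(0, n_total - 1, n_draws, dtype=int)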

x_pred = jnp.array(time_pred)
y_pred_samples = []

for i in draw_indices:
    params_i = jnp.array([posterior[name][i] for name in param_names])
    y_pred_i = model.model_function(x_pred, params_i, test_mode="startup", gamma_dot=1.0)
    y_pred_samples.append(np.array(y_pred_i))

y_pred_samples = np.array(y_pred_samples)
y_pred_mean = np.mean(y_pred_samples, axis=0)
y_pred_lower = np.percentile(y_pred_samples, 2.5, axis=0)
y_pred_upper = np.percentile(y_pred_samples, 97.5, axis=0)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, stress, 'o', label='Data', markersize=6, alpha=0.7, zorder=3)
ax.plot(time_pred, y_pred_mean, '-', label='Posterior Mean', linewidth=2, zorder=2)
ax.fill_between(time_pred, y_pred_lower, y_pred_upper, alpha=0.3, label='95% credible band (mean curve)', zorder=1)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('Stress [Pa]', fontsize=12)
ax.set_title('Posterior Predictive: Startup Shear', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)
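The shaded band in the previous plot propagates parameter uncertainty only, so it is a credible band for the mean stress curve. A posterior predictive band for new measurements also adds observation noise. A sketch, assuming Gaussian noise with a known standard deviation `sigma_noise` (a hypothetical value here; a Bayesian fit may instead sample the noise level as a parameter), with synthetic curves standing in for `y_pred_samples`:

```python
import numpy as np


def predictive_bands(curves, sigma_noise, prob=0.95, seed=0):
    """Credible band (mean curve) and predictive band (mean curve + Gaussian noise).

    curves: array of shape (n_draws, n_times), one model curve per posterior draw.
    """
    rng = np.random.default_rng(seed)
    q = [50 * (1 - prob), 50 * (1 + prob)]
    credible = np.percentile(curves, q, axis=0)
    noisy = curves + rng.normal(0.0, sigma_noise, size=curves.shape)
    predictive = np.percentile(noisy, q, axis=0)
    return credible, predictive


rng = np.random.default_rng(3)
t = np.linspace(0.0, 10.0, 50)
amp = rng.normal(10.0, 0.2, size=(400, 1))   # synthetic draws of a plateau level
curves = amp * (1.0 - np.exp(-t))            # shape (400, 50)
cred, pred = predictive_bands(curves, sigma_noise=0.5)
print((pred[1] - pred[0]).mean() > (cred[1] - cred[0]).mean())
```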

## Physical Interpretation

**Multi-species overshoot dynamics:**
- Fast species contributes early-time stress buildup
- Slow species contributes delayed stress buildup
- Composite overshoot shape depends on timescale separation
- Well-separated $\tau_{b,i}$ → multiple peaks possible

**Timescale hierarchy:**
- Peak times correlate with bond lifetimes
- Stress contributions scale with $G_i \tau_{b,i} \dot{\gamma}$, the plateau each species reaches in the linear form
- Lifetime ratio $\tau_{b,1}/\tau_{b,0}$ controls peak separation

**Uncertainty quantification:**
- Bayesian posteriors capture parameter correlations
- Overshoot region constrains both species parameters
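- Steady-state plateau constrains the total viscosity $\sum_i G_i \tau_{b,i} + \eta_s$; the solvent viscosity $\eta_s$ is separated mainly by the instantaneous stress jump $\eta_s \dot{\gamma}$ at $t = 0^+$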

## Save Results

In [None]:
save_tnt_results(model, result_bayes, "multi_species", "startup", param_names)
print("Results saved successfully.")

## Key Takeaways

1. **Multi-species overshoot**: Each species has independent stress buildup dynamics
2. **Timescale separation**: Fast species peaks before slow species
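3. **Composite dynamics**: Total stress is the superposition of the individual species' contributions plus the solvent term $\eta_s \dot{\gamma}$
4. **Peak correlations**: Peak times ~ $\tau_{b,i}$; stress contributions ~ $G_i \tau_{b,i} \dot{\gamma}$ (comparable to $G_i$ only when $\tau_{b,i} \dot{\gamma} \sim 1$)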
5. **Discrete spectrum**: Two relaxation times control startup behavior
6. **Bayesian inference**: Quantifies uncertainty in species parameters
7. **Species resolution**: Data quality and timescale separation determine identifiability