# TNT Multi-Species: Startup Shear

**Objectives:**
- Fit TNT multi-species model to startup shear data
- Understand multi-species stress overshoot dynamics
- Analyze fast vs slow species peak times
- Decompose composite overshoot into species contributions
- Compare NLSQ and Bayesian inference

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt

from rheojax.core.jax_config import safe_import_jax, verify_float64

# Import JAX via the rheojax helper, then check float64 precision
jax, jnp = safe_import_jax()
verify_float64()

from rheojax.models.tnt import TNTMultiSpecies

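# Make the tutorial helpers importable (assumes the notebook runs from its own
# directory): the parent directory exposes the `utils` package, and ../utils
# exposes `tnt_tutorial_utils` directly
sys.path.insert(0, os.path.dirname(os.path.abspath("")))
sys.path.insert(0, os.path.join("..", "utils"))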
from tnt_tutorial_utils import (
    load_pnas_startup,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_multi_species_param_names,
    plot_multi_species_spectrum,
    plot_mode_decomposition,
)

from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

## Theory: Multi-Species Startup Dynamics

During startup shear at constant $\dot{\gamma}$, each bond species evolves independently:

**Per-species stress evolution:**
$$\frac{d\sigma_i}{dt} = G_i \dot{\gamma} - \frac{\sigma_i}{\tau_{b,i}}$$

**Total stress:**
$$\sigma(t) = \sum_{i=0}^{N-1} \sigma_i(t) + \eta_s \dot{\gamma}$$

**Key physics:**
- Each species has its own overshoot timescale $\sim \tau_{b,i}$
- The fast species (short $\tau_{b,0}$) peaks first
- The slow species (long $\tau_{b,1}$) peaks later
- The composite overshoot may show multiple peaks or a single broadened peak
- Peak stress: $\sigma_{i,\text{max}} \sim G_i$ (order of magnitude)

**Timescale hierarchy:**
- $t_{\text{peak},0} \sim \tau_{b,0}$ (fast species)
- $t_{\text{peak},1} \sim \tau_{b,1}$ (slow species)
- Large $\tau_{b,1}/\tau_{b,0}$ ratio → well-separated peaks
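
Taken literally, the linear per-species equation above has a closed-form startup solution, $\sigma_i(t) = G_i \tau_{b,i} \dot{\gamma}\,(1 - e^{-t/\tau_{b,i}})$, which approaches its plateau $G_i \tau_{b,i} \dot{\gamma}$ on the timescale $\tau_{b,i}$ without overshooting. The NumPy sketch below evaluates that solution for two species so the fast and slow timescales can be compared. The function name and parameter values are illustrative assumptions, not fitted results, and any overshoot produced by `TNTMultiSpecies` would have to come from terms beyond this simplified equation.

```python
import numpy as np

def linear_startup_stress(t, G, tau_b, gamma_dot, eta_s=0.0):
    """Closed-form startup stress for the linear per-species equation.

    Each species obeys d(sigma_i)/dt = G_i*gamma_dot - sigma_i/tau_b_i with
    sigma_i(0) = 0, so sigma_i(t) = G_i*tau_b_i*gamma_dot*(1 - exp(-t/tau_b_i)).
    Returns (total stress, per-species stresses with shape (N, len(t))).
    """
    t = np.asarray(t, dtype=float)
    G = np.asarray(G, dtype=float)[:, None]
    tau_b = np.asarray(tau_b, dtype=float)[:, None]
    per_species = G * tau_b * gamma_dot * (1.0 - np.exp(-t[None, :] / tau_b))
    total = per_species.sum(axis=0) + eta_s * gamma_dot
    return total, per_species

# Illustrative values only (assumed, not taken from the fit)
G_demo = [100.0, 10.0]     # Pa
tau_demo = [0.1, 10.0]     # s
t_demo = np.linspace(0.0, 100.0, 2001)
sigma_total, sigma_species = linear_startup_stress(
    t_demo, G_demo, tau_demo, gamma_dot=1.0
)

# Each species reaches 63% (1 - 1/e) of its plateau at roughly t = tau_b_i
plateaus = np.array(G_demo) * np.array(tau_demo) * 1.0
for i, (sig, plateau) in enumerate(zip(sigma_species, plateaus)):
    t63 = t_demo[np.argmax(sig >= (1.0 - np.exp(-1.0)) * plateau)]
    print(f"species {i}: plateau = {plateau:.1f} Pa, t_63% ~ {t63:.2f} s")
```

Plotting `sigma_species` against `t_demo` on a logarithmic time axis makes the two-step build-up visible when $\tau_{b,1}/\tau_{b,0}$ is large.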

## Load Data

In [None]:
time_data, stress = load_pnas_startup(gamma_dot=1.0)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.4f} to {time_data.max():.2f} s")
print(f"Stress range: {stress.min():.2f} to {stress.max():.2f} Pa")
print(f"Peak stress: {stress.max():.2f} Pa at t = {time_data[np.argmax(stress)]:.2f} s")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, stress, 'o', label='Data', markersize=6)
ax.set_xlabel('Time [s]', fontsize=12)
ax.set_ylabel('Stress [Pa]', fontsize=12)
ax.set_title(r'Startup Shear ($\dot{\gamma}$ = 1.0 s$^{-1}$)', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.close("all")

## NLSQ Fitting

In [None]:
model = TNTMultiSpecies(n_species=2)
param_names = get_tnt_multi_species_param_names(n_species=2)
print(f"Parameters: {param_names}")

start_time = time.time()
model.fit(time_data, stress, test_mode="startup", gamma_dot=1.0, method='scipy')
nlsq_time = time.time() - start_time

print("\nNLSQ fit finished (convergence flag not printed; inspect the model state)")
print(f"Optimization time: {nlsq_time:.2f} s")
print("\nFitted parameters:")
for name in param_names:
    print(f"  {name}: {model.parameters.get_value(name):.6e}")

In [None]:
# Shear rate of the dataset; must match the value passed to load_pnas_startup and model.fit
gamma_dot = 1.0

# Compute metrics for plot title
metrics = compute_fit_quality(stress, model.predict(time_data, test_mode='startup', gamma_dot=gamma_dot))

# Plot NLSQ fit with uncertainty band
fig, ax = plot_nlsq_fit(
    time_data, stress, model, test_mode="startup",
    param_names=param_names, log_scale=False,
    xlabel='Time (s)',
    ylabel=r'Shear stress $\sigma$ (Pa)',
    title=f'NLSQ Fit (R² = {metrics["R2"]:.4f})',
    gamma_dot=gamma_dot
)
plt.close("all")
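
The R² shown in the plot title comes from `compute_fit_quality`, whose implementation is not shown in this notebook. As a reference point, the sketch below uses the conventional coefficient of determination, $R^2 = 1 - SS_\text{res}/SS_\text{tot}$; the helper name `r_squared` is hypothetical, and the tutorial utility may compute additional or differently defined metrics.

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Conventional coefficient of determination, 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # spread around the mean
    return 1.0 - ss_res / ss_tot

y_obs = np.array([1.0, 2.0, 3.0, 4.0])
r2_perfect = r_squared(y_obs, y_obs)             # prediction equals the data
r2_mean = r_squared(y_obs, np.full(4, 2.5))      # prediction is the data mean
print(r2_perfect, r2_mean)
```

A value near 1 over the whole curve can still hide a poor fit in the short overshoot region, so it is worth inspecting residuals near the peak as well.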

## Physical Analysis: Multi-Species Overshoot

In [None]:
time_pred = np.linspace(time_data.min(), time_data.max(), 200)
fig = plot_mode_decomposition(model, time_pred, "startup", gamma_dot=1.0)
plt.close("all")

G_0 = model.parameters.get_value('G_0')
tau_b_0 = model.parameters.get_value('tau_b_0')
G_1 = model.parameters.get_value('G_1')
tau_b_1 = model.parameters.get_value('tau_b_1')

print("\nSpecies overshoot analysis:")
print(f"\nSpecies 0 (fast):")
print(f"  G_0 = {G_0:.3e} Pa")
print(f"  tau_b_0 = {tau_b_0:.3e} s")
print(f"  Expected peak time ~ tau_b_0 = {tau_b_0:.3e} s")
print(f"  Expected peak stress ~ G_0 = {G_0:.1f} Pa")

print(f"\nSpecies 1 (slow):")
print(f"  G_1 = {G_1:.3e} Pa")
print(f"  tau_b_1 = {tau_b_1:.3e} s")
print(f"  Expected peak time ~ tau_b_1 = {tau_b_1:.3e} s")
print(f"  Expected peak stress ~ G_1 = {G_1:.1f} Pa")

print(f"\nTimescale separation: tau_b_1/tau_b_0 = {tau_b_1/tau_b_0:.2f}")

if tau_b_1/tau_b_0 > 10:
    print("  → Well-separated peaks expected")
elif tau_b_1/tau_b_0 > 3:
    print("  → Moderately separated peaks, composite overshoot")
else:
    print("  → Overlapping peaks, single broadened overshoot")

## Discrete Relaxation Spectrum

In [None]:
fig = plot_multi_species_spectrum(model)
plt.close("all")

print("\nSpectrum interpretation:")
print("- Two discrete relaxation times control startup dynamics")
print("- Fast species relaxes on short timescale (early overshoot)")
print("- Slow species relaxes on long timescale (delayed overshoot)")
print("- Total overshoot is superposition of individual species")
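
In the linear limit of the per-species equation from the theory section, the steady-state stress is $\left(\sum_i G_i \tau_{b,i} + \eta_s\right)\dot{\gamma}$, so a discrete spectrum also fixes a zero-shear viscosity and a viscosity-weighted mean relaxation time. The sketch below computes both; `spectrum_summary` is a hypothetical helper and the numbers are placeholders, not fitted values.

```python
import numpy as np

def spectrum_summary(G, tau_b, eta_s=0.0):
    """Summarize a discrete spectrum {(G_i, tau_b_i)} in the linear limit.

    Returns the zero-shear viscosity, each mode's share of the network
    viscosity, and the viscosity-weighted mean relaxation time.
    """
    G = np.asarray(G, dtype=float)
    tau_b = np.asarray(tau_b, dtype=float)
    eta_modes = G * tau_b                      # per-mode viscosity G_i * tau_b_i
    eta_network = eta_modes.sum()
    eta_0 = eta_network + eta_s                # add the solvent contribution
    weights = eta_modes / eta_network
    tau_mean = np.sum(weights * tau_b)
    return eta_0, weights, tau_mean

# Placeholder spectrum (assumed values)
eta_0, weights, tau_mean = spectrum_summary([100.0, 10.0], [0.1, 10.0], eta_s=0.5)
print(f"eta_0 = {eta_0:.1f} Pa.s, mode weights = {weights}, <tau> = {tau_mean:.2f} s")
```

In this notebook the same function could be called with `[G_0, G_1]` and `[tau_b_0, tau_b_1]` from the physical analysis cell, together with the fitted solvent viscosity.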

## Bayesian Inference

In [None]:
# FAST_MODE (env var FAST_MODE, default "1") skips the MCMC run entirely for
# quick validation; set FAST_MODE=0 to run the full Bayesian inference
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

# Configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

if FAST_MODE:
    print("FAST_MODE: Skipping Bayesian inference (JIT compilation takes >600s)")
    print("To run Bayesian analysis, run with FAST_MODE=0")
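    # Placeholder exposing only `posterior_samples`: each NLSQ point estimate is
    # repeated NUM_SAMPLES times so downstream cells still run. These are not
    # posterior draws and carry no uncertainty information.
    class BayesianResult:
        def __init__(self, model, param_names):
            self.posterior_samples = {
                name: np.array([model.parameters.get_value(name)] * NUM_SAMPLES)
                for name in param_names
            }
    result_bayes = BayesianResult(model, param_names)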
    bayes_time = 0.0
else:
    print(f"Running NUTS with {NUM_CHAINS} chain(s)...")
    print(f"Warmup: {NUM_WARMUP} samples, Sampling: {NUM_SAMPLES} samples")
    
    start_time = time.time()
    result_bayes = model.fit_bayesian(
        time_data, stress,
        test_mode='startup',
        gamma_dot=1.0,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        seed=42
    )
    bayes_time = time.time() - start_time
    
    print(f"\nBayesian inference completed in {bayes_time:.1f} seconds")


## Convergence Diagnostics

In [None]:
# Skip convergence diagnostics in FAST_MODE (no real posterior samples exist)
if not FAST_MODE:
    print_convergence_summary(result_bayes, param_names)
else:
    print("FAST_MODE: Skipping convergence diagnostics")
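
What `print_convergence_summary` reports is not shown here. A common convergence check is R-hat, and because this notebook samples with `NUM_CHAINS = 1`, a between-chain comparison is only possible by splitting the single chain in half. The following is a minimal split-R-hat sketch under that assumption; it omits the rank normalization that ArviZ applies by default, and `split_rhat` is a hypothetical helper, not part of the tutorial utilities.

```python
import numpy as np

def split_rhat(samples):
    """Basic split R-hat for one chain: treat its two halves as two chains."""
    x = np.asarray(samples, dtype=float)
    n = x.size // 2
    chains = np.stack([x[:n], x[n:2 * n]])           # shape (2, n)
    within = chains.var(axis=1, ddof=1).mean()       # W: mean within-half variance
    between = n * chains.mean(axis=1).var(ddof=1)    # B: scaled variance of half means
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return float(np.sqrt(var_plus / within))

# Synthetic chains for illustration (not output of the model)
rng = np.random.default_rng(0)
stationary = rng.normal(size=1000)                       # well-mixed chain
drifting = stationary + np.linspace(0.0, 5.0, 1000)      # chain still drifting

rhat_ok = split_rhat(stationary)
rhat_bad = split_rhat(drifting)
print(f"stationary: {rhat_ok:.3f}, drifting: {rhat_bad:.3f}")
```

Values close to 1 indicate that the two halves agree; a chain that is still drifting gives a clearly larger value. Running several chains gives a stronger check than splitting one.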


## ArviZ Diagnostics

In [None]:
# ArviZ diagnostics (trace, pair, forest, energy, autocorrelation, rank)
if FAST_MODE:
    print("FAST_MODE: Skipping ArviZ diagnostics")
elif hasattr(result_bayes, 'to_inference_data'):
    display_arviz_diagnostics(result_bayes, param_names, fast_mode=FAST_MODE)
else:
    print("Skipping ArviZ diagnostics: result has no to_inference_data()")

## NLSQ vs Bayesian Parameter Comparison

In [None]:
print_parameter_comparison(model, result_bayes.posterior_samples, param_names)

## Posterior Predictive: Startup Response

In [None]:
# Posterior predictive check
if not FAST_MODE and hasattr(result_bayes, 'posterior_samples'):
    fig, ax = plot_posterior_predictive(
        time_data,
        stress,
        model, result_bayes, test_mode="startup",
        param_names=param_names, log_scale=False,
        xlabel=r'Time (s)',
        ylabel=r'Shear stress $\sigma$ (Pa)', gamma_dot=gamma_dot
    )
    plt.close("all")
else:
    print("FAST_MODE: Skipping posterior predictive")

## Physical Interpretation

**Multi-species overshoot dynamics:**
- Fast species contributes early-time stress buildup
- Slow species contributes delayed stress buildup
- Composite overshoot shape depends on timescale separation
- Well-separated $\tau_{b,i}$ → multiple peaks possible

**Timescale hierarchy:**
- Peak times correlate with bond lifetimes
- Peak heights scale with moduli $G_i$
- Lifetime ratio $\tau_{b,1}/\tau_{b,0}$ controls peak separation

**Uncertainty quantification:**
- Bayesian posteriors capture parameter correlations
- Overshoot region constrains both species parameters
- Steady-state plateau constrains solvent viscosity $\eta_s$
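
To make the uncertainty statements above concrete, posterior draws can be reduced to medians, credible intervals, and a correlation matrix. The sketch below assumes a dictionary of 1-D sample arrays keyed by parameter name, which is how `result_bayes.posterior_samples` is used in this notebook; the draws here are synthetic stand-ins and `summarize_posterior` is a hypothetical helper.

```python
import numpy as np

def summarize_posterior(samples, names, level=0.95):
    """Median, equal-tailed credible interval, and correlation matrix."""
    lo = 100.0 * (1.0 - level) / 2.0
    hi = 100.0 * (1.0 + level) / 2.0
    summary = {}
    for name in names:
        draws = np.asarray(samples[name], dtype=float)
        q_lo, q_hi = np.percentile(draws, [lo, hi])
        summary[name] = (float(np.median(draws)), float(q_lo), float(q_hi))
    # Rows are parameters, columns are draws
    corr = np.corrcoef(np.stack([np.asarray(samples[n], dtype=float) for n in names]))
    return summary, corr

# Synthetic, deliberately anti-correlated draws (not results from the model)
rng = np.random.default_rng(1)
G_draws = rng.normal(100.0, 5.0, size=2000)
tau_draws = 1.0 - 0.005 * (G_draws - 100.0) + rng.normal(0.0, 0.01, size=2000)
fake_samples = {"G_0": G_draws, "tau_b_0": tau_draws}

summary, corr = summarize_posterior(fake_samples, ["G_0", "tau_b_0"])
for name, (med, q_lo, q_hi) in summary.items():
    print(f"{name}: median = {med:.3g}, 95% CI = [{q_lo:.3g}, {q_hi:.3g}]")
print(f"corr(G_0, tau_b_0) = {corr[0, 1]:.2f}")
```

With the FAST_MODE placeholder, every draw equals the NLSQ estimate, so the intervals collapse to a point and the correlations are undefined; the summary is only informative after a real MCMC run.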

## Save Results

In [None]:
save_tnt_results(model, result_bayes, "multi_species", "startup", param_names)
print("Results saved successfully.")

## Key Takeaways

1. **Multi-species overshoot**: Each species has independent stress buildup dynamics
2. **Timescale separation**: Fast species peaks before slow species
3. **Composite dynamics**: Total overshoot is superposition of individual peaks
4. **Peak correlations**: Peak times ~ $\tau_{b,i}$, peak heights ~ $G_i$
5. **Discrete spectrum**: Two relaxation times control startup behavior
6. **Bayesian inference**: Quantifies uncertainty in species parameters
7. **Species resolution**: Data quality and timescale separation determine identifiability