# TNT Sticky Rouse: Startup Shear
> **Handbook:** This notebook demonstrates the TNT Sticky Rouse model. For complete mathematical derivations and theoretical background, see [TNT Sticky Rouse Documentation](../../docs/source/models/tnt/tnt_sticky_rouse.rst).


**Estimated Time:** 3-5 minutes

## Protocol: Startup in Sticky Rouse

**Multi-mode stress buildup** from the hierarchical Rouse spectrum: fast modes reach their plateaus early, slow modes later.


## Learning Objectives

1. Fit the Sticky Rouse model to startup shear data
2. Identify the buildup timescale of each mode
3. Extract the Rouse timescale hierarchy relative to the sticker lifetime

## Runtime: ~10-20 min

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTStickyRouse

sys.path.insert(0, os.path.dirname(os.path.abspath("")))
sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_pnas_startup,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_sticky_rouse_param_names,
    plot_sticky_rouse_effective_times,
    plot_mode_decomposition,
)

from utils.plotting_utils import (
    plot_nlsq_fit, display_arviz_diagnostics, plot_posterior_predictive
)

print("Setup complete. JAX devices:", jax.devices())



### Bayesian Convergence Diagnostics

When running full Bayesian inference (FAST_MODE=0), check these diagnostics to assess MCMC convergence:

| Metric | Acceptable Range | Interpretation |
|--------|------------------|----------------|
| **R-hat** | < 1.01 | Measures chain convergence; values near 1.0 indicate chains mixed well |
| **ESS (Effective Sample Size)** | > 400 | Number of independent samples; higher is better |
| **Divergences** | < 1% of samples | Indicates numerical instability; should be near zero |
| **BFMI (Bayesian Fraction of Missing Information)** | > 0.3 | Low values suggest reparameterization needed |

**Troubleshooting poor diagnostics:**
- High R-hat (>1.01): Increase `num_warmup`; run several chains (`num_chains`) so that disagreement between chains can be detected
- Low ESS (<400): Increase `num_samples` or check for strong correlations
- Many divergences: Increase `target_accept` (default 0.8) or use NLSQ warm-start
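To make the R-hat threshold concrete, here is a minimal NumPy sketch of the classic split R-hat (Gelman et al.). It is a simplified version under stated assumptions: ArviZ reports a rank-normalized variant by default, so its numbers will differ slightly, and the function name `split_rhat` is hypothetical. Because each chain is cut in half, the statistic is still defined for the single-chain runs configured in this notebook (`NUM_CHAINS = 1`).

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for one scalar parameter (illustrative sketch).

    draws: array of shape (n_chains, n_draws). Each chain is cut in half so
    that drift within a chain also inflates R-hat. This is the simple
    (non-rank-normalized) form, not the variant ArviZ reports by default.
    """
    draws = np.asarray(draws, dtype=float)
    n_draws = draws.shape[1]
    half = n_draws // 2
    # First and last halves become separate chains (middle draw dropped if odd)
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B = n * split.mean(axis=1).var(ddof=1)    # between-chain variance
    var_hat = (n - 1) / n * W + B / n         # pooled variance estimate
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 2000))   # four chains sampling the same target
stuck = mixed.copy()
stuck[0] += 3.0                      # one chain trapped around a different value
print(f"well mixed:       R-hat = {split_rhat(mixed):.4f}")
print(f"one chain offset: R-hat = {split_rhat(stuck):.4f}")
```

The well-mixed chains give a value very close to 1, while a single chain sitting in a different region pushes R-hat far above the 1.01 threshold.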


## Theory: Startup Dynamics

**Stress Buildup:**

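Each mode builds stress on its effective timescale $\tau_{\textrm{eff},k} = \max(\tau_{R,k}, \tau_s)$, the approximation used in this notebook's analysis cells: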
$$\sigma_k(t) = G_k \tau_{\textrm{eff},k} \dot{\gamma} \left[1 - \exp\left(-\frac{t}{\tau_{\textrm{eff},k}}\right)\right]$$

Total stress: $\sigma(t) = \sum_k \sigma_k(t) + \eta_s \dot{\gamma}$
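The superposition can be checked numerically. The sketch below evaluates $\sigma(t)$ for a few modes, assuming $\tau_{\textrm{eff},k} = \max(\tau_{R,k}, \tau_s)$ (the rule used in this notebook's analysis cells; `TNTStickyRouse` itself may compute effective times differently). The function name `sticky_rouse_startup` and all parameter values are illustrative, not fitted. At long times the stress approaches $\sum_k G_k \tau_{\textrm{eff},k} \dot{\gamma} + \eta_s \dot{\gamma}$, and the initial polymer slope is $\sum_k G_k \dot{\gamma}$, because each mode starts rising at rate $G_k \dot{\gamma}$ regardless of its timescale.

```python
import numpy as np


def sticky_rouse_startup(t, G, tau_R, tau_s, eta_s, gamma_dot):
    """Linear multi-mode startup stress sigma(t) for a 1-D array of times t.

    Assumes tau_eff,k = max(tau_R,k, tau_s), the rule used in this notebook's
    analysis cells; TNTStickyRouse itself may compute effective times differently.
    """
    t = np.asarray(t, dtype=float)[:, None]                      # (n_t, 1)
    G = np.asarray(G, dtype=float)[None, :]                      # (1, n_modes)
    tau_eff = np.maximum(np.asarray(tau_R, dtype=float), tau_s)[None, :]
    # -expm1(-x) equals 1 - exp(-x), evaluated accurately for small x
    sigma_modes = G * tau_eff * gamma_dot * -np.expm1(-t / tau_eff)
    return sigma_modes.sum(axis=1) + eta_s * gamma_dot           # add solvent


# Illustrative parameters (not fitted values)
G_demo = [100.0, 50.0, 20.0]                         # Pa
tau_R_demo = [0.01, 0.1, 1.0]                        # s
tau_s_demo, eta_s_demo, rate_demo = 0.05, 0.5, 1.0   # s, Pa·s, 1/s

t_demo = np.logspace(-4, 2, 400)
sigma_demo = sticky_rouse_startup(t_demo, G_demo, tau_R_demo, tau_s_demo, eta_s_demo, rate_demo)

# Long-time plateau: sum_k G_k * tau_eff,k * gamma_dot + eta_s * gamma_dot
plateau_demo = (
    sum(g * max(tr, tau_s_demo) for g, tr in zip(G_demo, tau_R_demo)) * rate_demo
    + eta_s_demo * rate_demo
)
print(f"sigma(t = 100 s) = {sigma_demo[-1]:.4f} Pa, plateau = {plateau_demo:.4f} Pa")
```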

**Key Physics:**
- **Sticker-dominated modes** (τ_R,k < τ_s): All respond on timescale τ_s → collective buildup
- **Rouse-dominated modes** (τ_R,k > τ_s): Individual timescales → sequential buildup
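Because every sticker-dominated mode shares $\tau_{\textrm{eff}} = \tau_s$, their contributions collapse into a single exponential whose modulus is the sum of their $G_k$; this is why the collective sticker curve later in the notebook can be drawn with one `G_sticker`. A minimal sketch with made-up values, using a hypothetical helper `classify_modes` and the same max rule (and the same tie-breaking) as the notebook:

```python
import numpy as np


def classify_modes(tau_R, tau_s):
    """Label each mode 'sticker' (tau_R < tau_s) or 'Rouse' (tau_R >= tau_s)."""
    return np.where(np.asarray(tau_R, dtype=float) < tau_s, "sticker", "Rouse")


# Illustrative values: two fast modes below tau_s, one slow mode above it
G_c = np.array([100.0, 50.0, 20.0])     # Pa
tau_R_c = np.array([0.01, 0.02, 1.0])   # s
tau_s_c, rate_c = 0.05, 1.0             # s, 1/s
t_c = np.linspace(0.0, 0.5, 200)        # s

labels = classify_modes(tau_R_c, tau_s_c)
is_sticker = labels == "sticker"

# Each sticker-dominated mode uses tau_eff = tau_s, evaluated one by one
per_mode = G_c[is_sticker][:, None] * tau_s_c * rate_c * -np.expm1(-t_c[None, :] / tau_s_c)
summed = per_mode.sum(axis=0)

# One collective exponential with the summed modulus G_sticker
G_sticker_c = G_c[is_sticker].sum()
collective = G_sticker_c * tau_s_c * rate_c * -np.expm1(-t_c / tau_s_c)

print(labels, np.allclose(summed, collective))
```

The Rouse-dominated mode keeps its own timescale and is not part of the collective term.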

**Observable Signatures:**
- Sticker regime: Single exponential approach at early times
- Rouse regime: Multi-exponential with separated timescales
- Steady state: Same as flow curve at corresponding shear rate

## Load Data

In [None]:
# Load startup shear data at γ̇ = 1.0 s^-1
gamma_dot = 1.0
time_data, stress = load_pnas_startup(gamma_dot=gamma_dot)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.2e} - {time_data.max():.2e} s")
print(f"Shear rate: {gamma_dot} s^-1")
print(f"Stress range: {stress.min():.2e} - {stress.max():.2e} Pa")

# Plot raw data
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, stress, 'ko', label='PNAS data', markersize=6)
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title(f'Startup Shear at γ̇ = {gamma_dot} s$^{{-1}}$', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.show()
plt.close("all")

## NLSQ Fitting

In [None]:
# Initialize model
model = TNTStickyRouse(n_modes=3)
param_names = get_tnt_sticky_rouse_param_names(n_modes=3)
print(f"Model parameters ({len(param_names)}): {param_names}")

# Fit using NLSQ
print("\nFitting with NLSQ...")
start_time = time.time()
model.fit(time_data, stress, test_mode="startup", gamma_dot=gamma_dot, method='scipy')
fit_time = time.time() - start_time

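# Compute fit-quality metrics on the training data
stress_pred_train = model.predict(time_data, test_mode="startup", gamma_dot=gamma_dot)
metrics_nlsq = compute_fit_quality(stress, stress_pred_train)
residuals = np.asarray(stress) - np.asarray(stress_pred_train)

print(f"\nFit completed in {fit_time:.2f} seconds")
print(f"R² = {metrics_nlsq['R2']:.6f}")
print(f"RMSE = {metrics_nlsq['RMSE']:.4e} Pa")

# Residual analysis (startup predictions at the imposed shear rate)
print("\nResidual Statistics:")
print(f"  Mean residual = {np.mean(residuals):.4e} Pa")
print(f"  Std residual = {np.std(residuals):.4e} Pa")
print(f"  Max absolute residual = {np.max(np.abs(residuals)):.4e} Pa")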
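For reference, here is a plain sketch of the two headline metrics printed above: $R^2 = 1 - SS_{\textrm{res}}/SS_{\textrm{tot}}$ and RMSE $= \sqrt{\textrm{mean}(\textrm{residual}^2)}$. `fit_quality` is a hypothetical stand-in written for illustration; the actual `compute_fit_quality` helper in `tnt_tutorial_utils` may define or weight these quantities differently.

```python
import numpy as np


def fit_quality(y, y_pred):
    """Coefficient of determination and root-mean-square error (illustrative)."""
    y = np.asarray(y, dtype=float)
    resid = y - np.asarray(y_pred, dtype=float)
    ss_res = np.sum(resid**2)               # residual sum of squares
    ss_tot = np.sum((y - y.mean())**2)      # total sum of squares about the mean
    return {"R2": 1.0 - ss_res / ss_tot, "RMSE": float(np.sqrt(np.mean(resid**2)))}


y_obs = np.array([1.0, 2.0, 3.0, 4.0])
y_fit = np.array([1.1, 1.9, 3.2, 3.8])
print(fit_quality(y_obs, y_fit))
```

RMSE carries the units of the data (Pa here), so it is easiest to judge relative to the plateau stress; $R^2$ is dimensionless.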

## Fitted Parameters

In [None]:
# Extract fitted parameters
params_nlsq = {name: model.parameters.get_value(name) for name in param_names}

print("\nFitted Parameters:")
print("-" * 50)
for name, value in params_nlsq.items():
    if 'tau' in name:
        print(f"{name:10s} = {value:12.4e} s")
    elif 'eta' in name:
        print(f"{name:10s} = {value:12.4e} Pa·s")
    else:
        print(f"{name:10s} = {value:12.4e} Pa")

# Analyze effective relaxation times
tau_s = params_nlsq['tau_s']
print(f"\nSticker lifetime: τ_s = {tau_s:.4e} s")
print("\nMode Dynamics Analysis:")
print("-" * 50)
for i in range(3):
    tau_R = params_nlsq[f'tau_R_{i}']
    tau_eff = max(tau_R, tau_s)
    regime = "STICKER-DOMINATED" if tau_s > tau_R else "ROUSE-DOMINATED"
    print(f"Mode {i}: τ_eff = {tau_eff:.4e} s ({regime})")

## NLSQ Prediction vs Data

In [None]:
# Compute metrics for plot title
metrics = compute_fit_quality(stress, model.predict(time_data, test_mode='startup', gamma_dot=gamma_dot))

# Plot NLSQ fit with uncertainty band
fig, ax = plot_nlsq_fit(
    time_data, stress, model, test_mode="startup",
    param_names=param_names, log_scale=False,
    xlabel='Time (s)',
    ylabel=r'Shear stress $\sigma$ (Pa)',
    title=f'NLSQ Fit (R² = {metrics["R2"]:.4f})',
    gamma_dot=gamma_dot
)
plt.close("all")

## Mode-by-Mode Buildup

In [None]:
# Generate fine grid for smooth predictions
time_fine = np.linspace(time_data.min(), time_data.max(), 200)
stress_pred = model.predict(time_fine, test_mode='startup', gamma_dot=gamma_dot)

# Compute individual mode contributions
fig, ax = plt.subplots(figsize=(10, 7))

# Total prediction
ax.plot(time_fine, stress_pred, 'k-', label='Total', linewidth=2.5, zorder=5)

# Individual modes
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
for i in range(3):
    G_i = params_nlsq[f'G_{i}']
    tau_R_i = params_nlsq[f'tau_R_{i}']
    tau_eff_i = max(tau_R_i, tau_s)
    
    stress_i = G_i * tau_eff_i * gamma_dot * (1.0 - np.exp(-time_fine / tau_eff_i))
    
    regime = "sticker" if tau_s > tau_R_i else "Rouse"
    ax.plot(time_fine, stress_i, '--', color=colors[i], 
            label=f'Mode {i} ({regime}, τ_eff={tau_eff_i:.2e}s)', linewidth=1.5)

# Solvent contribution (constant from t = 0)
eta_s = params_nlsq['eta_s']
sigma_solvent = eta_s * gamma_dot
ax.axhline(sigma_solvent, color='gray', linestyle=':', linewidth=1.5,
           label=f'Solvent (η_s={eta_s:.2e} Pa·s)')

ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title('Mode-by-Mode Stress Buildup', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, loc='best')
ax.grid(True, alpha=0.3)
plt.show()
plt.close("all")

## Sticker-Dominated Initial Response

In [None]:
# Analyze early-time collective response
tau_s = params_nlsq['tau_s']
n_sticker_modes = sum(1 for i in range(3) if params_nlsq[f'tau_R_{i}'] < tau_s)

print(f"Sticker-dominated modes: {n_sticker_modes}/3")
print(f"Sticker timescale: τ_s = {tau_s:.4e} s")

if n_sticker_modes > 0:
    # Collective stress from sticker-dominated modes
    G_sticker = sum(params_nlsq[f'G_{i}'] for i in range(3) if params_nlsq[f'tau_R_{i}'] < tau_s)
    stress_sticker_collective = G_sticker * tau_s * gamma_dot * (1.0 - np.exp(-time_fine / tau_s))
    
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.plot(time_data, stress, 'ko', label='Data', markersize=6, alpha=0.5)
    ax.plot(time_fine, stress_pred, 'b-', label='Full Model', linewidth=2)
    ax.plot(time_fine, stress_sticker_collective, 'r--', 
            label=f'Sticker Collective (G_s={G_sticker:.2e} Pa, τ_s={tau_s:.2e}s)', linewidth=2)
    
    # Highlight early-time region
    t_early = 3.0 * tau_s
    ax.axvline(t_early, color='green', linestyle=':', linewidth=1.5, 
               label=f'Early regime (t < 3τ_s = {t_early:.2e}s)')
    
    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('Stress (Pa)', fontsize=12)
    ax.set_title('Sticker-Dominated Initial Response', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.show()
    plt.close("all")
    
    print(f"\nCollective sticker modulus: G_sticker = {G_sticker:.4e} Pa")
    print(f"Early-time regime: t < {t_early:.4e} s")
else:
    print("\nNo sticker-dominated modes detected. All modes exhibit Rouse dynamics.")

## Effective Relaxation Time Analysis

In [None]:
# Visualize sticker-mode interaction
fig = plot_sticky_rouse_effective_times(model)
plt.show()
plt.close("all")

## Bayesian Inference

In [None]:
# FAST_MODE (env var FAST_MODE, default "1") skips NUTS entirely and reuses the NLSQ estimates;
# set FAST_MODE=0 to run full Bayesian inference
FAST_MODE = os.environ.get("FAST_MODE", "1") == "1"

# Configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

if FAST_MODE:
    print("FAST_MODE: Skipping Bayesian inference (JIT compilation takes >600s)")
    print("To run Bayesian analysis, run with FAST_MODE=0")
    # Create a placeholder result with current NLSQ parameters
    class BayesianResult:
        def __init__(self, model, param_names):
            self.posterior_samples = {name: np.array([model.parameters.get_value(name)] * NUM_SAMPLES) for name in param_names}
    result_bayes = BayesianResult(model, param_names)
    bayes_time = 0.0
else:
    print(f"Running NUTS with {NUM_CHAINS} chain(s)...")
    print(f"Warmup: {NUM_WARMUP} samples, Sampling: {NUM_SAMPLES} samples")
    
    start_time = time.time()
    result_bayes = model.fit_bayesian(
        time_data, stress,
        test_mode='startup',
        gamma_dot=gamma_dot,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        seed=42
    )
    bayes_time = time.time() - start_time
    
    print(f"\nBayesian inference completed in {bayes_time:.1f} seconds")


## Convergence Diagnostics

In [None]:
# Skip convergence diagnostics in FAST_MODE (no MCMC samples were drawn)
if not FAST_MODE:
    print_convergence_summary(result_bayes, param_names)
else:
    print("FAST_MODE: Skipping convergence diagnostics")


## Parameter Comparison: NLSQ vs Bayesian

In [None]:
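# Compare NLSQ point estimates with posterior summaries
# (in FAST_MODE the placeholder "posterior" repeats the NLSQ values, so they agree trivially)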
print_parameter_comparison(model, result_bayes.posterior_samples, param_names)
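When the posterior is real (FAST_MODE=0), a useful check beyond comparing means is whether the NLSQ value lies inside the posterior credible interval. The sketch below uses an equal-tailed interval from `np.quantile`; the helper name `compare_point_to_posterior` and the synthetic lognormal draws are assumptions for illustration only. It also shows why the FAST_MODE placeholder yields a zero-width interval: every "sample" is the NLSQ value.

```python
import numpy as np


def compare_point_to_posterior(point, samples, level=0.95):
    """Posterior mean, equal-tailed credible interval, and whether `point` lies inside."""
    samples = np.asarray(samples, dtype=float)
    alpha = (1.0 - level) / 2.0
    lo, hi = np.quantile(samples, [alpha, 1.0 - alpha])
    return {
        "mean": float(samples.mean()),
        "ci": (float(lo), float(hi)),
        "inside": bool(lo <= point <= hi),
        "width": float(hi - lo),
    }


rng = np.random.default_rng(1)
# Synthetic draws standing in for, e.g., tau_s samples (median 0.05 s)
draws = rng.lognormal(mean=np.log(0.05), sigma=0.1, size=4000)
print(compare_point_to_posterior(0.05, draws))

# A FAST_MODE-style placeholder repeats one value, so the interval collapses
print(compare_point_to_posterior(0.05, np.full(500, 0.05)))
```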

## ArviZ Diagnostics

In [None]:
# ArviZ diagnostics (trace, pair, forest, energy, autocorrelation, rank)
if not FAST_MODE and hasattr(result_bayes, 'to_inference_data'):
    display_arviz_diagnostics(result_bayes, param_names, fast_mode=FAST_MODE)
elif FAST_MODE:
    print("FAST_MODE: Skipping ArviZ diagnostics")
else:
    print("Result has no to_inference_data(); skipping ArviZ diagnostics")

## Posterior Predictive Distribution

In [None]:
# Posterior predictive check
if not FAST_MODE and hasattr(result_bayes, 'posterior_samples'):
    fig, ax = plot_posterior_predictive(
        time_data, stress, model, result_bayes,
        test_mode="startup", param_names=param_names,
        log_scale=False,
        xlabel=r'Time (s)',
        ylabel=r'Shear stress $\sigma$ (Pa)',
        gamma_dot=gamma_dot
    )
    plt.close("all")
else:
    print("FAST_MODE: Skipping posterior predictive")
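A posterior predictive band is commonly built by evaluating the model once per posterior draw and taking pointwise quantiles across the resulting curves. The sketch below does that with a toy single-mode startup curve standing in for `model.predict`; `predictive_band` and `single_mode` are hypothetical names, and `plot_posterior_predictive` may construct its band differently (for example by also sampling observation noise, which this sketch omits).

```python
import numpy as np


def predictive_band(t, draws, predict, level=0.95):
    """Pointwise band from parameter draws.

    draws: dict of parameter arrays, all of length n_draws.
    predict: callable(t, **params_for_one_draw) returning an array shaped like t.
    """
    n_draws = len(next(iter(draws.values())))
    curves = np.stack([
        predict(t, **{k: v[i] for k, v in draws.items()}) for i in range(n_draws)
    ])                                                   # shape (n_draws, n_t)
    alpha = (1.0 - level) / 2.0
    lo, med, hi = np.quantile(curves, [alpha, 0.5, 1.0 - alpha], axis=0)
    return lo, med, hi


def single_mode(t, G, tau):
    """Toy single-mode startup curve at gamma_dot = 1 (illustrative only)."""
    return G * tau * -np.expm1(-t / tau)


rng = np.random.default_rng(2)
draws = {"G": rng.normal(100.0, 5.0, 300), "tau": rng.lognormal(np.log(0.1), 0.1, 300)}
t = np.linspace(0.0, 1.0, 50)
lo, med, hi = predictive_band(t, draws, single_mode)
print(lo.shape, bool(np.all(lo <= med) and np.all(med <= hi)))
```

The band is zero-width at $t = 0$, where every curve starts from zero stress, and widens as the uncertain plateau $G\tau\dot{\gamma}$ is approached.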

## Physical Interpretation

In [None]:
# Extract posterior means
posterior = result_bayes.posterior_samples
params_bayes = {name: float(np.mean(posterior[name])) for name in param_names}
tau_s_bayes = params_bayes['tau_s']

print("Physical Interpretation (Posterior Means):")
print("=" * 60)
print(f"\nSticker Lifetime: τ_s = {tau_s_bayes:.4e} s")
print(f"Imposed Shear Rate: γ̇ = {gamma_dot} s^-1")
print(f"Weissenberg number: Wi = γ̇ × τ_s = {gamma_dot * tau_s_bayes:.4f}")

print("\nMode-by-Mode Buildup Dynamics:")
print("-" * 60)
for i in range(3):
    G_i = params_bayes[f'G_{i}']
    tau_R_i = params_bayes[f'tau_R_{i}']
    tau_eff_i = max(tau_R_i, tau_s_bayes)
    # Long-time limit of the buildup formula in the Theory section
    stress_ss_i = G_i * tau_eff_i * gamma_dot

    print(f"\nMode {i}:")
    print(f"  Buildup timescale: τ_eff = {tau_eff_i:.4e} s")
    print(f"  Steady-state stress: σ_ss = {stress_ss_i:.4e} Pa")
    print(f"  Time to ~95% of steady state: t_95 ≈ 3τ_eff = {3.0*tau_eff_i:.4e} s")

    if tau_s_bayes > tau_R_i:
        print("  ✓ STICKER-DOMINATED: Responds on collective timescale τ_s")
    else:
        print("  ✓ ROUSE-DOMINATED: Responds on intrinsic timescale τ_R")

# Total steady-state stress: polymer modes plus solvent
sigma_ss_total = sum(
    params_bayes[f'G_{i}'] * max(params_bayes[f'tau_R_{i}'], tau_s_bayes) * gamma_dot
    for i in range(3)
) + params_bayes['eta_s'] * gamma_dot
print(f"\nTotal Steady-State Stress: σ_ss = {sigma_ss_total:.4e} Pa")

# Longest relaxation time
tau_max = max(max(params_bayes[f'tau_R_{i}'], tau_s_bayes) for i in range(3))
print(f"Longest relaxation time: τ_max = {tau_max:.4e} s")
print(f"Time to ~95% of total steady state (bounded by slowest mode): ≈ 3τ_max = {3.0*tau_max:.4e} s")
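The 3τ rule used above is the single-mode 95% point: $1 - e^{-3} \approx 0.950$, and the exact value is $\tau \ln 20 \approx 2.996\,\tau$. For several modes, the summed polymer stress reaches 95% of its plateau no later than the slowest mode does, because faster modes are always closer to their own plateaus. The sketch below finds that time by bisection; `time_to_fraction` is a hypothetical helper, and the moduli and times in the example are illustrative rather than fitted.

```python
import numpy as np


def time_to_fraction(G, tau_eff, frac=0.95, gamma_dot=1.0, tol=1e-12):
    """Time at which the summed polymer stress reaches `frac` of its plateau.

    Solves sum_k G_k tau_k [1 - exp(-t/tau_k)] = frac * sum_k G_k tau_k by
    bisection; the solvent term is excluded because it is present from t = 0.
    """
    G = np.asarray(G, dtype=float)
    tau_eff = np.asarray(tau_eff, dtype=float)
    weights = G * tau_eff * gamma_dot           # per-mode plateau stresses
    target = frac * weights.sum()

    def sigma(t):
        return np.sum(weights * -np.expm1(-t / tau_eff))

    lo, hi = 0.0, 50.0 * tau_eff.max()          # sigma(hi) is essentially the plateau
    while hi - lo > tol * hi:
        mid = 0.5 * (lo + hi)
        if sigma(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Single mode: t_95 / tau equals ln(20), close to 3
print(time_to_fraction([1.0], [2.0]) / 2.0, np.log(20.0))

# Multi-mode: bounded above by the slowest mode's t_95 = tau_max * ln(20)
print(time_to_fraction([100.0, 50.0, 20.0], [0.05, 0.1, 1.0]))
```

For the example moduli, the multi-mode value is about 2.59 s, below the slowest mode's $\ln 20 \approx 3.0$ s.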

## Save Results

In [None]:
# Save results to disk
output_path = save_tnt_results(model, result_bayes, "sticky_rouse", "startup", param_names)
print(f"Results saved to: {output_path}")

## Further Reading

### TNT Documentation

- **[TNT Model Family Overview](../../docs/source/models/tnt/index.rst)**: Complete guide to all 5 TNT models
- **[TNT Protocols Reference](../../docs/source/models/tnt/tnt_protocols.rst)**: Mathematical framework for all protocols
- **[TNT Knowledge Extraction](../../docs/source/models/tnt/tnt_knowledge_extraction.rst)**: Guide for interpreting fitted parameters

### Related Notebooks

Explore other protocols in this model family and compare with advanced TNT models.

## Next Steps

- **Notebook 27**: Stress relaxation with hierarchical power-law decay
- **Notebook 28**: Creep compliance analysis
- **Notebook 29**: SAOS frequency sweeps with Rouse scaling
- **Advanced Models**: Compare with other TNT variants (Notebooks 01-24)

### Key References

1. **Tanaka, F., & Edwards, S. F.** (1992). Viscoelastic properties of physically crosslinked networks. 1. Transient network theory. *Macromolecules*, 25(5), 1516-1523. [DOI: 10.1021/ma00031a024](https://doi.org/10.1021/ma00031a024)
   - **Original TNT framework**: Conformation tensor dynamics for reversible networks

2. **Green, M. S., & Tobolsky, A. V.** (1946). A new approach to the theory of relaxing polymeric media. *Journal of Chemical Physics*, 14(2), 80-92. [DOI: 10.1063/1.1724109](https://doi.org/10.1063/1.1724109)
   - **Transient network foundation**: Network strand creation and breakage kinetics

3. **Yamamoto, M.** (1956). The visco-elastic properties of network structure I. General formalism. *Journal of the Physical Society of Japan*, 11(4), 413-421. [DOI: 10.1143/JPSJ.11.413](https://doi.org/10.1143/JPSJ.11.413)
   - **Network viscoelasticity theory**: Mathematical formulation of temporary networks

4. **Bell, G. I.** (1978). Models for the specific adhesion of cells to cells. *Science*, 200(4342), 618-627. [DOI: 10.1126/science.347575](https://doi.org/10.1126/science.347575)
   - **Bell breakage model**: Stress-dependent bond dissociation kinetics

5. **Sprakel, J., Spruijt, E., Cohen Stuart, M. A., van der Gucht, J., & Besseling, N. A. M.** (2008). Universal route to a state of pure shear flow. *Physical Review Letters*, 101(24), 248304. [DOI: 10.1103/PhysRevLett.101.248304](https://doi.org/10.1103/PhysRevLett.101.248304)
   - **TNT experimental validation**: Flow curve measurements and rheological signatures
