# TNT Loop-Bridge: Startup Shear

## Objectives

- Fit the TNT Loop-Bridge model to startup shear data
- Understand bridge fraction evolution during transient flow
- Analyze force-dependent stress overshoot
- Quantify bridge stretching vs detachment dynamics
- Perform Bayesian inference for parameter uncertainty

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTLoopBridge

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_ml_ikh_flow_curve,
    load_pnas_startup,
    load_laponite_relaxation,
    load_ml_ikh_creep,
    load_epstein_saos,
    load_pnas_laos,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_loop_bridge_param_names,
    plot_loop_bridge_fraction,
    plot_bell_nu_sweep,
    compute_maxwell_moduli,
    compute_bell_effective_lifetime,
    print_nu_interpretation,
)

param_names = get_tnt_loop_bridge_param_names()

## Theory: Startup Shear Dynamics

### Physical Picture

During startup from rest:
1. **Initial state**: Bridges at equilibrium (f_B = f_B_eq)
2. **Early times**: Bridges stretch elastically → stress increases
3. **Intermediate**: Force on bridges increases → Bell detachment accelerates
4. **Overshoot**: Peak stress when the fractional rate of stress buildup from stretching equals the fractional rate of bridge loss
5. **Late times**: Bridge fraction decreases to steady state → stress plateau

### Governing Equations

**Bridge Fraction Evolution:**
```
df_B/dt = (1 - f_B)/tau_a - f_B * exp(nu * gamma_dot * tau_b) / tau_b
```
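For constant gamma_dot the Bell factor exp(nu * gamma_dot * tau_b) does not change in time, so the kinetic equation above is linear in f_B and has a closed-form solution that relaxes exponentially to f_B_ss = k_a / (k_a + k_d). The minimal sketch below uses plain NumPy and compares that solution with a forward-Euler integration of the same equation. The function names are hypothetical and not part of rheojax, whose internal ODE may differ from this simplified form.

```python
import numpy as np


def bridge_fraction_closed_form(t, f_B0, tau_a, tau_b, nu, gamma_dot):
    """Exact f_B(t) for df_B/dt = (1 - f_B)/tau_a - f_B*exp(nu*gamma_dot*tau_b)/tau_b."""
    k_a = 1.0 / tau_a                              # attachment rate (1/s)
    k_d = np.exp(nu * gamma_dot * tau_b) / tau_b   # Bell-enhanced detachment rate (1/s)
    lam = k_a + k_d                                # relaxation rate toward steady state
    f_ss = k_a / lam                               # fixed point of the kinetic equation
    return f_ss + (f_B0 - f_ss) * np.exp(-lam * np.asarray(t, dtype=float))


def bridge_fraction_euler(t_end, n_steps, f_B0, tau_a, tau_b, nu, gamma_dot):
    """Forward-Euler integration of the same equation, for comparison."""
    dt = t_end / n_steps
    k_d = np.exp(nu * gamma_dot * tau_b) / tau_b
    f_B = f_B0
    for _ in range(n_steps):
        f_B += dt * ((1.0 - f_B) / tau_a - f_B * k_d)
    return f_B


params = dict(f_B0=0.5, tau_a=5.0, tau_b=1.0, nu=1.0, gamma_dot=1.0)
exact = float(bridge_fraction_closed_form(2.0, **params))
euler = bridge_fraction_euler(2.0, 20_000, **params)
print(f"closed form: {exact:.6f}  Euler: {euler:.6f}")
```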

**Stress Evolution (Maxwell backbone):**
```
sigma(t) = f_B(t) * G * int_0^t gamma_dot * exp(-(t-s)/tau_b) ds + eta_s * gamma_dot
```

For constant gamma_dot:
```
sigma(t) = f_B(t) * G * gamma_dot * tau_b * [1 - exp(-t/tau_b)] + eta_s * gamma_dot
```
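The constant-rate expression follows because the memory integral evaluates to tau_b * [1 - exp(-t/tau_b)]. A quick numerical check with a hand-written trapezoidal rule (NumPy only; the function names are illustrative):

```python
import numpy as np


def memory_integral_numeric(t, tau_b, n=20_001):
    """Trapezoidal estimate of int_0^t exp(-(t - s)/tau_b) ds."""
    s = np.linspace(0.0, t, n)
    y = np.exp(-(t - s) / tau_b)
    ds = s[1] - s[0]
    return ds * (y.sum() - 0.5 * (y[0] + y[-1]))


def memory_integral_exact(t, tau_b):
    """Closed form used in the constant-gamma_dot stress expression."""
    return tau_b * (1.0 - np.exp(-t / tau_b))


for t, tau_b in [(0.1, 1.0), (1.0, 1.0), (5.0, 2.0)]:
    print(t, tau_b, memory_integral_numeric(t, tau_b), memory_integral_exact(t, tau_b))
```

In this form f_B(t) multiplies the whole memory integral, so a drop in bridge fraction rescales the stress accumulated at all earlier times.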

### Stress Overshoot Mechanism

The overshoot arises from competition between:
- **Stretching**: Builds stress via Maxwell element
- **Detachment**: Reduces stress via f_B decrease

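The peak occurs where d(sigma)/dt = 0, i.e. where the fractional rate of elastic stress buildup, d ln[1 - exp(-t/tau_b)]/dt, equals the fractional rate of bridge loss, -d ln(f_B)/dt. It is not where detachment is fastest: for constant gamma_dot, df_B/dt is most negative at t = 0, when the elastic stress is still zero.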
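The balance condition can be checked with the simplified equations of this section: solve the bridge-fraction equation exactly for constant gamma_dot (f_B then relaxes exponentially to k_a / (k_a + k_d)) and insert it into the constant-rate stress expression. The sketch below does this in NumPy. The helper name and the parameter values are illustrative assumptions (tau_a = 20 s is chosen so that a clear overshoot appears), and the rheojax model integrates its own ODE, which may behave differently.

```python
import numpy as np


def simplified_startup(t, G, tau_a, tau_b, nu, f_B0, eta_s, gamma_dot):
    """Return sigma(t), f_B(t) and f_B_ss from the simplified startup equations."""
    k_a = 1.0 / tau_a
    k_d = np.exp(nu * gamma_dot * tau_b) / tau_b
    lam = k_a + k_d
    f_ss = k_a / lam
    f_B = f_ss + (f_B0 - f_ss) * np.exp(-lam * t)
    memory = 1.0 - np.exp(-t / tau_b)              # Maxwell build-up factor
    sigma = f_B * G * gamma_dot * tau_b * memory + eta_s * gamma_dot
    return sigma, f_B, f_ss


p = dict(G=500.0, tau_a=20.0, tau_b=1.0, nu=1.0, f_B0=0.5, eta_s=0.01, gamma_dot=1.0)
t = np.linspace(0.0, 10.0, 20_001)
sigma, f_B, f_ss = simplified_startup(t, **p)

i_peak = int(np.argmax(sigma))
t_peak = t[i_peak]
sigma_ss = f_ss * p["G"] * p["gamma_dot"] * p["tau_b"] + p["eta_s"] * p["gamma_dot"]

# Fractional rates at the peak: d ln(memory)/dt versus -d ln(f_B)/dt
tau_b = p["tau_b"]
k_a = 1.0 / p["tau_a"]
lam = k_a + np.exp(p["nu"] * p["gamma_dot"] * tau_b) / tau_b
rate_stretch = np.exp(-t_peak / tau_b) / (tau_b * (1.0 - np.exp(-t_peak / tau_b)))
rate_loss = lam * (p["f_B0"] - f_ss) * np.exp(-lam * t_peak) / f_B[i_peak]

print(f"t_peak = {t_peak:.3f} s, overshoot ratio = {sigma[i_peak] / sigma_ss:.2f}")
print(f"stretch rate = {rate_stretch:.3f} 1/s, bridge-loss rate = {rate_loss:.3f} 1/s")
```

The two fractional rates agree at the grid point of maximum stress, which is the d(sigma)/dt = 0 condition written in logarithmic form (the constant solvent term does not move the peak).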

### Force-Dependent Features

- Higher nu → stronger overshoot (more force sensitivity)
- Higher gamma_dot → earlier overshoot (faster force buildup)
- Peak stress scales with G * f_B_eq
- Overshoot time ~ tau_b / (1 + nu * gamma_dot * tau_b)
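The nu trend can be checked with the simplified equations of this section (f_B solved exactly for constant gamma_dot and inserted into the constant-rate stress expression). This is a sketch with illustrative, unfitted parameters and a hypothetical helper name. Because sigma is proportional to gamma_dot and the kinetics depend on gamma_dot only through exp(nu * gamma_dot * tau_b), the peak time and overshoot ratio in this picture depend on nu and gamma_dot only through their product.

```python
import numpy as np


def simplified_startup_curve(t, G, tau_a, tau_b, nu, f_B0, eta_s, gamma_dot):
    """Return sigma(t) and the steady-state stress from the simplified equations."""
    k_a = 1.0 / tau_a
    k_d = np.exp(nu * gamma_dot * tau_b) / tau_b      # Bell-enhanced detachment
    lam = k_a + k_d
    f_ss = k_a / lam
    f_B = f_ss + (f_B0 - f_ss) * np.exp(-lam * t)
    sigma = f_B * G * gamma_dot * tau_b * (1.0 - np.exp(-t / tau_b)) + eta_s * gamma_dot
    sigma_ss = f_ss * G * gamma_dot * tau_b + eta_s * gamma_dot
    return sigma, sigma_ss


t = np.linspace(0.0, 10.0, 20_001)
base = dict(G=500.0, tau_a=20.0, tau_b=1.0, f_B0=0.5, eta_s=0.01, gamma_dot=1.0)

nu_values = [0.5, 1.0, 2.0]
peak_times, ratios = [], []
for nu in nu_values:
    sigma, sigma_ss = simplified_startup_curve(t, nu=nu, **base)
    i_peak = int(np.argmax(sigma))
    peak_times.append(t[i_peak])
    ratios.append(sigma[i_peak] / sigma_ss)
    print(f"nu = {nu:.1f}: t_peak = {t[i_peak]:.3f} s, overshoot ratio = {ratios[-1]:.2f}")
```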

## Load Startup Data

In [None]:
time_data, stress = load_pnas_startup(gamma_dot=1.0)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.2e} - {time_data.max():.2e} s")
print(f"Stress range: {stress.min():.2f} - {stress.max():.2f} Pa")
print("Shear rate: 1.0 1/s")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(time_data, stress, 'o', label='Data', markersize=6)
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title('Startup Shear Data (γ̇ = 1.0 1/s)', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
plt.close('all')

## NLSQ Fitting

In [None]:
# CI mode: Skip slow NLSQ fit - use reasonable defaults
CI_MODE = os.environ.get("CI_MODE", "0") == "1"

model = TNTLoopBridge()

# Shear rate for startup (matches data loading)
gamma_dot = 1.0  # s⁻¹

if CI_MODE:
    print("CI_MODE: Using default parameters (NLSQ fit for startup is slow)")
    # Set reasonable parameters
    model.parameters.set_value('G', 500.0)  # Pa
    model.parameters.set_value('tau_b', 1.0)  # s
    model.parameters.set_value('tau_a', 5.0)  # s
    model.parameters.set_value('nu', 1.0)
    model.parameters.set_value('f_B_eq', 0.5)
    model.parameters.set_value('eta_s', 0.01)  # Pa·s
    t_nlsq = 0.0
else:
    print("Starting NLSQ fit...")
    t_start = time.time()
    nlsq_result = model.fit(time_data, stress, test_mode='startup', gamma_dot=gamma_dot, method='scipy')
    t_nlsq = time.time() - t_start
    print(f"\nNLSQ fit completed in {t_nlsq:.2f} seconds")

print("\nModel parameters (NLSQ fit, or CI-mode defaults):")
for name in param_names:
    value = model.parameters.get_value(name)
    print(f"  {name}: {value:.4e}")

stress_pred_fit = model.predict(time_data, test_mode='startup', gamma_dot=gamma_dot)
metrics = compute_fit_quality(stress, stress_pred_fit)
print(f"\nFit quality:")
print(f"  R²: {metrics['R2']:.6f}")
print(f"  RMSE: {metrics['RMSE']:.4e}")
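`compute_fit_quality` comes from the tutorial utilities, and its definitions are not shown in this notebook. Assuming it uses the conventional definitions, R² = 1 - SS_res/SS_tot and RMSE = sqrt(mean(residual²)), a NumPy sketch looks like this (the function name is hypothetical):

```python
import numpy as np


def fit_quality_sketch(y_obs, y_pred):
    """Conventional R^2 and RMSE (assumed, not taken from tnt_tutorial_utils)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_obs - y_pred
    ss_res = np.sum(resid**2)                       # residual sum of squares
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)    # total sum of squares about the mean
    return {"R2": float(1.0 - ss_res / ss_tot), "RMSE": float(np.sqrt(np.mean(resid**2)))}


print(fit_quality_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
```

RMSE is in the units of the data (Pa here), while R² is dimensionless and is dominated by the large-stress part of the curve.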

## NLSQ Fit Visualization

In [None]:
time_pred = jnp.linspace(time_data.min(), time_data.max(), 200)
stress_pred = model.predict(time_pred, test_mode='startup', gamma_dot=1.0)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Startup curve
ax1.plot(time_data, stress, 'o', label='Data', markersize=6, alpha=0.7)
ax1.plot(time_pred, stress_pred, '-', label='NLSQ Fit', linewidth=2)
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Stress (Pa)', fontsize=12)
ax1.set_title(f'Startup Fit (R² = {metrics["R2"]:.4f})', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Residuals, normalized by the peak measured stress; a pointwise relative
# error blows up near t = 0, where the startup stress is close to zero
residuals = (stress - stress_pred_fit) / np.max(np.abs(stress)) * 100
ax2.plot(time_data, residuals, 'o', markersize=6)
ax2.axhline(0, color='k', linestyle='--', alpha=0.3)
ax2.set_xlabel('Time (s)', fontsize=12)
ax2.set_ylabel('Residual (% of max stress)', fontsize=12)
ax2.set_title('Fit Residuals', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
plt.close('all')

## Physical Analysis: Bridge Fraction Evolution

In [None]:
# Closed-form bridge-fraction trajectory during startup. For constant
# gamma_dot the Bell factor exp(nu * gamma_dot * tau_b) is constant, so the
# kinetic equation is linear in f_B and relaxes exponentially to f_B_ss.
gamma_dot = 1.0
G = model.parameters.get_value('G')
tau_a = model.parameters.get_value('tau_a')
tau_b = model.parameters.get_value('tau_b')
nu = model.parameters.get_value('nu')
f_B_eq = model.parameters.get_value('f_B_eq')

k_detach = jnp.exp(nu * gamma_dot * tau_b) / tau_b  # force-enhanced detachment rate (1/s)
k_attach = 1.0 / tau_a                              # attachment rate (1/s)
f_B_ss = k_attach / (k_attach + k_detach)           # steady-state bridge fraction
lambda_eff = k_attach + k_detach                    # relaxation rate toward f_B_ss (1/s)
f_B_t = f_B_ss + (f_B_eq - f_B_ss) * jnp.exp(-lambda_eff * time_pred)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Bridge fraction evolution
ax1.plot(time_pred, f_B_t, '-', linewidth=2)
ax1.axhline(f_B_eq, color='r', linestyle='--', alpha=0.5, label=f'f_B_eq = {f_B_eq:.4f}')
ax1.axhline(f_B_ss, color='g', linestyle='--', alpha=0.5, label=f'f_B_ss = {f_B_ss:.4f}')
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Bridge Fraction f_B', fontsize=12)
ax1.set_title('Bridge Fraction Evolution During Startup', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 1])

# Effective modulus evolution
G_eff_t = f_B_t * model.parameters.get_value('G')
ax2.plot(time_pred, G_eff_t, '-', linewidth=2)
ax2.axhline(model.parameters.get_value('f_B_eq') * model.parameters.get_value('G'), color='r', linestyle='--', alpha=0.5, label='G_eff(0)')
ax2.axhline(f_B_ss * model.parameters.get_value('G'), color='g', linestyle='--', alpha=0.5, label='G_eff(∞)')
ax2.set_xlabel('Time (s)', fontsize=12)
ax2.set_ylabel('Effective Modulus (Pa)', fontsize=12)
ax2.set_title('Effective Modulus Evolution', fontsize=14)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
plt.close('all')

print(f"\nBridge fraction dynamics:")
print(f"  Initial f_B: {model.parameters.get_value('f_B_eq'):.4f}")
print(f"  Steady-state f_B: {f_B_ss:.4f}")
print(f"  Reduction: {(1 - f_B_ss/model.parameters.get_value('f_B_eq'))*100:.2f}%")
print(f"  Characteristic decay time: {1.0/lambda_eff:.4e} s")

## Physical Analysis: Stress Overshoot

In [None]:
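# Locate the stress overshoot. The last predicted point is used as the steady
# state, which assumes the curve has plateaued by the end of the time window.
peak_idx = int(jnp.argmax(stress_pred))
peak_time = time_pred[peak_idx]
peak_stress = stress_pred[peak_idx]
steady_stress = stress_pred[-1]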
overshoot_ratio = peak_stress / steady_stress

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(time_pred, stress_pred, '-', linewidth=2, label='Model Prediction')
ax.plot(peak_time, peak_stress, 'ro', markersize=10, label=f'Peak: {peak_stress:.2f} Pa at {peak_time:.3f} s')
ax.axhline(steady_stress, color='g', linestyle='--', alpha=0.5, label=f'Steady State: {steady_stress:.2f} Pa')
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title(f'Stress Overshoot Analysis (Ratio = {overshoot_ratio:.3f})', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
plt.close('all')

print(f"\nStress overshoot analysis:")
print(f"  Peak stress: {peak_stress:.4e} Pa")
print(f"  Peak time: {peak_time:.4e} s")
print(f"  Steady-state stress: {steady_stress:.4e} Pa")
print(f"  Overshoot ratio: {overshoot_ratio:.4f}")
print(f"  Peak time / tau_b: {peak_time / model.parameters.get_value('tau_b'):.4f}")
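The overshoot ratio above takes the last predicted point as the steady state, which is only meaningful if the curve has plateaued inside the time window. A small helper (hypothetical, NumPy only) that averages the tail of the curve and flags residual drift:

```python
import numpy as np


def overshoot_metrics(t, sigma, tail_fraction=0.1, plateau_rtol=0.01):
    """Peak, tail-averaged steady state, and a plateau flag for a startup curve."""
    t = np.asarray(t, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    i_peak = int(np.argmax(sigma))
    n_tail = max(2, int(len(t) * tail_fraction))
    tail = sigma[-n_tail:]
    sigma_ss = float(tail.mean())
    # Relative change across the tail window; large drift means no plateau yet
    drift = abs(tail[-1] - tail[0]) / abs(sigma_ss)
    return {
        "t_peak": float(t[i_peak]),
        "sigma_peak": float(sigma[i_peak]),
        "sigma_ss": sigma_ss,
        "ratio": float(sigma[i_peak] / sigma_ss),
        "plateau_reached": bool(drift < plateau_rtol),
    }


# Synthetic test curve (not rheological data): peak of 2 at t = 1, relaxing to 1
t = np.linspace(0.0, 20.0, 2001)
sigma = 1.0 + t * np.exp(1.0 - t)
print(overshoot_metrics(t, sigma))
print(overshoot_metrics(t[:301], sigma[:301]))   # window ends at t = 3, before the plateau
```

Applied to this notebook, the call would be `overshoot_metrics(np.asarray(time_pred), np.asarray(stress_pred))`; if `plateau_reached` is False, extend the prediction window before interpreting the ratio.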

## Physical Analysis: Force-Dependent Overshoot

In [None]:
if CI_MODE:
    print("CI_MODE: Skipping shear rate sweep (multiple ODE predictions are slow)")
else:
    # Sweep shear rate to show overshoot variation
    gamma_dot_sweep = jnp.array([0.1, 1.0, 10.0, 100.0])
    time_sweep = jnp.linspace(0, 10.0 * model.parameters.get_value('tau_b'), 200)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Predict each curve once and reuse it for both the plot and the peak extraction
    peak_stresses = []
    peak_times = []
    for gd in gamma_dot_sweep:
        stress_sweep = model.predict(time_sweep, test_mode='startup', gamma_dot=float(gd))
        ax1.plot(time_sweep, stress_sweep, '-', linewidth=2, label=f'γ̇ = {float(gd):.1f} 1/s')
        i_max = int(jnp.argmax(stress_sweep))
        peak_stresses.append(float(stress_sweep[i_max]))
        peak_times.append(float(time_sweep[i_max]))

    ax1.set_xlabel('Time (s)', fontsize=12)
    ax1.set_ylabel('Stress (Pa)', fontsize=12)
    ax1.set_title('Startup Curves at Different Shear Rates', fontsize=14)
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    ax2.loglog(gamma_dot_sweep, peak_stresses, 'o-', linewidth=2, markersize=8, label='Peak Stress')
    ax2.set_xlabel('Shear Rate (1/s)', fontsize=12)
    ax2.set_ylabel('Peak Stress (Pa)', fontsize=12)
    ax2.set_title('Peak Stress vs Shear Rate', fontsize=14)
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    plt.close('all')

    print(f"\nShear rate dependence:")
    for i, gd in enumerate(gamma_dot_sweep):
        print(f"  γ̇ = {gd:.1f} 1/s: peak = {peak_stresses[i]:.4e} Pa at t = {peak_times[i]:.4e} s")

## Bayesian Inference

In [None]:
# CI mode: Skip Bayesian inference to avoid JIT compilation timeout
# Set CI_MODE=1 environment variable to skip
CI_MODE = os.environ.get("CI_MODE", "0") == "1"

# Configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

if CI_MODE:
    print("CI_MODE: Skipping Bayesian inference (JIT compilation takes >600s)")
    print("To run Bayesian analysis, run without CI_MODE environment variable")
    # Create a placeholder result with current NLSQ parameters
    class BayesianResult:
        def __init__(self, model, param_names):
            self.posterior_samples = {name: np.array([model.parameters.get_value(name)] * NUM_SAMPLES) for name in param_names}
    bayes_result = BayesianResult(model, param_names)
    bayes_time = 0.0
else:
    print(f"Running NUTS with {NUM_CHAINS} chain(s)...")
    print(f"Warmup: {NUM_WARMUP} samples, Sampling: {NUM_SAMPLES} samples")
    
    start_time = time.time()
    bayes_result = model.fit_bayesian(
        time_data, stress,
        test_mode='startup',
        gamma_dot=1.0,  # same rate as the loaded data and the NLSQ fit
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        seed=42
    )
    bayes_time = time.time() - start_time
    
    print(f"\nBayesian inference completed in {bayes_time:.1f} seconds")


## Convergence Diagnostics

In [None]:
# Skip convergence diagnostics in CI mode
if not CI_MODE:
    print_convergence_summary(bayes_result, param_names)
else:
    print("CI_MODE: Skipping convergence diagnostics")


## Parameter Comparison: NLSQ vs Bayesian

In [None]:
print_parameter_comparison(model, bayes_result.posterior_samples, param_names)
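`print_parameter_comparison` is a tutorial helper whose exact output is not documented here. To get the numbers directly, summaries can be computed from the `posterior_samples` dictionary. The sketch below (hypothetical helper name) reports equal-tailed percentile intervals, which are not the same as the highest-density intervals (HDI) that ArviZ plots.

```python
import numpy as np


def summarize_samples(samples, prob=0.95):
    """Mean, standard deviation, and equal-tailed interval for each parameter."""
    lo_pct = 100.0 * (1.0 - prob) / 2.0
    hi_pct = 100.0 - lo_pct
    summary = {}
    for name, draws in samples.items():
        draws = np.asarray(draws, dtype=float).ravel()   # flatten chains, if any
        summary[name] = {
            "mean": float(draws.mean()),
            "sd": float(draws.std(ddof=1)),
            "lower": float(np.percentile(draws, lo_pct)),
            "upper": float(np.percentile(draws, hi_pct)),
        }
    return summary


# Example with synthetic draws (not real posterior output)
demo = {"tau_b": np.linspace(0.0, 1.0, 1001)}
print(summarize_samples(demo)["tau_b"])
```

In the notebook this would be `summarize_samples(bayes_result.posterior_samples)`. In CI mode the placeholder samples are constant, so every `sd` is zero.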

## ArviZ: Trace Plot

In [None]:
# Skip trace plot in CI mode
if not CI_MODE:
    idata = az.from_dict(posterior=bayes_result.posterior_samples)
    
    axes = az.plot_trace(idata, var_names=param_names, compact=False, figsize=(12, 10))
    plt.tight_layout()
    plt.show()
    plt.close()
else:
    print("CI_MODE: Skipping trace plot")


## ArviZ: Posterior Distributions

In [None]:
# Skip ArviZ plot in CI mode
if not CI_MODE:
    # plot_posterior returns matplotlib axes, not a figure
    axes = az.plot_posterior(idata, var_names=param_names, hdi_prob=0.95, figsize=(12, 8))
    plt.tight_layout()
    plt.show()
    plt.close()
else:
    print("CI_MODE: Skipping ArviZ plot")


## ArviZ: Pair Plot

In [None]:
# Skip pair plot in CI mode
if not CI_MODE:
    axes = az.plot_pair(
        idata,
        var_names=param_names,
        kind='kde',
        marginals=True,
        figsize=(14, 14)
    )
    plt.tight_layout()
    plt.show()
    plt.close()
else:
    print("CI_MODE: Skipping pair plot")


## Posterior Predictive

In [None]:
if CI_MODE:
    print("CI_MODE: Skipping posterior predictive (200 ODE predictions would take >300s)")
    # Single NLSQ prediction for plot
    stress_pred_final = model.predict(time_pred, test_mode='startup', gamma_dot=1.0)

    fig, ax = plt.subplots(figsize=(10, 7))
    ax.plot(time_data, stress, 'o', label='Data', markersize=6, alpha=0.7)
    ax.plot(time_pred, stress_pred_final, '-', label='NLSQ fit', linewidth=2, color='C1')
    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('Stress (Pa)', fontsize=12)
    ax.set_title('NLSQ Fit (CI Mode)', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    plt.close('all')
else:
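    posterior = bayes_result.posterior_samples
    n_total = len(posterior[param_names[0]])
    n_draws = min(200, n_total)
    rng = np.random.default_rng(42)  # seeded so the credible band is reproducible
    indices = rng.choice(n_total, size=n_draws, replace=False)

    # Keep the NLSQ estimates so they can be restored after sampling
    nlsq_values = {name: model.parameters.get_value(name) for name in param_names}

    predictions = []
    for i in indices:
        # Set parameters from posterior sample i and predict the startup curve
        for name in param_names:
            model.parameters.set_value(name, float(posterior[name][i]))
        pred = model.predict(time_pred, test_mode='startup', gamma_dot=1.0)
        predictions.append(np.array(pred))

    # Restore the NLSQ parameters; later cells (interpretation, saving) read them
    for name, value in nlsq_values.items():
        model.parameters.set_value(name, value)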

    predictions = np.array(predictions)
    pred_mean = predictions.mean(axis=0)
    pred_lower = np.percentile(predictions, 2.5, axis=0)
    pred_upper = np.percentile(predictions, 97.5, axis=0)

    fig, ax = plt.subplots(figsize=(10, 7))
    ax.plot(time_data, stress, 'o', label='Data', markersize=6, alpha=0.7, zorder=3)
    ax.plot(time_pred, pred_mean, '-', label='Posterior Mean', linewidth=2, zorder=2)
    ax.fill_between(time_pred, pred_lower, pred_upper, alpha=0.3, label='95% Credible Interval', zorder=1)
    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('Stress (Pa)', fontsize=12)
    ax.set_title('Posterior Predictive Distribution', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    plt.close('all')

## Physical Interpretation

In [None]:
print("\n=== Physical Interpretation ===")
print(f"\n1. Material Properties:")
print(f"   - Plateau modulus G: {model.parameters.get_value('G'):.4e} Pa")
print(f"   - Equilibrium bridge fraction: {model.parameters.get_value('f_B_eq'):.4f}")
print(f"   - Initial effective modulus: {model.parameters.get_value('G') * model.parameters.get_value('f_B_eq'):.4e} Pa")

print(f"\n2. Startup Dynamics:")
print(f"   - Bridge detachment time tau_b: {model.parameters.get_value('tau_b'):.4e} s")
print(f"   - Loop attachment time tau_a: {model.parameters.get_value('tau_a'):.4e} s")
print(f"   - Effective decay time: {1.0/lambda_eff:.4e} s")
print(f"   - Bridge fraction reduction: {(1 - f_B_ss/model.parameters.get_value('f_B_eq'))*100:.2f}%")

print(f"\n3. Stress Overshoot:")
print(f"   - Peak stress: {peak_stress:.4e} Pa")
print(f"   - Peak time: {peak_time:.4e} s ({peak_time/model.parameters.get_value('tau_b'):.2f} * tau_b)")
print(f"   - Steady-state stress: {steady_stress:.4e} Pa")
print(f"   - Overshoot ratio: {overshoot_ratio:.4f}")

print(f"\n4. Bell Detachment:")
print(f"   - Nu parameter: {model.parameters.get_value('nu'):.4f}")
print(f"   - Force factor at γ̇ = 1.0 1/s: {jnp.exp(model.parameters.get_value('nu') * 1.0 * model.parameters.get_value('tau_b')):.4f}")
print(f"   - Effective detachment rate: {k_detach:.4e} 1/s")
print(f"   - Rate enhancement: {k_detach * model.parameters.get_value('tau_b'):.4f}x")

print(f"\n5. Timescale Separation:")
print(f"   - tau_a/tau_b ratio: {model.parameters.get_value('tau_a')/model.parameters.get_value('tau_b'):.4f}")
if model.parameters.get_value('tau_a') > model.parameters.get_value('tau_b'):
    print(f"   - Slow attachment dominates re-equilibration")
else:
    print(f"   - Fast attachment allows rapid re-equilibration")
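As a worked check of the Bell numbers printed above, take the CI default parameters (nu = 1, tau_b = 1 s, tau_a = 5 s) at gamma_dot = 1 1/s. The Weissenberg number Wi = gamma_dot * tau_b is 1, so the force factor is exp(nu * Wi) = e ≈ 2.718 and k_detach ≈ 2.718 1/s. The same arithmetic as a small, hypothetical helper (fitted parameters will give different values):

```python
import math


def bell_summary(nu, tau_b, tau_a, gamma_dot):
    """Bell force factor and derived kinetic quantities for constant gamma_dot."""
    wi = gamma_dot * tau_b                     # Weissenberg number
    force_factor = math.exp(nu * wi)           # enhancement of the detachment rate
    k_det = force_factor / tau_b               # detachment rate, 1/s
    k_att = 1.0 / tau_a                        # attachment rate, 1/s
    return {
        "Wi": wi,
        "force_factor": force_factor,
        "k_detach": k_det,
        "f_B_ss": k_att / (k_att + k_det),     # steady-state bridge fraction
        "decay_time": 1.0 / (k_att + k_det),   # characteristic decay time, s
    }


# CI default parameters from the NLSQ cell, at gamma_dot = 1 1/s
print(bell_summary(nu=1.0, tau_b=1.0, tau_a=5.0, gamma_dot=1.0))
```

Since k_detach * tau_b equals the force factor, the "Rate enhancement" line above reports the same number as the "Force factor" line.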

## Save Results

In [None]:
save_tnt_results(model, bayes_result, "loop_bridge", "startup", param_names)
print("Results saved to reference_outputs/tnt/loop_bridge_startup_results.npz")

## Key Takeaways

1. **Transient Kinetics**: Startup captures bridge fraction evolution from equilibrium to steady state

2. **Stress Overshoot**: Peak arises from competition between elastic stretching and force-enhanced detachment

3. **Force Dependence**: Higher nu → stronger overshoot, earlier peak time

4. **Bridge Depletion**: f_B decreases during startup, reducing effective modulus

5. **Timescale Control**: tau_b sets overshoot time, tau_a/tau_b controls re-equilibration

6. **Shear Rate Sensitivity**: Higher gamma_dot → earlier overshoot, higher peak stress (until saturation)

7. **Physical Consistency**: Verify overshoot ratio > 1, peak time ~ tau_b, f_B_ss < f_B_eq