# TNT Loop-Bridge: Stress Relaxation

## Objectives

- Fit TNT Loop-Bridge model to stress relaxation data
- Understand bridge re-equilibration after flow cessation
- Analyze tau_a/tau_b ratio effects on relaxation dynamics
- Quantify two-timescale relaxation behavior
- Perform Bayesian inference for parameter uncertainty

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTLoopBridge

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_ml_ikh_flow_curve,
    load_pnas_startup,
    load_laponite_relaxation,
    load_ml_ikh_creep,
    load_epstein_saos,
    load_pnas_laos,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_loop_bridge_param_names,
    plot_loop_bridge_fraction,
    plot_bell_nu_sweep,
    compute_maxwell_moduli,
    compute_bell_effective_lifetime,
    print_nu_interpretation,
)

param_names = get_tnt_loop_bridge_param_names()

## Theory: Stress Relaxation Dynamics

### Physical Picture

After flow cessation (gamma_dot → 0):
1. **Initial state**: Bridges depleted (f_B < f_B_eq) from prior shear
2. **Early times**: Elastic stress relaxes via tau_b (bridge detachment without force)
3. **Intermediate**: Bridges re-attach (loops → bridges)
4. **Late times**: Bridge fraction recovers to f_B_eq

### Governing Equations

**Bridge Fraction Recovery (no shear, gamma_dot = 0):**
```
df_B/dt = (1 - f_B)/tau_a - f_B/tau_b
```

**Steady State (equilibrium):**
```
f_B_eq = 1 / (1 + tau_a/tau_b)
```

**Solution (exponential recovery):**
```
f_B(t) = f_B_eq + (f_B_0 - f_B_eq) * exp(-t / tau_eq)
tau_eq = tau_a * tau_b / (tau_a + tau_b)
```
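Both the steady state and the recovery time come from rewriting the rate equation as linear relaxation toward a fixed point:

$$
\begin{aligned}
\frac{df_B}{dt} &= \frac{1 - f_B}{\tau_a} - \frac{f_B}{\tau_b}
= \frac{1}{\tau_a} - f_B\left(\frac{1}{\tau_a} + \frac{1}{\tau_b}\right)
= -\frac{f_B - f_{B,\mathrm{eq}}}{\tau_{\mathrm{eq}}}, \\
\frac{1}{\tau_{\mathrm{eq}}} &= \frac{1}{\tau_a} + \frac{1}{\tau_b}
\;\Rightarrow\; \tau_{\mathrm{eq}} = \frac{\tau_a \tau_b}{\tau_a + \tau_b}, \qquad
f_{B,\mathrm{eq}} = \frac{\tau_{\mathrm{eq}}}{\tau_a} = \frac{1}{1 + \tau_a/\tau_b}.
\end{aligned}
$$

Because $1/\tau_{\mathrm{eq}}$ exceeds both $1/\tau_a$ and $1/\tau_b$, $\tau_{\mathrm{eq}}$ is shorter than either timescale, and closing 90% of the initial deficit takes $\tau_{\mathrm{eq}} \ln 10$.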

**Stress Relaxation (Maxwell backbone):**
```
G(t) = f_B(t) * G * exp(-t/tau_b)
```
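As a sanity check on the closed-form recovery law, the sketch below integrates the rate equation with forward Euler and compares the result against the analytic f_B(t). It also evaluates the simplified Maxwell-backbone G(t). It uses plain NumPy, and the parameter values are illustrative assumptions, not fitted ones. The helper names are hypothetical and not part of rheojax.

```python
import numpy as np


def f_B_closed_form(t, tau_a, tau_b, f_B_0):
    """Analytic bridge-fraction recovery at zero shear rate."""
    tau_eq = tau_a * tau_b / (tau_a + tau_b)
    f_B_eq = 1.0 / (1.0 + tau_a / tau_b)
    return f_B_eq + (f_B_0 - f_B_eq) * np.exp(-t / tau_eq)


def f_B_euler(t_end, tau_a, tau_b, f_B_0, n_steps=100_000):
    """Forward-Euler integration of df_B/dt = (1 - f_B)/tau_a - f_B/tau_b."""
    dt = t_end / n_steps
    f = f_B_0
    for _ in range(n_steps):
        f += dt * ((1.0 - f) / tau_a - f / tau_b)
    return f


def G_relax(t, G, tau_a, tau_b, f_B_0):
    """Simplified relaxation modulus G(t) = f_B(t) * G * exp(-t/tau_b)."""
    return f_B_closed_form(t, tau_a, tau_b, f_B_0) * G * np.exp(-t / tau_b)


# Illustrative values (assumed, not fitted): tau_a = 2 s, tau_b = 1 s, G = 100 Pa
tau_a, tau_b, G = 2.0, 1.0, 100.0
f_B_eq = 1.0 / (1.0 + tau_a / tau_b)  # = 1/3
f_B_0 = 0.5 * f_B_eq                  # half-depleted starting state
t_end = 3.0

numeric = f_B_euler(t_end, tau_a, tau_b, f_B_0)
analytic = f_B_closed_form(t_end, tau_a, tau_b, f_B_0)
print(f"Euler: {numeric:.6f}, closed form: {analytic:.6f}")
print(f"G(t_end) = {G_relax(t_end, G, tau_a, tau_b, f_B_0):.4f} Pa")
```

The Euler step is deliberately naive. With 100,000 steps its error is far below the agreement tolerance, which is enough to confirm the closed form.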

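### Two-Timescale Relaxation

1. **Elastic decay (tau_b)**: Stress carried by bridges decays as bridges detach
2. **Structural recovery (tau_eq)**: The bridge fraction relaxes toward f_B_eq at rate 1/tau_eq = 1/tau_a + 1/tau_b, so tau_eq is always shorter than both tau_a and tau_b

If tau_a << tau_b:
- tau_eq ≈ tau_a: bridges re-form quickly, then G(t) decays over tau_b
- Two clearly separated stages: fast partial recovery of the modulus, then slower elastic relaxation

If tau_a >> tau_b:
- tau_eq ≈ tau_b: recovery and elastic decay occur on the same timescale
- f_B_eq is small, so structural recovery has little visible effect on G(t)

If tau_a ~ tau_b:
- Coupled relaxation-recovery, single effective timescale tau_eq ≈ tau_b/2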
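Two properties of the simplified G(t) = f_B(t) G exp(-t/tau_b) can be checked numerically:

- tau_eq is shorter than both tau_a and tau_b.
- From a half-depleted start (f_B_0 = f_B_eq/2), the initial log-slope is d ln G/dt = 1/tau_eq - 1/tau_b = 1/tau_a.

So G(t) briefly rises before the exp(-t/tau_b) decay takes over, and this rise is only noticeable when tau_a is short compared with tau_b. The minimal NumPy sketch below uses illustrative (assumed) timescales rather than fitted values.

```python
import numpy as np


def tau_eq(tau_a, tau_b):
    """Re-equilibration time from 1/tau_eq = 1/tau_a + 1/tau_b."""
    return tau_a * tau_b / (tau_a + tau_b)


def log_G(t, G, tau_a, tau_b, f_B_0):
    """ln G(t) for the simplified model G(t) = f_B(t) * G * exp(-t/tau_b)."""
    f_B_eq = 1.0 / (1.0 + tau_a / tau_b)
    f_B = f_B_eq + (f_B_0 - f_B_eq) * np.exp(-t / tau_eq(tau_a, tau_b))
    return np.log(f_B * G) - t / tau_b


tau_b = 1.0
for tau_a in (0.01, 1.0, 100.0):          # tau_a/tau_b = 0.01, 1, 100
    f_B_0 = 0.5 / (1.0 + tau_a / tau_b)    # half of f_B_eq
    h = 1e-6 * min(tau_a, tau_b)           # forward-difference step
    slope0 = (log_G(h, 1.0, tau_a, tau_b, f_B_0) - log_G(0.0, 1.0, tau_a, tau_b, f_B_0)) / h
    print(f"tau_a/tau_b = {tau_a / tau_b:g}: tau_eq = {tau_eq(tau_a, tau_b):.4g} s, "
          f"initial d ln G/dt = {slope0:.4g} 1/s (1/tau_a = {1.0 / tau_a:.4g})")
```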

### Ratio Effects

- **tau_a/tau_b >> 1**: Slow attachment, low f_B_eq, weak recovery
- **tau_a/tau_b ~ 1**: Balanced kinetics, moderate f_B_eq
- **tau_a/tau_b << 1**: Fast attachment, high f_B_eq, strong recovery
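The ratio effects listed above can be tabulated directly from the two closed-form expressions. This is a small sketch with tau_b fixed at 1, and `bridge_kinetics` is a hypothetical helper name.

```python
def bridge_kinetics(ratio, tau_b=1.0):
    """Equilibrium bridge fraction and tau_eq for a given tau_a/tau_b ratio."""
    tau_a = ratio * tau_b
    f_B_eq = 1.0 / (1.0 + ratio)
    tau_eq = tau_a * tau_b / (tau_a + tau_b)
    return f_B_eq, tau_eq


print(f"{'tau_a/tau_b':>12} {'f_B_eq':>8} {'tau_eq/tau_b':>13}")
for r in (0.01, 0.1, 1.0, 10.0, 100.0):
    f_B_eq, t_eq = bridge_kinetics(r)
    print(f"{r:>12g} {f_B_eq:>8.4f} {t_eq:>13.4f}")
```

With tau_b = 1, the last column equals 1 - f_B_eq. Faster attachment therefore both raises the equilibrium bridge fraction and shortens the recovery.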

## Load Relaxation Data

In [None]:
time_data, modulus = load_laponite_relaxation(aging_time=1800)

print(f"Data points: {len(time_data)}")
print(f"Time range: {time_data.min():.2e} - {time_data.max():.2e} s")
print(f"Modulus range: {modulus.min():.2e} - {modulus.max():.2e} Pa")
print(f"Aging time: 1800 s")

fig, ax = plt.subplots(figsize=(8, 6))
ax.loglog(time_data, modulus, 'o', label='Data', markersize=6)
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('Relaxation Modulus G(t) (Pa)', fontsize=12)
ax.set_title('Stress Relaxation Data', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

## NLSQ Fitting

In [None]:
model = TNTLoopBridge()

print("Starting NLSQ fit...")
t_start = time.time()

nlsq_result = model.fit(time_data, modulus, test_mode='relaxation', method='scipy')

t_nlsq = time.time() - t_start
print(f"\nNLSQ fit completed in {t_nlsq:.2f} seconds")
print(f"\nFitted parameters:")
for name in param_names:
    value = model.parameters.get_value(name)
    print(f"  {name}: {value:.4e}")

modulus_pred_fit = model.predict(time_data, test_mode='relaxation')
metrics = compute_fit_quality(modulus, modulus_pred_fit)
print(f"\nFit quality:")
print(f"  R²: {metrics['R2']:.6f}")
print(f"  RMSE: {metrics['RMSE']:.4e}")
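`compute_fit_quality` is a tutorial helper whose implementation is not shown here. The sketch below uses the standard definitions of R² and RMSE, which is an assumption about what it computes; `fit_quality` is a hypothetical stand-in. Because G(t) spans several decades, linear-scale metrics are dominated by the largest moduli. Applying the same metrics to log10(G) weights the long-time tail more evenly.

```python
import numpy as np


def fit_quality(y_obs, y_pred):
    """Standard R^2 and RMSE (hypothetical stand-in for compute_fit_quality)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_obs - y_pred) ** 2)
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)
    return {"R2": 1.0 - ss_res / ss_tot, "RMSE": float(np.sqrt(np.mean((y_obs - y_pred) ** 2)))}


# Usage on synthetic values spanning two decades (illustrative only)
y_obs = np.array([100.0, 50.0, 10.0, 1.0])
y_pred = np.array([98.0, 52.0, 11.0, 2.0])
print("linear:", fit_quality(y_obs, y_pred))
print("log10: ", fit_quality(np.log10(y_obs), np.log10(y_pred)))
```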

## NLSQ Fit Visualization

In [None]:
time_pred = jnp.logspace(jnp.log10(time_data.min()), jnp.log10(time_data.max()), 200)
modulus_pred = model.predict(time_pred, test_mode='relaxation')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Relaxation curve
ax1.loglog(time_data, modulus, 'o', label='Data', markersize=6, alpha=0.7)
ax1.loglog(time_pred, modulus_pred, '-', label='NLSQ Fit', linewidth=2)
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('G(t) (Pa)', fontsize=12)
ax1.set_title(f'Relaxation Fit (R² = {metrics["R2"]:.4f})', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Residuals
modulus_fit = model.predict(time_data, test_mode='relaxation')
residuals = (modulus - modulus_fit) / modulus * 100
ax2.semilogx(time_data, residuals, 'o', markersize=6)
ax2.axhline(0, color='k', linestyle='--', alpha=0.3)
ax2.set_xlabel('Time (s)', fontsize=12)
ax2.set_ylabel('Relative Error (%)', fontsize=12)
ax2.set_title('Fit Residuals', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

## Physical Analysis: Bridge Re-equilibration

In [None]:
# Pull fitted values into local variables for readability
G_fit = model.parameters.get_value('G')
tau_a_fit = model.parameters.get_value('tau_a')
tau_b_fit = model.parameters.get_value('tau_b')
f_B_eq_fit = model.parameters.get_value('f_B_eq')

# Compute re-equilibration timescale and kinetic ratio
tau_eq = tau_a_fit * tau_b_fit / (tau_a_fit + tau_b_fit)
ratio = tau_a_fit / tau_b_fit

# Assume initial state is depleted (e.g., 50% of equilibrium)
f_B_0 = 0.5 * f_B_eq_fit
f_B_recovery = f_B_eq_fit + (f_B_0 - f_B_eq_fit) * jnp.exp(-time_pred / tau_eq)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Bridge fraction recovery
ax1.semilogx(time_pred, f_B_recovery, '-', linewidth=2)
ax1.axhline(f_B_eq_fit, color='r', linestyle='--', alpha=0.5, label=f'f_B_eq = {f_B_eq_fit:.4f}')
ax1.axhline(f_B_0, color='g', linestyle='--', alpha=0.5, label=f'f_B_0 = {f_B_0:.4f}')
ax1.axvline(tau_eq, color='purple', linestyle='--', alpha=0.5, label=f'τ_eq = {tau_eq:.4e} s')
ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('Bridge Fraction f_B', fontsize=12)
ax1.set_title('Bridge Fraction Recovery', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 1])

# Effective modulus evolution: with recovery vs. fixed f_B = f_B_eq
G_eff_recovery = f_B_recovery * G_fit * jnp.exp(-time_pred / tau_b_fit)
G_no_recovery = G_fit * f_B_eq_fit * jnp.exp(-time_pred / tau_b_fit)
ax2.loglog(time_pred, G_eff_recovery, '-', linewidth=2, label='G(t) with recovery')
ax2.loglog(time_pred, G_no_recovery, '--', linewidth=2, alpha=0.5, label='G(t) no recovery')
ax2.set_xlabel('Time (s)', fontsize=12)
ax2.set_ylabel('G(t) (Pa)', fontsize=12)
ax2.set_title('Modulus Evolution with Re-equilibration', fontsize=14)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)

print(f"\nBridge re-equilibration:")
print(f"  Equilibrium bridge fraction: {model.parameters.get_value('f_B_eq'):.4f}")
print(f"  Re-equilibration time tau_eq: {tau_eq:.4e} s")
print(f"  tau_a / tau_b ratio: {ratio:.4f}")
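The recovery law gives a closed form for how long the bridge fraction takes to close a given fraction X of its deficit: exp(-t/tau_eq) = 1 - X. Note that this depends only on tau_eq, not on the assumed initial depletion f_B_0. The helper below is a minimal sketch, and `recovery_time` is a hypothetical name, not part of rheojax.

```python
import math


def recovery_time(tau_eq, fraction):
    """Time for f_B to close `fraction` of the gap to f_B_eq."""
    if not 0.0 < fraction < 1.0:
        raise ValueError("fraction must lie in (0, 1)")
    return -tau_eq * math.log(1.0 - fraction)


for X in (0.5, 0.9, 0.99):
    print(f"{X:.0%} recovery: {recovery_time(1.0, X):.4f} tau_eq")
```

Passing the fitted value, as in `recovery_time(float(tau_eq), 0.9)`, gives the same 90% recovery time as `tau_eq * ln(10)`.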

## Physical Analysis: Two-Timescale Dynamics

In [None]:
# Compare the full model prediction with a single-exponential elastic decay
# at the equilibrium bridge fraction (fixed f_B = f_B_eq, no structural recovery)
G_fit = model.parameters.get_value('G')
tau_b_fit = model.parameters.get_value('tau_b')
f_B_eq_fit = model.parameters.get_value('f_B_eq')
G_elastic = G_fit * f_B_eq_fit * jnp.exp(-time_pred / tau_b_fit)
G_total = modulus_pred

fig, ax = plt.subplots(figsize=(10, 7))
ax.loglog(time_pred, G_total, '-', linewidth=2, label='Total G(t)', color='blue')
ax.loglog(time_pred, G_elastic, '--', linewidth=2, label=f'Elastic decay (τ_b = {tau_b_fit:.2e} s)', color='red', alpha=0.7)
ax.axvline(tau_b_fit, color='red', linestyle=':', alpha=0.5, label='τ_b')
ax.axvline(tau_eq, color='purple', linestyle=':', alpha=0.5, label='τ_eq')
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('G(t) (Pa)', fontsize=12)
ax.set_title('Two-Timescale Relaxation', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

print("\nTimescale separation:")
print(f"  Elastic decay time: tau_b = {model.parameters.get_value('tau_b'):.4e} s")
print(f"  Bridge attachment time: tau_a = {model.parameters.get_value('tau_a'):.4e} s")
print(f"  Structural recovery time: tau_eq = {tau_eq:.4e} s (shorter than both tau_a and tau_b)")
print(f"  Kinetic ratio: tau_a/tau_b = {ratio:.4f}")

if ratio < 0.1:
    print("  → tau_a << tau_b: bridges recover quickly (tau_eq ≈ tau_a), then stress decays over tau_b")
elif ratio > 10:
    print("  → tau_a >> tau_b: recovery and elastic decay overlap (tau_eq ≈ tau_b); low f_B_eq limits recovery")
else:
    print("  → Comparable timescales: coupled relaxation-recovery on ~tau_eq")

## Physical Analysis: Ratio Effects

In [None]:
# Sweep tau_a/tau_b ratio (G and tau_b held at fitted values) to show relaxation variation
G_fit = model.parameters.get_value('G')
tau_b_fit = model.parameters.get_value('tau_b')
ratio_sweep = [0.1, 1.0, 10.0, 100.0]
time_sweep = jnp.logspace(-3, 2, 200)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

for r in ratio_sweep:
    tau_a_sweep = r * tau_b_fit
    f_B_eq_sweep = 1.0 / (1.0 + r)
    tau_eq_sweep = tau_a_sweep * tau_b_fit / (tau_a_sweep + tau_b_fit)
    # Start each curve at 50% of its own equilibrium bridge fraction
    f_B_0_sweep = 0.5 * f_B_eq_sweep
    f_B_sweep = f_B_eq_sweep + (f_B_0_sweep - f_B_eq_sweep) * jnp.exp(-time_sweep / tau_eq_sweep)
    G_sweep = f_B_sweep * G_fit * jnp.exp(-time_sweep / tau_b_fit)

    ax1.loglog(time_sweep, G_sweep, '-', linewidth=2, label=f'τ_a/τ_b = {r:g}')

ax1.set_xlabel('Time (s)', fontsize=12)
ax1.set_ylabel('G(t) (Pa)', fontsize=12)
ax1.set_title('Relaxation Modulus vs Ratio', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Equilibrium bridge fraction vs ratio
ratio_range = jnp.logspace(-1, 2, 50)
f_B_eq_range = 1.0 / (1.0 + ratio_range)
ax2.semilogx(ratio_range, f_B_eq_range, '-', linewidth=2)
ax2.axhline(0.5, color='r', linestyle='--', alpha=0.3)
ax2.axvline(1.0, color='r', linestyle='--', alpha=0.3)
ax2.set_xlabel('τ_a / τ_b', fontsize=12)
ax2.set_ylabel('Equilibrium f_B', fontsize=12)
ax2.set_title('Bridge Fraction vs Kinetic Ratio', fontsize=14)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

plt.tight_layout()
display(fig)
plt.close(fig)

## Bayesian Inference

In [None]:
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1  # two or more chains allow between-chain convergence checks (R-hat)

print(f"Starting Bayesian inference with NUTS...")
print(f"  Warmup: {NUM_WARMUP}, Samples: {NUM_SAMPLES}, Chains: {NUM_CHAINS}")

t_start = time.time()
bayes_result = model.fit_bayesian(
    time_data, modulus,
    test_mode='relaxation',
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    seed=42
)
t_bayes = time.time() - t_start

print(f"\nBayesian inference completed in {t_bayes:.2f} seconds")
print(f"Cost relative to NLSQ: {t_bayes/t_nlsq:.1f}x the NLSQ wall time (includes MCMC overhead)")

## Convergence Diagnostics

In [None]:
print_convergence_summary(bayes_result, param_names)
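`print_convergence_summary` is a tutorial helper. To see what an R-hat value measures, here is a minimal NumPy sketch of the classic split-R-hat. Splitting each chain in half means that even a single chain (NUM_CHAINS = 1) can reveal drift between its first and second halves. ArviZ's `az.rhat` defaults to a rank-normalized variant, so its values will differ slightly. `split_rhat` is a hypothetical helper name.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for draws shaped (n_chains, n_draws) or (n_draws,)."""
    draws = np.atleast_2d(np.asarray(draws, dtype=float))
    n_draws = draws.shape[1]
    half = n_draws // 2
    # Split every chain into two halves so within-chain drift shows up as disagreement
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)      # between-chain variance
    var_plus = (n - 1) / n * W + B / n           # pooled posterior variance estimate
    return float(np.sqrt(var_plus / W))


rng = np.random.default_rng(0)
good = rng.normal(size=(1, 1000))                 # stationary single chain
drifting = good + np.linspace(0.0, 3.0, 1000)     # chain with a trend
print(f"stationary: {split_rhat(good):.3f}, drifting: {split_rhat(drifting):.3f}")
```

Values close to 1 indicate the halves agree. A common rule of thumb flags values above about 1.01 to 1.1 as a sign of poor mixing.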

## Parameter Comparison: NLSQ vs Bayesian

In [None]:
print_parameter_comparison(model, bayes_result.posterior_samples, param_names)

## ArviZ: Trace Plot

In [None]:
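idata = az.from_dict(posterior=bayes_result.posterior_samples)

# az.plot_trace returns an array of Axes, not a Figure, so recover the Figure from the Axes
axes = az.plot_trace(idata, var_names=param_names, compact=False, backend_kwargs={'figsize': (12, 10)})
fig = np.atleast_1d(axes).ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)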
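`az.from_dict` interprets the leading axes of each array as (chain, draw), and it treats a 1-D array as a single chain, which matches NUM_CHAINS = 1 here. If more chains are run, the arrays may need reshaping first. This applies if `posterior_samples` stores the chains concatenated one after another, which is an assumption about rheojax's storage layout worth checking. A minimal sketch, with `to_chain_draw` as a hypothetical helper:

```python
import numpy as np


def to_chain_draw(samples, num_chains):
    """Reshape flat posterior arrays of length num_chains * num_draws to (chain, draw, ...)."""
    out = {}
    for name, values in samples.items():
        values = np.asarray(values)
        if values.shape[0] % num_chains != 0:
            raise ValueError(f"{name}: {values.shape[0]} draws not divisible by {num_chains} chains")
        out[name] = values.reshape(num_chains, -1, *values.shape[1:])
    return out
```

Under that assumption, usage would be `idata = az.from_dict(posterior=to_chain_draw(bayes_result.posterior_samples, NUM_CHAINS))`.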

## ArviZ: Posterior Distributions

In [None]:
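# az.plot_posterior returns Axes (a single Axes or an array), so recover the Figure from them
axes = az.plot_posterior(idata, var_names=param_names, hdi_prob=0.95, backend_kwargs={'figsize': (12, 8)})
fig = np.atleast_1d(axes).ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)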

## ArviZ: Pair Plot

In [None]:
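# az.plot_pair returns an array of Axes, so recover the Figure from them
axes = az.plot_pair(
    idata,
    var_names=param_names,
    kind='kde',
    marginals=True,
    backend_kwargs={'figsize': (14, 14)}
)
fig = np.atleast_1d(axes).ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)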

## Posterior Predictive

In [None]:
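posterior = bayes_result.posterior_samples
n_total = len(posterior[param_names[0]])  # number of stored posterior draws
n_draws = min(200, n_total)
rng = np.random.default_rng(42)
indices = rng.choice(n_total, size=n_draws, replace=False)

# Save the NLSQ point estimates: the loop below overwrites the model parameters
nlsq_values = {name: model.parameters.get_value(name) for name in param_names}

predictions = []
for i in indices:
    # Set parameters from posterior sample
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name][i]))
    # Use predict method
    pred = model.predict(time_pred, test_mode='relaxation')
    predictions.append(np.array(pred))

# Restore the NLSQ estimates so later cells report (and save) the fitted values
for name, value in nlsq_values.items():
    model.parameters.set_value(name, value)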

predictions = np.array(predictions)
pred_mean = predictions.mean(axis=0)
pred_lower = np.percentile(predictions, 2.5, axis=0)
pred_upper = np.percentile(predictions, 97.5, axis=0)

fig, ax = plt.subplots(figsize=(10, 7))
ax.loglog(time_data, modulus, 'o', label='Data', markersize=6, alpha=0.7, zorder=3)
ax.loglog(time_pred, pred_mean, '-', label='Posterior Mean', linewidth=2, zorder=2)
ax.fill_between(time_pred, pred_lower, pred_upper, alpha=0.3, label='95% Credible Interval', zorder=1)
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('G(t) (Pa)', fontsize=12)
ax.set_title('Posterior Predictive Distribution', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

## Physical Interpretation

In [None]:
print("\n=== Physical Interpretation ===")
print(f"\n1. Material Properties:")
print(f"   - Plateau modulus G: {model.parameters.get_value('G'):.4e} Pa")
print(f"   - Equilibrium bridge fraction: {model.parameters.get_value('f_B_eq'):.4f}")
print(f"   - Equilibrium modulus: {model.parameters.get_value('G') * model.parameters.get_value('f_B_eq'):.4e} Pa")

print(f"\n2. Relaxation Timescales:")
print(f"   - Bridge detachment time tau_b: {model.parameters.get_value('tau_b'):.4e} s")
print(f"   - Loop attachment time tau_a: {model.parameters.get_value('tau_a'):.4e} s")
print(f"   - Re-equilibration time tau_eq: {tau_eq:.4e} s")
print(f"   - Ratio tau_a/tau_b: {ratio:.4f}")

print(f"\n3. Two-Timescale Behavior:")
if ratio < 0.1:
    print("   - tau_a << tau_b: well-separated timescales")
    print("   - Fast bridge recovery (~ tau_a) followed by elastic decay (~ tau_b)")
elif ratio > 10:
    print("   - tau_a >> tau_b: recovery and elastic decay overlap (tau_eq ≈ tau_b)")
    print("   - Low f_B_eq, so structural recovery barely affects G(t)")
else:
    print("   - Coupled relaxation-recovery")
    print("   - Single effective timescale (~ tau_eq)")

print(f"\n4. Bridge Re-equilibration:")
print(f"   - Initial depletion assumed: 50% of f_B_eq")
print(f"   - Recovery follows: f_B(t) = f_B_eq + (f_B_0 - f_B_eq) * exp(-t/tau_eq)")
print(f"   - Time to 90% recovery: {-tau_eq * jnp.log(0.1):.4e} s")

print(f"\n5. Modulus Decay:")
initial_modulus = modulus_pred[0]
final_modulus = modulus_pred[-1]
decay_ratio = final_modulus / initial_modulus
print(f"   - Initial modulus: {initial_modulus:.4e} Pa")
print(f"   - Final modulus: {final_modulus:.4e} Pa")
print(f"   - Decay ratio: {decay_ratio:.4f}")
print(f"   - Decades of relaxation: {jnp.log10(initial_modulus/final_modulus):.2f}")

## Save Results

In [None]:
save_tnt_results(model, bayes_result, "loop_bridge", "relaxation", param_names)
print("Results saved to reference_outputs/tnt/loop_bridge_relaxation_results.npz")

## Key Takeaways

1. **Two-Timescale Relaxation**: Elastic stress decays on tau_b, while bridges re-attach with attachment time tau_a

2. **Bridge Re-equilibration**: f_B recovers exponentially with time constant tau_eq = tau_a * tau_b / (tau_a + tau_b)

3. **Ratio Control**: tau_a/tau_b determines equilibrium bridge fraction and recovery rate

4. **Physical Mechanism**: Detachment (bridges → loops) competes with attachment (loops → bridges)

5. **Timescale Separation**: Low ratio (tau_a << tau_b) → fast bridge recovery followed by slower elastic decay; high ratio → both processes on ~tau_b with little visible recovery

6. **Model Limitation**: Assumes no aging or thixotropic effects during relaxation

7. **Experimental Design**: Relaxation after pre-shear can constrain both tau_b and tau_a; check the pair plot for parameter correlations before treating the estimates as independent