# TNT Sticky Rouse: Stress Relaxation

## Objectives
- Fit the TNT Sticky Rouse model to stress relaxation data
- Analyze the multi-exponential relaxation spectrum
- Understand the sticker-truncated spectrum and plateau formation
- Demonstrate mode-resolved relaxation dynamics

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTStickyRouse

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_laponite_relaxation,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_sticky_rouse_param_names,
    plot_sticky_rouse_effective_times,
    plot_mode_decomposition,
)

print("Setup complete. JAX devices:", jax.devices())

## Theory: Multi-Exponential Relaxation

**Relaxation Modulus:**

$$G(t) = \sum_k G_k \exp\left(-\frac{t}{\tau_{\textrm{eff},k}}\right)$$

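where $\tau_{\textrm{eff},k} = \max(\tau_{R,k}, \tau_s)$: each mode relaxes on its own Rouse time $\tau_{R,k}$ or on the sticker lifetime $\tau_s$, whichever is longer.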

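The formula above can be evaluated directly with NumPy. The sketch below is a minimal illustration under stated assumptions: `relaxation_modulus` is a hypothetical helper (not part of the rheojax API), the mode parameters are made-up values, and it implements only the closed-form expression, not necessarily what `TNTStickyRouse.predict` computes internally.

```python
import numpy as np

def relaxation_modulus(t, G_k, tau_R, tau_s):
    """G(t) = sum_k G_k exp(-t / tau_eff_k), with tau_eff_k = max(tau_R_k, tau_s).

    Hypothetical helper for illustration only; not part of rheojax.
    """
    t = np.atleast_1d(np.asarray(t, dtype=float))
    G_k = np.asarray(G_k, dtype=float)
    # No mode can relax faster than the sticker lifetime allows
    tau_eff = np.maximum(np.asarray(tau_R, dtype=float), tau_s)
    # Broadcast: rows are time points, columns are modes; then sum over modes
    return np.sum(G_k[None, :] * np.exp(-t[:, None] / tau_eff[None, :]), axis=1)

# Made-up illustrative parameters (not fitted values): three modes, tau_s = 1 s
G_k = np.array([100.0, 50.0, 10.0])    # Pa
tau_R = np.array([0.01, 0.5, 20.0])    # s
G_demo = relaxation_modulus(np.array([0.0, 1.0, 100.0]), G_k, tau_R, tau_s=1.0)
```

At t = 0 every exponential equals 1, so the first entry is simply Σ G_k.
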
**Sticker-Truncated Spectrum:**
- Modes with τ_R,k < τ_s: cannot relax before their stickers dissociate, so all of them relax on the common timescale τ_s → plateau in G(t)
- Modes with τ_R,k > τ_s: decay individually on their own Rouse times τ_R,k

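To see why the truncated modes act as one unit, note that every mode with τ_R,k < τ_s gets τ_eff,k = τ_s, so their sum factors into a single exponential whose amplitude is Σ G_k over those modes. A quick numerical check, using made-up illustrative parameters (not fitted values):

```python
import numpy as np

# Made-up illustrative parameters (not fitted values)
G_k = np.array([100.0, 50.0, 10.0])    # Pa
tau_R = np.array([0.01, 0.5, 20.0])    # s
tau_s = 1.0                            # s
t = np.logspace(-3, 3, 50)

tau_eff = np.maximum(tau_R, tau_s)
truncated = tau_R < tau_s              # modes whose Rouse time is shorter than tau_s
G_trunc = G_k[truncated].sum()         # combined amplitude of the truncated modes

# Per-mode contributions, shape (n_times, n_modes)
modes = G_k[None, :] * np.exp(-t[:, None] / tau_eff[None, :])
collective = modes[:, truncated].sum(axis=1)

# The truncated modes together are exactly one exponential with time tau_s
single = G_trunc * np.exp(-t / tau_s)
```
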
**Key Physics:**
- **Early times (t ≪ τ_s)**: All modes frozen → G(t) ≈ G_total = Σ G_k
- **Intermediate times (t ~ τ_s)**: Sticker-dominated modes relax collectively → plateau
- **Late times (t ≫ τ_s)**: Rouse-dominated modes decay individually

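The three regimes can be checked numerically by evaluating G(t) well before, at, and well after τ_s. A sketch with the same made-up illustrative parameters; the function `G` is a hypothetical helper for this check only:

```python
import numpy as np

# Made-up illustrative parameters (not fitted values)
G_k = np.array([100.0, 50.0, 10.0])    # Pa
tau_R = np.array([0.01, 0.5, 20.0])    # s
tau_s = 1.0                            # s
tau_eff = np.maximum(tau_R, tau_s)
G_total = G_k.sum()

def G(t):
    """Closed-form multi-exponential G(t) at a single time t (hypothetical helper)."""
    return float(np.sum(G_k * np.exp(-t / tau_eff)))

early = G(1e-3 * tau_s) / G_total      # close to 1: nothing has relaxed yet
at_sticker = G(tau_s) / G_total        # truncated modes have decayed by a factor e
late = G(10 * tau_eff.max())           # essentially only the slowest mode survives
```
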
**Observable Signatures:**
- Initial elastic modulus: G(0) = Σ G_k
- Sticker plateau: G_plateau = Σ G_k (for modes with τ_R,k < τ_s)
- Long-time decay: set by the longest effective time, τ_max = max_k τ_eff,k (the slowest Rouse mode, or τ_s if every mode is sticker-truncated)

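The observable signatures follow directly from the fitted parameters. The helper below is a hypothetical sketch that assumes the parameter naming used later in this notebook (`G_i`, `tau_R_i`, `tau_s`); it mirrors the formulas listed above and is not a rheojax function.

```python
def spectrum_signatures(params, n_modes):
    """Return (G0, G_plateau, tau_max) from a parameter dict.

    Hypothetical helper; assumes keys 'G_i', 'tau_R_i' (i = 0..n_modes-1) and 'tau_s'.
    """
    tau_s = params["tau_s"]
    # Initial modulus: every mode contributes at t = 0
    G0 = sum(params[f"G_{i}"] for i in range(n_modes))
    # Plateau modulus: only modes whose Rouse time is shorter than tau_s
    G_plateau = sum(params[f"G_{i}"] for i in range(n_modes)
                    if params[f"tau_R_{i}"] < tau_s)
    # Terminal time: the longest effective relaxation time
    tau_max = max(max(params[f"tau_R_{i}"], tau_s) for i in range(n_modes))
    return G0, G_plateau, tau_max

# Made-up illustrative values (not fitted)
example = {"G_0": 100.0, "G_1": 50.0, "G_2": 10.0,
           "tau_R_0": 0.01, "tau_R_1": 0.5, "tau_R_2": 20.0, "tau_s": 1.0}
G0, G_plateau, tau_max = spectrum_signatures(example, n_modes=3)
```
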
## Load Data

In [None]:
# Load Laponite relaxation data (aged 1800s = 30 min)
time_data, G_t = load_laponite_relaxation(aging_time=1800)

print(f"Data shape: {len(time_data)} points")
print(f"Time range: {time_data.min():.2e} - {time_data.max():.2e} s")
print(f"G(t) range: {G_t.min():.2e} - {G_t.max():.2e} Pa")
print(f"Aging time: 1800 s (30 minutes)")

# Plot raw data
fig, ax = plt.subplots(figsize=(8, 6))
ax.loglog(time_data, G_t, 'ko', label='Laponite data', markersize=6)
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('G(t) (Pa)', fontsize=12)
ax.set_title('Stress Relaxation: Laponite Gel', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.close("all")

## NLSQ Fitting

In [None]:
# Initialize model
model = TNTStickyRouse(n_modes=3)
param_names = get_tnt_sticky_rouse_param_names(n_modes=3)
print(f"Model parameters ({len(param_names)}): {param_names}")

# Fit using NLSQ
print("\nFitting with NLSQ...")
start_time = time.time()
model.fit(time_data, G_t, test_mode="relaxation", method='scipy')
fit_time = time.time() - start_time

# Compute metrics
G_pred_train = model.predict(time_data, test_mode="relaxation")
metrics_nlsq = compute_fit_quality(G_t, G_pred_train)

print(f"\nFit completed in {fit_time:.2f} seconds")
print(f"R² = {metrics_nlsq['R2']:.6f}")
print(f"RMSE = {metrics_nlsq['RMSE']:.4e} Pa")

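`compute_fit_quality` comes from the tutorial's `tnt_tutorial_utils` module, whose source is not shown here. Assuming it uses the standard definitions, R² and RMSE are computed as below. Because G(t) spans several decades, a linear-space R² is dominated by the early, large values; the extra log-space R² (`R2_log`, an addition that the helper is not known to report) weights each decade more evenly. This is a sketch, not the helper's actual code.

```python
import numpy as np

def fit_quality(y_true, y_pred):
    """Standard R² and RMSE, plus a log-space R² (sketch; assumes all values > 0)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - y_true.mean())**2)
    r2 = 1.0 - ss_res / ss_tot
    rmse = np.sqrt(np.mean(residuals**2))
    # Same R² definition applied to log10 values, so each decade counts similarly
    log_true, log_pred = np.log10(y_true), np.log10(y_pred)
    ss_tot_log = np.sum((log_true - log_true.mean())**2)
    r2_log = 1.0 - np.sum((log_true - log_pred)**2) / ss_tot_log
    return {"R2": r2, "RMSE": rmse, "R2_log": r2_log}

# Illustrative data spanning two decades
y = np.array([100.0, 50.0, 10.0, 1.0])
q_perfect = fit_quality(y, y)          # exact prediction
q_offset = fit_quality(y, y * 1.1)     # uniform 10 % overprediction
```
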
## Fitted Parameters

In [None]:
# Extract fitted parameters
params_nlsq = {name: model.parameters.get_value(name) for name in param_names}

print("\nFitted Parameters:")
print("-" * 50)
for name, value in params_nlsq.items():
    if 'tau' in name:
        print(f"{name:10s} = {value:12.4e} s")
    elif 'eta' in name:
        print(f"{name:10s} = {value:12.4e} Pa·s")
    else:
        print(f"{name:10s} = {value:12.4e} Pa")

# Compute total modulus
G_total = sum(params_nlsq[f'G_{i}'] for i in range(3))
print(f"\nTotal elastic modulus: G(0) = {G_total:.4e} Pa")

# Analyze effective relaxation times
tau_s = params_nlsq['tau_s']
print(f"Sticker lifetime: τ_s = {tau_s:.4e} s")
print("\nRelaxation Spectrum:")
print("-" * 50)
for i in range(3):
    tau_R = params_nlsq[f'tau_R_{i}']
    tau_eff = max(tau_R, tau_s)
    G_i = params_nlsq[f'G_{i}']
    regime = "STICKER-TRUNCATED" if tau_s > tau_R else "ROUSE"
    print(f"Mode {i}: G = {G_i:.3e} Pa, τ_eff = {tau_eff:.3e} s ({regime})")

## NLSQ Prediction vs Data

In [None]:
# Generate predictions
time_fine = np.logspace(np.log10(time_data.min()), np.log10(time_data.max()), 300)
G_pred = model.predict(time_fine, test_mode="relaxation")

# Plot
fig, ax = plt.subplots(figsize=(10, 7))
ax.loglog(time_data, G_t, 'ko', label='Data', markersize=6, zorder=3)
ax.loglog(time_fine, G_pred, 'r-', label='NLSQ Fit', linewidth=2, zorder=2)
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('G(t) (Pa)', fontsize=12)
ax.set_title(f"Stress Relaxation Fit (R² = {metrics_nlsq['R2']:.6f})", fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.close("all")

## Multi-Exponential Decomposition

In [None]:
# Plot individual mode contributions
fig, ax = plt.subplots(figsize=(10, 7))

# Total prediction
ax.loglog(time_fine, G_pred, 'k-', label='Total', linewidth=2.5, zorder=5)

# Individual modes
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
for i in range(3):
    G_i = params_nlsq[f'G_{i}']
    tau_R_i = params_nlsq[f'tau_R_{i}']
    tau_eff_i = max(tau_R_i, tau_s)
    
    G_i_t = G_i * np.exp(-time_fine / tau_eff_i)
    
    regime = "sticker" if tau_s > tau_R_i else "Rouse"
    ax.loglog(time_fine, G_i_t, '--', color=colors[i], 
             label=f'Mode {i} ({regime}, τ_eff={tau_eff_i:.2e}s)', linewidth=1.5)

ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('G(t) (Pa)', fontsize=12)
ax.set_title('Multi-Exponential Decomposition', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, loc='best')
ax.grid(True, alpha=0.3)
plt.close("all")

## Sticker-Truncated Spectrum Visualization

In [None]:
# Analyze sticker plateau formation
tau_s = params_nlsq['tau_s']
n_sticker_modes = sum(1 for i in range(3) if params_nlsq[f'tau_R_{i}'] < tau_s)

print(f"Sticker-truncated modes: {n_sticker_modes}/3")

if n_sticker_modes > 0:
    # Sticker plateau modulus
    G_plateau = sum(params_nlsq[f'G_{i}'] for i in range(3) if params_nlsq[f'tau_R_{i}'] < tau_s)
    
    # Compute G(t) with and without sticker truncation
    G_without_sticker = sum(params_nlsq[f'G_{i}'] * np.exp(-time_fine / params_nlsq[f'tau_R_{i}']) 
                            for i in range(3))
    
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.loglog(time_data, G_t, 'ko', label='Data', markersize=6, alpha=0.5)
    ax.loglog(time_fine, G_pred, 'b-', label='With Stickers (Fit)', linewidth=2)
    ax.loglog(time_fine, G_without_sticker, 'r--', label='Without Stickers (Rouse only)', linewidth=2)
    
    # Highlight sticker plateau region
    ax.axhline(G_plateau, color='green', linestyle=':', linewidth=1.5, 
               label=f'Sticker Plateau (G_s={G_plateau:.2e} Pa)')
    ax.axvline(tau_s, color='purple', linestyle=':', linewidth=1.5, 
               label=f'Sticker timescale (τ_s={tau_s:.2e}s)')
    
    ax.set_xlabel('Time (s)', fontsize=12)
    ax.set_ylabel('G(t) (Pa)', fontsize=12)
    ax.set_title('Sticker-Truncated vs Pure Rouse Relaxation', fontsize=14, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.close("all")
    
    print(f"\nSticker plateau modulus: G_plateau = {G_plateau:.4e} Pa")
    print(f"Plateau fraction: G_plateau/G_total = {G_plateau/G_total:.2%}")
else:
    print("\nNo sticker-truncated modes detected. All modes exhibit intrinsic Rouse relaxation.")

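In this comparison the sticky curve always lies on or above the pure-Rouse curve. The reason: τ_eff,k = max(τ_R,k, τ_s) ≥ τ_R,k, so every term exp(-t/τ_eff,k) ≥ exp(-t/τ_R,k). A sketch of this "enhancement ratio" (a hypothetical diagnostic, not something the notebook or rheojax computes), with made-up parameters:

```python
import numpy as np

# Made-up illustrative parameters (not fitted values)
G_k = np.array([100.0, 50.0, 10.0])    # Pa
tau_R = np.array([0.01, 0.5, 20.0])    # s
tau_s = 1.0                            # s
t = np.logspace(-3, 2, 200)

tau_eff = np.maximum(tau_R, tau_s)
G_sticky = (G_k[None, :] * np.exp(-t[:, None] / tau_eff)).sum(axis=1)
G_rouse = (G_k[None, :] * np.exp(-t[:, None] / tau_R)).sum(axis=1)

# Stickers can only slow relaxation, so the ratio is >= 1 at every time;
# it peaks near t ~ tau_s, where truncated modes would otherwise have relaxed
enhancement = G_sticky / G_rouse
```
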
## Effective Relaxation Time Analysis

In [None]:
# Visualize sticker-mode interaction
fig = plot_sticky_rouse_effective_times(model)
plt.close("all")

## Bayesian Inference

In [None]:
# FAST_MODE=1 skips Bayesian inference entirely for quick validation (e.g. in CI).
# Set via the FAST_MODE environment variable; defaults to "0" (run full NUTS).
FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"

# Configuration
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

if FAST_MODE:
    print("FAST_MODE: Skipping Bayesian inference (JIT compilation takes >600s)")
    print("To run Bayesian analysis, run with FAST_MODE=0")
    # Create a placeholder result with current NLSQ parameters
    class BayesianResult:
        def __init__(self, model, param_names):
            self.posterior_samples = {name: np.array([model.parameters.get_value(name)] * NUM_SAMPLES) for name in param_names}
    result_bayes = BayesianResult(model, param_names)
    bayes_time = 0.0
else:
    print(f"Running NUTS with {NUM_CHAINS} chain(s)...")
    print(f"Warmup: {NUM_WARMUP} samples, Sampling: {NUM_SAMPLES} samples")
    
    start_time = time.time()
    result_bayes = model.fit_bayesian(
        time_data, G_t,
        test_mode='relaxation',
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        seed=42
    )
    bayes_time = time.time() - start_time
    
    print(f"\nBayesian inference completed in {bayes_time:.1f} seconds")


## Convergence Diagnostics

In [None]:
# Skip convergence diagnostics in CI mode
if not FAST_MODE:
    print_convergence_summary(result_bayes, param_names)
else:
    print("FAST_MODE: Skipping convergence diagnostics")

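`print_convergence_summary` is a tutorial helper whose internals are not shown. With `NUM_CHAINS = 1`, between-chain diagnostics such as R-hat can only compare halves of the single chain. A minimal sketch of the classic split-R-hat from Gelman et al., which is not the rank-normalized variant that ArviZ's `az.rhat` reports by default; values well above 1 (a common threshold is 1.01) signal poor mixing:

```python
import numpy as np

def split_rhat(samples):
    """Classic split-R-hat for an array of shape (n_chains, n_draws).

    Sketch only; not the rank-normalized R-hat that ArviZ computes by default.
    """
    samples = np.asarray(samples, dtype=float)
    n = samples.shape[1] // 2
    # Split every chain in half, so even one chain yields two "chains" to compare
    halves = np.concatenate([samples[:, :n], samples[:, n:2 * n]], axis=0)
    chain_means = halves.mean(axis=1)
    W = halves.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B = n * chain_means.var(ddof=1)            # between-chain variance
    var_hat = (n - 1) / n * W + B / n          # pooled posterior variance estimate
    return float(np.sqrt(var_hat / W))

rng = np.random.default_rng(0)
mixed = rng.normal(size=(1, 1000))             # one well-mixed chain (NUM_CHAINS = 1)
stuck = np.concatenate([rng.normal(0.0, 1.0, 500),
                        rng.normal(3.0, 1.0, 500)])[None, :]  # drifts halfway through
rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
```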

## Parameter Comparison: NLSQ vs Bayesian

In [None]:
# Compare point estimates
print_parameter_comparison(model, result_bayes.posterior_samples, param_names)

## ArviZ: Trace Plot

In [None]:
# Skip trace plot in CI mode
if not FAST_MODE:
    # Convert to ArviZ InferenceData
    idata = az.from_dict(posterior={name: result_bayes.posterior_samples[name][None, :] for name in param_names})
    
    # Trace plot
    axes = az.plot_trace(idata, compact=False, figsize=(12, 2*len(param_names)))
    fig = axes.ravel()[0].figure
    fig.suptitle('MCMC Trace Plot', fontsize=14, fontweight='bold', y=1.001)
    fig.tight_layout()
    plt.close("all")
else:
    print("FAST_MODE: Skipping trace plot")

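The trace-plot cell builds `idata` with `[None, :]`, which treats all draws as a single chain. That is correct for `NUM_CHAINS = 1`. If `NUM_CHAINS` is raised and `posterior_samples` holds the chains concatenated end to end (an assumption about rheojax's output layout; check it before relying on this), the draws should be reshaped to `(chain, draw)` so ArviZ can compare chains. A minimal sketch with a hypothetical helper:

```python
import numpy as np

def to_chain_draw(flat_samples, num_chains):
    """Reshape flattened draws into the (chain, draw) layout ArviZ expects.

    Hypothetical helper; assumes draws are stored chain by chain
    (all of chain 0, then all of chain 1, ...).
    """
    flat = np.asarray(flat_samples)
    if flat.size % num_chains != 0:
        raise ValueError("number of draws is not divisible by num_chains")
    return flat.reshape(num_chains, -1)

# Hypothetical flattened samples: 2 chains x 3 draws
flat = np.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0])
arr = to_chain_draw(flat, num_chains=2)

# In the notebook this would become, under the same layout assumption:
# posterior = {name: to_chain_draw(result_bayes.posterior_samples[name], NUM_CHAINS)
#              for name in param_names}
# idata = az.from_dict(posterior=posterior)
```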

## ArviZ: Posterior Distributions

In [None]:
# Skip ArviZ plot in CI mode
if not FAST_MODE:
    # Posterior plot
    axes = az.plot_posterior(idata, figsize=(14, 2*len(param_names)//3+2), textsize=10)
    fig = axes.ravel()[0].figure
    fig.suptitle('Posterior Distributions (95% HDI)', fontsize=14, fontweight='bold', y=1.001)
    fig.tight_layout()
    plt.close("all")
else:
    print("FAST_MODE: Skipping ArviZ plot")


## ArviZ: Pair Plot

In [None]:
# Skip pair plot in CI mode
if not FAST_MODE:
    # Pair plot for correlations
    key_params = ['G_0', 'tau_R_0', 'tau_s', 'eta_s']
    axes = az.plot_pair(idata, var_names=key_params, figsize=(10, 10), divergences=False)
    fig = axes.ravel()[0].figure
    fig.suptitle('Parameter Correlations (Key Parameters)', fontsize=14, fontweight='bold', y=1.001)
    plt.close("all")
else:
    print("FAST_MODE: Skipping pair plot")


## Posterior Predictive Distribution

In [None]:
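# Generate predictions from posterior samples
posterior = result_bayes.posterior_samples
n_total = len(posterior[param_names[0]])  # draws across all chains
n_draws = min(200, n_total)
indices = np.linspace(0, n_total - 1, n_draws, dtype=int)

predictions = []
for i in indices:
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name][i]))
    pred_i = model.predict(time_fine, test_mode="relaxation")
    predictions.append(np.array(pred_i))

# Restore the NLSQ point estimate: the loop above left the model holding the
# last posterior draw, which would otherwise leak into later cells
for name, value in params_nlsq.items():
    model.parameters.set_value(name, float(value))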

predictions = np.array(predictions)
pred_mean = np.mean(predictions, axis=0)
pred_lower = np.percentile(predictions, 2.5, axis=0)
pred_upper = np.percentile(predictions, 97.5, axis=0)

# Plot
fig, ax = plt.subplots(figsize=(10, 7))
ax.loglog(time_data, G_t, 'ko', label='Data', markersize=6, zorder=3)
ax.loglog(time_fine, pred_mean, 'b-', label='Posterior Mean', linewidth=2, zorder=2)
ax.fill_between(time_fine, pred_lower, pred_upper, alpha=0.3, color='blue', label='95% Credible Interval')
ax.set_xlabel('Time (s)', fontsize=12)
ax.set_ylabel('G(t) (Pa)', fontsize=12)
ax.set_title('Posterior Predictive Distribution', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.close("all")

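The loop above sets each posterior draw on the model one at a time. Because the Theory section gives G(t) in closed form, all draws can instead be evaluated in a single NumPy broadcast. This sketch assumes `model.predict` matches that closed form (worth checking against `model.predict` for one draw before swapping it in); the function name and the synthetic draws are hypothetical:

```python
import numpy as np

def posterior_predictive_band(t, G_draws, tauR_draws, tau_s_draws, level=95.0):
    """Mean and central credible band of G(t) over posterior draws.

    G_draws, tauR_draws: shape (n_draws, n_modes); tau_s_draws: shape (n_draws,).
    Assumes the closed-form multi-exponential G(t) from the Theory section.
    """
    tau_eff = np.maximum(tauR_draws, tau_s_draws[:, None])           # (n_draws, n_modes)
    curves = np.sum(G_draws[:, None, :]
                    * np.exp(-t[None, :, None] / tau_eff[:, None, :]),
                    axis=2)                                           # (n_draws, n_t)
    tail = (100.0 - level) / 2.0
    lower, upper = np.percentile(curves, [tail, 100.0 - tail], axis=0)
    return curves.mean(axis=0), lower, upper

# Synthetic (hypothetical) draws scattered around made-up parameter values
rng = np.random.default_rng(1)
n = 200
G_draws = np.array([100.0, 50.0, 10.0]) * rng.lognormal(0.0, 0.05, size=(n, 3))
tauR_draws = np.array([0.01, 0.5, 20.0]) * rng.lognormal(0.0, 0.1, size=(n, 3))
tau_s_draws = 1.0 * rng.lognormal(0.0, 0.1, size=n)
t = np.logspace(-3, 2, 100)
mean, lower, upper = posterior_predictive_band(t, G_draws, tauR_draws, tau_s_draws)
```

With the real posterior, the inputs would be assembled column-wise, e.g. `np.column_stack([posterior[f'G_{i}'] for i in range(3)])`, and likewise for `tau_R_{i}`.
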
## Physical Interpretation

In [None]:
# Extract posterior means
params_bayes = {name: float(np.mean(posterior[name])) for name in param_names}
tau_s_bayes = params_bayes['tau_s']

print("Physical Interpretation (Posterior Means):")
print("=" * 60)

# Total modulus
G_total_bayes = sum(params_bayes[f'G_{i}'] for i in range(3))
print(f"\nInitial Elastic Modulus: G(0) = {G_total_bayes:.4e} Pa")
print(f"Sticker Lifetime: τ_s = {tau_s_bayes:.4e} s")

print("\nRelaxation Spectrum Analysis:")
print("-" * 60)
for i in range(3):
    G_i = params_bayes[f'G_{i}']
    tau_R_i = params_bayes[f'tau_R_{i}']
    tau_eff_i = max(tau_R_i, tau_s_bayes)
    weight = G_i / G_total_bayes
    
    print(f"\nMode {i}:")
    print(f"  Modulus: G_{i} = {G_i:.4e} Pa ({weight:.1%} of total)")
    print(f"  Rouse time: τ_R,{i} = {tau_R_i:.4e} s")
    print(f"  Effective time: τ_eff,{i} = {tau_eff_i:.4e} s")
    
    if tau_s_bayes > tau_R_i:
        print(f"  ✓ STICKER-TRUNCATED: Relaxation limited by sticker lifetime")
        print(f"    Without stickers would relax {tau_s_bayes/tau_R_i:.1f}x faster")
    else:
        print(f"  ✓ ROUSE RELAXATION: Intrinsic chain dynamics dominate")
        print(f"    Stickers dissolve {tau_R_i/tau_s_bayes:.1f}x faster than chain relaxes")

# Sticker plateau analysis
n_sticker_modes = sum(1 for i in range(3) if params_bayes[f'tau_R_{i}'] < tau_s_bayes)
if n_sticker_modes > 0:
    G_plateau = sum(params_bayes[f'G_{i}'] for i in range(3) if params_bayes[f'tau_R_{i}'] < tau_s_bayes)
    print(f"\nSticker Plateau:")
    print(f"  Number of truncated modes: {n_sticker_modes}/3")
    print(f"  Plateau modulus: G_plateau = {G_plateau:.4e} Pa")
    print(f"  Plateau fraction: {G_plateau/G_total_bayes:.1%} of total modulus")
    print(f"  Plateau lifetime: τ_s = {tau_s_bayes:.4e} s")

# Longest relaxation time
tau_max = max(max(params_bayes[f'tau_R_{i}'], tau_s_bayes) for i in range(3))
print(f"\nLongest Relaxation Time: τ_max = {tau_max:.4e} s")
print(f"Terminal relaxation: G(t→∞) decays as exp(-t/{tau_max:.2e})")

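The interpretation above classifies each mode from posterior means only. When τ_R,k is close to τ_s, the regime assignment itself is uncertain, which is what the "Bayesian Insights" takeaway refers to. A sketch that reports, for each mode, the fraction of posterior draws with τ_R,k < τ_s; it assumes `posterior` maps the notebook's parameter names to 1-D sample arrays, and the demo draws are synthetic:

```python
import numpy as np

def truncation_probability(posterior, n_modes):
    """Fraction of posterior draws in which tau_R_i < tau_s, for each mode i.

    Assumes `posterior` maps names like 'tau_R_0' and 'tau_s' to 1-D sample arrays.
    """
    tau_s = np.asarray(posterior["tau_s"])
    return {i: float(np.mean(np.asarray(posterior[f"tau_R_{i}"]) < tau_s))
            for i in range(n_modes)}

# Synthetic draws: mode 0 clearly truncated, mode 2 clearly Rouse, mode 1 ambiguous
rng = np.random.default_rng(2)
n = 4000
fake_posterior = {
    "tau_s": rng.lognormal(np.log(1.0), 0.1, n),
    "tau_R_0": rng.lognormal(np.log(0.01), 0.1, n),
    "tau_R_1": rng.lognormal(np.log(1.0), 0.1, n),
    "tau_R_2": rng.lognormal(np.log(20.0), 0.1, n),
}
probs = truncation_probability(fake_posterior, n_modes=3)
```

In the notebook, `truncation_probability(posterior, n_modes=3)` would give one probability per mode; values near 0.5 flag modes whose sticker-versus-Rouse label should not be over-interpreted.
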
## Save Results

In [None]:
# Save results to disk
output_path = save_tnt_results(model, result_bayes, "sticky_rouse", "relaxation", param_names)
print(f"Results saved to: {output_path}")

## Key Takeaways

1. **Multi-Exponential Nature**: The relaxation modulus is a sum of exponentials, one per mode, each with its own effective timescale

2. **Sticker Truncation**: Fast Rouse modes (τ_R,k < τ_s) are limited by sticker lifetime → collective decay

3. **Plateau Formation**: Modes with τ_R < τ_s create a plateau at intermediate times in G(t)

4. **Spectrum Modification**: Stickers reshape the relaxation spectrum: no mode can relax faster than τ_s, so the fast end of the spectrum is cut off at τ_s

5. **Timescale Hierarchy**: The effective relaxation spectrum reveals which modes are sticker-limited vs Rouse-limited

6. **Bayesian Insights**: Posterior uncertainty quantifies confidence in mode assignment and spectrum structure