# TNT Sticky Rouse: Flow Curve Fitting

## Objectives
- Fit the TNT Sticky Rouse model to flow curve data
- Analyze sticker-dominated vs Rouse-dominated regimes
- Visualize effective relaxation times and mode contributions
- Demonstrate Bayesian inference for associative polymer systems

## Setup

In [None]:
import os
import sys
import time

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax

import numpy as np
import matplotlib.pyplot as plt
import arviz as az

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
from rheojax.core.jax_config import verify_float64
verify_float64()

from rheojax.models.tnt import TNTStickyRouse

sys.path.insert(0, os.path.join("..", "utils"))
from tnt_tutorial_utils import (
    load_ml_ikh_flow_curve,
    compute_fit_quality,
    print_convergence_summary,
    print_parameter_comparison,
    save_tnt_results,
    get_tnt_sticky_rouse_param_names,
    plot_sticky_rouse_effective_times,
    plot_mode_decomposition,
)

print("Setup complete. JAX devices:", jax.devices())

## Theory: Sticky Rouse Model

The Sticky Rouse model describes polymers with transient physical crosslinks (stickers):

**Key Physics:**
- Rouse chain has N modes with relaxation times τ_R,k
- Physical crosslinks (stickers) have lifetime τ_s
- Effective relaxation time: **τ_eff,k = max(τ_R,k, τ_s)**

**Regime Analysis:**
- When τ_s > τ_R,k: the sticker dominates → mode k relaxes on τ_s; all such modes share this time and form the sticker plateau
- When τ_R,k > τ_s: Rouse dynamics dominate → mode-specific relaxation
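
The regime rule can be checked numerically. The sketch below applies τ_eff,k = max(τ_R,k, τ_s) with plain NumPy; the Rouse times and sticker lifetime are made-up illustrative values, not fitted results.

```python
import numpy as np

# Illustrative values only (not fitted results).
tau_R = np.array([1.0e1, 1.0e-1, 1.0e-3])  # Rouse time per mode, s
tau_s = 1.0                                # sticker lifetime, s

# Effective time: the slower of the Rouse time and the sticker lifetime.
tau_eff = np.maximum(tau_R, tau_s)

# A mode is sticker-dominated when the sticker outlives its Rouse relaxation.
sticker_dominated = tau_s > tau_R

for k, (t_r, t_e, flag) in enumerate(zip(tau_R, tau_eff, sticker_dominated)):
    regime = "sticker-dominated" if flag else "Rouse-dominated"
    print(f"mode {k}: tau_R = {t_r:.1e} s, tau_eff = {t_e:.1e} s ({regime})")
```

With these numbers the slowest mode keeps its own Rouse time, while the two faster modes are both clamped to τ_s.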

**Steady-State Stress (Flow Curve):**
$$\sigma(\dot{\gamma}) = \sum_k \frac{G_k \tau_{\textrm{eff},k} \dot{\gamma}}{1 + (\tau_{\textrm{eff},k} \dot{\gamma})^2} + \eta_s \dot{\gamma}$$
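
As a minimal sketch, the expression above can be evaluated directly with NumPy. The function name `sticky_rouse_stress` and the parameter values are illustrative assumptions, not part of the rheojax API. The final lines check the low-rate limit, where each mode contributes G_k τ_eff,k to the viscosity.

```python
import numpy as np

def sticky_rouse_stress(gamma_dot, G, tau_R, tau_s, eta_s):
    """Steady shear stress from the multi-mode expression above (illustrative helper)."""
    gamma_dot = np.asarray(gamma_dot, dtype=float)[:, None]      # shape (n_rates, 1)
    tau_eff = np.maximum(np.asarray(tau_R, dtype=float), tau_s)  # shape (n_modes,)
    wi = tau_eff * gamma_dot                                     # per-mode tau_eff * rate
    mode_stress = np.asarray(G, dtype=float) * wi / (1.0 + wi**2)
    return mode_stress.sum(axis=1) + eta_s * gamma_dot[:, 0]

# Illustrative parameters (not fitted results).
G = [100.0, 50.0, 10.0]      # Pa
tau_R = [10.0, 0.1, 0.001]   # s
tau_s, eta_s = 1.0, 0.01     # s, Pa·s

rates = np.logspace(-5, 3, 9)
sigma = sticky_rouse_stress(rates, G, tau_R, tau_s, eta_s)

# Low-rate limit: sigma / gamma_dot -> sum_k G_k * tau_eff_k + eta_s
eta_0 = sum(g * max(t, tau_s) for g, t in zip(G, tau_R)) + eta_s
print(f"eta_0 = {eta_0:.4e} Pa·s, sigma/gamma_dot at lowest rate = {sigma[0] / rates[0]:.4e} Pa·s")
```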

**Parameters (n_modes=3):**
- G_0, τ_R,0: First mode modulus and Rouse time
- G_1, τ_R,1: Second mode
- G_2, τ_R,2: Third mode
- τ_s: Sticker lifetime (critical parameter)
- η_s: Solvent viscosity
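
The parameter count follows the pattern 2·n_modes + 2. The sketch below builds the name list for any number of modes. `sticky_rouse_param_names` is a hypothetical helper written for illustration; the ordering is an assumption and may differ from what `get_tnt_sticky_rouse_param_names` returns.

```python
def sticky_rouse_param_names(n_modes):
    """Hypothetical helper: one (G_k, tau_R_k) pair per mode, then tau_s and eta_s."""
    names = []
    for k in range(n_modes):
        names += [f"G_{k}", f"tau_R_{k}"]
    # Ordering is assumed for illustration only.
    return names + ["tau_s", "eta_s"]

names = sticky_rouse_param_names(3)
print(len(names), names)
```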

## Load Data

In [None]:
# Load ARES flow curve data
gamma_dot, stress = load_ml_ikh_flow_curve("ARES_up")

print(f"Data shape: {len(gamma_dot)} points")
print(f"Shear rate range: {gamma_dot.min():.2e} - {gamma_dot.max():.2e} s^-1")
print(f"Stress range: {stress.min():.2e} - {stress.max():.2e} Pa")

# Plot raw data
fig, ax = plt.subplots(figsize=(8, 6))
ax.loglog(gamma_dot, stress, 'ko', label='ARES data', markersize=8)
ax.set_xlabel('Shear Rate (s$^{-1}$)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title('Flow Curve: Carbopol Gel', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)

print("\n⚠️  Note: Carbopol is not an ionomer. This demonstrates the Sticky Rouse workflow and sticker-mode analysis.")

## NLSQ Fitting

In [None]:
# Initialize model with 3 modes
model = TNTStickyRouse(n_modes=3)
param_names = get_tnt_sticky_rouse_param_names(n_modes=3)
print(f"Model parameters ({len(param_names)}): {param_names}")

# Fit using NLSQ
print("\nFitting with NLSQ...")
start_time = time.time()
result_nlsq = model.fit(gamma_dot, stress, test_mode="flow_curve")
fit_time = time.time() - start_time

print(f"\nFit completed in {fit_time:.2f} seconds")
print(f"R² = {result_nlsq.r_squared:.6f}")
print(f"RMSE = {result_nlsq.rmse:.4e} Pa")
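
For reference, R² and RMSE can be computed from residuals as sketched below. This uses the textbook linear-scale definitions on toy numbers. Whether rheojax evaluates `r_squared` and `rmse` in exactly this way (for example on linear rather than logarithmic stress) is an assumption, and `fit_quality` is an illustrative name.

```python
import numpy as np

def fit_quality(y_obs, y_pred):
    """Return (R^2, RMSE) for observed and predicted values on a linear scale."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_obs - y_pred
    ss_res = np.sum(residuals**2)                 # residual sum of squares
    ss_tot = np.sum((y_obs - y_obs.mean())**2)    # total sum of squares
    r_squared = 1.0 - ss_res / ss_tot
    rmse = np.sqrt(np.mean(residuals**2))
    return r_squared, rmse

# Toy data for illustration.
y_obs = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.9])
r2, rmse = fit_quality(y_obs, y_pred)
print(f"R^2 = {r2:.4f}, RMSE = {rmse:.4f}")
```

For flow curves spanning several decades of stress, a linear-scale RMSE is dominated by the high-stress points, so a good R² does not guarantee a good fit at low shear rates.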

## Fitted Parameters

In [None]:
# Extract fitted parameters
params_nlsq = {name: result_nlsq.params[name].value for name in param_names}

print("\nFitted Parameters:")
print("-" * 50)
for name, value in params_nlsq.items():
    if 'tau' in name:
        print(f"{name:10s} = {value:12.4e} s")
    elif 'eta' in name:
        print(f"{name:10s} = {value:12.4e} Pa·s")
    else:
        print(f"{name:10s} = {value:12.4e} Pa")

# Compute effective relaxation times
tau_s = params_nlsq['tau_s']
print(f"\nSticker lifetime: τ_s = {tau_s:.4e} s")
print("\nEffective Relaxation Times:")
print("-" * 50)
for i in range(3):
    tau_R = params_nlsq[f'tau_R_{i}']
    tau_eff = max(tau_R, tau_s)
    regime = "STICKER-DOMINATED" if tau_s > tau_R else "ROUSE-DOMINATED"
    print(f"Mode {i}: τ_R = {tau_R:.4e} s, τ_eff = {tau_eff:.4e} s ({regime})")

## Physical Analysis: Effective Relaxation Times

In [None]:
# Visualize sticker-mode interaction
fig = plot_sticky_rouse_effective_times(model)
display(fig)
plt.close(fig)

## NLSQ Prediction vs Data

In [None]:
# Generate predictions
gamma_dot_fine = np.logspace(np.log10(gamma_dot.min()), np.log10(gamma_dot.max()), 200)
stress_pred = model.predict_flow_curve(gamma_dot_fine)

# Plot
fig, ax = plt.subplots(figsize=(10, 7))
ax.loglog(gamma_dot, stress, 'ko', label='Data', markersize=8, zorder=3)
ax.loglog(gamma_dot_fine, stress_pred, 'r-', label='NLSQ Fit', linewidth=2, zorder=2)
ax.set_xlabel('Shear Rate (s$^{-1}$)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title(f'Flow Curve Fit (R² = {result_nlsq.r_squared:.6f})', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)

## Mode Decomposition

In [None]:
# Plot individual mode contributions
fig = plot_mode_decomposition(model, gamma_dot_fine, "flow_curve")
display(fig)
plt.close(fig)

## Sticker-Limited Plateau Analysis

In [None]:
# Analyze sticker plateau regime
tau_s = params_nlsq['tau_s']
sticker_shear_rate = 1.0 / tau_s
print(f"Sticker-limited plateau region: γ̇ ~ 1/τ_s = {sticker_shear_rate:.4e} s^-1")

# Count modes in sticker-dominated regime
n_sticker_modes = sum(1 for i in range(3) if params_nlsq[f'tau_R_{i}'] < tau_s)
print(f"\nNumber of sticker-dominated modes: {n_sticker_modes}/3")

# Plateau modulus estimate (sum of sticker-dominated modes)
G_plateau = sum(params_nlsq[f'G_{i}'] for i in range(3) if params_nlsq[f'tau_R_{i}'] < tau_s)
print(f"Sticker plateau modulus: G_plateau ≈ {G_plateau:.4e} Pa")

## Bayesian Inference

In [None]:
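# Bayesian configuration
# A single short chain keeps the demo fast; use more chains (e.g. 4) for
# production runs so that between-chain diagnostics such as R-hat are meaningful.
NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1

print("Running Bayesian inference...")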
print(f"Warmup: {NUM_WARMUP}, Samples: {NUM_SAMPLES}, Chains: {NUM_CHAINS}")

start_time = time.time()
result_bayes = model.fit_bayesian(
    gamma_dot, stress,
    test_mode="flow_curve",
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    seed=42
)
bayes_time = time.time() - start_time

print(f"\nBayesian inference completed in {bayes_time:.2f} seconds")
print(f"Total samples: {NUM_SAMPLES * NUM_CHAINS}")

## Convergence Diagnostics

In [None]:
# Print convergence summary
print_convergence_summary(result_bayes, param_names)

## Parameter Comparison: NLSQ vs Bayesian

In [None]:
# Compare point estimates
print_parameter_comparison(result_nlsq, result_bayes, param_names)

## ArviZ: Trace Plot

In [None]:
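# Convert to ArviZ InferenceData.
# ArviZ expects arrays shaped (chain, draw); [None, :] adds the chain axis,
# which is only correct for the single-chain run configured above.
idata = az.from_dict(posterior={name: result_bayes.posterior_samples[name][None, :] for name in param_names})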

# Trace plot
axes = az.plot_trace(idata, compact=False, figsize=(12, 2*len(param_names)))
fig = axes.ravel()[0].figure
fig.suptitle('MCMC Trace Plot', fontsize=14, fontweight='bold', y=1.001)
fig.tight_layout()
display(fig)
plt.close(fig)
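
ArviZ expects posterior arrays shaped (chain, draw). If the sampler is run with several chains and returns one flat array per parameter, the array has to be reshaped before conversion. The sketch below shows the idea on synthetic numbers. It assumes the flat array is chain-major (all draws of chain 0 first, then chain 1, and so on), which should be verified for the library in use.

```python
import numpy as np

n_chains, n_draws = 4, 500
rng = np.random.default_rng(0)

# Synthetic stand-in for a flat array of posterior samples (assumed chain-major).
flat_samples = rng.lognormal(mean=0.0, sigma=0.1, size=n_chains * n_draws)

# Reshape to the (chain, draw) layout that ArviZ expects.
by_chain = flat_samples.reshape(n_chains, n_draws)

# With a single chain this reduces to adding a leading axis.
single = flat_samples[:n_draws][None, :]
print(by_chain.shape, single.shape)
```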

## ArviZ: Posterior Distributions

In [None]:
# Posterior plot (hdi_prob is set explicitly so the interval matches the 95% in the title)
axes = az.plot_posterior(idata, hdi_prob=0.95, figsize=(14, 2*len(param_names)//3+2), textsize=10)
fig = axes.ravel()[0].figure
fig.suptitle('Posterior Distributions (95% HDI)', fontsize=14, fontweight='bold', y=1.001)
fig.tight_layout()
display(fig)
plt.close(fig)

## ArviZ: Pair Plot

In [None]:
# Pair plot for correlations (select key parameters to avoid overcrowding)
key_params = ['G_0', 'tau_R_0', 'tau_s', 'eta_s']
axes = az.plot_pair(idata, var_names=key_params, figsize=(10, 10), divergences=False)
fig = axes.ravel()[0].figure
fig.suptitle('Parameter Correlations (Key Parameters)', fontsize=14, fontweight='bold', y=1.001)
display(fig)
plt.close(fig)

## Posterior Predictive Distribution

In [None]:
# Generate predictions from posterior samples
posterior = result_bayes.posterior_samples
n_total = len(posterior[param_names[0]])  # total draws across all chains
n_draws = min(200, n_total)
indices = np.linspace(0, n_total - 1, n_draws, dtype=int)

predictions = []
for i in indices:
    params_i = jnp.array([posterior[name][i] for name in param_names])
    pred_i = model.model_function(jnp.array(gamma_dot_fine), params_i, test_mode="flow_curve")
    predictions.append(np.array(pred_i))

predictions = np.array(predictions)
pred_mean = np.mean(predictions, axis=0)
pred_lower = np.percentile(predictions, 2.5, axis=0)
pred_upper = np.percentile(predictions, 97.5, axis=0)

# Plot
fig, ax = plt.subplots(figsize=(10, 7))
ax.loglog(gamma_dot, stress, 'ko', label='Data', markersize=8, zorder=3)
ax.loglog(gamma_dot_fine, pred_mean, 'b-', label='Posterior Mean', linewidth=2, zorder=2)
ax.fill_between(gamma_dot_fine, pred_lower, pred_upper, alpha=0.3, color='blue', label='95% Credible Interval')
ax.set_xlabel('Shear Rate (s$^{-1}$)', fontsize=12)
ax.set_ylabel('Stress (Pa)', fontsize=12)
ax.set_title('Posterior Predictive Distribution', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
display(fig)
plt.close(fig)
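
The band above is built by evaluating the model once per posterior draw and taking pointwise percentiles across draws. The sketch below isolates that step with synthetic draws for a single-mode curve. The draws are illustrative stand-ins, not fitted results.

```python
import numpy as np

rng = np.random.default_rng(1)
rates = np.logspace(-2, 2, 50)

# Synthetic "posterior draws" for a single-mode curve (illustrative only).
G_draws = rng.normal(100.0, 5.0, size=300)
tau_draws = rng.lognormal(np.log(1.0), 0.1, size=300)

# One predicted curve per draw: shape (n_draws, n_rates).
wi = tau_draws[:, None] * rates[None, :]
curves = G_draws[:, None] * wi / (1.0 + wi**2)

# Pointwise summaries across draws (axis 0).
mean_curve = curves.mean(axis=0)
lower, upper = np.percentile(curves, [2.5, 97.5], axis=0)
print(curves.shape, lower.shape, upper.shape)
```

A band computed this way reflects parameter uncertainty only. It does not include observation noise, so individual data points can fall outside it even when the fit is sound.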

## Physical Interpretation

In [None]:
# Extract posterior means
params_bayes = {name: float(np.mean(posterior[name])) for name in param_names}
tau_s_bayes = params_bayes['tau_s']

print("Physical Interpretation (Posterior Means):")
print("=" * 60)
print(f"\nSticker Lifetime: τ_s = {tau_s_bayes:.4e} s")
print(f"Critical Shear Rate: γ̇_c ~ 1/τ_s = {1.0/tau_s_bayes:.4e} s^-1")

print("\nMode-by-Mode Analysis:")
print("-" * 60)
for i in range(3):
    G_i = params_bayes[f'G_{i}']
    tau_R_i = params_bayes[f'tau_R_{i}']
    tau_eff_i = max(tau_R_i, tau_s_bayes)
    ratio = tau_s_bayes / tau_R_i
    
    print(f"\nMode {i}:")
    print(f"  G_{i} = {G_i:.4e} Pa")
    print(f"  τ_R,{i} = {tau_R_i:.4e} s")
    print(f"  τ_eff,{i} = {tau_eff_i:.4e} s")
    print(f"  τ_s/τ_R = {ratio:.2f}")
    
    if ratio > 1.0:
        print("  ✓ STICKER-DOMINATED (τ_s > τ_R): Mode contributes to sticker plateau")
    else:
        print("  ✓ ROUSE-DOMINATED (τ_R > τ_s): Mode exhibits intrinsic Rouse dynamics")

# Estimate zero-shear viscosity
eta_0 = sum(params_bayes[f'G_{i}'] * max(params_bayes[f'tau_R_{i}'], tau_s_bayes) for i in range(3)) + params_bayes['eta_s']
print(f"\nZero-Shear Viscosity: η_0 = {eta_0:.4e} Pa·s")

# Total modulus
G_total = sum(params_bayes[f'G_{i}'] for i in range(3))
print(f"Total Modulus: G_total = {G_total:.4e} Pa")

## Save Results

In [None]:
# Save results to disk
output_path = save_tnt_results(model, result_bayes, "sticky_rouse", "flow_curve", param_names)
print(f"Results saved to: {output_path}")

## Key Takeaways

1. **Sticker Physics**: The Sticky Rouse model captures transient crosslink dynamics through the sticker lifetime parameter τ_s

2. **Effective Relaxation**: Each mode's effective relaxation time is τ_eff,k = max(τ_R,k, τ_s), creating a sticker-limited plateau when τ_s dominates

3. **Regime Separation**:
   - Low γ̇ (γ̇ ≪ 1/τ_s): Sticker plateau, collective mode response
   - High γ̇ (γ̇ ≫ 1/τ_s): Rouse dynamics, mode-specific relaxation

4. **Model Limitations**: Carbopol is not an ionomer, so this notebook only demonstrates the workflow. True ionomer data should show stronger sticker signatures

5. **Bayesian Insights**: Posterior distributions quantify uncertainty in sticker lifetime and mode separation

6. **Parameter Count**: 8 parameters (3 modes) provide rich description but require sufficient data coverage across shear rate decades
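
Takeaway 5 can be made concrete. Given posterior samples, the probability that a mode is sticker-dominated is the fraction of draws with τ_s > τ_R,k. The sketch below uses synthetic log-normal samples as stand-ins; with real output, the per-parameter arrays from the sampler would be compared draw by draw in the same way, which preserves any posterior correlation between τ_s and τ_R,k.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 2000

# Synthetic log-normal "posterior samples" (illustrative stand-ins, not fitted results).
tau_s = rng.lognormal(np.log(1.0), 0.3, size=n)
tau_R = {
    0: rng.lognormal(np.log(10.0), 0.3, size=n),
    1: rng.lognormal(np.log(0.8), 0.3, size=n),
    2: rng.lognormal(np.log(1e-3), 0.3, size=n),
}

# Fraction of draws in which the sticker outlives the Rouse relaxation of mode k.
p_sticker = {k: float(np.mean(tau_s > t)) for k, t in tau_R.items()}
for k, p in p_sticker.items():
    print(f"mode {k}: P(tau_s > tau_R) = {p:.2f}")
```

A probability near 0 or 1 indicates a clearly separated mode, while intermediate values mean the data cannot decide which regime the mode belongs to.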