# VLB Startup Shear: NLSQ → NUTS

**PNAS Digital Rheometer Twin — startup shear at $\dot{\gamma}$ = 1 s⁻¹**

## Learning Objectives

- Fit startup stress $\sigma(t) = \eta_0 \dot{\gamma} (1 - e^{-k_d t})$
- Extract G₀ and k_d from transient shear data
- Understand why VLB shows no stress overshoot

## Data Source

PNAS 2022 Digital Rheometer Twin dataset (startup at 1 s⁻¹).

## Estimated Runtime

- Fast demo (1 chain): ~1 min
- Full run (4 chains): ~3 min

## 1. Setup

In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax openpyxl
    import os
    os.environ["JAX_ENABLE_X64"] = "true"

In [None]:
%matplotlib inline
import warnings
import time
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models import VLBLocal

jax, jnp = safe_import_jax()
verify_float64()
warnings.filterwarnings('ignore', category=FutureWarning)

# Tutorial utilities
import sys
sys.path.insert(0, str(Path('..').resolve()))
from utils.vlb_tutorial_utils import (
    get_bayesian_config, print_convergence, print_parameter_table,
    plot_trace_and_forest, get_output_dir, save_results, save_figure,
    load_pnas_startup,
)

print(f"JAX version: {jax.__version__}")

## 2. Load Experimental Data

In [None]:
GAMMA_DOT = 1.0  # Applied shear rate [1/s]
time_data, stress_data = load_pnas_startup(GAMMA_DOT, max_points=200)

print(f'Data points: {len(time_data)}')
print(f'Time range: {time_data.min():.3f} - {time_data.max():.1f} s')
print(f'Stress range: {stress_data.min():.2f} - {stress_data.max():.1f} Pa')

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(time_data, stress_data, 'ko', markersize=3)
ax.set_xlabel('Time [s]')
ax.set_ylabel('Stress [Pa]')
ax.set_title(rf'Startup Shear ($\dot{{\gamma}}$ = {GAMMA_DOT} s$^{{-1}}$)')
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

## 3. VLB Forward Model

VLB startup stress under constant shear rate $\dot{\gamma}$:

$$\sigma(t) = \frac{G_0}{k_d} \dot{\gamma} \left(1 - e^{-k_d t}\right) = \eta_0 \dot{\gamma} \left(1 - e^{-t/t_R}\right)$$

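where $\eta_0 = G_0/k_d$ is the zero-shear viscosity and $t_R = 1/k_d$ is the relaxation time.

- **No stress overshoot**: $d\sigma/dt = G_0 \dot{\gamma}\, e^{-k_d t} > 0$ for all $t$, so the stress rises monotonically to steady state
- Steady-state stress: $\sigma_{\infty} = \eta_0 \dot{\gamma}$
- Initial slope: $d\sigma/dt\,|_{t=0} = G_0 \dot{\gamma}$, which fixes $G_0$ independently of the plateau
- Rise time: $\sigma$ reaches $1 - e^{-3} \approx 95\%$ of $\sigma_{\infty}$ at $t = 3 t_R = 3/k_d$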
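Because the startup response has a closed form, the properties listed above can be checked numerically. The sketch below evaluates the formula directly with NumPy; the function name `vlb_startup_stress` and the parameter values are hypothetical illustrations, not rheojax API or fitted results.

```python
import numpy as np

def vlb_startup_stress(t, G0, k_d, gamma_dot):
    """Closed-form VLB startup stress: sigma(t) = (G0/k_d) * gamma_dot * (1 - exp(-k_d t))."""
    t = np.asarray(t, dtype=float)
    eta0 = G0 / k_d                      # zero-shear viscosity [Pa s]
    return eta0 * gamma_dot * (1.0 - np.exp(-k_d * t))

# Hypothetical parameter values, for illustration only (not fitted to the PNAS data)
G0, k_d, gamma_dot = 100.0, 0.5, 1.0
t_R = 1.0 / k_d
t = np.linspace(0.0, 10 * t_R, 501)
sigma = vlb_startup_stress(t, G0, k_d, gamma_dot)

sigma_inf = G0 / k_d * gamma_dot                          # steady-state stress
monotonic = bool(np.all(np.diff(sigma) > 0))              # True means no overshoot
frac_at_3tR = float(vlb_startup_stress(3 * t_R, G0, k_d, gamma_dot)) / sigma_inf
initial_slope = (sigma[1] - sigma[0]) / (t[1] - t[0])    # finite-difference estimate of G0 * gamma_dot

print(f"sigma_inf = {sigma_inf:.1f} Pa, monotonic = {monotonic}")
print(f"sigma(3 t_R) / sigma_inf = {frac_at_3tR:.3f}")
print(f"initial slope ~ {initial_slope:.1f} Pa/s (G0 * gamma_dot = {G0 * gamma_dot:.1f})")
```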

## 4. Step 1: NLSQ Point Estimation

In [None]:
model = VLBLocal()

t0 = time.time()
model.fit(time_data, stress_data, test_mode='startup', gamma_dot=GAMMA_DOT)
t_nlsq = time.time() - t0

print(f'NLSQ fit time: {t_nlsq:.2f} s')
print(f'G0  = {model.G0:.2f} Pa')
print(f'k_d = {model.k_d:.4f} 1/s')
print(f'eta_0 = {model.G0 / model.k_d:.2f} Pa*s')
print(f'sigma_ss = {model.G0 / model.k_d * GAMMA_DOT:.2f} Pa')

# Fit quality
time_fine = np.linspace(time_data.min(), time_data.max(), 300)
stress_pred_fine = model.predict(time_fine, test_mode='startup', gamma_dot=GAMMA_DOT)
stress_pred = model.predict(time_data, test_mode='startup', gamma_dot=GAMMA_DOT)
ss_res = np.sum((stress_data - np.array(stress_pred))**2)
ss_tot = np.sum((stress_data - np.mean(stress_data))**2)
r2 = 1 - ss_res / ss_tot
print(f'R-squared = {r2:.6f}')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(time_data, stress_data, 'ko', markersize=3, label='Data')
ax1.plot(time_fine, stress_pred_fine, 'r-', lw=2, label='VLB fit')
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Stress [Pa]')
ax1.set_title('Startup Stress Fit')
ax1.legend()
ax1.grid(True, alpha=0.3)

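# Normalize by the peak stress: dividing by stress_data itself blows up near t = 0, where sigma -> 0
residuals = (stress_data - np.array(stress_pred)) / np.max(np.abs(stress_data)) * 100
ax2.plot(time_data, residuals, 'ko-', markersize=3)
ax2.axhline(0, color='gray', ls='--')
ax2.set_xlabel('Time [s]')
ax2.set_ylabel('Residual [% of max stress]')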
ax2.set_title('Residuals')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
display(fig)
plt.close(fig)
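As an independent sanity check on the NLSQ step, the closed-form model can also be fitted with SciPy's `curve_fit`. This is a sketch on synthetic data with hypothetical true values ($G_0$ = 100 Pa, $k_d$ = 0.5 s⁻¹, 2% noise); it does not reproduce whatever optimizer `VLBLocal.fit` uses internally. The covariance-based standard errors give a rough preview of the parameter widths that the Bayesian step quantifies properly.

```python
import numpy as np
from scipy.optimize import curve_fit

def vlb_startup_stress(t, G0, k_d, gamma_dot=1.0):
    # sigma(t) = (G0/k_d) * gamma_dot * (1 - exp(-k_d t))
    return G0 / k_d * gamma_dot * (1.0 - np.exp(-k_d * t))

# Synthetic "measurements" from hypothetical true values plus 2% Gaussian noise
rng = np.random.default_rng(0)
G0_true, kd_true, gamma_dot = 100.0, 0.5, 1.0
t = np.linspace(0.05, 15.0, 200)
clean = vlb_startup_stress(t, G0_true, kd_true, gamma_dot)
noisy = clean + 0.02 * clean.max() * rng.standard_normal(t.size)

# Fit G0 and k_d with gamma_dot held fixed; bounds keep both parameters positive
popt, pcov = curve_fit(
    lambda tt, G0, k_d: vlb_startup_stress(tt, G0, k_d, gamma_dot),
    t, noisy, p0=[50.0, 1.0], bounds=([1e-6, 1e-6], [np.inf, np.inf]),
)
G0_fit, kd_fit = popt
G0_err, kd_err = np.sqrt(np.diag(pcov))   # 1-sigma errors from the linearized covariance
print(f"G0  = {G0_fit:.2f} +/- {G0_err:.2f} Pa")
print(f"k_d = {kd_fit:.4f} +/- {kd_err:.4f} 1/s")
```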

## 4. Step 2: Bayesian Inference (NUTS)

In [None]:
config = get_bayesian_config()
initial_values = {'G0': float(model.G0), 'k_d': float(model.k_d)}
print(f'Config: {config}')
print(f'Warm-start: {initial_values}')

t0 = time.time()
result = model.fit_bayesian(
    time_data, stress_data, test_mode='startup',
    gamma_dot=GAMMA_DOT,
    initial_values=initial_values,
    seed=42,
    **config,
)
t_bayes = time.time() - t0
print(f'\nBayesian inference time: {t_bayes:.1f} s')
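The priors, likelihood and NUTS settings inside `fit_bayesian` are not shown in this notebook. To illustrate what posterior sampling does, and why a warm start at the NLSQ optimum helps, here is a minimal random-walk Metropolis sketch. Note that this is a simpler sampler than NUTS, which uses gradients of the log posterior. Assumptions: Gaussian noise with known standard deviation, flat priors on $\log \eta_0$ and $\log k_d$, and hypothetical synthetic data.

```python
import numpy as np

def vlb_startup_stress(t, eta0, k_d, gamma_dot=1.0):
    # sigma(t) = eta0 * gamma_dot * (1 - exp(-k_d t)), with eta0 = G0 / k_d
    return eta0 * gamma_dot * (1.0 - np.exp(-k_d * t))

# Hypothetical synthetic data (true G0 = 100 Pa, k_d = 0.5 1/s, so eta0 = 200 Pa s)
rng = np.random.default_rng(1)
eta0_true, kd_true, noise_sd = 200.0, 0.5, 2.0
t = np.linspace(0.05, 15.0, 200)
y = vlb_startup_stress(t, eta0_true, kd_true) + noise_sd * rng.standard_normal(t.size)

def log_post(theta):
    """Gaussian log-likelihood (known noise_sd) plus a flat prior on (log eta0, log k_d)."""
    eta0, k_d = np.exp(theta)
    resid = y - vlb_startup_stress(t, eta0, k_d)
    return -0.5 * np.sum(resid**2) / noise_sd**2

# Warm start near a (hypothetical) NLSQ estimate, analogous to initial_values above
theta = np.log([200.5, 0.495])
lp = log_post(theta)
step = np.array([0.001, 0.004])   # hand-tuned proposal scales in log space
n_iter, n_burn, n_accept = 6000, 1000, 0
draws = []
for i in range(n_iter):
    proposal = theta + step * rng.standard_normal(2)
    lp_prop = log_post(proposal)
    if np.log(rng.uniform()) < lp_prop - lp:   # Metropolis accept/reject rule
        theta, lp = proposal, lp_prop
        n_accept += 1
    if i >= n_burn:                            # discard burn-in draws
        draws.append(np.exp(theta))
draws = np.array(draws)                        # columns: eta0, k_d
G0_draws = draws[:, 0] * draws[:, 1]           # G0 = eta0 * k_d, per draw
print(f"acceptance rate = {n_accept / n_iter:.2f}")
print(f"G0  mean = {G0_draws.mean():.2f} Pa, sd = {G0_draws.std():.2f}")
print(f"k_d mean = {draws[:, 1].mean():.4f} 1/s, sd = {draws[:, 1].std():.4f}")
```

Sampling $(\log \eta_0, \log k_d)$ rather than $(\log G_0, \log k_d)$ is a deliberate choice in this sketch: the plateau pins down $\eta_0$ far more tightly than $G_0$, so a random walk in $G_0$ would have to follow a narrow diagonal ridge. Gradient-based samplers such as NUTS cope with such correlations more efficiently.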

## 5. Convergence Diagnostics

In [None]:
param_names = ["G0", "k_d"]
converged = print_convergence(result, param_names)

In [None]:
fig_trace, fig_forest = plot_trace_and_forest(result, param_names)
display(fig_trace)
plt.close(fig_trace)
display(fig_forest)
plt.close(fig_forest)
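`print_convergence` is a tutorial helper whose thresholds are not shown here; a commonly used criterion is $\hat{R}$ < 1.01 for every parameter. The sketch below computes the classic split-$\hat{R}$ by hand on synthetic chains to show what the statistic measures. ArviZ's `az.rhat` defaults to a rank-normalized variant, so its values can differ slightly.

```python
import numpy as np

def split_rhat(chains):
    """Classic split-R-hat for one parameter; chains has shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split every chain in half so drift within a chain also inflates R-hat
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    W = split.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * chain_means.var(ddof=1)             # between-chain variance
    var_hat = (n - 1) / n * W + B / n           # pooled estimate of the posterior variance
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(2)
well_mixed = rng.normal(0.5, 0.01, size=(4, 1000))            # 4 chains, same target
stuck = well_mixed + np.array([[0.0], [0.0], [0.0], [0.05]])  # one chain offset
print(f"R-hat (well mixed) = {split_rhat(well_mixed):.3f}")
print(f"R-hat (one stuck)  = {split_rhat(stuck):.3f}")
```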

## 6. Posterior Summary

In [None]:
posterior = result.posterior_samples
nlsq_vals = {"G0": model.G0, "k_d": model.k_d}
print_parameter_table(["G0", "k_d"], nlsq_vals, posterior)
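Derived quantities such as $\eta_0 = G_0/k_d$ and $\sigma_\infty$ should be computed draw by draw from the posterior rather than from summary statistics. The sketch below uses hypothetical draws in which the plateau constrains $\eta_0$ tightly, so $G_0$ and $k_d$ are strongly correlated; the dict layout mimics `result.posterior_samples` as used above (arrays keyed by parameter name).

```python
import numpy as np

# Hypothetical stand-in for result.posterior_samples: eta0 tightly constrained,
# k_d less so, which makes G0 = eta0 * k_d strongly correlated with k_d
rng = np.random.default_rng(3)
k_d_draws = rng.normal(0.5, 0.005, size=4000)
eta0_latent = rng.normal(200.0, 0.2, size=4000)
posterior = {"G0": eta0_latent * k_d_draws, "k_d": k_d_draws}

# Derived quantities computed per draw, so the G0-k_d correlation is propagated
gamma_dot = 1.0
eta0 = posterior["G0"] / posterior["k_d"]
sigma_ss = eta0 * gamma_dot
eta_lo, eta_hi = np.percentile(eta0, [2.5, 97.5])
ss_lo, ss_hi = np.percentile(sigma_ss, [2.5, 97.5])
print(f"eta0     95% interval = [{eta_lo:.1f}, {eta_hi:.1f}] Pa*s")
print(f"sigma_ss 95% interval = [{ss_lo:.1f}, {ss_hi:.1f}] Pa")

# For comparison: dividing the marginal interval endpoints ignores the correlation
g_lo, g_hi = np.percentile(posterior["G0"], [2.5, 97.5])
k_lo, k_hi = np.percentile(posterior["k_d"], [2.5, 97.5])
naive_lo, naive_hi = g_lo / k_hi, g_hi / k_lo
print(f"naive    95% interval = [{naive_lo:.1f}, {naive_hi:.1f}] Pa*s")
```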

## 7. Posterior Predictive Check

In [None]:
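posterior = result.posterior_samples
n_total = len(posterior['G0'])
n_draws = min(200, n_total)
# Thin evenly across all stored draws instead of taking only the first n_draws
draw_idx = np.linspace(0, n_total - 1, n_draws).astype(int)

# Save the current point estimate so it can be restored after the loop
g0_before, kd_before = float(model.G0), float(model.k_d)

stress_samples = []
for i in draw_idx:
    model.parameters.set_value('G0', float(posterior['G0'][i]))
    model.parameters.set_value('k_d', float(posterior['k_d'][i]))
    stress_samples.append(np.array(
        model.predict(time_fine, test_mode='startup', gamma_dot=GAMMA_DOT)
    ))

# Restore the point estimate; otherwise the model keeps the last posterior draw
model.parameters.set_value('G0', g0_before)
model.parameters.set_value('k_d', kd_before)

stress_arr = np.array(stress_samples)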

fig, ax = plt.subplots(figsize=(9, 6))
ax.fill_between(time_fine,
    np.percentile(stress_arr, 2.5, axis=0),
    np.percentile(stress_arr, 97.5, axis=0),
    alpha=0.3, color='C0', label='95% credible band (parameters only)')
ax.plot(time_fine, np.median(stress_arr, axis=0),
        'C0-', lw=2, label='Posterior median')
ax.plot(time_data, stress_data, 'ko', markersize=3, label='Data')
ax.set_xlabel('Time [s]')
ax.set_ylabel('Stress [Pa]')
ax.set_title('Posterior Predictive Check')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)
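Because the startup stress has a closed form, the predictive loop can also be vectorized with NumPy broadcasting. The band plotted above reflects parameter uncertainty only; a full posterior predictive also adds observation noise. This sketch uses hypothetical draws and assumes a Gaussian noise parameter `sigma_noise`; whether `fit_bayesian` samples such a parameter is not shown in this notebook.

```python
import numpy as np

# Hypothetical posterior draws (stand-ins for result.posterior_samples). sigma_noise is
# an assumed Gaussian noise parameter; the real posterior may not contain one.
rng = np.random.default_rng(4)
n_draws = 500
G0 = rng.normal(100.0, 1.0, n_draws)
k_d = rng.normal(0.5, 0.005, n_draws)
sigma_noise = np.abs(rng.normal(2.0, 0.1, n_draws))
gamma_dot = 1.0
time_fine = np.linspace(0.0, 15.0, 300)

# Broadcast (n_draws, 1) parameters against the (n_times,) grid -> (n_draws, n_times)
mean_curves = (G0 / k_d)[:, None] * gamma_dot * (1.0 - np.exp(-k_d[:, None] * time_fine))
# Posterior predictive: add one noise realization per draw and time point
predictive = mean_curves + sigma_noise[:, None] * rng.standard_normal(mean_curves.shape)

# 95% band widths: mean curve (parameter uncertainty only) vs. predictive (plus noise)
width_mean = np.diff(np.percentile(mean_curves, [2.5, 97.5], axis=0), axis=0)[0]
width_pred = np.diff(np.percentile(predictive, [2.5, 97.5], axis=0), axis=0)[0]
print(f"mean band width:       {width_mean.mean():.2f} Pa (time-averaged)")
print(f"predictive band width: {width_pred.mean():.2f} Pa (time-averaged)")
```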

## 8. Save Results

In [None]:
output_dir = get_output_dir("startup")
save_results(output_dir, model, result, ["G0", "k_d"], {"gamma_dot": GAMMA_DOT})
print("Done.")

## Key Takeaways

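1. **Startup constrains both G₀ and k_d**: the plateau fixes $\eta_0 = G_0/k_d$ and the rise time fixes $k_d$ (equivalently, the initial slope $G_0 \dot{\gamma}$ fixes $G_0$)
2. **No overshoot** in VLB, since $d\sigma/dt = G_0 \dot{\gamma}\, e^{-k_d t} > 0$ at all times; an overshoot requires thixotropy or nonlinear stretching
3. **Steady-state stress** $\sigma_\infty = \eta_0 \dot{\gamma}$ gives $\eta_0 = G_0/k_d$, which can be cross-checked against the zero-shear viscosity from the flow curve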
4. **Multiple shear rates** could be fitted simultaneously for tighter constraints
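Takeaway 4 can be sketched with SciPy's `least_squares` by stacking residuals from several startup curves that share $G_0$ and $k_d$. Data and shear rates are hypothetical, and whether rheojax offers a built-in multi-dataset fit is not covered here. Each curve's residuals are scaled by its peak stress so the highest rate does not dominate.

```python
import numpy as np
from scipy.optimize import least_squares

def vlb_startup_stress(t, G0, k_d, gamma_dot):
    # sigma(t) = (G0/k_d) * gamma_dot * (1 - exp(-k_d t))
    return G0 / k_d * gamma_dot * (1.0 - np.exp(-k_d * t))

# Hypothetical synthetic startup curves at three shear rates (shared G0, k_d)
rng = np.random.default_rng(5)
G0_true, kd_true = 100.0, 0.5
rates = [0.5, 1.0, 2.0]
t = np.linspace(0.05, 15.0, 100)
datasets = []
for gd in rates:
    clean = vlb_startup_stress(t, G0_true, kd_true, gd)
    datasets.append((gd, clean + 0.02 * clean.max() * rng.standard_normal(t.size)))

def joint_residuals(log_params):
    G0, k_d = np.exp(log_params)           # log-parameters keep G0 and k_d positive
    res = []
    for gd, y in datasets:
        # Scale each curve's residuals by its peak stress
        res.append((y - vlb_startup_stress(t, G0, k_d, gd)) / y.max())
    return np.concatenate(res)

fit = least_squares(joint_residuals, x0=np.log([50.0, 1.0]))
G0_fit, kd_fit = np.exp(fit.x)
print(f"joint fit: G0 = {G0_fit:.2f} Pa, k_d = {kd_fit:.4f} 1/s")
```

In the linear VLB model the stress is proportional to $\dot{\gamma}$, so each rate yields a scaled copy of the same curve; the tighter constraint comes from the additional data rather than from new rate-dependent information.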

## Next

- **Notebook 15**: SAOS NLSQ → NUTS (best protocol for VLB)