# FIKH Model: Startup Shear

## Learning Objectives

1. Fit the **FIKH model** to startup shear (stress overshoot) data
2. Understand how **alpha_structure** affects stress overshoot timing and magnitude
3. Analyze fractional breakdown dynamics during flow inception
4. Compare FIKH predictions across different alpha values
5. Use Bayesian inference to quantify parameter uncertainty

## Prerequisites

- NB01: FIKH Flow Curve (calibrated parameters)
- Bayesian inference fundamentals (bayesian/01_bayesian_basics.ipynb)

## Runtime

- Fast demo (NUM_CHAINS=1, NUM_SAMPLES=500): ~3-5 minutes
- Full run (NUM_CHAINS=4, NUM_SAMPLES=2000): ~15-20 minutes

## 1. Setup

In [None]:
# Google Colab setup
import sys

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
# Imports
%matplotlib inline
import sys
import time
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.fikh import FIKH

# Make examples/utils importable regardless of the working directory
_nb_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()
_utils_candidates = [_nb_dir / ".." / "utils", Path("examples/utils"), _nb_dir.parent / "utils"]
for _p in _utils_candidates:
    if (_p / "fikh_tutorial_utils.py").exists():
        sys.path.insert(0, str(_p.resolve()))
        break
from fikh_tutorial_utils import (
    load_pnas_startup,
    save_fikh_results,
    print_convergence_summary,
    print_parameter_comparison,
    compute_fit_quality,
    get_fikh_param_names,
    plot_alpha_sweep,
    print_alpha_interpretation,
)

jax, jnp = safe_import_jax()
verify_float64()

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

## 2. Theory: Startup Shear with Fractional Memory

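During startup shear at constant rate $\dot{\gamma}$ (so that $\gamma = \dot{\gamma}\, t$), the FIKH model predicts three regimes:

1. **Initial elastic response**: $\sigma \approx G \gamma = G \dot{\gamma}\, t$
2. **Stress overshoot**: the peak follows the onset of plastic flow, as structure breakdown lowers the yield stress
3. **Steady-state approach**: governed by structure evolution toward its rate-dependent steady state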
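To see where the overshoot comes from, here is a minimal sketch of a *toy* thixotropic elasto-viscoplastic model integrated with explicit Euler. It is not the FIKH model: it has no kinematic hardening and uses classical ($\alpha = 1$) structure kinetics, and every parameter name and value (`sigma_y0`, `k3`, `eta_p`, `tau_thix`, `Gamma`) is a hypothetical illustration. The stress rises as $G\dot{\gamma} t$ until it reaches the structure-dependent yield stress; plastic flow then breaks down structure, the yield stress falls, and the stress passes through a maximum before settling.

```python
import numpy as np


def toy_startup(t_end=10.0, dt=1e-3, gamma_dot=1.0, G=100.0, sigma_y0=10.0,
                k3=20.0, eta_p=1.0, tau_thix=10.0, Gamma=1.0):
    """Explicit-Euler startup of a toy thixotropic elasto-viscoplastic model.

    Hypothetical illustration only (alpha = 1, no kinematic hardening);
    not the FIKH equations implemented in rheojax.
    """
    n = int(round(t_end / dt))
    t = np.arange(n + 1) * dt
    sigma = np.zeros(n + 1)
    lam = np.ones(n + 1)  # structure parameter: fully built up at rest
    for i in range(n):
        sigma_y = sigma_y0 + k3 * lam[i]  # structure-dependent yield stress
        # Perzyna-type plastic rate: zero below yield, linear in the excess stress above it
        gp = max(sigma[i] - sigma_y, 0.0) / eta_p
        sigma[i + 1] = sigma[i] + dt * G * (gamma_dot - gp)
        # Build-up at rate 1/tau_thix, breakdown driven by plastic flow
        lam[i + 1] = lam[i] + dt * ((1.0 - lam[i]) / tau_thix - Gamma * lam[i] * gp)
    return t, sigma, lam


t, sigma, lam = toy_startup()
i_peak = int(np.argmax(sigma))
print(f"peak {sigma[i_peak]:.1f} Pa at t = {t[i_peak]:.3f} s, final {sigma[-1]:.2f} Pa")
```

With these values the steady state is $\sigma_{ss} = \sigma_{y0} + k_3 \lambda_{ss} + \eta_p \dot{\gamma}$ with $\lambda_{ss} = 1/(1 + \Gamma \tau_{thix} \dot{\gamma}) = 1/11$, about 12.8 Pa, well below the peak reached after yielding at roughly 30 Pa.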

### Alpha Effect on Startup

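The fractional order $\alpha$ (`alpha_structure`) affects:
- **Overshoot timing**: lower $\alpha$ → later peak (slower structure breakdown)
- **Overshoot magnitude**: altered because the memory kernel weights the past deformation history when the structure is updated
- **Approach to steady state**: power-law (Mittag-Leffler-type) convergence for $\alpha < 1$, recovering exponential convergence as $\alpha \to 1$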
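The power-law versus exponential contrast already shows up in the simplest fractional equation, the Caputo relaxation equation $D^\alpha y = -k y$. Its solution is the Mittag-Leffler function $E_\alpha(-k t^\alpha)$, which decays like $t^{-\alpha}$ at long times for $\alpha < 1$ and reduces to $e^{-kt}$ at $\alpha = 1$. The sketch below solves it with a first-order implicit Grünwald-Letnikov scheme. It stands in for, and is not, the FIKH structure equation; the function names and the choice of scheme are assumptions for illustration.

```python
import numpy as np


def gl_weights(alpha, n):
    """Grünwald-Letnikov weights w_j = (-1)^j * C(alpha, j) for j = 0..n."""
    w = np.empty(n + 1)
    w[0] = 1.0
    for j in range(1, n + 1):
        w[j] = w[j - 1] * (1.0 - (alpha + 1.0) / j)
    return w


def fractional_relaxation(alpha, k=1.0, t_end=10.0, n=1000, y0=1.0):
    """Implicit Grünwald-Letnikov solve of the Caputo relaxation equation
    D^alpha y = -k y, y(0) = y0, whose exact solution is y0 * E_alpha(-k t^alpha)."""
    h = t_end / n
    w = gl_weights(alpha, n)
    t = np.linspace(0.0, t_end, n + 1)
    y = np.empty(n + 1)
    y[0] = y0
    for m in range(1, n + 1):
        # Memory term: every earlier state contributes, weighted by w_j at lag j
        history = np.dot(w[1:m + 1], y[m - 1::-1] - y0)
        y[m] = (y0 - history) / (1.0 + k * h**alpha)
    return t, y


t, y_frac = fractional_relaxation(alpha=0.5)
_, y_exp = fractional_relaxation(alpha=1.0)  # reduces to backward Euler for exp(-k t)
print(f"y(t=10): alpha=0.5 -> {y_frac[-1]:.3f}, alpha=1.0 -> {y_exp[-1]:.2e}")
```

At $t = 10$ the $\alpha = 0.5$ solution is still of order $10^{-1}$, while the $\alpha = 1$ solution has fallen to about $e^{-10} \approx 5 \times 10^{-5}$.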

## 3. Load Data

In [None]:
# Load PNAS startup data at gamma_dot = 1.0 s^-1
GAMMA_DOT = 1.0
time_data, stress_data = load_pnas_startup(gamma_dot=GAMMA_DOT)

print(f"Data points: {len(time_data)}")
print(f"Time range: [{time_data.min():.4f}, {time_data.max():.2f}] s")
print(f"Stress range: [{stress_data.min():.2f}, {stress_data.max():.2f}] Pa")
print(f"Shear rate: {GAMMA_DOT} 1/s")

In [None]:
# Plot raw data
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(time_data, stress_data, "ko", markersize=5, label="Data")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title(f"PNAS Startup Shear Data ($\\dot{{\\gamma}}$ = {GAMMA_DOT} s$^{{-1}}$)", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)
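Overshoot timing and magnitude can be read directly off the raw data before any fitting. A minimal sketch with a hypothetical helper `overshoot_metrics`, which estimates the steady-state stress as the mean of the last few points (this assumes the record is long enough to reach steady state):

```python
import numpy as np


def overshoot_metrics(t, sigma, gamma_dot, n_tail=5):
    """Hypothetical helper: peak time, strain and stress, plus overshoot ratio sigma_max / sigma_ss."""
    t = np.asarray(t, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    i_peak = int(np.argmax(sigma))
    # Steady-state stress estimated from the tail of the record
    sigma_ss = float(np.mean(sigma[-n_tail:]))
    return {
        "t_peak": float(t[i_peak]),
        "strain_peak": float(gamma_dot * t[i_peak]),  # gamma = gamma_dot * t in startup
        "sigma_peak": float(sigma[i_peak]),
        "sigma_ss": sigma_ss,
        "overshoot_ratio": float(sigma[i_peak] / sigma_ss),
    }


# Synthetic check: 12 Pa steady stress plus a bump peaking at t = 0.5 s (30 Pa in total)
t = np.linspace(0.0, 10.0, 1001)
sigma = 12.0 + 18.0 * (t / 0.5) * np.exp(1.0 - t / 0.5)
m = overshoot_metrics(t, sigma, gamma_dot=1.0)
print(m)
```

On this notebook's data the call would be `overshoot_metrics(time_data, stress_data, GAMMA_DOT)`; the synthetic curve only checks the arithmetic.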

## 4. NLSQ Fitting

In [None]:
# Create and fit FIKH model
model = FIKH(include_thermal=False, alpha_structure=0.7)

# Compute strain from time and shear rate
strain_data = GAMMA_DOT * time_data

t0 = time.time()
model.fit(time_data, stress_data, test_mode="startup", strain=strain_data, method='scipy')
t_nlsq = time.time() - t0

param_names = get_fikh_param_names(include_thermal=False)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print("\nFitted parameters:")
for name in param_names:
    val = model.parameters.get_value(name)
    print(f"  {name:15s} = {val:.4g}")

In [None]:
# Compute fit quality and plot
stress_pred = model.predict(time_data, test_mode="startup", strain=strain_data)
metrics = compute_fit_quality(stress_data, stress_pred)

print("\nFit Quality:")
print(f"  R^2:   {metrics['R2']:.6f}")
print(f"  RMSE:  {metrics['RMSE']:.4g} Pa")

# Fine time array for smooth prediction
time_fine = np.linspace(0.01, time_data.max(), 300)
stress_pred_fine = model.predict(time_fine, test_mode="startup", strain=GAMMA_DOT * time_fine)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(time_data, stress_data, "ko", markersize=5, alpha=0.7, label="Data")
ax.plot(time_fine, stress_pred_fine, "-", lw=2.5, color="C0", label="FIKH fit")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title(f"FIKH Startup Fit (R$^2$ = {metrics['R2']:.5f})", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)
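The R² and RMSE above come from `compute_fit_quality`, whose source is not shown in this notebook. The sketch below assumes the standard definitions, $R^2 = 1 - SS_\text{res}/SS_\text{tot}$ and $\text{RMSE} = \sqrt{\overline{(y - \hat{y})^2}}$; the function name `fit_quality` is hypothetical.

```python
import numpy as np


def fit_quality(y_obs, y_pred):
    """Standard R^2 and RMSE (assumed definitions; compute_fit_quality may differ)."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_obs - y_pred
    ss_res = float(np.sum(resid**2))                     # residual sum of squares
    ss_tot = float(np.sum((y_obs - y_obs.mean()) ** 2))  # total sum of squares about the mean
    return {"R2": 1.0 - ss_res / ss_tot, "RMSE": float(np.sqrt(np.mean(resid**2)))}


print(fit_quality([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8]))
```

Because $SS_\text{tot}$ is dominated by the large stress rise in the elastic regime, an $R^2$ close to 1 can coexist with a visibly missed overshoot peak, so residuals near the peak are worth inspecting as well.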

## 5. Alpha Exploration

In [None]:
# Alpha sweep for startup shear
alpha_values = [0.3, 0.5, 0.7, 0.9, 0.99]

fig = plot_alpha_sweep(
    model,
    protocol="startup",
    alpha_values=alpha_values,
    x_data=time_fine,
    gamma_dot=GAMMA_DOT,
    figsize=(14, 5),
)

# Add data to left panel
fig.axes[0].plot(time_data, stress_data, "ko", markersize=3, alpha=0.5, label="Data")
fig.axes[0].legend(fontsize=8, loc="best")

display(fig)
plt.close(fig)
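One way to rationalize the α sweep is to look at how a fractional derivative weights the past. In a Grünwald-Letnikov discretization of an order-α derivative (assumed here for illustration; rheojax's internal discretization is not shown in this notebook), the weight at lag $j$ decays like $j^{-(1+\alpha)}$, so lower α keeps more of the deformation history in play. The sketch computes the share of the total memory weight that lies more than 10 time steps in the past; `memory_beyond_lag` is a hypothetical name.

```python
import numpy as np


def gl_weights(alpha, n):
    """Grünwald-Letnikov weights w_j = (-1)^j * C(alpha, j) for j = 0..n."""
    w = np.empty(n + 1)
    w[0] = 1.0
    for j in range(1, n + 1):
        w[j] = w[j - 1] * (1.0 - (alpha + 1.0) / j)
    return w


def memory_beyond_lag(alpha, lag):
    """Share of the total memory weight sum_{j>=1} |w_j| carried by lags j > lag.

    For 0 < alpha <= 1 the weights w_j (j >= 1) are non-positive and sum to -1,
    so the tail share is exactly 1 - sum_{j=1}^{lag} |w_j|.
    """
    w = gl_weights(alpha, lag)
    return float(1.0 - np.abs(w[1:]).sum())


for alpha in [0.3, 0.5, 0.7, 0.9, 0.99]:
    print(f"alpha = {alpha:4.2f}: memory beyond 10 steps = {memory_beyond_lag(alpha, 10):.3f}")
```

The share falls from roughly 0.4 at α = 0.3 to about 0.001 at α = 0.99, and at α = 1 it is exactly zero, the classical memoryless limit.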

In [None]:
# Physical interpretation
fitted_alpha = model.parameters.get_value("alpha_structure")
print_alpha_interpretation(fitted_alpha)

## 6. Bayesian Inference

In [None]:
# Bayesian inference with NLSQ warm-start
initial_values = {name: model.parameters.get_value(name) for name in param_names}

NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1
# NUM_WARMUP = 1000; NUM_SAMPLES = 2000; NUM_CHAINS = 4  # production

print(f"Running NUTS: {NUM_WARMUP} warmup + {NUM_SAMPLES} samples x {NUM_CHAINS} chain(s)")
t0 = time.time()
result = model.fit_bayesian(
    time_data,
    stress_data,
    test_mode="startup",
    strain=strain_data,
    num_warmup=NUM_WARMUP,
    num_samples=NUM_SAMPLES,
    num_chains=NUM_CHAINS,
    initial_values=initial_values,
    seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

In [None]:
# Convergence diagnostics
all_pass = print_convergence_summary(result, param_names)
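The checks performed by `print_convergence_summary` are defined in `fikh_tutorial_utils` and not reproduced here. As a reference point, the sketch below computes the classic split-$\hat{R}$ with NumPy: each chain is cut in half and the between-half variance is compared with the within-half variance. Splitting is what keeps the check meaningful in the fast single-chain setting (`NUM_CHAINS = 1`). ArviZ's `az.rhat` defaults to a rank-normalized refinement of this statistic, and values below about 1.01 are the commonly recommended target. The helper name `split_rhat` is hypothetical.

```python
import numpy as np


def split_rhat(chains):
    """Classic split-R-hat for draws of shape (n_chains, n_draws).

    Every chain is cut into two halves, so a single chain still yields two
    sub-chains to compare.
    """
    chains = np.atleast_2d(np.asarray(chains, dtype=float))
    half = chains.shape[1] // 2
    sub = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = sub.shape[1]
    W = sub.var(axis=1, ddof=1).mean()    # mean within-sub-chain variance
    B = n * sub.mean(axis=1).var(ddof=1)  # between-sub-chain variance
    var_plus = (n - 1) / n * W + B / n    # pooled estimate of the posterior variance
    return float(np.sqrt(var_plus / W))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(1, 500))  # one well-mixed chain, as with NUM_CHAINS = 1
drifting = np.concatenate([rng.normal(0.0, 1.0, 250), rng.normal(3.0, 1.0, 250)])[None, :]
print(f"well mixed: {split_rhat(mixed):.3f}, level shift: {split_rhat(drifting):.3f}")
```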

In [None]:
# Trace plots
idata = result.to_inference_data()
key_params = ["G", "sigma_y0", "tau_thix", "alpha_structure"]
axes = az.plot_trace(idata, var_names=key_params, figsize=(12, 8))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots (Key Parameters)", fontsize=14, y=1.00)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
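# Posterior predictive
posterior = result.posterior_samples
# FAST mode: 10 draws for batch testing (gives only a rough 95% band)
# FULL mode: 100 draws for publication
FAST_POSTERIOR_PREDICTIVE = True
n_total = len(next(iter(posterior.values())))
n_draws = min(10 if FAST_POSTERIOR_PREDICTIVE else 100, n_total)
# Spread the draws over the whole sample instead of taking the first n_draws,
# which are autocorrelated
draw_idx = np.linspace(0, n_total - 1, n_draws, dtype=int)

# Remember the current parameters; the loop below overwrites them
saved_values = {name: model.parameters.get_value(name) for name in param_names}

pred_samples = []
for i in draw_idx:
    # Set model parameters from posterior sample i
    for name in param_names:
        model.parameters.set_value(name, float(posterior[name][i]))
    pred_i = model.predict(time_fine, test_mode="startup", strain=GAMMA_DOT * time_fine)
    pred_samples.append(np.array(pred_i))

# Restore the saved parameters so later cells (comparison, saving) do not see the last posterior draw
for name, val in saved_values.items():
    model.parameters.set_value(name, val)

pred_samples = np.array(pred_samples)
pred_median = np.median(pred_samples, axis=0)
pred_lo = np.percentile(pred_samples, 2.5, axis=0)
pred_hi = np.percentile(pred_samples, 97.5, axis=0)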

fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(time_fine, pred_lo, pred_hi, alpha=0.3, color="C0", label="95% CI")
ax.plot(time_fine, pred_median, "-", lw=2, color="C0", label="Posterior median")
ax.plot(time_data, stress_data, "ko", markersize=5, label="Data")
ax.set_xlabel("Time [s]", fontsize=12)
ax.set_ylabel("Stress [Pa]", fontsize=12)
ax.set_title("FIKH Startup Posterior Predictive", fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
# Parameter comparison
print_parameter_comparison(model, posterior, param_names)
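A quick complementary check is whether each NLSQ point estimate falls inside the corresponding 95% credible interval. A minimal sketch with a hypothetical helper `point_vs_posterior` (equal-tailed interval from sample percentiles):

```python
import numpy as np


def point_vs_posterior(point, samples, level=0.95):
    """Hypothetical helper: posterior median, equal-tailed credible interval, and
    whether a point estimate (e.g. from NLSQ) falls inside that interval."""
    samples = np.asarray(samples, dtype=float)
    tail = 100.0 * (1.0 - level) / 2.0
    lo, hi = np.percentile(samples, [tail, 100.0 - tail])
    return {
        "median": float(np.median(samples)),
        "lo": float(lo),
        "hi": float(hi),
        "inside": bool(lo <= point <= hi),
    }


rng = np.random.default_rng(1)
draws = rng.normal(loc=2.0, scale=0.1, size=4000)  # synthetic stand-in for one parameter's draws
print(point_vs_posterior(2.05, draws))
print(point_vs_posterior(3.0, draws))
```

In this notebook the NLSQ estimates were captured in `initial_values` before `fit_bayesian`, so a per-parameter call would look like `point_vs_posterior(initial_values[name], posterior[name])`. An estimate outside the interval usually points to a multimodal or prior-dominated posterior rather than a bug.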

## 7. Save Results

In [None]:
save_fikh_results(model, result, "fikh", "startup", param_names)
print("\nResults saved for downstream analysis.")

## Key Takeaways

1. **Startup shear reveals fractional dynamics** through stress overshoot behavior
2. **Lower alpha** → later overshoot peak, slower structure breakdown
3. **Higher alpha** → approaches classical IKH exponential behavior
4. **Startup data constrains kinematic hardening** parameters (C, gamma_dyn)
5. **Combining startup and flow curve data** constrains alpha more tightly than either dataset alone

### Next Steps

- **NB03**: Stress relaxation (power-law tails show alpha most clearly)
- **NB04**: Creep response (delayed yielding with memory)
- **NB06**: LAOS (intra-cycle memory effects)