# SGR Creep Compliance: Power-Law Creep in Biological Soft Matter

**Learning Objectives:**
- Fit the SGR creep model to biological soft-matter data
- Understand J(t) ~ t^(2-x) power-law creep behavior
- Handle limited data in Bayesian inference (wider posteriors)

**Prerequisites:** `basic/01_getting_started.ipynb`, `bayesian/01_bayesian_intro.ipynb`

**Runtime:**
- Fast mode (1 chain): ~2 min
- Full mode (4 chains): ~5 min

## 1. Setup

In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q rheojax
    import os
    os.environ["JAX_ENABLE_X64"] = "true"
    print("RheoJAX installed successfully.")

In [None]:
%matplotlib inline
import os
import json
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.sgr import SGRConventional, SGRGeneric

jax, jnp = safe_import_jax()
verify_float64()

warnings.filterwarnings("ignore", category=FutureWarning)
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

In [None]:
def compute_fit_quality(y_true, y_pred):
    """Compute R² and RMSE."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    residuals = y_true - y_pred
    if y_true.ndim > 1:
        residuals = residuals.ravel()
        y_true = y_true.ravel()
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    rmse = np.sqrt(np.mean(residuals**2))
    return {"R2": r2, "RMSE": rmse}

## 2. Theory: SGR Creep

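The SGR creep compliance is given by:

$$J(t) = \frac{(1 + t/\tau_0)^{2-x}}{G_0\,\mathcal{N}(x)}$$

where $\mathcal{N}(x) = \Gamma(2-x)\,\Gamma(x)/\Gamma(2) = \Gamma(2-x)\,\Gamma(x)$ is a dimensionless normalization factor, separate from the modulus scale $G_0$. It is finite only for $0 < x < 2$, since $\Gamma(x)$ diverges at $x = 0$ and $\Gamma(2-x)$ at $x = 2$.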

**Growth exponent:** $(2-x)$

**Phase regimes:**
- $x < 1$: glass; very slow creep toward an elastic plateau
- $1 < x < 2$: power-law fluid with sublinear power-law creep
- $x \ge 2$: Newtonian fluid; linear creep at long times

**Biological soft matter** (mucus) often exhibits soft glassy behavior with $1 < x < 2$, reflecting structural rearrangements under constant stress.
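
To see what the growth exponent $(2-x)$ means in practice, the sketch below evaluates the formula above directly with NumPy and estimates the local log-log slope $d\ln J/d\ln t$. It is a hand transcription of the equation, not the `rheojax` implementation, and `sgr_creep_compliance` and `local_loglog_slope` are hypothetical helper names. Analytically the slope is $(2-x)\,(t/\tau_0)/(1+t/\tau_0)$, so it rises from about 0 for $t \ll \tau_0$ to $(2-x)$ for $t \gg \tau_0$.

```python
import math

import numpy as np


def sgr_creep_compliance(time, x, G0=1.0, tau0=1.0):
    """Hand transcription of the creep formula above (hypothetical helper, not rheojax).

    J(t) = (1 + t/tau0)**(2 - x) / (G0 * N(x)),  N(x) = Gamma(2 - x) * Gamma(x) / Gamma(2).
    Only finite for 0 < x < 2.
    """
    norm = math.gamma(2.0 - x) * math.gamma(x) / math.gamma(2.0)
    return (1.0 + np.asarray(time, dtype=float) / tau0) ** (2.0 - x) / (G0 * norm)


def local_loglog_slope(time, compliance):
    """Finite-difference estimate of d ln J / d ln t on the given grid."""
    return np.gradient(np.log(compliance), np.log(time))


t_grid = np.logspace(-3, 4, 400)  # tau0 = 1, so this spans t << tau0 to t >> tau0
for x_try in (1.2, 1.5, 1.8):
    s = local_loglog_slope(t_grid, sgr_creep_compliance(t_grid, x_try))
    print(f"x={x_try}: slope at t=1e-3 is {s[0]:.4f}, at t=1e4 is {s[-1]:.4f} (2 - x = {2 - x_try:.2f})")
```

Reading the exponent off a log-log plot is therefore only meaningful for $t \gg \tau_0$; closer to $\tau_0$ the curve bends and the apparent exponent is smaller than $(2-x)$.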

## 3. Load Data

In [None]:
data_path = os.path.join("..", "data", "creep", "biological", "creep_mucus_data.csv")
raw = np.loadtxt(data_path, delimiter="\t", skiprows=1)
t = raw[:, 0]
J_t = raw[:, 1]

print(f"Data points: {len(t)}")
print(f"Time range: {t.min():.1f} – {t.max():.1f} s")
print(f"J(t) range: {J_t.min():.4f} – {J_t.max():.4f} 1/Pa")
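
Because the plots and fits below use logarithmic axes, it is worth checking the loaded arrays first: a zero or negative time or compliance is silently dropped from a log-log plot. A minimal sketch; `check_creep_data` is a hypothetical helper, not part of `rheojax`.

```python
import numpy as np


def check_creep_data(time, compliance):
    """Return a list of problems that would distort a log-log creep fit (empty list = OK)."""
    time = np.asarray(time, dtype=float)
    compliance = np.asarray(compliance, dtype=float)
    problems = []
    if time.shape != compliance.shape:
        problems.append(f"shape mismatch: {time.shape} vs {compliance.shape}")
        return problems
    if not (np.all(np.isfinite(time)) and np.all(np.isfinite(compliance))):
        problems.append("non-finite values present")
    if np.any(np.diff(time) <= 0):
        problems.append("time is not strictly increasing")
    if np.any(time <= 0):
        problems.append(f"{np.sum(time <= 0)} non-positive time(s): hidden on log axes")
    if np.any(compliance <= 0):
        problems.append(f"{np.sum(compliance <= 0)} non-positive J(t) value(s): hidden on log axes")
    return problems


# Small synthetic example; in the notebook, call check_creep_data(t, J_t) instead
print(check_creep_data([0.0, 1.0, 2.0], [0.1, 0.2, 0.3]))
```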

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(t, J_t, "ko-", markersize=6)
ax1.set_xlabel("Time [s]")
ax1.set_ylabel("J(t) [1/Pa]")
ax1.set_title("Mucus Creep Compliance (linear)")
ax1.grid(True, alpha=0.3)

ax2.loglog(t, J_t, "ko-", markersize=6)
ax2.set_xlabel("Time [s]")
ax2.set_ylabel("J(t) [1/Pa]")
ax2.set_title("Mucus Creep Compliance (log-log)")
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
display(fig)
plt.close(fig)

## 4. NLSQ Fitting

In [None]:
model = SGRConventional()

t0_fit = time.time()
model.fit(t, J_t, test_mode="creep", method='scipy')
t_nlsq = time.time() - t0_fit

# Compute fit quality
J_pred_fit = model.predict(t)
metrics = compute_fit_quality(J_t, J_pred_fit)

print(f"NLSQ fit time: {t_nlsq:.2f} s")
print(f"R²: {metrics['R2']:.6f}")
print(f"RMSE: {metrics['RMSE']:.4g} 1/Pa")
print("\nFitted parameters:")
for name in ["x", "G0", "tau0"]:
    val = model.parameters.get_value(name)
    print(f"  {name:5s} = {val:.4g}")
print(f"Phase regime: {model.get_phase_regime()}")
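
As an independent cross-check on the NLSQ estimate of x, one can fit a straight line to the late-time data on log-log axes: when $t \gg \tau_0$, the creep formula of Section 2 gives slope $\approx 2 - x$, so $x \approx 2 - \text{slope}$. The sketch below uses synthetic data generated from that formula with assumed values x = 1.4, G₀ = 10 Pa, τ₀ = 0.5 s and about 2% multiplicative noise (illustrative only, not the mucus data); `estimate_x_from_tail` is a hypothetical helper. For the notebook data, call `estimate_x_from_tail(t, J_t, t_min=...)` with a `t_min` well above the fitted τ₀.

```python
import math

import numpy as np


def estimate_x_from_tail(time, compliance, t_min):
    """Estimate x from the log-log slope of J(t) for time >= t_min, assuming slope = 2 - x."""
    time = np.asarray(time, dtype=float)
    compliance = np.asarray(compliance, dtype=float)
    mask = (time >= t_min) & (compliance > 0)
    slope, _intercept = np.polyfit(np.log(time[mask]), np.log(compliance[mask]), 1)
    return 2.0 - slope


# Synthetic creep curve from the Section 2 formula (illustrative parameters, not the mucus data)
rng = np.random.default_rng(0)
x_true, G0_true, tau0_true = 1.4, 10.0, 0.5
t_syn = np.logspace(-1, 3, 20)
norm = math.gamma(2 - x_true) * math.gamma(x_true)
J_syn = (1 + t_syn / tau0_true) ** (2 - x_true) / (G0_true * norm)
J_syn = J_syn * np.exp(0.02 * rng.standard_normal(t_syn.size))  # ~2% multiplicative noise

x_tail = estimate_x_from_tail(t_syn, J_syn, t_min=20 * tau0_true)
print(f"x from tail slope: {x_tail:.3f} (true value {x_true})")
```

If the tail estimate and the NLSQ value of x disagree strongly, the data may not reach far enough into the $t \gg \tau_0$ regime, or τ₀ may be poorly determined.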

In [None]:
t_fine = np.logspace(np.log10(max(t.min(), 0.1)) - 0.3, np.log10(t.max()) + 0.3, 200)
J_pred = model.predict(t_fine)

fig, ax = plt.subplots(figsize=(8, 5))
ax.loglog(t, J_t, "ko", markersize=6, label="Mucus data")
ax.loglog(t_fine, J_pred, "-", lw=2, color="C0", label="SGR fit")

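# Power-law reference line with slope (2 - x), anchored at the first positive-time data point
x_fit = model.parameters.get_value("x")
slope = 2 - x_fit
i0 = int(np.argmax(t > 0))  # log axes cannot show t = 0, so skip it when anchoring
ref_t = np.logspace(np.log10(t[i0]), np.log10(t.max()), 50)
ref_J = J_t[i0] * (ref_t / t[i0]) ** slope
ax.loglog(ref_t, ref_J, ":", lw=1.5, color="gray", alpha=0.5, label=f"$t^{{{slope:.2f}}}$ reference")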

ax.set_xlabel("Time [s]")
ax.set_ylabel("J(t) [1/Pa]")
ax.set_title(f"SGR Creep Fit — x={x_fit:.3f} ({model.get_phase_regime()})")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

### 4.1 Creep Regime Exploration

Predict J(t) for different x values to show how the creep exponent changes.

In [None]:
x_values = [0.7, 1.0, 1.5, 2.0, 2.5]
t_sweep = np.logspace(-1, 3, 200)

fig, ax = plt.subplots(figsize=(8, 5))
colors = plt.cm.coolwarm(np.linspace(0, 1, len(x_values)))

for i, x_val in enumerate(x_values):
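    m = SGRConventional()
    m.parameters.set_value("x", x_val)
    m.parameters.set_value("G0", 1.0)
    m.parameters.set_value("tau0", 1.0)
    # Predict without fitting: mark the model as fitted and select the creep protocol.
    # _test_mode is a private attribute, so this shortcut may break in other rheojax versions.
    m.fitted_ = True
    m._test_mode = "creep"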

    J_sweep = m.predict(t_sweep)
    regime = m.get_phase_regime()
    ax.loglog(t_sweep, J_sweep, "-", color=colors[i], lw=2, label=f"x={x_val} ({regime})")

ax.set_xlabel("Time [s]")
ax.set_ylabel("J(t) [1/Pa]")
ax.set_title("SGR Creep — Phase Regime Comparison")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

**Note:** `SGRGeneric` does NOT support creep in `model_function()`. Only `SGRConventional` provides creep predictions. This is a current limitation — `SGRGeneric` supports oscillation, relaxation, and steady_shear.

## 5. Bayesian Inference with NUTS

### 5.1 Run NUTS

**Note:** With only 20 data points, the posteriors will be wider than usual. This is expected and statistically meaningful: it reflects genuine parameter uncertainty rather than a sampling problem.
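
Why fewer points widen the posterior can be seen even without MCMC: for a straight-line fit in log-log space with fixed noise, the standard error of the slope (and hence of x) shrinks roughly as $1/\sqrt{N}$. The sketch below is a hypothetical illustration with synthetic data and an assumed 2% multiplicative noise level; it compares the empirical spread of the fitted slope for N = 20 and N = 200 points over the same time window.

```python
import numpy as np


def slope_spread(n_points, n_trials=2000, noise=0.02, true_slope=0.6, seed=1):
    """Std. dev. of least-squares log-log slopes across repeated noisy synthetic datasets."""
    rng = np.random.default_rng(seed)
    log_t = np.log(np.logspace(0, 3, n_points))  # same 3-decade window for every N
    slopes = np.empty(n_trials)
    for k in range(n_trials):
        log_J = true_slope * log_t + noise * rng.standard_normal(n_points)
        slopes[k] = np.polyfit(log_t, log_J, 1)[0]
    return slopes.std()


s20, s200 = slope_spread(20), slope_spread(200)
print(f"slope std: N=20 -> {s20:.4f}, N=200 -> {s200:.4f}, ratio = {s20 / s200:.2f}")
```

The ratio comes out near 3, in line with the $1/\sqrt{N}$ rule ($\sqrt{10} \approx 3.2$). Posterior widths from NUTS roughly follow the same logic when the priors are weak.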

In [None]:
initial_values = {name: model.parameters.get_value(name) for name in ["x", "G0", "tau0"]}
print("Warm-start values:", initial_values)

NUM_WARMUP = 200
NUM_SAMPLES = 500
NUM_CHAINS = 1
# NUM_WARMUP = 1000; NUM_SAMPLES = 2000; NUM_CHAINS = 4  # production

t0 = time.time()
result = model.fit_bayesian(
    t, J_t, test_mode="creep",
    num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES, num_chains=NUM_CHAINS,
    initial_values=initial_values, seed=42,
)
t_bayes = time.time() - t0
print(f"\nBayesian inference time: {t_bayes:.1f} s")

### 5.2 Convergence Diagnostics

In [None]:
diag = result.diagnostics
param_names = ["x", "G0", "tau0"]

print("Convergence Diagnostics")
print("=" * 50)
print(f"{'Parameter':>10s}  {'R-hat':>8s}  {'ESS':>8s}")
print("-" * 50)
for p in param_names:
    r_hat = diag.get("r_hat", {}).get(p, float("nan"))
    ess = diag.get("ess", {}).get(p, float("nan"))
    print(f"{p:>10s}  {r_hat:8.4f}  {ess:8.0f}")
n_div = diag.get("divergences", diag.get("num_divergences", 0))
print(f"\nDivergences: {n_div}")
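
R-hat compares the variance between chains with the variance within them; values close to 1 (a common threshold is below about 1.01) suggest the chains agree. For intuition, the sketch below computes the classic split-R̂ by hand with NumPy. It is an illustration only: ArviZ's `az.rhat` defaults to a rank-normalized variant, and which variant populates `result.diagnostics` is an assumption here, so numbers may differ slightly. Splitting each chain in half is what lets R̂ be computed at all from the single chain of fast mode, although four chains give a far more reliable check.

```python
import numpy as np


def split_rhat(samples):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    samples = np.asarray(samples, dtype=float)
    n_chains, n_draws = samples.shape
    half = n_draws // 2
    # Split each chain into first and second halves (middle draw dropped if n_draws is odd)
    split = np.concatenate([samples[:, :half], samples[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B_over_n = split.mean(axis=1).var(ddof=1)  # variance of chain means (= B / n)
    var_plus = (n - 1) / n * W + B_over_n      # pooled estimate of the posterior variance
    return np.sqrt(var_plus / W)


rng = np.random.default_rng(42)
good = rng.standard_normal((4, 500))                   # four chains sampling the same target
bad = good + np.array([0.0, 0.0, 0.0, 3.0])[:, None]   # one chain stuck in a different region
print(f"R-hat (mixed chains):  {split_rhat(good):.3f}")
print(f"R-hat (one chain off): {split_rhat(bad):.3f}")
```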

### 5.3 ArviZ Plots

In [None]:
idata = result.to_inference_data()
axes = az.plot_trace(idata, var_names=param_names, figsize=(12, 6))
fig = axes.ravel()[0].figure
fig.suptitle("Trace Plots", fontsize=14, y=1.02)
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
axes = az.plot_pair(idata, var_names=param_names, kind="scatter", divergences=True, figsize=(9, 9))
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

In [None]:
axes = az.plot_forest(idata, var_names=param_names, combined=True, hdi_prob=0.95, figsize=(10, 4))
fig = axes.ravel()[0].figure
plt.tight_layout()
display(fig)
plt.close(fig)

### 5.4 Posterior Predictive

In [None]:
posterior = result.posterior_samples
n_total = len(next(iter(posterior.values())))
n_draws = min(200, n_total)
# Spread the draws over the whole (possibly multi-chain) sample instead of taking the first n_draws
draw_idx = np.linspace(0, n_total - 1, n_draws).astype(int)
t_pred = np.logspace(np.log10(max(t.min(), 0.1)) - 0.3, np.log10(t.max()) + 0.3, 100)

pred_samples = []
for i in draw_idx:
    # Set parameters from posterior draw i
    for name in ["x", "G0", "tau0"]:
        model.parameters.set_value(name, float(posterior[name][int(i)]))
    pred_samples.append(np.asarray(model.predict(t_pred)))

# Restore the NLSQ point estimates so later cells do not see the last posterior draw
for name, val in initial_values.items():
    model.parameters.set_value(name, val)

pred_samples = np.array(pred_samples)
pred_median = np.median(pred_samples, axis=0)
pred_lo = np.percentile(pred_samples, 2.5, axis=0)
pred_hi = np.percentile(pred_samples, 97.5, axis=0)

fig, ax = plt.subplots(figsize=(9, 6))
ax.fill_between(t_pred, pred_lo, pred_hi, alpha=0.3, color="C0", label="95% credible band")
ax.loglog(t_pred, pred_median, "-", lw=2, color="C0", label="Posterior median")
ax.loglog(t, J_t, "ko", markersize=6, label="Data")
ax.set_xlabel("Time [s]")
ax.set_ylabel("J(t) [1/Pa]")
ax.set_title(f"Posterior Predictive ({len(t)} data points → wider CI)")
ax.legend()
ax.grid(True, alpha=0.3, which="both")
plt.tight_layout()
display(fig)
plt.close(fig)

### 5.5 Limited-Data Bayesian Discussion

With only 20 data points, the posteriors are wider. This is expected and informative:

1. **Wider credible intervals reflect genuine uncertainty** from sparse data
2. **The posterior can still constrain x meaningfully**, since the power-law slope is informed by the whole time window
3. **G₀ and τ₀ may show stronger correlations** (an identifiability challenge)
4. **Production runs with 4 chains are especially important here**, because R-hat is most informative when it can compare several chains

This is a feature, not a bug: Bayesian inference correctly propagates data scarcity into posterior uncertainty.
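
The G₀–τ₀ correlation in point 3 follows from the creep formula in Section 2: for $t \gg \tau_0$, $(1 + t/\tau_0)^{2-x} \approx (t/\tau_0)^{2-x}$, so $J(t) \approx t^{2-x} / [G_0\,\Gamma(2-x)\Gamma(x)\,\tau_0^{2-x}]$ and the data mainly constrain the product $G_0\,\tau_0^{2-x}$. The sketch below (a hand transcription of the formula; `creep_J` and the parameter values are hypothetical) rescales $\tau_0$ by a factor $k$ and $G_0$ by $k^{-(2-x)}$, then reports how much $J(t)$ changes over a window with $t \gg \tau_0$.

```python
import math

import numpy as np


def creep_J(time, x, G0, tau0):
    """Section 2 creep formula (hand transcription, valid for 0 < x < 2)."""
    norm = math.gamma(2 - x) * math.gamma(x)
    return (1 + np.asarray(time, dtype=float) / tau0) ** (2 - x) / (G0 * norm)


x_demo, G0_demo, tau0_demo = 1.5, 10.0, 0.1  # hypothetical values
t_demo = np.logspace(1, 3, 20)               # observation window with t >> tau0

J_ref = creep_J(t_demo, x_demo, G0_demo, tau0_demo)
for k in (0.5, 2.0, 5.0):
    # Rescale tau0 by k and compensate G0 so that G0 * tau0**(2 - x) stays fixed
    J_k = creep_J(t_demo, x_demo, G0_demo * k ** (-(2 - x_demo)), k * tau0_demo)
    max_rel = np.max(np.abs(J_k / J_ref - 1))
    print(f"k={k}: max relative change in J over the window = {max_rel:.2%}")
```

Across a tenfold range of $k$ the curve moves by about 2% at most, whereas doubling $G_0$ alone halves $J$. A ridge along $G_0 \propto \tau_0^{-(2-x)}$ in the pair plot would be the visual signature of this degeneracy.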

## 6. Save Results

In [None]:
output_dir = os.path.join("..", "outputs", "sgr", "creep")
os.makedirs(output_dir, exist_ok=True)

# NLSQ values captured before Bayesian inference; model.parameters may hold a posterior draw by now
nlsq_params = {name: float(val) for name, val in initial_values.items()}
with open(os.path.join(output_dir, "nlsq_params.json"), "w") as f:
    json.dump(nlsq_params, f, indent=2)

posterior_dict = {k: np.array(v).tolist() for k, v in posterior.items()}
with open(os.path.join(output_dir, "posterior_samples.json"), "w") as f:
    json.dump(posterior_dict, f)

print(f"Results saved to {output_dir}/")
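
To reuse the saved samples later, they can be reloaded and summarized without re-running NUTS. A minimal sketch, assuming the JSON layout written above (a dict mapping each parameter name to a flat list of draws); `summarize_posterior` is a hypothetical helper, and the demo writes a small synthetic file to a temporary directory so that it runs on its own.

```python
import json
import os
import tempfile

import numpy as np


def summarize_posterior(path, cred=0.95):
    """Mean and equal-tailed credible interval for each parameter in a saved JSON file."""
    with open(path) as fh:
        samples = json.load(fh)
    lo_q, hi_q = 50 * (1 - cred), 50 * (1 + cred)  # e.g. 2.5 and 97.5 percentiles
    summary = {}
    for name, draws in samples.items():
        arr = np.asarray(draws, dtype=float).ravel()
        summary[name] = {
            "mean": float(arr.mean()),
            "lo": float(np.percentile(arr, lo_q)),
            "hi": float(np.percentile(arr, hi_q)),
        }
    return summary


# Demo with synthetic draws; in the notebook, pass os.path.join(output_dir, "posterior_samples.json")
rng = np.random.default_rng(3)
fake = {"x": rng.normal(1.5, 0.05, 1000).tolist(), "G0": rng.normal(10.0, 1.0, 1000).tolist()}
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "posterior_samples.json")
    with open(demo_path, "w") as fh:
        json.dump(fake, fh)
    for name, stats in summarize_posterior(demo_path).items():
        print(f"{name}: mean={stats['mean']:.3g}, 95% interval=[{stats['lo']:.3g}, {stats['hi']:.3g}]")
```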

## Key Takeaways

1. **SGR creep** $J(t) \sim (1+t/\tau_0)^{2-x}$ — power-law growth encodes phase regime
2. **Biological soft matter** (mucus) exhibits soft glassy creep behavior
3. **Limited data (20 points)** gives wider posteriors; this reflects genuine uncertainty, not a bug
4. **The creep exponent $(2-x)$** is complementary to the relaxation exponent $(x-2)$: they sum to zero
5. **`SGRGeneric` does not support creep**; use `SGRConventional` for this protocol

**Next:** NB 05 (startup) or NB 06 (LAOS) for nonlinear protocols