# FluiditySaramitoLocal: Small-Amplitude Oscillatory Shear (SAOS)

Linear viscoelastic response of the fluidity-Saramito model in the small-strain limit.

**Learning Objectives:**
- Understand the SAOS behavior of thixotropic elastoviscoplastic materials
- Recognize the Maxwell-like linear response in G'(ω) and G''(ω)
- Infer parameters from frequency sweeps with NLSQ and Bayesian methods
- Validate the model with Cole-Cole analysis
- Interpret MCMC diagnostics (R-hat, ESS, divergences)

## 1. Setup

In [None]:
# Google Colab setup (uncomment if running on Colab)
# !pip install rheojax nlsq numpyro arviz

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# JAX-safe import (critical for float64); configure JAX before the remaining rheojax imports
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

from rheojax.utils.metrics import compute_fit_quality

from rheojax.models.fluidity.saramito import FluiditySaramitoLocal
from rheojax.core.data import RheoData
from rheojax.logging import configure_logging, get_logger

# Configure logging
configure_logging(level="INFO")
logger = get_logger(__name__)

# Plotting style
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## 2. Theory: SAOS in the Linear Limit

### Governing Equations

For small oscillatory strain $\gamma = \gamma_0 \sin(\omega t)$ with $\gamma_0 \ll 1$:

**Maxwell backbone (coupling='minimal'):**
$$\tau + \lambda \frac{d\tau}{dt} = G \lambda \dot{\gamma}$$

where $\lambda = 1/f$ is the fluidity-dependent relaxation time.

**Complex modulus:**
$$G^*(\omega) = G' + iG'' = \frac{G \omega^2 \lambda^2}{1 + \omega^2 \lambda^2} + i \frac{G \omega \lambda}{1 + \omega^2 \lambda^2}$$

**Key features:**
- Low frequency ($\omega \ll 1/\lambda$): $G' \sim \omega^2$, $G'' \sim \omega$ (viscous liquid)
- High frequency ($\omega \gg 1/\lambda$): $G' \to G$, $G'' \sim 1/\omega$ (elastic solid)
- Crossover at $\omega = 1/\lambda$: $G' = G'' = G/2$
- Cole-Cole plot: $G''$ vs $G'$ forms a semicircle (Maxwell signature)
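
The limits listed above can be checked numerically. The sketch below evaluates the single-mode Maxwell moduli with NumPy; the helper name `maxwell_moduli` and the values of `G` and `lam` are illustrative assumptions and are not part of rheojax.

```python
import numpy as np

def maxwell_moduli(omega, G, lam):
    """Storage and loss moduli of a single-mode Maxwell element."""
    wl = np.asarray(omega, dtype=float) * lam   # dimensionless frequency
    G_prime = G * wl**2 / (1.0 + wl**2)
    G_double_prime = G * wl / (1.0 + wl**2)
    return G_prime, G_double_prime

G, lam = 1.0e5, 2.0              # illustrative values (Pa, s)
omega = np.logspace(-3, 3, 61)   # rad/s
Gp, Gpp = maxwell_moduli(omega, G, lam)

# Crossover: G' = G'' = G/2 at omega = 1/lam
Gp_c, Gpp_c = maxwell_moduli(1.0 / lam, G, lam)

# Low-frequency log-log slopes: close to 2 for G' and 1 for G''
slope_Gp = np.polyfit(np.log(omega[:5]), np.log(Gp[:5]), 1)[0]
slope_Gpp = np.polyfit(np.log(omega[:5]), np.log(Gpp[:5]), 1)[0]

# Cole-Cole: every point sits on a circle of radius G/2 centred at (G/2, 0)
radius = np.hypot(Gp - G / 2, Gpp)

print(f"crossover moduli: {float(Gp_c):.3e}, {float(Gpp_c):.3e} Pa")
print(f"terminal slopes:  {slope_Gp:.3f}, {slope_Gpp:.3f}")
print(f"Cole-Cole radius spread: {np.ptp(radius):.2e} Pa")
```

The same `radius` check can be applied to measured data: a systematic drift away from `G/2` suggests more than one relaxation time.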

**Thixotropic effects:**
- Fluidity evolution: $df/dt = (1-f)/t_{eq} + b|\dot{\gamma}|^n$; the rejuvenation term scales as $(\gamma_0 \omega)^n$ and is negligible for $\gamma_0 \ll 1$
- Steady state: $f \approx 1$ (fully structured), $\lambda \to \lambda_{min}$
- Pre-shearing history affects initial fluidity and transient response
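
To see how small the rejuvenation contribution is at SAOS amplitudes, the sketch below integrates the fluidity equation above with a forward-Euler step for $\dot{\gamma} = \gamma_0 \omega \cos(\omega t)$. The function name and every parameter value are illustrative assumptions, not rheojax defaults or fitted values.

```python
import math

def mean_fluidity_under_saos(gamma0, omega, t_eq, b, n,
                             n_cycles=20, steps_per_cycle=500):
    """Forward-Euler integration of df/dt = (1 - f)/t_eq + b*|gamma_dot|**n.

    The strain is gamma0*sin(omega*t), so gamma_dot = gamma0*omega*cos(omega*t).
    Starts from f = 1 and returns f averaged over the final cycle.
    """
    dt = 2.0 * math.pi / omega / steps_per_cycle
    n_steps = n_cycles * steps_per_cycle
    f = 1.0
    last_cycle_sum = 0.0
    for k in range(n_steps):
        gamma_dot = gamma0 * omega * math.cos(omega * k * dt)
        f += dt * ((1.0 - f) / t_eq + b * abs(gamma_dot) ** n)
        if k >= n_steps - steps_per_cycle:
            last_cycle_sum += f
    return last_cycle_sum / steps_per_cycle

# Illustrative kinetics (assumed values)
kinetics = dict(omega=1.0, t_eq=10.0, b=0.5, n=1.0)
f_linear = mean_fluidity_under_saos(gamma0=1e-3, **kinetics)     # SAOS amplitude
f_nonlinear = mean_fluidity_under_saos(gamma0=1.0, **kinetics)   # large amplitude

print(f"gamma0 = 1e-3: mean f - 1 = {f_linear - 1:.2e}")
print(f"gamma0 = 1:    mean f - 1 = {f_nonlinear - 1:.2e}")
```

For `n = 1` the shift in the cycle-averaged fluidity is proportional to $\gamma_0$, which is why the relaxation time can be treated as constant across a small-amplitude sweep.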

## 3. Generate Synthetic SAOS Data

In [None]:
# Frequency sweep (rad/s)
omega = np.logspace(-2, 2, 30)  # 0.01 to 100 rad/s

# Ground truth parameters (using actual model parameter names with valid bounds)
params_true = {
    'G': 1e5,          # Elastic modulus (Pa) - bounds (10, 1e8)
    'eta_s': 100.0,    # Solvent viscosity (Pa·s) - bounds (0, 1000)
    'tau_y0': 50.0,    # Yield stress (Pa) - bounds (0.1, 1e5)
    'K_HB': 100.0,     # Consistency coefficient - bounds (0.01, 1e5)
    'n_HB': 0.5,       # Power-law exponent - bounds (0.1, 1.5)
    't_a': 10.0,       # Aging time (s) - bounds (0.01, 1e5)
    'b': 0.5,          # Rejuvenation coefficient - bounds (0, 1000)
    'n_rej': 1.0,      # Rejuvenation exponent - bounds (0.1, 3)
    'f_age': 1e-4,     # Equilibrium fluidity - bounds (1e-12, 0.01)
    'f_flow': 0.1,     # Maximum fluidity - bounds (1e-6, 1)
}

# Create model and generate data
model_true = FluiditySaramitoLocal(coupling='minimal')

# Set parameters (check bounds)
for name, value in params_true.items():
    if name in model_true.parameters.keys():
        try:
            model_true.parameters.set_value(name, value)
        except ValueError as e:
            logger.warning(f"Could not set {name}={value}: {e}")

# Predict complex modulus
G_pred = model_true.predict(omega, test_mode='oscillation')

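# Handle output format: (N, 2) array of [G', G''] or a complex G* array
G_pred = np.asarray(G_pred)
if G_pred.ndim == 2:
    G_prime_true = G_pred[:, 0]
    G_double_prime_true = G_pred[:, 1]
elif np.iscomplexobj(G_pred):
    G_prime_true = G_pred.real
    G_double_prime_true = G_pred.imag
else:
    # A magnitude-only output cannot be split into G' and G''
    raise ValueError(f"Unexpected prediction shape {G_pred.shape}; expected (N, 2) or complex")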

# Add 5% noise
np.random.seed(42)
noise_level = 0.05
G_prime = G_prime_true * (1 + noise_level * np.random.randn(len(omega)))
G_double_prime = G_double_prime_true * (1 + noise_level * np.random.randn(len(omega)))
G_star = np.column_stack([G_prime, G_double_prime])

logger.info(f"Generated SAOS data: {len(omega)} frequency points")
logger.info(f"Frequency range: {omega.min():.2e} to {omega.max():.2e} rad/s")
logger.info(f"G' range: {G_prime.min():.2e} to {G_prime.max():.2e} Pa")
logger.info(f"G'' range: {G_double_prime.min():.2e} to {G_double_prime.max():.2e} Pa")

In [None]:
# Visualize raw data
fig, ax = plt.subplots(figsize=(8, 6))

ax.loglog(omega, G_prime, 'o', label="G' (Storage)", markersize=8, alpha=0.7)
ax.loglog(omega, G_double_prime, 's', label="G'' (Loss)", markersize=8, alpha=0.7)
ax.loglog(omega, G_prime_true, '-', color='C0', label='G\' (True)', linewidth=2, alpha=0.5)
ax.loglog(omega, G_double_prime_true, '-', color='C1', label='G\'\' (True)', linewidth=2, alpha=0.5)

ax.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12)
ax.set_ylabel('Modulus (Pa)', fontsize=12)
ax.set_title('SAOS Data: G\' and G\'\' vs Frequency', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, which='both', alpha=0.3)

plt.tight_layout()
plt.show()

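# Crossover frequency (where G' = G''). Compare on a log scale: the absolute
# difference is dominated by the small moduli at low frequency.
idx_cross = np.argmin(np.abs(np.log(G_prime / G_double_prime)))
omega_cross = omega[idx_cross]
logger.info(f"Crossover frequency ω_c ≈ {omega_cross:.3f} rad/s (λ ≈ {1/omega_cross:.3f} s)")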

## 4. NLSQ Fitting

In [None]:
# Initialize model for fitting
model = FluiditySaramitoLocal(coupling='minimal')

# Print actual parameter names
logger.info("Model parameters:")
for name in model.parameters.keys():
    val = model.parameters.get_value(name)
    logger.info(f"  {name}: {val:.4e}")

logger.info("\nStarting NLSQ optimization...")
model.fit(omega, G_star, test_mode='oscillation', method='scipy')

# Compute fit quality
G_fit = model.predict(omega, test_mode='oscillation')
if G_fit.ndim == 2:
    G_prime_fit = G_fit[:, 0]
    G_double_prime_fit = G_fit[:, 1]
    G_star_fit = np.abs(G_fit[:, 0] + 1j * G_fit[:, 1])
else:
    G_prime_fit = G_fit
    G_double_prime_fit = G_fit
    G_star_fit = G_fit

G_star_data = np.abs(G_prime + 1j * G_double_prime)
metrics = compute_fit_quality(G_star_data, np.array(G_star_fit).flatten())

logger.info("\nNLSQ Results:")
logger.info(f"R² = {metrics['R2']:.6f}")
logger.info(f"RMSE = {metrics['RMSE']:.4e} Pa")

In [None]:
# Compare fitted vs true parameters
print("\nParameter Comparison:")
print("=" * 70)
print(f"{'Parameter':<15} {'True':<15} {'NLSQ':<15} {'Error (%)':<15}")
print("=" * 70)

for name in model.parameters.keys():
    true_val = params_true.get(name, None)
    fitted_val = model.parameters.get_value(name)
    if true_val is not None:
        # Relative error is undefined for a zero true value
        error_pct = 100 * abs(fitted_val - true_val) / abs(true_val) if true_val != 0 else float('nan')
        print(f"{name:<15} {true_val:<15.3e} {fitted_val:<15.3e} {error_pct:<15.2f}")
    else:
        print(f"{name:<15} {'N/A':<15} {fitted_val:<15.3e} {'N/A':<15}")

print("=" * 70)

In [None]:
# Plot fit quality
G_fit_full = model.predict(omega, test_mode='oscillation')
if G_fit_full.ndim == 2:
    G_prime_fit = np.array(G_fit_full[:, 0])
    G_double_prime_fit = np.array(G_fit_full[:, 1])
else:
    G_prime_fit = np.array(G_fit_full)
    G_double_prime_fit = np.array(G_fit_full)

fig, axes = plt.subplots(2, 1, figsize=(10, 10))

# G' and G''
ax = axes[0]
ax.loglog(omega, G_prime, 'o', label="G' (Data)", markersize=8, alpha=0.7)
ax.loglog(omega, G_double_prime, 's', label="G'' (Data)", markersize=8, alpha=0.7)
ax.loglog(omega, G_prime_fit, '-', color='C0', label="G' (NLSQ Fit)", linewidth=2)
ax.loglog(omega, G_double_prime_fit, '-', color='C1', label="G'' (NLSQ Fit)", linewidth=2)
ax.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12)
ax.set_ylabel('Modulus (Pa)', fontsize=12)
ax.set_title(f"NLSQ Fit (R² = {metrics['R2']:.6f})", fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, which='both', alpha=0.3)

# Residuals
ax = axes[1]
residuals_prime = (G_prime - G_prime_fit) / np.maximum(G_prime_fit, 1e-10) * 100
residuals_double_prime = (G_double_prime - G_double_prime_fit) / np.maximum(G_double_prime_fit, 1e-10) * 100
ax.semilogx(omega, residuals_prime, 'o-', label="G' Residuals", markersize=6)
ax.semilogx(omega, residuals_double_prime, 's-', label="G'' Residuals", markersize=6)
ax.axhline(0, color='k', linestyle='--', linewidth=1)
ax.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12)
ax.set_ylabel('Residuals (%)', fontsize=12)
ax.set_title('Fit Residuals', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, which='both', alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Bayesian Inference

In [None]:
# Bayesian inference with NLSQ warm-start (critical for convergence)
logger.info("\nStarting Bayesian inference (NUTS)...")
logger.info("Using NLSQ solution as warm-start")

# Get initial values from fitted model
initial_values = {
    name: model.parameters.get_value(name)
    for name in model.parameters.keys()
}

bayes_result = model.fit_bayesian(
    omega,
    G_star,
    test_mode='oscillation',
    num_warmup=500,
    num_samples=1000,
    num_chains=1,        # Fast demo; use 4 for production (R-hat needs multiple chains)
    seed=42,             # Reproducibility
    initial_values=initial_values,
)

logger.info("Bayesian inference complete")

## 6. ArviZ Diagnostics

In [None]:
import arviz as az

# Convert to ArviZ InferenceData
idata = bayes_result.to_inference_data()

# Get parameter names from model
param_names = list(model.parameters.keys())

# Summary statistics  
summary = az.summary(idata, var_names=param_names, hdi_prob=0.95)
print("\nBayesian Parameter Estimates (95% HDI):")
print("=" * 80)
print(summary)
print("=" * 80)

# Check diagnostics (missing or NaN values are reported as N/A rather than WARN)
logger.info("\nDiagnostic Checks:")
diag = bayes_result.diagnostics
for param in param_names:
    r_hat = diag.get("r_hat", {}).get(param, float("nan"))
    ess = diag.get("ess", {}).get(param, float("nan"))
    if np.isnan(r_hat) or np.isnan(ess):
        status = "N/A"
    else:
        status = "PASS" if (r_hat < 1.01 and ess > 400) else "WARN"
    logger.info(f"{param:<12} R-hat={r_hat:.4f}, ESS={ess:.0f} [{status}]")

In [None]:
# Trace plots (MCMC convergence)
az.plot_trace(idata, compact=True, figsize=(12, 10))
plt.suptitle('MCMC Trace Plots', fontsize=16, fontweight='bold', y=1.001)
plt.tight_layout()
plt.show()

In [None]:
# Forest plot (credible intervals)
az.plot_forest(idata, hdi_prob=0.95, figsize=(8, 6))
plt.title('Parameter Credible Intervals (95% HDI)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Pair plot (correlations)
# Select a subset of key parameters
key_params = ['G', 'eta_s', 't_a', 'b']
key_params = [p for p in key_params if p in bayes_result.posterior_samples][:4]

if len(key_params) >= 2:
    az.plot_pair(
        idata,
        var_names=key_params,
        divergences=True,
        figsize=(10, 10)
    )
    plt.suptitle('Parameter Correlations', fontsize=16, fontweight='bold', y=1.001)
    plt.tight_layout()
    plt.show()
else:
    logger.info("Not enough parameters for pair plot")

In [None]:
# Energy plot (HMC diagnostics)
# Note: Energy plot requires sample_stats with 'energy' field
try:
    az.plot_energy(idata, figsize=(8, 5))
    plt.title('HMC Energy Diagnostic', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
except (AttributeError, KeyError) as e:
    logger.info(f"Energy plot skipped: {e}")
    logger.info("Note: Energy diagnostic requires MCMC sample stats with energy field")

## 7. Cole-Cole Plot Analysis

In [None]:
# Cole-Cole plot: G'' vs G' (should form semicircle for Maxwell model)
fig, ax = plt.subplots(figsize=(8, 8))

# Data
ax.plot(G_prime, G_double_prime, 'o', label='Data', markersize=10, alpha=0.7)

# NLSQ fit
ax.plot(G_prime_fit, G_double_prime_fit, '-', label='NLSQ Fit', linewidth=2)

# Theoretical semicircle (Maxwell)
G_fit_param = model.parameters.get_value('G')
theta = np.linspace(0, np.pi, 100)
G_circle = G_fit_param / 2 * (1 + np.cos(theta))
G_double_circle = G_fit_param / 2 * np.sin(theta)
ax.plot(G_circle, G_double_circle, '--', color='gray', label='Maxwell Semicircle', linewidth=2, alpha=0.5)

ax.set_xlabel("G' (Storage Modulus) [Pa]", fontsize=12)
ax.set_ylabel("G'' (Loss Modulus) [Pa]", fontsize=12)
ax.set_title('Cole-Cole Plot', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()

logger.info("Cole-Cole analysis: Data should follow semicircle for ideal Maxwell behavior")

## 8. Posterior Predictive Analysis

In [None]:
# Compute credible intervals from posterior samples
posterior_samples = bayes_result.posterior_samples
n_samples = len(list(posterior_samples.values())[0])

# Random subset of at most 50 posterior samples, for speed
np.random.seed(42)
n_subsample = min(50, n_samples)
sample_indices = np.random.choice(n_samples, size=n_subsample, replace=False)

G_posterior_samples = []
param_names = list(model.parameters.keys())

for idx in sample_indices:
    # Set parameters from posterior sample
    for name in param_names:
        if name in posterior_samples:
            model.parameters.set_value(name, float(posterior_samples[name][idx]))
    
    G_pred = model.predict(omega, test_mode='oscillation')
    G_posterior_samples.append(G_pred)

G_posterior_samples = np.array(G_posterior_samples)

if G_posterior_samples.ndim == 3:
    G_prime_posterior = G_posterior_samples[:, :, 0]
    G_double_prime_posterior = G_posterior_samples[:, :, 1]
else:
    G_prime_posterior = G_posterior_samples
    G_double_prime_posterior = G_posterior_samples

# Compute percentiles
G_prime_median = np.median(G_prime_posterior, axis=0)
G_prime_lower = np.percentile(G_prime_posterior, 2.5, axis=0)
G_prime_upper = np.percentile(G_prime_posterior, 97.5, axis=0)

G_double_prime_median = np.median(G_double_prime_posterior, axis=0)
G_double_prime_lower = np.percentile(G_double_prime_posterior, 2.5, axis=0)
G_double_prime_upper = np.percentile(G_double_prime_posterior, 97.5, axis=0)

logger.info("Posterior predictive credible intervals computed")

In [None]:
# Plot with credible intervals
fig, ax = plt.subplots(figsize=(10, 7))

# Data
ax.loglog(omega, G_prime, 'o', label="G' (Data)", markersize=8, alpha=0.7, color='C0')
ax.loglog(omega, G_double_prime, 's', label="G'' (Data)", markersize=8, alpha=0.7, color='C1')

# Posterior median
ax.loglog(omega, G_prime_median, '-', label="G' (Posterior Median)", linewidth=2, color='C0')
ax.loglog(omega, G_double_prime_median, '-', label="G'' (Posterior Median)", linewidth=2, color='C1')

# 95% credible intervals
ax.fill_between(omega, G_prime_lower, G_prime_upper, alpha=0.3, color='C0', label="G' 95% CI")
ax.fill_between(omega, G_double_prime_lower, G_double_prime_upper, alpha=0.3, color='C1', label="G'' 95% CI")

ax.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12)
ax.set_ylabel('Modulus (Pa)', fontsize=12)
ax.set_title('Bayesian Posterior Predictive (95% Credible Intervals)', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, loc='best')
ax.grid(True, which='both', alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Save Results

In [None]:
# Create output directory
import json
output_dir = Path('../outputs/fluidity/saramito_local/saos')
output_dir.mkdir(parents=True, exist_ok=True)

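# Save NLSQ results. `model` now holds the last posterior draw set in Section 8,
# so use the NLSQ values captured in `initial_values` before Bayesian inference.
nlsq_results = {
    'parameters': {name: float(value) for name, value in initial_values.items()},
    'r_squared': float(metrics['R2']),
    'rmse': float(metrics['RMSE']),
}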

with open(output_dir / 'nlsq_results.json', 'w') as f:
    json.dump(nlsq_results, f, indent=2)

# Save posterior samples
posterior_dict = {k: np.array(v).tolist() for k, v in posterior_samples.items()}
with open(output_dir / 'posterior_samples.json', 'w') as f:
    json.dump(posterior_dict, f)

# Save summary
summary.to_csv(output_dir / 'bayesian_summary.csv')

logger.info(f"Results saved to {output_dir}")
print("\nOutput files:")
print("  - nlsq_results.json")
print("  - posterior_samples.json")
print("  - bayesian_summary.csv")

## 10. Key Takeaways

### Physical Insights

1. **Maxwell-like SAOS behavior:**
   - Storage modulus G'(ω) transitions from ω² (viscous) to plateau (elastic)
   - Loss modulus G''(ω) exhibits peak at crossover frequency ω_c = 1/λ
   - Cole-Cole plot forms semicircle (characteristic of single relaxation time)

2. **Thixotropic signatures:**
   - Initial fluidity f_0 affects transient approach to steady oscillation
   - Equilibrium time t_eq controls structural recovery during rest periods
   - Yield stress τ_y has minimal effect in linear regime (γ_0 ≪ 1)

3. **Parameter identifiability:**
   - Elastic modulus G well-constrained by high-frequency plateau
   - Zero-shear viscosity η_0 = Gλ (set by the fluidity) from the low-frequency limit G'' ≈ η_0 ω
   - Relaxation time λ from crossover frequency ω_c
   - Thixotropic parameters (b, n, t_eq) weakly identifiable from SAOS alone
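
The identifiability statements above suggest quick estimates that can be read off a sweep before any fitting. The sketch below is one way to do this for ideal single-mode Maxwell data; `estimate_maxwell_from_sweep` is a hypothetical helper, not a rheojax function, and it assumes the sweep reaches both the terminal regime and the plateau.

```python
import numpy as np

def estimate_maxwell_from_sweep(omega, G_prime, G_double_prime):
    """Rough single-mode Maxwell estimates from a frequency sweep.

    Assumes omega is sorted in ascending order, the lowest frequency lies in
    the terminal regime (G'' ~ eta0*omega) and the highest on the plateau.
    """
    omega = np.asarray(omega, dtype=float)
    G_prime = np.asarray(G_prime, dtype=float)
    G_double_prime = np.asarray(G_double_prime, dtype=float)

    G_est = G_prime[-1]                       # high-frequency plateau
    eta0_est = G_double_prime[0] / omega[0]   # low-frequency slope of G''
    lam_est = eta0_est / G_est                # relaxation time from eta0 = G*lambda

    # Crossover: first sign change of log(G'/G''), interpolated in log(omega)
    d = np.log(G_prime / G_double_prime)
    crossings = np.flatnonzero(np.sign(d[:-1]) != np.sign(d[1:]))
    if crossings.size:
        i = crossings[0]
        frac = d[i] / (d[i] - d[i + 1])
        omega_c = float(np.exp(np.log(omega[i]) + frac * np.log(omega[i + 1] / omega[i])))
    else:
        omega_c = float("nan")                # no crossover inside the sweep
    return {"G": float(G_est), "eta0": float(eta0_est),
            "lambda": float(lam_est), "omega_c": omega_c}

# Noise-free Maxwell data with illustrative parameters
G_true, lam_true = 1.0e5, 0.5
omega = np.logspace(-3, 3, 40)
wl = omega * lam_true
estimates = estimate_maxwell_from_sweep(
    omega, G_true * wl**2 / (1 + wl**2), G_true * wl / (1 + wl**2)
)
print(estimates)
```

With noisy data, averaging the last few G' points and the first few G''/ω ratios is more robust than using single end points. If the two relaxation-time estimates (`lambda` and `1/omega_c`) disagree strongly, the sweep probably does not cover both limits.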

### Computational Best Practices

4. **NLSQ warm-start is critical:**
   - Bayesian inference requires good initialization (run `.fit()` first)
   - Multi-start optimization is recommended for complex models
   - Check R² > 0.95 before proceeding to Bayesian inference

5. **Diagnostic interpretation:**
   - R-hat < 1.01: chains converged (needs at least 2 chains; the single-chain demo above cannot establish this)
   - ESS > 400: sufficient effective samples
   - Energy plot: marginal/conditional distributions should match
   - Divergences: indicate problematic geometry (reparameterize or increase warmup)

6. **Model validation:**
   - Cole-Cole semicircle confirms Maxwell backbone
   - Residuals within the noise level (5% here) indicate a good fit
   - Posterior predictive intervals should encompass data
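
To make the R-hat threshold above concrete, the sketch below computes the classic split R-hat (between-sequence versus within-sequence variance) with NumPy. It is a simplified illustration: `split_rhat` is a hypothetical helper, the chains are synthetic, and ArviZ reports a rank-normalised variant, so its values can differ slightly.

```python
import numpy as np

def split_rhat(chains):
    """Classic split R-hat for draws shaped (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split every chain in two, giving 2*n_chains sequences of length `half`
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()             # W: mean within-sequence variance
    between = half * halves.mean(axis=1).var(ddof=1)       # B: variance of sequence means
    var_plus = (half - 1) / half * within + between / half
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))             # four chains sampling the same target
stuck = mixed + 2.0 * np.arange(4)[:, None]    # four chains stuck at different means

print(f"well-mixed chains: R-hat = {split_rhat(mixed):.3f}")
print(f"disagreeing chains: R-hat = {split_rhat(stuck):.3f}")
```

Because each chain is split in half, the statistic is defined even for one chain, but in that case it only tests stationarity within the chain and says nothing about agreement between independent chains.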

### Experimental Design

7. **Frequency sweep strategy:**
   - Log-spaced frequencies spanning 3-4 decades
   - Include both low (viscous) and high (elastic) frequency limits
   - Strain amplitude γ_0 < 1% to ensure linear regime

8. **Pre-shearing protocol:**
   - Control initial fluidity f_0 via rest time after pre-shear
   - Long rest (t ≫ t_eq): fully structured (f → f_min)
   - Short rest (t ≪ t_eq): partially rejuvenated (f > f_min)
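
The rest-time rule of thumb above can be made quantitative under a simple assumption: with no shear, the fluidity relaxes exponentially toward its rest value with time constant t_eq, as the aging term of the fluidity equation suggests. In the sketch below, `fluidity_after_rest`, `f_rest` (playing the role of f_min) and all numbers are illustrative assumptions.

```python
import math

def fluidity_after_rest(t_rest, f0, f_rest, t_eq):
    """Solution of df/dt = (f_rest - f)/t_eq with f(0) = f0 and no shear."""
    return f_rest + (f0 - f_rest) * math.exp(-t_rest / t_eq)

# Illustrative values: fluidity right after pre-shear, rest value, aging time (s)
f0, f_rest, t_eq = 0.1, 1e-4, 10.0

for t_rest in (0.1 * t_eq, t_eq, 10.0 * t_eq):
    f = fluidity_after_rest(t_rest, f0, f_rest, t_eq)
    print(f"rest {t_rest:6.1f} s -> f = {f:.3e}")
```

Starting the sweep only after several multiples of t_eq keeps the initial fluidity close to its rest value, so the measured moduli do not drift during the experiment.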

### Next Steps

- **LAOS (Large-Amplitude Oscillatory Shear):** Nonlinear viscoelasticity and yield transition
- **Startup flows:** Stress overshoot and thixotropic kinetics
- **Step-strain relaxation:** Direct probe of relaxation time spectrum
- **Combined protocols:** Simultaneous fitting of SAOS + flow curves for full parameter identification