# Small Amplitude Oscillatory Shear (SAOS) - FluiditySaramitoNonlocal

**Linear Viscoelastic Response with Spatial Effects**

This notebook demonstrates SAOS analysis using the nonlocal Fluidity-Saramito model, which couples:
- Tensorial viscoelasticity (upper-convected Maxwell, UCM, backbone)
- Von Mises yield criterion
- Thixotropic fluidity evolution
- Spatial diffusion effects

In the linear regime (strain amplitude γ₀ → 0), we extract the storage and loss moduli G'(ω) and G''(ω) over a range of frequencies.
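For an imposed strain γ(t) = γ₀ sin(ωt), the steady-state stress in the linear regime is τ(t) = γ₀[G'(ω) sin(ωt) + G''(ω) cos(ωt)], so G' and G'' can be recovered by projecting τ(t) onto sin(ωt) and cos(ωt) over an integer number of periods. A minimal NumPy sketch of that projection follows; the stress signal is synthetic, with arbitrary moduli chosen only to check the arithmetic, and `extract_moduli` is a hypothetical helper, not part of RheoJAX.

```python
import numpy as np

def extract_moduli(t, stress, gamma_0, omega):
    """Project a steady-state stress signal onto sin/cos at the drive frequency.

    Assumes t is uniformly spaced and spans an integer number of periods with
    the endpoint excluded, so the discrete averages of sin^2 and cos^2 are 1/2
    and the sin*cos cross term averages to zero.
    """
    G_prime = 2.0 * np.mean(stress * np.sin(omega * t)) / gamma_0
    G_double_prime = 2.0 * np.mean(stress * np.cos(omega * t)) / gamma_0
    return G_prime, G_double_prime

# Synthetic check: arbitrary moduli, 5 periods, 200 samples per period
omega_drive, gamma_0_demo = 2.0, 0.01
n_periods, pts_per_period = 5, 200
t = np.linspace(0.0, n_periods * 2 * np.pi / omega_drive,
                n_periods * pts_per_period, endpoint=False)
stress = gamma_0_demo * (800.0 * np.sin(omega_drive * t) + 300.0 * np.cos(omega_drive * t))
print(extract_moduli(t, stress, gamma_0_demo, omega_drive))
```

On measured data, discard the start-up transient first; the projection assumes a steady periodic signal.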

## Setup and Imports

In [None]:
# Google Colab setup
try:
    import google.colab
    IN_COLAB = True
    !pip install -q rheojax
except ImportError:
    IN_COLAB = False

import sys
from pathlib import Path

if not IN_COLAB:
    # Add project root to path for local development
    project_root = Path.cwd().parent.parent
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

In [None]:
# JAX configuration (MUST come before any other JAX-dependent import)
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# RheoJAX imports
from rheojax.models.fluidity import FluiditySaramitoNonlocal, FluiditySaramitoLocal
from rheojax.core.data import RheoData
from rheojax.logging import configure_logging, get_logger
from rheojax.utils.metrics import compute_fit_quality

# Configure logging
configure_logging(level="INFO")
logger = get_logger(__name__)

# Set random seeds for reproducibility
np.random.seed(42)

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Float64 enabled: {jax.config.jax_enable_x64}")

## Theory: SAOS with Spatial Effects

### Governing Equations

For small amplitude oscillatory shear γ(t) = γ₀ sin(ωt) with γ₀ → 0:

**Stress tensor:**
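$$\boldsymbol{\tau} + \lambda(f)\,\overset{\nabla}{\boldsymbol{\tau}} = 2\eta(f)\mathbf{D}$$

where $\overset{\nabla}{\boldsymbol{\tau}}$ is the upper-convected derivative of the stress, $\mathbf{D}$ is the rate-of-deformation tensor, and λ(f) = 1/f under minimal coupling (other coupling options add aging effects).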

**Fluidity evolution:**
$$\frac{\partial f}{\partial t} = \frac{1 - f}{t_{\text{eq}}} + b|\dot{\gamma}|^n f + D_f \nabla^2 f$$
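Taking the fluidity equation above literally for a spatially uniform sample (∇²f = 0) under the imposed shear rate γ̇(t) = γ₀ω cos(ωt), a forward-Euler integration shows how the linear regime is reached: the rejuvenation term scales as γ₀ⁿ, so the steady deviation of f from its rest value f = 1 shrinks with the strain amplitude. This is an illustrative sketch of the equation as written in this notebook, not the RheoJAX solver; `max_fluidity_deviation` is a hypothetical helper and its defaults follow the synthetic parameters used later.

```python
import numpy as np

def max_fluidity_deviation(gamma_0, omega=1.0, t_eq=10.0, b=1.0, n=1.0,
                           t_end=200.0, dt=0.01):
    """Forward-Euler integration of df/dt = (1 - f)/t_eq + b*|gamma_dot|^n * f.

    Homogeneous sample (diffusion term dropped), strain gamma_0*sin(omega*t).
    Starts from the rest-state fixed point f = 1 and returns max |f - 1|
    over the second half of the run, after the start-up transient.
    """
    f, max_dev = 1.0, 0.0
    for k in range(int(round(t_end / dt))):
        t = k * dt
        gamma_dot = gamma_0 * omega * np.cos(omega * t)  # time derivative of the strain
        f += dt * ((1.0 - f) / t_eq + b * abs(gamma_dot) ** n * f)
        if t >= 0.5 * t_end:
            max_dev = max(max_dev, abs(f - 1.0))
    return max_dev

# Defaults follow the synthetic parameters used later (t_eq = 10 s, b = 1, n = 1)
for g0 in (1e-3, 1e-2):
    print(f"gamma_0 = {g0:.0e}: max |f - 1| = {max_fluidity_deviation(g0):.3e}")
```

Checking this deviation for the chosen γ₀ is a quick way to confirm that a protocol stays linear for a given parameter set.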

**Linear regime:**
- Von Mises yield criterion inactive (α = 1)
- Stress response: τ(t) = G'(ω) γ₀ sin(ωt) + G''(ω) γ₀ cos(ωt)
- Storage modulus: G'(ω) ~ ω² in the terminal regime (ωλ ≪ 1), approaching the elastic plateau G at high frequency
- Loss modulus: G''(ω) ~ ω in the terminal regime (viscous dissipation)
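In the linear regime the fluidity can be held at its rest value, so λ and η are constants, and in simple shear the upper-convected terms do not enter the shear stress. The shear component of the stress equation then reduces to one Maxwell mode, τ + λ dτ/dt = η γ̇ with G = η/λ, whose complex modulus is G*(ω) = G iωλ/(1 + iωλ). The sketch below evaluates it and checks the terminal slopes quoted above. It adds a Newtonian term iωη_s on the assumption (based only on the parameter name `eta_s`) that the solvent acts in parallel; it is a hand-derived illustration, not the RheoJAX implementation.

```python
import numpy as np

def maxwell_solvent_modulus(omega, G, lam, eta_s=0.0):
    """G*(omega) for one Maxwell mode plus an optional Newtonian solvent term.

    G: elastic modulus (Pa); lam: relaxation time (s); eta_s: solvent
    viscosity (Pa*s), assumed here to add i*omega*eta_s to G*.
    """
    omega = np.asarray(omega, dtype=float)
    iwl = 1j * omega * lam
    return G * iwl / (1.0 + iwl) + 1j * omega * eta_s

def loglog_slope(x, y):
    """Slope of log(y) versus log(x) between the first two points."""
    return (np.log(y[1]) - np.log(y[0])) / (np.log(x[1]) - np.log(x[0]))

omega_grid = np.logspace(-3, 3, 61)
G_star_demo = maxwell_solvent_modulus(omega_grid, G=1000.0, lam=1.0)
G_p, G_pp = G_star_demo.real, G_star_demo.imag

print("terminal slope of G': ", loglog_slope(omega_grid, G_p))    # close to 2
print("terminal slope of G'':", loglog_slope(omega_grid, G_pp))   # close to 1
print("G' at the highest frequency:", G_p[-1])                    # close to the plateau G
```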

### Spatial Effects

The diffusion term D_f∇²f can affect:
1. Spatial homogeneity of fluidity distribution
2. Effective relaxation timescale
3. High-frequency modulus plateau
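Balancing the diffusion term D_f∇²f against the relaxation term (1 − f)/t_eq gives a rough nonlocal length ξ = √(D_f t_eq). Comparing ξ with the gap width h indicates whether diffusion can matter: for ξ ≪ h fluidity gradients can survive inside the gap, while for ξ ≳ h diffusion tends to smooth them across the whole gap. This is a dimensional estimate rather than a library result, and `cooperativity_length` is a hypothetical helper; the values are the synthetic parameters used below.

```python
import math

def cooperativity_length(D_f, t_eq):
    """Dimensional estimate xi = sqrt(D_f * t_eq), from D_f * lap(f) ~ (1 - f) / t_eq."""
    return math.sqrt(D_f * t_eq)

# Synthetic-data values used later in this notebook
D_f = 1e-6        # m^2/s, fluidity diffusion
t_eq = 10.0       # s, equilibration time
gap_width = 1e-3  # m

xi = cooperativity_length(D_f, t_eq)
print(f"xi = {xi:.3e} m, xi / gap_width = {xi / gap_width:.2f}")
```

With these values ξ is about three times the gap, so diffusion would be expected to keep the fluidity nearly uniform across it; with D_f at its lower fitting bound of 1e-8 m²/s, ξ ≈ 0.32 mm is smaller than the gap.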

## Data Loading

Load experimental SAOS data or generate synthetic data from the model.
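A CSV file cannot store complex numbers directly, so SAOS exports typically keep G' and G'' in separate columns. If `RheoData.from_csv` does not match a given file layout, a plain NumPy loader that rebuilds G* is a simple fallback. A minimal sketch; the helper `load_saos_csv` and the column names `omega`, `G_prime` and `G_double_prime` are hypothetical and should be matched to the actual export.

```python
import numpy as np

def load_saos_csv(source, omega_col="omega", gp_col="G_prime", gpp_col="G_double_prime"):
    """Read a comma-separated table with a header row and return (omega, G_star).

    Column names are hypothetical defaults; `source` may be a path, a file-like
    object or a list of lines, as accepted by np.genfromtxt.
    """
    table = np.genfromtxt(source, delimiter=",", names=True)
    omega = np.asarray(table[omega_col], dtype=float)
    G_star = table[gp_col] + 1j * table[gpp_col]   # rebuild the complex modulus
    return omega, G_star
```

For example, `omega, G_star_exp = load_saos_csv(data_file)` could stand in for the `RheoData.from_csv` call in the cell below, with `RheoData(x=omega, y=G_star_exp, test_mode='oscillation')` built as in the fitting section.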

In [None]:
# Try to load real data
data_file = Path("../data/polystyrene_saos.csv")

if data_file.exists():
    logger.info(f"Loading experimental data from {data_file}")
    rheo_data = RheoData.from_csv(
        data_file,
        x_col="omega",
        y_col="G_star",
        test_mode="oscillation"
    )
    omega = rheo_data.x
    G_star_exp = rheo_data.y
    USE_SYNTHETIC = False
else:
    logger.info("Experimental data not found, generating synthetic data")
    USE_SYNTHETIC = True
    
    # Generate synthetic SAOS data
    omega = jnp.logspace(-2, 2, 50)  # 0.01 to 100 rad/s
    
    # True parameters for synthetic data
    true_params = {
        'G': 1000.0,        # Pa - elastic modulus
        'eta_s': 100.0,     # Pa·s - solvent viscosity
        'tau_y': 50.0,      # Pa - yield stress
        't_eq': 10.0,       # s - equilibration time
        'b': 1.0,           # s^(-n) - rejuvenation rate
        'n': 1.0,           # dimensionless - power-law exponent
        'D_f': 1e-6,        # m²/s - fluidity diffusion
        'gap_width': 1e-3   # m - gap width
    }
    
    # Generate clean data
    model_true = FluiditySaramitoNonlocal(
        coupling="minimal",
        n_points=51
    )
    for name, value in true_params.items():
        getattr(model_true, name).value = value
    
    # Small amplitude for linear regime
    gamma_0 = 0.01
    
    G_star_clean = model_true.predict(omega, test_mode='oscillation', gamma_0=gamma_0)
    
    # Add multiplicative Gaussian noise (5% on G', 3% on G'')
    G_prime = jnp.real(G_star_clean)
    G_double_prime = jnp.imag(G_star_clean)
    
    noise_G_prime = G_prime * 0.05 * np.random.randn(len(omega))
    noise_G_double_prime = G_double_prime * 0.03 * np.random.randn(len(omega))
    
    G_star_exp = (G_prime + noise_G_prime) + 1j * (G_double_prime + noise_G_double_prime)
    
    logger.info(f"Generated synthetic data: {len(omega)} points")
    logger.info(f"True parameters: {true_params}")

# Extract components
G_prime_exp = jnp.real(G_star_exp)
G_double_prime_exp = jnp.imag(G_star_exp)

print(f"\nData summary:")
print(f"  Frequency range: {omega.min():.3e} - {omega.max():.3e} rad/s")
print(f"  G' range: {G_prime_exp.min():.2e} - {G_prime_exp.max():.2e} Pa")
print(f"  G'' range: {G_double_prime_exp.min():.2e} - {G_double_prime_exp.max():.2e} Pa")

In [None]:
# Visualize experimental data
fig, ax = plt.subplots(figsize=(10, 6))

ax.loglog(omega, G_prime_exp, 'o', label="G' (storage)", markersize=6, alpha=0.7)
ax.loglog(omega, G_double_prime_exp, 's', label="G'' (loss)", markersize=6, alpha=0.7)

ax.set_xlabel('Angular Frequency ω (rad/s)', fontsize=12)
ax.set_ylabel('Modulus (Pa)', fontsize=12)
ax.set_title('Synthetic SAOS Data' if USE_SYNTHETIC else 'Experimental SAOS Data', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## NLSQ Fitting

Fit the nonlocal Fluidity-Saramito model using NLSQ optimization.
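Fitting a complex modulus with a real-valued least-squares solver usually means stacking the G' and G'' residuals into one vector. The sketch below does this with `scipy.optimize.least_squares` for the single-mode Maxwell-plus-solvent form (an illustrative stand-in, not the nonlocal model or the RheoJAX optimizer). It uses relative residuals so that points spanning several decades carry comparable weight, and log-parameters so that the scale parameters stay positive. Parameter values are illustrative test inputs.

```python
import numpy as np
from scipy.optimize import least_squares

def maxwell_solvent(omega, G, lam, eta_s):
    """One Maxwell mode plus an assumed Newtonian solvent term."""
    iwl = 1j * omega * lam
    return G * iwl / (1.0 + iwl) + 1j * omega * eta_s

def stacked_relative_residuals(log_params, omega, G_star_obs):
    """Concatenate relative residuals of G' and G'' into one real vector."""
    G, lam, eta_s = np.exp(log_params)           # fit in log space: all parameters stay positive
    G_star = maxwell_solvent(omega, G, lam, eta_s)
    r_prime = (G_star.real - G_star_obs.real) / np.abs(G_star_obs.real)
    r_double = (G_star.imag - G_star_obs.imag) / np.abs(G_star_obs.imag)
    return np.concatenate([r_prime, r_double])

omega_demo = np.logspace(-2, 2, 50)
p_true = np.array([1000.0, 0.5, 5.0])            # illustrative G (Pa), lambda (s), eta_s (Pa*s)
G_star_demo = maxwell_solvent(omega_demo, *p_true)   # noise-free test data
x0 = np.log([800.0, 0.8, 3.0])                   # rough starting guess
fit = least_squares(stacked_relative_residuals, x0, args=(omega_demo, G_star_demo))
print(fit.success, np.exp(fit.x))
```

Relative residuals act much like fitting log-moduli: without them, the points with the largest moduli would dominate the cost and the low-frequency terminal region would be poorly constrained.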

In [None]:
# Initialize model
model = FluiditySaramitoNonlocal(
    coupling="minimal",  # λ = 1/f only
    n_points=51          # Spatial resolution
)

# Set parameter bounds based on physical constraints
model.G.bounds = (1e2, 1e5)          # Elastic modulus (Pa)
model.eta_s.bounds = (1.0, 1e4)      # Solvent viscosity (Pa·s)
model.tau_y.bounds = (1.0, 500.0)    # Yield stress (Pa)
model.t_eq.bounds = (0.1, 100.0)     # Equilibration time (s)
model.b.bounds = (0.01, 10.0)        # Rejuvenation rate
model.n.bounds = (0.5, 2.0)          # Power-law exponent
model.D_f.bounds = (1e-8, 1e-4)      # Fluidity diffusion (m²/s)
model.gap_width.bounds = (1e-4, 1e-2)  # Gap width (m)

print("Initial parameters:")
for param in model.parameters:
    print(f"  {param.name}: {param.value:.4e} (bounds: {param.bounds})")

In [None]:
# Prepare data
rheo_data = RheoData(
    x=omega,
    y=G_star_exp,
    test_mode='oscillation'
)

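# Fit model
logger.info("Starting NLSQ fitting...")
result = model.fit(
    rheo_data,
    gamma_0=0.01,  # Small amplitude for linear regime
    max_iter=2000,
    verbose=True,
    method='scipy',
)

# Goodness-of-fit metrics on the stacked [G', G''] vector
G_star_nlsq = model.predict(omega, test_mode='oscillation', gamma_0=0.01)
y_obs = np.concatenate([np.real(G_star_exp), np.imag(G_star_exp)])
y_hat = np.concatenate([np.real(G_star_nlsq), np.imag(G_star_nlsq)])
resid = y_obs - y_hat
metrics = {
    "R2": 1.0 - np.sum(resid**2) / np.sum((y_obs - y_obs.mean())**2),
    "RMSE": np.sqrt(np.mean(resid**2)),
}

print("\n" + "="*60)
print("NLSQ Fitting Results")
print("="*60)
print(f"Convergence: {result.success}")
print(f"R² score: {metrics['R2']:.6f}")
print(f"RMSE: {metrics['RMSE']:.4e} Pa")
print(f"Iterations: {result.nit}")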
print(f"\nOptimized parameters:")
for param in model.parameters:
    print(f"  {param.name}: {param.value:.4e}")

if USE_SYNTHETIC:
    print(f"\nTrue parameters (for comparison):")
    for key, val in true_params.items():
        print(f"  {key}: {val:.4e}")

In [None]:
# Generate predictions
G_star_fit = model.predict(omega, test_mode='oscillation', gamma_0=0.01)
G_prime_fit = jnp.real(G_star_fit)
G_double_prime_fit = jnp.imag(G_star_fit)

# Plot fit
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Storage modulus
ax1.loglog(omega, G_prime_exp, 'o', label='Experimental', markersize=6, alpha=0.7)
ax1.loglog(omega, G_prime_fit, '-', linewidth=2, label='NLSQ fit')
ax1.set_xlabel('ω (rad/s)', fontsize=12)
ax1.set_ylabel("G' (Pa)", fontsize=12)
ax1.set_title('Storage Modulus', fontsize=13, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Loss modulus
ax2.loglog(omega, G_double_prime_exp, 's', label='Experimental', markersize=6, alpha=0.7)
ax2.loglog(omega, G_double_prime_fit, '-', linewidth=2, label='NLSQ fit')
ax2.set_xlabel('ω (rad/s)', fontsize=12)
ax2.set_ylabel("G'' (Pa)", fontsize=12)
ax2.set_title('Loss Modulus', fontsize=13, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Bayesian Inference

Perform Bayesian parameter estimation with NUTS sampling, warm-started from NLSQ results.
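A LogUniform(a, b) prior is uniform in log x, so every decade between a and b receives equal prior mass; that is why it suits the scale parameters below (G, η_s, t_eq, b, D_f, gap width), while tau_y and n get plain uniform priors. A quick NumPy check of that property follows, sampling by exponentiating a uniform draw on [log a, log b]. This is a standard construction, independent of how RheoJAX interprets the `('LogUniform', a, b)` tuples, and `sample_log_uniform` is a hypothetical helper.

```python
import numpy as np

def sample_log_uniform(rng, low, high, size):
    """Draw from LogUniform(low, high) by exponentiating a uniform on [log low, log high]."""
    return np.exp(rng.uniform(np.log(low), np.log(high), size))

rng = np.random.default_rng(0)
draws = sample_log_uniform(rng, 1e2, 1e5, 200_000)   # same range as the prior on G
decade_counts, _ = np.histogram(np.log10(draws), bins=[2, 3, 4, 5])
print(decade_counts / draws.size)   # roughly 1/3 of the mass per decade
print(np.median(draws))             # close to sqrt(1e2 * 1e5), the geometric mean
```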

In [None]:
# Set priors (log-uniform for scale parameters)
model.G.prior = ('LogUniform', 1e2, 1e5)
model.eta_s.prior = ('LogUniform', 1.0, 1e4)
model.tau_y.prior = ('Uniform', 1.0, 500.0)
model.t_eq.prior = ('LogUniform', 0.1, 100.0)
model.b.prior = ('LogUniform', 0.01, 10.0)
model.n.prior = ('Uniform', 0.5, 2.0)
model.D_f.prior = ('LogUniform', 1e-8, 1e-4)
model.gap_width.prior = ('LogUniform', 1e-4, 1e-2)

print("Priors:")
for param in model.parameters:
    print(f"  {param.name}: {param.prior}")

In [None]:
# Run Bayesian inference
logger.info("Starting Bayesian inference with NUTS...")

bayes_result = model.fit_bayesian(
    rheo_data,
    gamma_0=0.01,
    num_warmup=1000,
    num_samples=2000,
    num_chains=4,
    seed=42,
    progress_bar=True
)

print("\n" + "="*60)
print("Bayesian Inference Complete")
print("="*60)
print(f"Total samples: {bayes_result.num_samples}")
print(f"Number of chains: {bayes_result.num_chains}")

## ArviZ Diagnostics

Check MCMC convergence and posterior quality.
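R-hat compares the spread between chains with the spread within chains; values near 1 mean the chains agree. A minimal NumPy sketch of the classic split-R̂ is below. ArviZ's `az.summary` reports a rank-normalized variant, so its numbers will differ slightly, but the idea is the same. The `split_rhat` helper is illustrative only.

```python
import numpy as np

def split_rhat(samples):
    """Classic (non-rank-normalized) split-R-hat for an array of shape (n_chains, n_draws)."""
    samples = np.asarray(samples, dtype=float)
    half = samples.shape[1] // 2
    # Split each chain in half so that drift within a chain also inflates R-hat
    splits = np.concatenate([samples[:, :half], samples[:, half:2 * half]], axis=0)
    n = splits.shape[1]
    W = splits.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B = n * splits.mean(axis=1).var(ddof=1)      # between-chain variance
    var_plus = (n - 1) / n * W + B / n           # pooled estimate of the posterior variance
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 2000))                       # four well-mixed chains
stuck = mixed + np.array([[0.0], [0.0], [0.0], [1.0]])   # one chain offset by 1
r_mixed, r_stuck = split_rhat(mixed), split_rhat(stuck)
print(f"well mixed: {r_mixed:.4f}, one chain offset: {r_stuck:.4f}")
```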

In [None]:
# Convert to ArviZ InferenceData
try:
    import arviz as az
    
    idata = bayes_result.to_inference_data()
    
    var_names = [p.name for p in model.parameters]

    # Summary statistics with 95% HDI (computed once and reused for diagnostics)
    summary = az.summary(idata, var_names=var_names, hdi_prob=0.95)
    print("\nPosterior Summary:")
    print(summary)

    # Check R-hat and bulk ESS against the usual thresholds
    max_rhat = summary['r_hat'].max()
    min_ess_bulk = summary['ess_bulk'].min()

    print("\nDiagnostics:")
    print(f"  Max R-hat: {max_rhat:.4f} {'✓' if max_rhat < 1.01 else '✗ (>1.01)'}")
    print(f"  Min ESS (bulk): {min_ess_bulk:.0f} {'✓' if min_ess_bulk > 400 else '✗ (<400)'}")
    
except ImportError:
    print("ArviZ not installed. Install with: pip install arviz")
    idata = None

In [None]:
# Trace plots
if idata is not None:
    az.plot_trace(
        idata,
        var_names=[p.name for p in model.parameters],
        compact=True,
        figsize=(12, 10)
    )
    plt.tight_layout()
    plt.show()

In [None]:
# Pair plot for correlations
if idata is not None:
    az.plot_pair(
        idata,
        var_names=[p.name for p in model.parameters],
        kind='hexbin',
        figsize=(14, 14),
        divergences=True
    )
    plt.tight_layout()
    plt.show()

In [None]:
# Forest plot (credible intervals)
if idata is not None:
    az.plot_forest(
        idata,
        var_names=[p.name for p in model.parameters],
        hdi_prob=0.95,
        figsize=(10, 6)
    )
    plt.tight_layout()
    plt.show()

## Posterior Predictions

In [None]:
# Extract credible intervals
intervals = model.get_credible_intervals(
    bayes_result.posterior_samples,
    credibility=0.95
)

print("\n95% Credible Intervals:")
for param_name, (lower, median, upper) in intervals.items():
    print(f"  {param_name}: {median:.4e} [{lower:.4e}, {upper:.4e}]")
    if USE_SYNTHETIC and param_name in true_params:
        true_val = true_params[param_name]
        in_interval = lower <= true_val <= upper
        print(f"    True value: {true_val:.4e} {'✓' if in_interval else '✗'}")

In [None]:
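# Posterior samples of the predicted moduli (parameter uncertainty only, no noise model)
n_posterior_samples = 100
posterior_samples = bayes_result.posterior_samples

# Remember the current parameter values so they can be restored after the loop
point_estimate = {p.name: p.value for p in model.parameters}

# Randomly select posterior draws without replacement
indices = np.random.choice(
    len(posterior_samples[model.parameters[0].name]),
    size=n_posterior_samples,
    replace=False
)

G_star_posterior = []
for idx in indices:
    # Set parameters from posterior sample
    for param in model.parameters:
        param.value = float(posterior_samples[param.name][idx])

    # Predict
    G_star_pred = model.predict(omega, test_mode='oscillation', gamma_0=0.01)
    G_star_posterior.append(G_star_pred)

# Restore the point estimate (the loop left the last posterior draw in the model)
for param in model.parameters:
    param.value = point_estimate[param.name]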

G_star_posterior = jnp.array(G_star_posterior)
G_prime_posterior = jnp.real(G_star_posterior)
G_double_prime_posterior = jnp.imag(G_star_posterior)

# Compute percentiles
G_prime_median = jnp.percentile(G_prime_posterior, 50, axis=0)
G_prime_lower = jnp.percentile(G_prime_posterior, 2.5, axis=0)
G_prime_upper = jnp.percentile(G_prime_posterior, 97.5, axis=0)

G_double_prime_median = jnp.percentile(G_double_prime_posterior, 50, axis=0)
G_double_prime_lower = jnp.percentile(G_double_prime_posterior, 2.5, axis=0)
G_double_prime_upper = jnp.percentile(G_double_prime_posterior, 97.5, axis=0)

print(f"Generated {n_posterior_samples} posterior predictive samples")

In [None]:
# Plot posterior predictions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Storage modulus
ax1.loglog(omega, G_prime_exp, 'o', label='Experimental', markersize=6, alpha=0.7, zorder=3)
ax1.loglog(omega, G_prime_median, '-', linewidth=2, label='Posterior median', color='red', zorder=2)
ax1.fill_between(
    omega,
    G_prime_lower,
    G_prime_upper,
    alpha=0.3,
    label='95% credible interval',
    color='red',
    zorder=1
)
ax1.set_xlabel('ω (rad/s)', fontsize=12)
ax1.set_ylabel("G' (Pa)", fontsize=12)
ax1.set_title('Storage Modulus - Bayesian Fit', fontsize=13, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Loss modulus
ax2.loglog(omega, G_double_prime_exp, 's', label='Experimental', markersize=6, alpha=0.7, zorder=3)
ax2.loglog(omega, G_double_prime_median, '-', linewidth=2, label='Posterior median', color='blue', zorder=2)
ax2.fill_between(
    omega,
    G_double_prime_lower,
    G_double_prime_upper,
    alpha=0.3,
    label='95% credible interval',
    color='blue',
    zorder=1
)
ax2.set_xlabel('ω (rad/s)', fontsize=12)
ax2.set_ylabel("G'' (Pa)", fontsize=12)
ax2.set_title('Loss Modulus - Bayesian Fit', fontsize=13, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Comparison with Local Model

Compare nonlocal predictions with the local Fluidity-Saramito model (no spatial diffusion).
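A higher R² alone does not show that the nonlocal model is better: it has two extra parameters (D_f and gap_width), and a model with more free parameters typically fits at least as well. Information criteria penalize the parameter count. A minimal sketch follows, assuming independent Gaussian residuals with a common variance; `information_criteria` is a hypothetical helper, and the residual sums of squares are illustrative inputs, not results from this notebook.

```python
import numpy as np

def information_criteria(rss, n_obs, n_params):
    """AIC and BIC for a least-squares fit, assuming i.i.d. Gaussian residuals.

    Additive constants shared by all models fitted to the same data are dropped.
    """
    log_lik_term = n_obs * np.log(rss / n_obs)
    aic = log_lik_term + 2 * n_params
    bic = log_lik_term + n_params * np.log(n_obs)
    return aic, bic

# Illustrative inputs: 100 residuals (50 G' + 50 G''), 6 local vs 8 nonlocal parameters
n_obs = 100
aic_local, bic_local = information_criteria(rss=1.00, n_obs=n_obs, n_params=6)
aic_nonlocal, bic_nonlocal = information_criteria(rss=0.98, n_obs=n_obs, n_params=8)
print(f"Delta AIC (nonlocal - local): {aic_nonlocal - aic_local:+.2f}")
print(f"Delta BIC (nonlocal - local): {bic_nonlocal - bic_local:+.2f}")
```

With these illustrative numbers, a 2% lower RSS does not pay for two extra parameters under either criterion (both differences are positive). On real fits, substitute the residual sums of squares of the two models computed on the same stacked [G', G''] vector.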

In [None]:
# Fit local model
model_local = FluiditySaramitoLocal(coupling="minimal")

# Set same bounds
model_local.G.bounds = (1e2, 1e5)
model_local.eta_s.bounds = (1.0, 1e4)
model_local.tau_y.bounds = (1.0, 500.0)
model_local.t_eq.bounds = (0.1, 100.0)
model_local.b.bounds = (0.01, 10.0)
model_local.n.bounds = (0.5, 2.0)

logger.info("Fitting local model for comparison...")
result_local = model_local.fit(
    rheo_data,
    gamma_0=0.01,
    max_iter=2000,
    verbose=True,
    method='scipy',
)

print(f"\nLocal model R²: {result_local.r_squared:.6f}")
print(f"Nonlocal model R²: {metrics['R2']:.6f}")
print(f"R² difference (nonlocal - local): {(metrics['R2'] - result_local.r_squared)*100:.2f} percentage points")

In [None]:
# Generate local predictions
G_star_local = model_local.predict(omega, test_mode='oscillation', gamma_0=0.01)
G_prime_local = jnp.real(G_star_local)
G_double_prime_local = jnp.imag(G_star_local)

# Compare fits
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Storage modulus comparison
ax1.loglog(omega, G_prime_exp, 'o', label='Experimental', markersize=6, alpha=0.7)
ax1.loglog(omega, G_prime_fit, '-', linewidth=2, label='Nonlocal', color='red')
ax1.loglog(omega, G_prime_local, '--', linewidth=2, label='Local', color='green')
ax1.set_xlabel('ω (rad/s)', fontsize=12)
ax1.set_ylabel("G' (Pa)", fontsize=12)
ax1.set_title('Storage Modulus - Model Comparison', fontsize=13, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Loss modulus comparison
ax2.loglog(omega, G_double_prime_exp, 's', label='Experimental', markersize=6, alpha=0.7)
ax2.loglog(omega, G_double_prime_fit, '-', linewidth=2, label='Nonlocal', color='red')
ax2.loglog(omega, G_double_prime_local, '--', linewidth=2, label='Local', color='green')
ax2.set_xlabel('ω (rad/s)', fontsize=12)
ax2.set_ylabel("G'' (Pa)", fontsize=12)
ax2.set_title('Loss Modulus - Model Comparison', fontsize=13, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Save Results
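`np.savez` stores a Python dict as a pickled 0-d object array, which `np.load` only reads back with `allow_pickle=True` followed by `.item()`. Writing one key per value keeps the archive pickle-free. A small round-trip sketch with hypothetical parameter values:

```python
import tempfile
from pathlib import Path

import numpy as np

params = {"G": 1000.0, "eta_s": 100.0, "tau_y": 50.0}   # hypothetical values

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "params_demo.npz"
    # One array per parameter, prefixed to avoid clashing with other keys
    np.savez(path, **{f"param_{k}": v for k, v in params.items()})
    with np.load(path) as archive:                 # allow_pickle stays at its default (False)
        loaded = {k.removeprefix("param_"): float(archive[k]) for k in archive.files}

print(loaded)
```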

In [None]:
# Create output directory
output_dir = Path("../outputs/fluidity/saramito_nonlocal/saos")
output_dir.mkdir(parents=True, exist_ok=True)

# Save NLSQ results
nlsq_file = output_dir / "nlsq_results.npz"
np.savez(
    nlsq_file,
    omega=np.array(omega),
    G_star_exp=np.array(G_star_exp),
    G_star_fit=np.array(G_star_fit),
    r_squared=metrics["R2"],
    rmse=metrics["RMSE"],
    # One array per parameter, so the file loads without allow_pickle=True
    **{f"param_{p.name}": p.value for p in model.parameters},
)
print(f"Saved NLSQ results to {nlsq_file}")

# Save Bayesian results
bayes_file = output_dir / "bayesian_results.npz"
np.savez(
    bayes_file,
    omega=np.array(omega),
    G_prime_median=np.array(G_prime_median),
    G_prime_lower=np.array(G_prime_lower),
    G_prime_upper=np.array(G_prime_upper),
    G_double_prime_median=np.array(G_double_prime_median),
    G_double_prime_lower=np.array(G_double_prime_lower),
    G_double_prime_upper=np.array(G_double_prime_upper),
    # Flatten the dicts into one array per entry (keys posterior_<name>, ci_<name>)
    # so the file loads without allow_pickle=True
    **{f"posterior_{k}": np.asarray(v) for k, v in bayes_result.posterior_samples.items()},
    **{f"ci_{k}": np.asarray(v) for k, v in intervals.items()},
)
print(f"Saved Bayesian results to {bayes_file}")

# Save ArviZ InferenceData if available
if idata is not None:
    arviz_file = output_dir / "inference_data.nc"
    idata.to_netcdf(arviz_file)
    print(f"Saved ArviZ InferenceData to {arviz_file}")

# Save comparison with local model
comparison_file = output_dir / "local_comparison.npz"
np.savez(
    comparison_file,
    omega=np.array(omega),
    G_star_nonlocal=np.array(G_star_fit),
    G_star_local=np.array(G_star_local),
    r_squared_nonlocal=metrics["R2"],
    r_squared_local=result_local.r_squared,
    # One array per parameter and model (keys nonlocal_<name>, local_<name>)
    **{f"nonlocal_{p.name}": p.value for p in model.parameters},
    **{f"local_{p.name}": p.value for p in model_local.parameters},
)
print(f"Saved local comparison to {comparison_file}")

print(f"\nAll results saved to {output_dir}")

## Key Takeaways

### Model Capabilities
1. **Linear Viscoelasticity**: FluiditySaramitoNonlocal captures frequency-dependent storage (G') and loss (G'') moduli
2. **Spatial Effects**: Fluidity diffusion (D_f) introduces spatial homogenization that can affect:
   - Effective relaxation timescale
   - High-frequency modulus plateau
   - Transition between elastic and viscous regimes
3. **Comparison with Local**: The nonlocal model may fit better when spatial heterogeneity is significant, but it has two extra parameters (D_f, gap_width), so a higher R² alone is not sufficient evidence

### Workflow Summary
1. **Data**: SAOS measurements of G'(ω) and G''(ω) across frequency range
2. **NLSQ**: Fast point estimates of the parameters (R² close to 1 is expected on clean synthetic data)
3. **Bayesian**: Quantified uncertainty with credible intervals
4. **Diagnostics**: R-hat < 1.01 and bulk ESS > 400 are the usual convergence thresholds; passing them is necessary but does not guarantee convergence
5. **Validation**: Compare with local model to assess spatial effects

### Physical Insights
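- **G'(ω)**: For the Maxwell-like backbone, dominates at intermediate-to-high ω (elastic, solid-like), until a solvent contribution η_s ω, if present, takes over
- **G''(ω)**: Dominates at low ω (viscous, liquid-like)
- **Crossover frequency**: The frequency ω_c at which G' = G'' sets a characteristic relaxation time, λ ≈ 1/ω_c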
- **Spatial diffusion**: D_f > 0 smooths fluidity gradients, affecting moduli transitions
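The crossover frequency can be located directly from the G' and G'' arrays by finding where log(G'/G'') changes sign and interpolating linearly in log ω. A sketch follows, checked against a single Maxwell mode without solvent, for which log(G'/G'') = log(ωλ) and the crossover sits exactly at ω = 1/λ. `crossover_frequency` is a hypothetical helper, not a RheoJAX function.

```python
import numpy as np

def crossover_frequency(omega, G_prime, G_double_prime):
    """Return the first omega where G' = G'', interpolated in log-log space.

    Returns None if the curves do not cross inside the sampled range.
    """
    log_w = np.log(np.asarray(omega, dtype=float))
    d = np.log(np.asarray(G_prime, dtype=float)) - np.log(np.asarray(G_double_prime, dtype=float))
    idx = np.nonzero(np.sign(d[:-1]) != np.sign(d[1:]))[0]
    if idx.size == 0:
        return None
    i = idx[0]
    # Linear interpolation of d against log(omega) between the bracketing points
    frac = d[i] / (d[i] - d[i + 1])
    return float(np.exp(log_w[i] + frac * (log_w[i + 1] - log_w[i])))

lam_demo, G_demo = 2.0, 1000.0
omega_grid = np.logspace(-3, 2, 60)
iwl = 1j * omega_grid * lam_demo
G_star_mx = G_demo * iwl / (1.0 + iwl)          # single Maxwell mode, no solvent
omega_c = crossover_frequency(omega_grid, G_star_mx.real, G_star_mx.imag)
print(omega_c, 1.0 / lam_demo)
```

The helper returns `None` when the curves do not cross in the measured window, for example for a gel-like response with G' > G'' throughout. Applied to `omega`, `G_prime_fit` and `G_double_prime_fit` from the NLSQ section, it should give the crossover of the fitted curves.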

### Next Steps
- Explore nonlinear regime with LAOS (large amplitude oscillatory shear)
- Investigate startup and creep protocols for transient dynamics
- Compare with flow curve measurements for yield stress validation
- Study spatial fluidity profiles and shear banding phenomena