# FluidityNonlocal: Startup Shear with Fluidity Profile Evolution

## Learning Objectives

1. **Fluidity Profile Evolution**: Track the spatial distribution f(y,t) across the gap during startup
2. **Shear Banding Onset**: Detect localization and band formation from fluidity gradients
3. **1D Couette Flow**: Understand the wall boundary conditions and the gap-averaged stress
4. **Nonlocal Effects**: Quantify how the diffusion coefficient D_f affects band width and stability
5. **NLSQ + Bayesian Pipeline**: Fit startup curves with spatially-resolved fluidity diagnostics

**Physical Context**: Startup shear reveals transient localization dynamics that precede steady-state banding. The fluidity profile f(y,t) starts homogeneous and can become localized, with low fluidity in arrested regions and high fluidity in flowing ones. Which state develops is controlled by the competition between shear-induced destructuring and aging (thixotropy).

**Model**: FluidityNonlocal, whose diffusion term D_f∇²f smooths the fluidity field and prevents singular (infinitely sharp) band interfaces.

## Setup

In [None]:
%matplotlib inline
# Colab detection and installation
try:
    import google.colab
    IN_COLAB = True
    !pip install -q rheojax nlsq numpyro arviz
except ImportError:
    IN_COLAB = False

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# JAX with float64 (CRITICAL for numerical stability)
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

from rheojax.models.fluidity import FluidityNonlocal
from rheojax.core.data import RheoData
from rheojax.logging import configure_logging, get_logger

# Configure logging
configure_logging(level="INFO")
logger = get_logger(__name__)

# Output directory
output_dir = Path("../outputs/fluidity/nonlocal/startup")
output_dir.mkdir(parents=True, exist_ok=True)

print(f"JAX devices: {jax.devices()}")
print(f"Float64 enabled: {jax.config.jax_enable_x64}")
# Flag for conditional Bayesian sections
bayesian_completed = False


## Theory: 1D Couette Flow with Fluidity Diffusion

### Governing Equations

**Stress evolution** (Maxwell backbone):
$$
\frac{\partial \sigma}{\partial t} = G \dot{\gamma}(y,t) - f(y,t) \sigma(y,t)
$$

**Fluidity evolution** (aging-rejuvenation with diffusion):
$$
\frac{\partial f}{\partial t} = \frac{1}{\tau_{\text{age}}} - \alpha f + D_f \nabla^2 f
$$
- $\alpha = a |\dot{\gamma}|^c / \tau_{\text{age}}$: Shear-induced destructuring
- $D_f$: Fluidity diffusion coefficient (nonlocal coupling)
- $\tau_{\text{age}}$: Aging timescale (the `tau_eq` parameter in the code below)

**Mechanical equilibrium** (1D Couette):
$$
\frac{\partial \sigma}{\partial y} = 0 \quad \Rightarrow \quad \sigma(y,t) = \sigma(t) \quad \text{(uniform stress)}
$$

**Boundary conditions** (gap width h, stationary bottom wall at y = 0, top plate moving at velocity V):
$$
v(0,t) = 0, \quad v(h,t) = V \quad \Rightarrow \quad \int_0^h \dot{\gamma}(y,t) \, dy = V \quad \Rightarrow \quad \langle \dot{\gamma} \rangle = V/h
$$
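
Since σ does not vary with y, neither does $\partial \sigma / \partial t$, so $G\dot{\gamma}(y,t) - f(y,t)\sigma(t)$ takes the same value at every position. Averaging the stress equation across the gap, with $\langle \cdot \rangle = \frac{1}{h}\int_0^h (\cdot)\,dy$ and $\langle \dot{\gamma} \rangle = V/h$, gives the following worked consequence of the equations above (whether `FluidityNonlocal` uses exactly this reduction internally is an assumption):

$$
\frac{d\sigma}{dt} = G \langle \dot{\gamma} \rangle - \langle f \rangle \sigma(t),
\qquad
\dot{\gamma}(y,t) = \langle \dot{\gamma} \rangle + \frac{\sigma(t)}{G} \left[ f(y,t) - \langle f \rangle \right]
$$

More fluid layers therefore shear faster than the gap average, the stress responds only to the gap-averaged fluidity, and at steady state the second relation reduces to $\dot{\gamma}(y) = f(y)\sigma/G$.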

**Startup protocol**:
- Initial condition: $f(y,0) = f_0$ (homogeneous, equilibrated)
- Applied shear rate: $\dot{\gamma}_{\text{avg}} = V/h$ (constant)
- Observe: $\sigma(t)$ (gap-averaged stress), $f(y,t)$ (spatial profile)
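
A minimal method-of-lines sketch of this startup protocol, integrating the stress and fluidity equations exactly as written above with explicit Euler steps. It is not the rheojax solver: the node average stands in for the gap integral, the zero-flux (Neumann) wall condition on f is an assumption, and the local fluidity source term is passed in as a function so a different aging-rejuvenation law can be swapped in. The uniform-stress constraint enters through $d\sigma/dt = G\langle\dot{\gamma}\rangle - \langle f\rangle\sigma$ and $\dot{\gamma}(y) = \langle\dot{\gamma}\rangle + (\sigma/G)(f - \langle f\rangle)$. `D_f` defaults to 1e-9 m²/s because explicit diffusion is only stable for Δt < Δy²/(2D_f); with 51 points across 1 mm, the notebook's D_f = 1e-6 m²/s would require Δt < 2×10⁻⁴ s.

```python
import numpy as np


def written_fluidity_rate(f, gdot, tau_age=10.0, a=1.0, c=1.0):
    """Local source term 1/tau_age - alpha*f, alpha = a|gdot|^c / tau_age, as written above."""
    alpha = a * np.abs(gdot) ** c / tau_age
    return 1.0 / tau_age - alpha * f


def simulate_startup_1d(rate=written_fluidity_rate, G=1000.0, f0=0.1, D_f=1e-9,
                        gap=1e-3, n_y=51, gamma_dot_avg=1.0, t_end=200.0, dt=0.05):
    """Explicit-Euler method-of-lines sketch of 1D Couette startup with fluidity diffusion."""
    y = np.linspace(0.0, gap, n_y)
    dy = y[1] - y[0]
    # Explicit diffusion is only stable for dt < dy^2 / (2 D_f)
    if D_f > 0 and dt >= dy**2 / (2.0 * D_f):
        raise ValueError("dt too large for the explicit diffusion step; reduce dt or D_f")
    n_steps = int(round(t_end / dt))
    t = np.arange(n_steps + 1) * dt
    sigma_hist = np.zeros(n_steps + 1)
    f_hist = np.empty((n_steps + 1, n_y))
    f = np.full(n_y, f0, dtype=float)  # f0 may be a scalar or one value per node
    sigma = 0.0                        # material at rest before startup
    f_hist[0] = f
    for k in range(n_steps):
        f_mean = f.mean()  # node average as a stand-in for (1/h) * integral of f
        # Uniform stress: local rate = average rate + (sigma / G) * (f - <f>)
        gdot = gamma_dot_avg + (sigma / G) * (f - f_mean)
        # Zero-flux (Neumann) walls via mirrored ghost nodes (assumed boundary condition)
        f_pad = np.concatenate(([f[1]], f, [f[-2]]))
        lap = (f_pad[2:] - 2.0 * f + f_pad[:-2]) / dy**2
        dsigma = G * gamma_dot_avg - f_mean * sigma
        f = f + dt * (rate(f, gdot) + D_f * lap)
        sigma = sigma + dt * dsigma
        sigma_hist[k + 1] = sigma
        f_hist[k + 1] = f
    return t, y, sigma_hist, f_hist


t_s, y_s, sigma_s, f_s = simulate_startup_1d()
print(f"peak stress {sigma_s.max():.1f} Pa, final stress {sigma_s[-1]:.1f} Pa")
```

Source-term parameters can be changed with `functools.partial(written_fluidity_rate, a=2.0)`, and passing an array for `f0` seeds a non-uniform initial profile. With the defaults, a uniform start stays uniform across the gap while the stress overshoots and relaxes to a plateau.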

### Shear Banding Criterion

**Localization index** (relative fluidity spread across the gap):
$$
\xi = \frac{\max_y f(y) - \min_y f(y)}{\langle f(y) \rangle} > \xi_{\text{thresh}} \quad \text{(e.g., 0.3)}
$$
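
To see what ξ actually measures, it helps to evaluate it on a few hand-built profiles (hypothetical values, not model output). Because ξ is a normalized spread, it ignores an overall rescaling of f. It also exceeds the example threshold of 0.3 for a smooth ±20% linear gradient that contains no distinct bands, so a large ξ on its own does not prove banding:

```python
import numpy as np


def localization_index(f_profile):
    """xi = (max f - min f) / <f>, the relative fluidity spread across the gap."""
    f_profile = np.asarray(f_profile, dtype=float)
    f_mean = f_profile.mean()
    return (f_profile.max() - f_profile.min()) / f_mean if f_mean > 0 else 0.0


n = 50
uniform = np.full(n, 0.5)
linear = np.linspace(0.8, 1.2, n)                       # smooth +/-20% gradient, no bands
two_band = np.where(np.arange(n) < n // 2, 0.1, 1.0)   # arrested half + flowing half

for name, prof in [("uniform", uniform), ("linear", linear), ("two-band", two_band)]:
    xi = localization_index(prof)
    print(f"{name:9s} xi = {xi:.3f}  above 0.3: {xi > 0.3}")
```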

**Band width** (characteristic length from diffusion):
$$
\delta \sim \sqrt{D_f \tau_{\text{age}}}
$$
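
As an order-of-magnitude check (the scaling hides an O(1) prefactor), the defaults used in the parameter cell below (`D_f = 1e-6` m²/s, `tau_eq = 10` s as the aging time, a 1 mm gap with `N_y = 51`) give δ ≈ 3.2 mm, roughly three times the gap, while the grid places about 158 points across δ:

```python
import numpy as np


def diffusion_length_check(D_f, tau_age, gap_width, n_y):
    """Return delta = sqrt(D_f * tau_age), delta / gap, and grid points per delta."""
    delta = np.sqrt(D_f * tau_age)
    dy = gap_width / (n_y - 1)
    return delta, delta / gap_width, delta / dy


delta, frac_gap, pts = diffusion_length_check(D_f=1e-6, tau_age=10.0, gap_width=1e-3, n_y=51)
print(f"delta = {delta*1e6:.0f} um, delta/h = {frac_gap:.2f}, points per delta = {pts:.0f}")
```

By the same estimate, `D_f = 1e-9` m²/s would give δ ≈ 0.1 h, an interface narrow compared with the gap.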

## Load Calibrated Parameters or Use Defaults

In [None]:
# Try to load calibrated parameters from flow_curve fitting
param_file = output_dir.parent / "flow_curve" / "fluidity_nonlocal_params.npz"

# Default parameters for demonstration
default_params = {
    'G': 1000.0,          # Elastic modulus (Pa)
    'tau_eq': 10.0,       # Aging timescale (s)
    'a': 1.0,             # Destructuring coefficient
    'c': 1.0,             # Shear rate exponent
    'f_eq': 0.1,          # Equilibrium fluidity (1/s)
    'D_f': 1e-6,          # Fluidity diffusion (m²/s)
}

# Initialize model
model = FluidityNonlocal(
    N_y=51,  # Spatial resolution
    gap_width=1e-3  # 1 mm gap
)

if param_file.exists():
    logger.info(f"Loading calibrated parameters from {param_file}")
    # Context manager closes the .npz archive once the values are read
    with np.load(param_file) as loaded_params:
        # Set parameters from file, falling back to defaults for missing keys
        for key, default_value in default_params.items():
            if key in loaded_params.files:
                value = float(loaded_params[key])
                logger.info(f"  {key} = {value:.6e}")
            else:
                value = default_value
                logger.info(f"  {key} = {value:.6e} (default)")
            model.parameters.set_value(key, value)
else:
    logger.warning("No calibrated parameters found, using defaults")
    
    # Set physically reasonable defaults
    for key, value in default_params.items():
        try:
            model.parameters.set_value(key, value)
        except Exception as e:
            logger.warning(f"Could not set {key}: {e}")
    
    logger.info("Using default parameters")

print("\nModel configuration:")
print(f"  Spatial points: {model.N_y}")
print(f"  Gap width: {model.gap_width*1e3:.2f} mm")
print(f"  Grid spacing: {model.gap_width/(model.N_y-1)*1e6:.2f} μm")
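
The loader above expects a `.npz` archive keyed by the parameter names (`G`, `tau_eq`, `a`, `c`, `f_eq`, `D_f`). As a sketch of that file contract, assuming the upstream flow-curve notebook stores one scalar per key (the values below are hypothetical), this writes and re-reads such an archive in a temporary directory, opening it with `with np.load(...)` so the file is closed after reading:

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical calibrated values, for illustration only
calibrated = {"G": 850.0, "tau_eq": 12.0, "a": 0.8, "c": 1.0, "f_eq": 0.08, "D_f": 2e-7}

with tempfile.TemporaryDirectory() as tmp:
    param_path = Path(tmp) / "fluidity_nonlocal_params.npz"
    np.savez(param_path, **calibrated)        # one 0-d array per key

    with np.load(param_path) as archive:      # NpzFile supports the context-manager protocol
        restored = {key: float(archive[key]) for key in archive.files}

print(restored)
```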

## Generate Synthetic Startup Data

In [None]:
# Startup protocol parameters
gamma_dot = 1.0  # Applied shear rate (1/s)
t_end = 100.0    # Total time (s) - capture transient and steady state

# Time array: log spacing resolves the early transient, linear spacing the approach to steady state
t_early = np.logspace(-2, 0, 50)  # 0.01 to 1 s
t_late = np.linspace(1.0, t_end, 150)  # 1 to 100 s
t = np.unique(np.concatenate([t_early, t_late]))  # np.unique drops the duplicate t = 1 s (199 points)

logger.info(f"Simulating startup shear at γ̇ = {gamma_dot} 1/s for {t_end} s")

# Simulate startup (this populates model._f_field_trajectory)
sigma_true = model.predict(t, test_mode='startup', gamma_dot=gamma_dot)

# Add realistic noise (5% relative error)
noise_level = 0.05
np.random.seed(42)
sigma_noisy = sigma_true + noise_level * np.abs(sigma_true) * np.random.randn(len(sigma_true))

# Create RheoData object
data = RheoData(
    x=t,
    y=sigma_noisy,
    initial_test_mode='startup',
    metadata={'gamma_dot': gamma_dot, 'noise_level': noise_level}
)

logger.info(f"Generated {len(t)} data points with {noise_level*100}% noise")
logger.info(f"Stress range: {sigma_noisy.min():.2f} to {sigma_noisy.max():.2f} Pa")

# Plot synthetic data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Linear time
ax1.plot(t, sigma_true, 'k-', label='True', linewidth=2)
ax1.plot(t, sigma_noisy, 'o', markersize=3, alpha=0.5, label='Noisy')
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Stress (Pa)')
ax1.set_title('Startup Shear: Linear Time')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Log time (emphasize transient)
ax2.plot(t, sigma_true, 'k-', label='True', linewidth=2)
ax2.plot(t, sigma_noisy, 'o', markersize=3, alpha=0.5, label='Noisy')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Stress (Pa)')
ax2.set_title('Startup Shear: Log Time')
ax2.set_xscale('log')
ax2.legend()
ax2.grid(True, alpha=0.3, which='both')

plt.tight_layout()
plt.savefig(output_dir / 'synthetic_data.png', dpi=150, bbox_inches='tight')
plt.show()
plt.close('all')

steady_mean = sigma_noisy[-10:].mean()  # mean of the last 10 (noisy) points
print("\nData characteristics:")
print(f"  Peak stress: {sigma_noisy.max():.2f} Pa at t = {t[np.argmax(sigma_noisy)]:.2f} s")
print(f"  Steady-state stress: {steady_mean:.2f} Pa")
print(f"  Overshoot ratio: {sigma_noisy.max() / steady_mean:.2f}" if steady_mean > 0 else "  Overshoot ratio: N/A")
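
The peak stress and overshoot ratio above are read from the noisy data, so the maximum tends to pick up the largest positive noise excursion near the peak, which biases the ratio upward. A minimal sketch on a hypothetical, densely sampled Gaussian-bump curve (not model output) compares the raw estimate with one taken after a moving average; the 51-point window is an arbitrary choice:

```python
import numpy as np


def overshoot_ratio(sigma, n_tail=10, window=1):
    """Peak / steady-state stress; window > 1 applies a moving average first."""
    sigma = np.asarray(sigma, dtype=float)
    if window > 1:
        sigma = np.convolve(sigma, np.ones(window) / window, mode="valid")
    steady = sigma[-n_tail:].mean()
    return sigma.max() / steady if steady > 0 else np.nan


# Hypothetical curve: 1000 Pa plateau with a 500 Pa overshoot centred at t = 10 s
t_demo = np.linspace(0.0, 100.0, 4001)
sigma_clean = 1000.0 + 500.0 * np.exp(-(((t_demo - 10.0) / 4.0) ** 2))

rng = np.random.default_rng(42)
sigma_meas = sigma_clean * (1.0 + 0.05 * rng.standard_normal(t_demo.size))  # 5% relative noise

print(f"true     : {overshoot_ratio(sigma_clean):.3f}")
print(f"raw      : {overshoot_ratio(sigma_meas):.3f}")
print(f"smoothed : {overshoot_ratio(sigma_meas, window=51):.3f}")
```

On the notebook's mixed log/linear time grid, a fixed-index window spans very different time intervals early and late, so smoothing in log time or fitting a local polynomial around the peak may be more appropriate there.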


## NLSQ Fitting

In [None]:
# Initialize fresh model for fitting
model_fit = FluidityNonlocal(N_y=51, gap_width=1e-3)

# Set reasonable initial values (starting points for optimization)
initial_params = {
    'G': 500.0,
    'tau_eq': 5.0,
    'a': 0.5,
    'c': 1.0,
    'f_eq': 0.05,
    'D_f': 1e-7
}

for key, value in initial_params.items():
    try:
        model_fit.parameters.set_value(key, value)
    except Exception as e:
        logger.warning(f"Could not set initial value for {key}: {e}")

logger.info("Starting NLSQ fitting...")

# Fit with test_mode='startup' and gamma_dot
result = model_fit.fit(
    data.x,
    data.y,
    test_mode='startup',
    gamma_dot=gamma_dot,
    method='scipy'
)

# Generate predictions
sigma_pred = model_fit.predict(t, test_mode='startup', gamma_dot=gamma_dot)

# Compute fit quality
from rheojax.utils.metrics import compute_fit_quality
metrics = compute_fit_quality(sigma_noisy, sigma_pred)

logger.info(f"NLSQ completed: R² = {metrics['R2']:.6f}")
logger.info("Fitted parameters:")
param_names = ['G', 'tau_eq', 'a', 'c', 'f_eq', 'D_f']
for name in param_names:
    try:
        value = model_fit.parameters.get_value(name)
        logger.info(f"  {name} = {value:.6e}")
    except Exception as e:
        logger.warning(f"  {name}: could not read fitted value ({e})")

# Plot fit
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(t, sigma_noisy, 'o', markersize=4, alpha=0.5, label='Data')
ax.plot(t, sigma_pred, 'r-', linewidth=2, label=f'NLSQ Fit (R² = {metrics["R2"]:.4f})')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Stress (Pa)')
ax.set_title(f'NLSQ Fit: Startup Shear at γ̇ = {gamma_dot} 1/s')
ax.set_xscale('log')
ax.legend()
ax.grid(True, alpha=0.3, which='both')
plt.tight_layout()
plt.savefig(output_dir / 'nlsq_fit.png', dpi=150, bbox_inches='tight')
plt.show()
plt.close('all')

# Residual analysis
residuals = sigma_noisy - sigma_pred
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(t, residuals, 'o', markersize=3)
ax1.axhline(0, color='k', linestyle='--', alpha=0.3)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Residuals (Pa)')
ax1.set_title('Residuals vs Time')
ax1.set_xscale('log')
ax1.grid(True, alpha=0.3)

ax2.hist(residuals, bins=30, edgecolor='k', alpha=0.7)
ax2.set_xlabel('Residuals (Pa)')
ax2.set_ylabel('Count')
ax2.set_title(f'Residual Distribution (σ = {residuals.std():.2f} Pa)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / 'nlsq_residuals.png', dpi=150, bbox_inches='tight')
plt.show()
plt.close('all')
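
`compute_fit_quality` is used above as a black box. For reference, here is a minimal NumPy sketch of the standard coefficient of determination, R² = 1 - SS_res/SS_tot (rheojax's helper may report additional or differently defined metrics), plus a relative-residual check. The synthetic noise was generated as 5% of |σ|, so residuals divided by |σ_pred| should scatter with a standard deviation near 0.05 for an adequate fit, whereas absolute residuals grow with the stress level:

```python
import numpy as np


def r_squared(y_obs, y_pred):
    """Standard coefficient of determination, 1 - SS_res / SS_tot."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_obs - y_pred) ** 2)
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def relative_residual_std(y_obs, y_pred):
    """Std of (obs - pred) / |pred|; close to the relative noise level for a good fit."""
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mask = np.abs(y_pred) > 0  # skip points where the prediction is exactly zero
    return np.std((y_obs[mask] - y_pred[mask]) / np.abs(y_pred[mask]))


print(r_squared([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0]))
```

In the notebook this would be called as `relative_residual_std(sigma_noisy, sigma_pred)`.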


## Bayesian Inference with NUTS

In [None]:
logger.info("Starting Bayesian inference (NUTS)...")

# Run NUTS (warm-started from NLSQ)

# FAST_MODE for CI: set FAST_MODE=1 env var for quick iteration
FAST_MODE = os.environ.get('FAST_MODE', '0') == '1'
_num_warmup = 50 if FAST_MODE else 200
_num_samples = 100 if FAST_MODE else 500
_num_chains = 1

if FAST_MODE:
    print('FAST_MODE: Skipping Bayesian inference (nonlocal ODE+NUTS too memory-intensive)')
    bayesian_completed = False
else:
    bayesian_result = model_fit.fit_bayesian(
        data.x,
        data.y,
        test_mode='startup',
        gamma_dot=gamma_dot,
        num_warmup=_num_warmup,
        num_samples=_num_samples,
        num_chains=_num_chains,
        seed=42
    )

    logger.info("Bayesian inference completed")

    # Extract posterior samples
    posterior = bayesian_result.posterior_samples

    # Compute credible intervals
    intervals = model_fit.get_credible_intervals(posterior, credibility=0.95)

    print("\n95% Credible Intervals:")
    for param_name, (lower, upper) in intervals.items():
        mean = float(posterior[param_name].mean())
        print(f"  {param_name}: {mean:.6e} [{lower:.6e}, {upper:.6e}]")

    # Diagnostics
    try:
        import arviz as az

        # Convert to ArviZ InferenceData
        idata = az.from_dict(posterior={k: np.array(v)[None, :] for k, v in posterior.items()})

        # R-hat (should be < 1.01)
        rhat = az.rhat(idata)
        print("\nR-hat (convergence diagnostic, target < 1.01):")
        for var in rhat.data_vars:
            print(f"  {var}: {float(rhat[var]):.4f}")

        # ESS (rule of thumb: > 400; only _num_samples draws from a single chain are available here)
        ess = az.ess(idata)
        print("\nEffective Sample Size (target > 400):")
        for var in ess.data_vars:
            print(f"  {var}: {float(ess[var]):.0f}")

        # Plot posterior distributions
        az.plot_trace(idata, compact=True, figsize=(12, 10))
        plt.tight_layout()
        plt.savefig(output_dir / 'posterior_trace.png', dpi=150, bbox_inches='tight')
        plt.close('all')

        # Plot pair plot
        az.plot_pair(idata, kind='kde', figsize=(10, 10))
        plt.tight_layout()
        plt.savefig(output_dir / 'posterior_pair.png', dpi=150, bbox_inches='tight')
        plt.close('all')

    except ImportError:
        logger.warning("ArviZ not available, skipping diagnostics")

    bayesian_completed = True
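
With a single chain (`num_chains=1` above), R-hat can only compare the two halves of that chain. A minimal sketch of the classic split R-hat, without the rank normalization ArviZ applies by default (so values can differ slightly from `az.rhat`), makes the computation explicit for draws shaped `(chains, draws)`:

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for draws shaped (n_chains, n_draws); no rank normalization."""
    draws = np.atleast_2d(np.asarray(draws, dtype=float))
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into two halves (dropping the middle draw if n_draws is odd)
    splits = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    m, n = splits.shape
    chain_means = splits.mean(axis=1)
    W = splits.var(axis=1, ddof=1).mean()  # mean within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_plus = (n - 1) / n * W + B / n     # pooled posterior variance estimate
    return np.sqrt(var_plus / W)


rng = np.random.default_rng(1)
print(split_rhat(rng.standard_normal((4, 1000))))
```

For a parameter from the notebook this could be `split_rhat(np.asarray(posterior['G'])[None, :])`, assuming the posterior arrays are one-dimensional as in the `az.from_dict` call above.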


## Fluidity Profile Evolution

In [None]:
if (hasattr(model_fit, '_f_field_trajectory') and model_fit._f_field_trajectory is not None
        and hasattr(model_fit, '_y_coords') and model_fit._y_coords is not None):
    # Access fluidity field trajectory (populated during last predict() call)
    f_trajectory = model_fit._f_field_trajectory  # Shape: (len(t), N_y)
    y_coords = model_fit._y_coords  # Spatial coordinates (m)

    logger.info(f"Fluidity trajectory shape: {f_trajectory.shape}")
    logger.info(f"Spatial coordinates: {len(y_coords)} points from 0 to {model_fit.gap_width*1e3:.2f} mm")

    # Select snapshots spanning the early transient to the end of the run
    t_snapshots = [0.1, 1.0, 10.0, t_end]
    snapshot_indices = [np.argmin(np.abs(t - t_snap)) for t_snap in t_snapshots]

    # Plot fluidity profiles
    fig, ax = plt.subplots(figsize=(8, 6))

    colors = plt.cm.viridis(np.linspace(0, 1, len(t_snapshots)))
    for idx, color in zip(snapshot_indices, colors):
        f_profile = f_trajectory[idx, :]
        # Label with the sampled time nearest to the requested snapshot
        ax.plot(y_coords * 1e3, f_profile, '-o', color=color, label=f't = {t[idx]:.2f} s', markersize=4)

    ax.set_xlabel('Position across gap (mm)')
    ax.set_ylabel('Fluidity f (1/s)')
    ax.set_title('Fluidity Profile Evolution During Startup')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_dir / 'fluidity_profiles.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close('all')

    # Spatiotemporal heatmap
    fig, ax = plt.subplots(figsize=(10, 6))

    # Use only subset of times for visualization
    t_plot_indices = np.linspace(0, len(t)-1, 100, dtype=int)
    t_plot = t[t_plot_indices]
    f_plot = f_trajectory[t_plot_indices, :]

    im = ax.pcolormesh(y_coords * 1e3, t_plot, f_plot, shading='auto', cmap='plasma')
    ax.set_xlabel('Position across gap (mm)')
    ax.set_ylabel('Time (s)')
    ax.set_title('Fluidity Field f(y,t) Spatiotemporal Evolution')
    ax.set_yscale('log')
    cbar = plt.colorbar(im, ax=ax, label='Fluidity (1/s)')
    plt.tight_layout()
    plt.savefig(output_dir / 'fluidity_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close('all')

    print("\nFluidity profile statistics:")
    print(f"  First sample (t={t[0]:.2f} s): f_min = {f_trajectory[0].min():.4f}, f_max = {f_trajectory[0].max():.4f}")
    print(f"  Final (t={t[-1]:.1f} s): f_min = {f_trajectory[-1].min():.4f}, f_max = {f_trajectory[-1].max():.4f}")
    print(f"  Final relative spread ξ = (f_max - f_min)/<f>: {(f_trajectory[-1].max() - f_trajectory[-1].min()) / f_trajectory[-1].mean():.2%}")
else:
    print('Fluidity trajectory not available (model did not expose _f_field_trajectory / _y_coords)')


## Shear Banding Detection

In [None]:
if 'f_trajectory' in globals():
    def detect_shear_banding(f_profile, threshold=0.3):
        """
        Detect shear banding from fluidity profile.

        Parameters
        ----------
        f_profile : array
            Fluidity profile f(y) across gap
        threshold : float
            Threshold on the localization index ξ = (f_max - f_min) / <f>

        Returns
        -------
        dict
            Banding diagnostics: is_banded, localization_index, band_ratio,
            f_mean, f_min, f_max
        """
        f_mean = np.mean(f_profile)
        f_min = np.min(f_profile)
        f_max = np.max(f_profile)

        # Localization index (normalized variation)
        localization_index = (f_max - f_min) / f_mean if f_mean > 0 else 0.0

        # Band ratio (high fluidity / low fluidity)
        band_ratio = f_max / f_min if f_min > 0 else np.inf

        # Banding detected if localization exceeds threshold
        is_banded = localization_index > threshold

        return {
            'is_banded': is_banded,
            'localization_index': localization_index,
            'band_ratio': band_ratio,
            'f_mean': f_mean,
            'f_min': f_min,
            'f_max': f_max
        }

    # Analyze banding evolution
    banding_evolution = []
    for i, t_val in enumerate(t):
        diagnostics = detect_shear_banding(f_trajectory[i, :], threshold=0.3)
        diagnostics['time'] = t_val
        banding_evolution.append(diagnostics)

    # Convert to arrays for plotting
    times = np.array([d['time'] for d in banding_evolution])
    localization = np.array([d['localization_index'] for d in banding_evolution])
    is_banded = np.array([d['is_banded'] for d in banding_evolution])

    # Find banding onset time
    banding_onset_idx = np.argmax(is_banded)
    banding_onset_time = times[banding_onset_idx] if is_banded.any() else None

    # Plot localization index evolution
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

    # Localization index
    ax1.plot(times, localization, 'b-', linewidth=2)
    ax1.axhline(0.3, color='r', linestyle='--', label='Banding threshold')
    if banding_onset_time is not None:
        ax1.axvline(banding_onset_time, color='g', linestyle=':', label=f'Onset at t={banding_onset_time:.2f} s')
    ax1.fill_between(times, 0, localization, where=is_banded, alpha=0.3, color='orange', label='Banded region')
    ax1.set_ylabel('Localization Index ξ')
    ax1.set_title('Shear Banding Detection')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xscale('log')

    # Stress correlation
    ax2.plot(times, sigma_pred, 'k-', linewidth=2, label='Predicted stress')
    if banding_onset_time is not None:
        ax2.axvline(banding_onset_time, color='g', linestyle=':', label='Banding onset')
    ax2.set_xlabel('Time (s)')
    ax2.set_ylabel('Stress (Pa)')
    ax2.set_title('Stress Evolution (with banding onset marker)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xscale('log')

    plt.tight_layout()
    plt.savefig(output_dir / 'shear_banding_detection.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close('all')

    # Summary
    print("\nShear Banding Analysis:")
    if banding_onset_time is not None:
        print("  Banding detected: YES")
        print(f"  Onset time: {banding_onset_time:.2f} s")
        print(f"  Final localization index: {localization[-1]:.3f}")
        print(f"  Final band ratio (f_max/f_min): {banding_evolution[-1]['band_ratio']:.2f}")
    else:
        print("  Banding detected: NO")
        print(f"  Maximum localization index: {localization.max():.3f}")

    # Estimate band width from diffusion length
    D_f = model_fit.parameters.get_value('D_f')
    tau_eq = model_fit.parameters.get_value('tau_eq')
    band_width = np.sqrt(D_f * tau_eq)
    print(f"\nEstimated band width (δ ~ √(D_f*τ_eq)): {band_width*1e6:.2f} μm")
    print(f"  Relative to gap: {band_width/model_fit.gap_width:.2%}")
else:
    print('Skipping shear banding analysis (fluidity trajectory not available)')
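
The band-width estimate above comes from the scaling δ ~ √(D_f τ_eq). It can be compared with a width measured directly from a profile such as `f_trajectory[-1]`. A minimal sketch, assuming a single monotonic interface (it does not handle several bands or profiles without a clear interface), locates the midpoint crossing and the 10-90% rise distance by linear interpolation. For a tanh profile with width parameter δ, the 10-90% distance is 2·artanh(0.8)·δ ≈ 2.20δ:

```python
import numpy as np


def interface_metrics(y, f_profile, lo=0.1, hi=0.9):
    """Midpoint position and lo-to-hi rise distance of a single monotonic fluidity interface."""
    y = np.asarray(y, dtype=float)
    f = np.asarray(f_profile, dtype=float)
    if f[-1] < f[0]:                 # np.interp needs increasing sample points
        y, f = y[::-1], f[::-1]
    if np.any(np.diff(f) <= 0):
        raise ValueError("profile is not strictly monotonic; sketch assumes one interface")
    f_min, f_max = f[0], f[-1]
    levels = f_min + np.array([lo, 0.5, hi]) * (f_max - f_min)
    y_lo, y_mid, y_hi = np.interp(levels, f, y)
    return y_mid, abs(y_hi - y_lo)


# Hypothetical tanh interface at y0 = 0.4 mm with width parameter 50 um
y_demo = np.linspace(0.0, 1e-3, 2001)
f_demo = 0.1 + 0.9 * (1.0 + np.tanh((y_demo - 4e-4) / 5e-5)) / 2.0
position, width = interface_metrics(y_demo, f_demo)
print(f"interface at {position*1e3:.3f} mm, 10-90% width {width*1e6:.1f} um")
```

Applied to the notebook, the call would be `interface_metrics(y_coords, f_trajectory[-1])`, provided the final profile is monotonic.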

## Save Results

In [None]:
# Save fitted NLSQ parameters (available whether or not NUTS ran)
param_names = ['G', 'tau_eq', 'a', 'c', 'f_eq', 'D_f']
param_dict = {}
for name in param_names:
    try:
        param_dict[name] = model_fit.parameters.get_value(name)
    except Exception as e:
        logger.warning(f"Could not read {name}: {e}")

np.savez(
    output_dir / 'startup_params.npz',
    **param_dict,
    gamma_dot=gamma_dot,
    r_squared=metrics['R2']
)

# Fluidity trajectory exists only if the model exposed the spatial field
if 'f_trajectory' in globals():
    np.savez(
        output_dir / 'fluidity_trajectory.npz',
        t=t,
        y=y_coords,
        f_trajectory=np.array(f_trajectory),
        sigma=sigma_pred
    )

# Banding diagnostics exist only if the shear banding cell ran
if 'localization' in globals():
    np.savez(
        output_dir / 'banding_diagnostics.npz',
        times=times,
        localization_index=localization,
        is_banded=is_banded,
        onset_time=banding_onset_time if banding_onset_time is not None else -1.0
    )

# Posterior samples exist only if NUTS ran (skipped in FAST_MODE)
if bayesian_completed:
    np.savez(output_dir / 'posterior_samples.npz', **posterior)
else:
    print('Skipping posterior samples (Bayesian inference was skipped)')

logger.info(f"Results saved to {output_dir}")
print(f"\nAll results saved to: {output_dir}")


## Key Takeaways

### Fluidity Profile Evolution

1. **Initial Homogeneity**: At t=0, fluidity is spatially uniform (f(y,0) = f_eq)
2. **Shear-Induced Localization**: Regions with higher shear rate experience faster destructuring (lower fluidity)
3. **Diffusion Smoothing**: D_f prevents singular band interfaces, creating finite band width δ ~ √(D_f*τ_age)
4. **Steady-State Banding**: At long times, fluidity profile stabilizes with distinct high/low regions

### Shear Banding Onset

**Criterion**: Localization index ξ = (f_max - f_min)/⟨f⟩ exceeds threshold (~0.3)

**Onset Time**: Typically occurs after stress peak, during approach to steady state

**Stress Signature**: Banding onset often correlates with stress plateau or slight decrease
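
The banding cell above takes the first sample with ξ above the threshold as the onset, so a single excursion (for example from numerical noise) can set it. A hypothetical variant, not part of rheojax, requires ξ to stay above the threshold for `min_run` consecutive samples before declaring onset:

```python
import numpy as np


def persistent_onset_index(xi, threshold=0.3, min_run=3):
    """First index where xi > threshold holds for min_run consecutive samples, else None."""
    if min_run < 1:
        raise ValueError("min_run must be at least 1")
    above = np.asarray(xi) > threshold
    run = 0
    for i, flag in enumerate(above):
        run = run + 1 if flag else 0
        if run == min_run:
            return i - min_run + 1  # start of the first sufficiently long run
    return None


xi_demo = np.array([0.1, 0.4, 0.1, 0.35, 0.4, 0.5, 0.6])
print(persistent_onset_index(xi_demo, min_run=1), persistent_onset_index(xi_demo, min_run=3))
```

With `min_run=1` this reproduces the notebook's first-crossing rule; larger values trade a slightly later onset estimate for robustness against isolated spikes.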

### Model Parameters from Startup

- **G**: Elastic modulus controls initial stress rise (linear regime)
- **τ_age** (`tau_eq` in code): Aging timescale sets time to stress peak
- **a, c**: Destructuring controls peak magnitude and overshoot ratio
- **D_f**: Diffusion coefficient determines band width and stability
- **f_eq**: Equilibrium fluidity sets steady-state stress level

### NLSQ + Bayesian Workflow

1. **NLSQ**: Fast point estimation with gap-averaged stress σ(t)
2. **NUTS**: Bayesian uncertainty quantification (R-hat < 1.01, ESS > 400)
3. **Fluidity Access**: Use `model._f_field_trajectory` for spatially-resolved diagnostics
4. **Validation**: Check banding predictions against known SGR/ITT-MCT regimes

### Experimental Connections

- **Rheo-PIV**: Compare predicted f(y,t) gradients with velocity profiles
- **Rheo-NMR**: Validate spatial localization against NMR/MRI velocity profiles
- **Ultrasonic Velocimetry**: Time-resolved band position from Doppler shifts
- **Stress Overshoot**: Magnitude indicates strength of thixotropic memory
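
For a Rheo-PIV comparison, a fluidity profile can be converted into a velocity profile. Under the Maxwell stress equation from the Theory section, uniform stress across the gap fixes the local rate as γ̇(y) = ⟨γ̇⟩ + (σ/G)(f(y) - ⟨f⟩), which reduces to γ̇(y) = f(y)σ/G at steady state; integrating from the stationary wall then gives v(y). A minimal sketch, assuming no slip at y = 0 and using a NumPy-only stand-in for `scipy.integrate.cumulative_trapezoid(..., initial=0)`:

```python
import numpy as np


def cumulative_trapezoid(values, y):
    """Running trapezoidal integral of values over y, starting at 0."""
    increments = 0.5 * (values[1:] + values[:-1]) * np.diff(y)
    return np.concatenate(([0.0], np.cumsum(increments)))


def velocity_profile(y, f_profile, sigma, G, gamma_dot_avg):
    """v(y) from a fluidity profile via gdot(y) = <gdot> + (sigma / G) * (f - <f>)."""
    y = np.asarray(y, dtype=float)
    f = np.asarray(f_profile, dtype=float)
    h = y[-1] - y[0]
    f_mean = cumulative_trapezoid(f, y)[-1] / h  # gap average (1/h) * integral of f
    gdot = gamma_dot_avg + (sigma / G) * (f - f_mean)
    return cumulative_trapezoid(gdot, y)         # v(0) = 0 at the stationary wall


# Hypothetical two-band profile: arrested lower half, fluid upper half
y_demo = np.linspace(0.0, 1e-3, 101)
f_demo = np.where(y_demo > 5e-4, 1.0, 0.1)
v_demo = velocity_profile(y_demo, f_demo, sigma=500.0, G=1000.0, gamma_dot_avg=1.0)
print(f"top-plate velocity {v_demo[-1]*1e3:.4f} mm/s")
```

Because ⟨f⟩ is computed with the same trapezoidal rule, the integrated profile reaches V = ⟨γ̇⟩h at the moving plate, and the slope of v(y) is steeper in the high-fluidity band.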