# TensorialEPM Usage Examples

This notebook demonstrates the full capabilities of the **TensorialEPM** model in RheoJAX. The model extends the scalar Lattice EPM to track the full in-plane stress tensor [σ_xx, σ_yy, σ_xy].

**Key Features**:
- Prediction of normal stress differences (N₁, N₂)
- Von Mises and Hill anisotropic yield criteria
- Flexible fitting modes (shear-only or combined)
- Comprehensive visualization tools
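
The normal stress differences listed above follow directly from the stress components. The sketch below makes several assumptions:

- x is the flow direction, y the gradient direction and z the vorticity direction.
- The out-of-plane stress comes from a plane-strain closure, σ_zz = ν(σ_xx + σ_yy). The notebook passes `nu` to the N₁ and von Mises plotting helpers, which suggests a closure of this kind, but the exact RheoJAX convention is not confirmed here.
- `normal_stress_differences` is a hypothetical helper, not a library function.

```python
import numpy as np


def normal_stress_differences(stress, nu=0.48):
    """Return (N1, N2) for a stress array of shape (3, ...) ordered [σ_xx, σ_yy, σ_xy].

    Assumes plane strain, so the out-of-plane stress is σ_zz = ν (σ_xx + σ_yy).
    This closure is an illustrative assumption, not confirmed RheoJAX internals.
    """
    sxx, syy = stress[0], stress[1]
    szz = nu * (sxx + syy)  # assumed plane-strain out-of-plane normal stress
    n1 = sxx - syy          # first normal stress difference (flow - gradient)
    n2 = syy - szz          # second normal stress difference (gradient - vorticity)
    return n1, n2


# Single-site example with ν = 0.5
stress = np.array([[2.0], [0.5], [1.0]])
n1, n2 = normal_stress_differences(stress, nu=0.5)
print(n1, n2)  # → [1.5] [-0.75]
```

The same function works element-wise on a whole lattice field of shape (3, L, L).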

**Contents**:
1. Basic Flow Curve with N₁ Prediction
2. Fitting to Shear-Only Data (Backward Compatible)
3. Fitting to Combined [σ_xy, N₁] Data
4. Visualization Gallery
5. Comparison of Von Mises vs Hill Criteria
6. Animation of Avalanche Dynamics

In [None]:
# Import required modules
import numpy as np
import matplotlib.pyplot as plt

from rheojax.core.jax_config import safe_import_jax
from rheojax.models.epm.tensor import TensorialEPM
from rheojax.models.epm.lattice import LatticeEPM
from rheojax.core.data import RheoData
from rheojax.visualization.epm_plots import (
    plot_lattice_fields,
    plot_tensorial_fields,
    plot_normal_stress_field,
    plot_von_mises_field,
    plot_normal_stress_ratio,
    animate_tensorial_evolution,
)

jax, jnp = safe_import_jax()

# Set plot style
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## Example 1: Basic Flow Curve with N₁ Prediction

The simplest use case: predict shear stress and normal stress differences for a range of shear rates.

In [None]:
# Initialize TensorialEPM with default parameters
model = TensorialEPM(
    L=32,  # Smaller lattice for faster computation
    dt=0.01,
    mu=1.0,
    nu=0.48,
    sigma_c_mean=1.0,
    sigma_c_std=0.15,
)

# Define shear rates (logarithmic spacing)
shear_rates = np.logspace(-2, 1, 10)
data = RheoData(x=shear_rates, y=None, initial_test_mode="flow_curve")

# Predict flow curve (smooth mode for differentiable predictions)
print("Running flow curve simulation...")
result = model.predict(data, smooth=True, seed=42)

# Extract results
sigma_xy = result.y
N1 = result.metadata["N1"]

print(f"Shear rates: {shear_rates}")
print(f"Shear stress σ_xy: {sigma_xy}")
print(f"Normal stress N₁: {N1}")

In [None]:
# Plot flow curve and normal stress
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Shear stress vs shear rate
axes[0].loglog(shear_rates, sigma_xy, 'o-', label='σ_xy (Shear Stress)', linewidth=2)
axes[0].set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
axes[0].set_ylabel('Shear Stress σ_xy [Pa]', fontsize=12)
axes[0].set_title('Flow Curve', fontsize=14)
axes[0].grid(True, which='both', alpha=0.3)
axes[0].legend()

# Normal stress difference vs shear rate
axes[1].loglog(shear_rates, N1, 's-', color='C1', label='N₁', linewidth=2)
axes[1].set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
axes[1].set_ylabel('First Normal Stress Difference N₁ [Pa]', fontsize=12)
axes[1].set_title('Normal Stress Difference', fontsize=14)
axes[1].grid(True, which='both', alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.show()

# Check power-law behavior
# Fit σ_xy ~ γ̇^n and N₁ ~ γ̇^m
log_rates = np.log10(shear_rates[2:])  # Skip low rates
log_sigma = np.log10(sigma_xy[2:])
log_N1 = np.log10(N1[2:])

n_shear = np.polyfit(log_rates, log_sigma, 1)[0]
n_N1 = np.polyfit(log_rates, log_N1, 1)[0]

print("\nPower-law exponents:")
print(f"  σ_xy ~ γ̇^{n_shear:.2f}")
print(f"  N₁ ~ γ̇^{n_N1:.2f}")
print("\nRough reference ranges: σ_xy ~ γ̇^0.5-0.9 (shear thinning), N₁ ~ γ̇^1.0-2.0")
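
The cell above passes N₁ straight to `np.log10`. If any simulated N₁ value is zero or negative, which can happen with stochastic dynamics at low rates, the log is `nan` and `np.polyfit` fails. A slightly more defensive variant masks non-positive points before fitting. `power_law_exponent` is a hypothetical helper, not part of RheoJAX.

```python
import numpy as np


def power_law_exponent(x, y, skip=0):
    """Slope of log10(y) vs log10(x), dropping the first `skip` points and any non-positive values."""
    x = np.asarray(x, dtype=float)[skip:]
    y = np.asarray(y, dtype=float)[skip:]
    mask = (x > 0) & (y > 0)  # log10 is undefined for zero or negative values
    if mask.sum() < 2:
        raise ValueError("need at least two positive points to fit a slope")
    slope, _intercept = np.polyfit(np.log10(x[mask]), np.log10(y[mask]), 1)
    return float(slope)


rates = np.logspace(-2, 1, 10)
print(round(power_law_exponent(rates, 3.0 * rates**1.5), 3))  # → 1.5
```

In this notebook, `power_law_exponent(shear_rates, N1, skip=2)` reproduces the fit above while skipping any non-positive N₁ values.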

## Example 2: Fitting to Shear-Only Data (Backward Compatible)

When only shear stress data is available, TensorialEPM can be fitted like the scalar LatticeEPM.

In [None]:
# Generate synthetic experimental data (shear stress only)
gamma_dot_exp = np.array([0.01, 0.1, 0.5, 1.0, 5.0, 10.0])

# Herschel-Bulkley: σ = σ_y + K*γ̇^n
sigma_y = 0.5
K = 0.8
n = 0.7
rng = np.random.default_rng(0)  # Fixed seed so the synthetic data are reproducible
sigma_xy_exp = sigma_y + K * gamma_dot_exp**n + 0.05 * rng.standard_normal(len(gamma_dot_exp))

print("Synthetic experimental data:")
for rate, stress in zip(gamma_dot_exp, sigma_xy_exp):
    print(f"  γ̇ = {rate:5.2f} 1/s → σ_xy = {stress:.3f} Pa")
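
The synthetic data follow a Herschel–Bulkley law, so a direct HB fit with SciPy's `curve_fit` gives reference values of σ_y, K and n. These can be compared against the EPM fit below. This baseline is an illustrative addition, not part of the TensorialEPM workflow, and the starting guess and bounds are assumptions.

```python
import numpy as np
from scipy.optimize import curve_fit


def herschel_bulkley(gdot, sigma_y, K, n):
    """Herschel-Bulkley flow curve: σ = σ_y + K γ̇^n."""
    return sigma_y + K * gdot**n


gdot = np.array([0.01, 0.1, 0.5, 1.0, 5.0, 10.0])
rng = np.random.default_rng(0)
sigma = herschel_bulkley(gdot, 0.5, 0.8, 0.7) + 0.01 * rng.standard_normal(gdot.size)

# Bounds keep σ_y and K non-negative and n in a physically sensible range
popt, pcov = curve_fit(
    herschel_bulkley, gdot, sigma,
    p0=[0.1, 1.0, 0.5],
    bounds=([0.0, 0.0, 0.05], [np.inf, np.inf, 2.0]),
)
perr = np.sqrt(np.diag(pcov))  # one-sigma parameter uncertainties
for name, val, err in zip(["sigma_y", "K", "n"], popt, perr):
    print(f"{name} = {val:.3f} ± {err:.3f}")
```

To apply it to the notebook data, pass `gamma_dot_exp` and `sigma_xy_exp` in place of the locally generated arrays.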

In [None]:
# Create RheoData
rheo_data = RheoData(x=gamma_dot_exp, y=sigma_xy_exp, initial_test_mode="flow_curve")

# Initialize model with its default parameter values and bounds
model_fit = TensorialEPM(L=32, dt=0.01)

# Fit to shear stress data
print("\nFitting TensorialEPM to shear-only data...")
model_fit.fit(rheo_data, max_iter=50, method='scipy')

# Print fitted parameters
print("\nFitted parameters:")
for param_name in ["mu", "sigma_c_mean", "sigma_c_std", "tau_pl_shear"]:
    value = model_fit.params.get_value(param_name)
    print(f"  {param_name}: {value:.4f}")

# Predict with fitted model
pred_result = model_fit.predict(rheo_data, smooth=True)
sigma_xy_pred = pred_result.y

# Calculate R²
ss_res = np.sum((sigma_xy_exp - sigma_xy_pred)**2)
ss_tot = np.sum((sigma_xy_exp - np.mean(sigma_xy_exp))**2)
r_squared = 1 - ss_res / ss_tot
print(f"\nR² = {r_squared:.4f}")

In [None]:
# Plot fit quality
fig, ax = plt.subplots(figsize=(8, 6))

ax.loglog(gamma_dot_exp, sigma_xy_exp, 'o', markersize=10, label='Experimental Data', zorder=3)
ax.loglog(gamma_dot_exp, sigma_xy_pred, 's--', markersize=8, label='TensorialEPM Fit', zorder=2)

ax.set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
ax.set_ylabel('Shear Stress σ_xy [Pa]', fontsize=12)
ax.set_title(f'Shear-Only Fitting (R² = {r_squared:.4f})', fontsize=14)
ax.grid(True, which='both', alpha=0.3)
ax.legend(fontsize=11)

plt.tight_layout()
plt.show()

## Example 3: Fitting to Combined [σ_xy, N₁] Data

When normal stress measurements are available, we can use them to constrain the model parameters.

**Note**: The current implementation fits σ_xy only and then uses N₁ to validate the fitted parameters. Future versions will support multi-objective fitting.
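
A common way to build a multi-objective fit is a weighted sum of normalized residuals, where a weight such as `w_N1` sets the relative importance of N₁. The sketch below is generic and hypothetical:

- `combined_loss` and `toy_predict` are illustrative stand-ins, not RheoJAX functions.
- `toy_predict` uses a Herschel–Bulkley σ_xy and a power-law N₁.
- RheoJAX's eventual objective may be built differently.

```python
import numpy as np
from scipy.optimize import minimize


def combined_loss(params, gdot, sigma_obs, n1_obs, predict_fn, w_N1=2.0):
    """Weighted sum of normalized mean-squared residuals for σ_xy and N₁.

    Each channel is divided by the RMS of its observations so the two terms
    are dimensionless and comparable; w_N1 sets the relative weight of N₁.
    """
    sigma_pred, n1_pred = predict_fn(params, gdot)
    r_sigma = (sigma_pred - sigma_obs) / np.sqrt(np.mean(sigma_obs**2))
    r_n1 = (n1_pred - n1_obs) / np.sqrt(np.mean(n1_obs**2))
    return np.mean(r_sigma**2) + w_N1 * np.mean(r_n1**2)


def toy_predict(params, gdot):
    """Hypothetical stand-in model: Herschel-Bulkley σ_xy and power-law N₁."""
    sigma_y, K, n, a, m = params
    return sigma_y + K * gdot**n, a * gdot**m


gdot = np.array([0.1, 0.5, 1.0, 5.0, 10.0])
true_params = [0.5, 0.8, 0.7, 0.3, 1.2]
sigma_obs, n1_obs = toy_predict(true_params, gdot)

x0 = [0.3, 1.0, 0.5, 0.5, 1.0]
res = minimize(
    combined_loss, x0,
    args=(gdot, sigma_obs, n1_obs, toy_predict),  # w_N1 keeps its default of 2.0
    method="Nelder-Mead",
    options={"maxiter": 5000},
)
print(f"loss at start:   {combined_loss(x0, gdot, sigma_obs, n1_obs, toy_predict):.3e}")
print(f"loss at optimum: {res.fun:.3e}")
```

Normalizing each channel by its own RMS keeps a large-magnitude N₁ from dominating the fit simply because of its units.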

In [None]:
# Generate synthetic data with normal stresses
gamma_dot_combined = np.array([0.1, 0.5, 1.0, 5.0, 10.0])

# Shear stress: σ_xy = σ_y + K*γ̇^n
sigma_xy_combined = 0.5 + 0.8 * gamma_dot_combined**0.7

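# Normal stress: N₁ ~ γ̇^m (typically m > n)
# Multiplicative noise keeps N₁ positive so every point can be shown on log-log axes
rng = np.random.default_rng(1)
N1_combined = 0.3 * gamma_dot_combined**1.2 * (1 + 0.02 * rng.standard_normal(len(gamma_dot_combined)))

print("Synthetic data with normal stresses:")
# Use a lowercase loop variable so the N1 array from Example 1 is not overwritten
for rate, sigma, n1 in zip(gamma_dot_combined, sigma_xy_combined, N1_combined):
    print(f"  γ̇ = {rate:5.2f} 1/s → σ_xy = {sigma:.3f} Pa, N₁ = {n1:.3f} Pa")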

In [None]:
# Best practice: Fit shear first, then check N₁ predictions
rheo_combined = RheoData(x=gamma_dot_combined, y=sigma_xy_combined, initial_test_mode="flow_curve")

model_combined = TensorialEPM(L=32, dt=0.01, w_N1=2.0)  # N₁ weight for future combined fitting; the shear-first fit below uses σ_xy only

print("\nFitting to shear stress...")
model_combined.fit(rheo_combined, max_iter=100, method='scipy')

# Predict both σ_xy and N₁
pred_combined = model_combined.predict(rheo_combined, smooth=True)
sigma_xy_pred_combined = pred_combined.y
N1_pred_combined = pred_combined.metadata["N1"]

# Calculate errors
rmse_sigma = np.sqrt(np.mean((sigma_xy_pred_combined - sigma_xy_combined)**2))
rmse_N1 = np.sqrt(np.mean((N1_pred_combined - N1_combined)**2))

print(f"\nPrediction errors:")
print(f"  σ_xy RMSE: {rmse_sigma:.4f} Pa")
print(f"  N₁ RMSE: {rmse_N1:.4f} Pa")
print(f"  Relative N₁ error: {rmse_N1 / np.mean(N1_combined) * 100:.1f}%")

In [None]:
# Plot combined fit
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Shear stress
axes[0].loglog(gamma_dot_combined, sigma_xy_combined, 'o', markersize=10, label='Experimental σ_xy')
axes[0].loglog(gamma_dot_combined, sigma_xy_pred_combined, 's--', markersize=8, label='Predicted σ_xy')
axes[0].set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
axes[0].set_ylabel('Shear Stress σ_xy [Pa]', fontsize=12)
axes[0].set_title(f'Shear Stress (RMSE = {rmse_sigma:.3f} Pa)', fontsize=13)
axes[0].grid(True, which='both', alpha=0.3)
axes[0].legend()

# Normal stress
axes[1].loglog(gamma_dot_combined, N1_combined, 'o', markersize=10, color='C1', label='Experimental N₁')
axes[1].loglog(gamma_dot_combined, N1_pred_combined, 's--', markersize=8, color='C3', label='Predicted N₁')
axes[1].set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
axes[1].set_ylabel('Normal Stress N₁ [Pa]', fontsize=12)
axes[1].set_title(f'Normal Stress (RMSE = {rmse_N1:.3f} Pa)', fontsize=13)
axes[1].grid(True, which='both', alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.show()

## Example 4: Visualization Gallery

RheoJAX provides comprehensive visualization tools for tensorial stress fields.

In [None]:
# Run a simulation to get a stress field snapshot
model_viz = TensorialEPM(L=32, dt=0.01, mu=1.0, nu=0.48, sigma_c_std=0.2)

# Simulate steady state at γ̇ = 1.0
shear_rate_single = np.array([1.0])
data_single = RheoData(x=shear_rate_single, y=None, initial_test_mode="flow_curve")

print("Running simulation to generate stress field...")
result_viz = model_viz.predict(data_single, smooth=False, seed=123)

# Note: To access intermediate stress fields, we need to extract from simulation history
# For demonstration, we'll create a synthetic stress field with realistic structure
L = 32
x, y = np.meshgrid(np.arange(L), np.arange(L), indexing='ij')

# Create spatially correlated stress field
np.random.seed(42)
stress_field = np.zeros((3, L, L))

# σ_xx: Extensional pattern with disorder
stress_field[0] = 0.5 * np.sin(2 * np.pi * x / L) + 0.2 * np.random.randn(L, L)

# σ_yy: Compressional (opposite sign to σ_xx)
stress_field[1] = -0.5 * np.sin(2 * np.pi * x / L) + 0.2 * np.random.randn(L, L)

# σ_xy: Shear (dominant component)
stress_field[2] = 1.5 * np.cos(2 * np.pi * x / L) * np.sin(2 * np.pi * y / L) + 0.3 * np.random.randn(L, L)

# Yield thresholds with disorder
thresholds = np.abs(1.0 + 0.2 * np.random.randn(L, L))

print(f"Stress field shape: {stress_field.shape}")
print(f"Mean stresses: σ_xx={np.mean(stress_field[0]):.3f}, σ_yy={np.mean(stress_field[1]):.3f}, σ_xy={np.mean(stress_field[2]):.3f}")

In [None]:
# Visualization 1: Auto-detection of tensorial fields
print("Plot 1: Auto-detection (3-panel tensorial view)")
fig1 = plot_lattice_fields(stress_field, thresholds)
plt.show()

In [None]:
# Visualization 2: Tensorial field components
print("\nPlot 2: Tensorial field components (detailed view)")
fig2, axes2 = plot_tensorial_fields(stress_field)
plt.show()

In [None]:
# Visualization 3: Normal stress difference N₁
print("\nPlot 3: First normal stress difference N₁ = σ_xx - σ_yy")
fig3, ax3 = plot_normal_stress_field(stress_field, nu=0.48)
plt.show()

In [None]:
# Visualization 4: Von Mises effective stress
print("\nPlot 4: Von Mises effective stress and yield map")
fig4, axes4 = plot_von_mises_field(stress_field, thresholds, nu=0.48)
plt.show()
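
The yield map compares an effective stress with the local threshold. The NumPy sketch below rests on two assumptions about how `plot_von_mises_field` builds its map:

- The out-of-plane stress follows the plane-strain closure σ_zz = ν(σ_xx + σ_yy).
- A site counts as yielding when σ_vm ≥ σ_c. Some EPM formulations instead compare σ_vm/√3 with a shear threshold.

The helper names are hypothetical.

```python
import numpy as np


def von_mises_plane_strain(stress, nu=0.48):
    """Von Mises effective stress for stress[...] = [σ_xx, σ_yy, σ_xy] under plane strain."""
    sxx, syy, sxy = stress[0], stress[1], stress[2]
    szz = nu * (sxx + syy)  # assumed plane-strain closure
    return np.sqrt(
        0.5 * ((sxx - syy) ** 2 + (syy - szz) ** 2 + (szz - sxx) ** 2) + 3.0 * sxy**2
    )


def yielding_fraction(stress, thresholds, nu=0.48):
    """Fraction of lattice sites whose von Mises stress reaches the local threshold."""
    return float(np.mean(von_mises_plane_strain(stress, nu) >= thresholds))


# Pure shear check: σ_vm = √3 |σ_xy| when σ_xx = σ_yy = 0
print(von_mises_plane_strain(np.array([0.0, 0.0, 1.0])))  # ≈ 1.732
```

With the synthetic field above, `yielding_fraction(stress_field, thresholds)` gives the share of sites that a yield map built on these assumptions would mark as yielding.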

In [None]:
# Visualization 5: Normal stress ratio vs shear rate
print("\nPlot 5: Normal stress ratio N₁/σ_xy vs shear rate")

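# Use the flow-curve results from Example 1, re-read from `result` so that
# variables reassigned in later cells cannot leak in
sigma_xy_ex1 = np.asarray(result.y)
N1_ex1 = np.asarray(result.metadata["N1"])
N1_ratio = N1_ex1 / sigma_xy_ex1

fig5, ax5 = plot_normal_stress_ratio(shear_rates, N1_ex1, sigma_xy_ex1)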
plt.show()

print(f"\nN₁/σ_xy ratio range: {np.min(N1_ratio):.3f} to {np.max(N1_ratio):.3f}")
print("Typical polymers: N₁/σ_xy ~ 0.1-2.0")

## Example 5: Comparison of Von Mises vs Hill Criteria

Compare isotropic (von Mises) and anisotropic (Hill) yield criteria.
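
Hill's quadratic anisotropic criterion can be written for the in-plane stresses as σ_H² = ½[F(σ_yy − σ_zz)² + G(σ_zz − σ_xx)² + H(σ_xx − σ_yy)² + 2Nσ_xy²]. With F = G = H = 1 and N = 3 it reduces to von Mises, which is consistent with the interpretation printed at the end of this example (H≠1, N≠3 departs from isotropy). The sketch assumes:

- F = G = 1.
- The plane-strain closure σ_zz = ν(σ_xx + σ_yy).
- `hill_H` and `hill_N` map onto H and N in this form. This is not the documented RheoJAX parameterization.

```python
import numpy as np


def hill_effective_stress(stress, H=1.0, N=3.0, F=1.0, G=1.0, nu=0.48):
    """Hill effective stress for stress[...] = [σ_xx, σ_yy, σ_xy] (illustrative form)."""
    sxx, syy, sxy = stress[0], stress[1], stress[2]
    szz = nu * (sxx + syy)  # assumed plane-strain closure
    q = 0.5 * (
        F * (syy - szz) ** 2
        + G * (szz - sxx) ** 2
        + H * (sxx - syy) ** 2  # H weights N₁² = (σ_xx - σ_yy)²
        + 2.0 * N * sxy**2      # N weights the shear component
    )
    return np.sqrt(q)


def von_mises_plane_strain(stress, nu=0.48):
    """Reference isotropic measure (the Hill form with F = G = H = 1, N = 3)."""
    sxx, syy, sxy = stress[0], stress[1], stress[2]
    szz = nu * (sxx + syy)
    return np.sqrt(0.5 * ((sxx - syy) ** 2 + (syy - szz) ** 2 + (szz - sxx) ** 2) + 3.0 * sxy**2)


rng = np.random.default_rng(0)
stress = rng.standard_normal((3, 8, 8))
print(np.allclose(hill_effective_stress(stress), von_mises_plane_strain(stress)))  # → True

# Under pure shear the Hill measure is √N |σ_xy|, so N = 2 < 3 lowers it
pure_shear = np.array([0.0, 0.0, 1.0])
print(hill_effective_stress(pure_shear, H=1.5, N=2.0))  # ≈ 1.414
```

Read this way, H > 1 makes states with large N₁ reach the threshold sooner. N < 3 means a site must carry more shear stress before it yields. This is one possible reading of the `hill_H=1.5`, `hill_N=2.0` settings below.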

In [None]:
# Prepare test data
gamma_dot_test = np.logspace(-1, 1, 8)
data_test = RheoData(x=gamma_dot_test, y=None, initial_test_mode="flow_curve")

# Model 1: Von Mises (isotropic)
model_vm = TensorialEPM(
    L=32,
    dt=0.01,
    mu=1.0,
    nu=0.48,
    sigma_c_mean=1.0,
    sigma_c_std=0.15,
    yield_criterion="von_mises"
)

print("Simulating with Von Mises criterion...")
result_vm = model_vm.predict(data_test, smooth=True, seed=100)
sigma_vm = result_vm.y
N1_vm = result_vm.metadata["N1"]

# Model 2: Hill (anisotropic)
model_hill = TensorialEPM(
    L=32,
    dt=0.01,
    mu=1.0,
    nu=0.48,
    sigma_c_mean=1.0,
    sigma_c_std=0.15,
    yield_criterion="hill"
)

# Set anisotropic parameters
model_hill.params.set_value("hill_H", 1.5)  # Stronger normal stress coupling
model_hill.params.set_value("hill_N", 2.0)  # Modified shear response

print("Simulating with Hill criterion (H=1.5, N=2.0)...")
result_hill = model_hill.predict(data_test, smooth=True, seed=100)
sigma_hill = result_hill.y
N1_hill = result_hill.metadata["N1"]

print("\nDone!")

In [None]:
# Compare predictions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Shear stress comparison
axes[0].loglog(gamma_dot_test, sigma_vm, 'o-', label='Von Mises (Isotropic)', linewidth=2, markersize=8)
axes[0].loglog(gamma_dot_test, sigma_hill, 's--', label='Hill (Anisotropic H=1.5, N=2.0)', linewidth=2, markersize=8)
axes[0].set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
axes[0].set_ylabel('Shear Stress σ_xy [Pa]', fontsize=12)
axes[0].set_title('Yield Criterion Effect on Shear Stress', fontsize=13)
axes[0].grid(True, which='both', alpha=0.3)
axes[0].legend(fontsize=10)

# Normal stress comparison
axes[1].loglog(gamma_dot_test, N1_vm, 'o-', label='Von Mises N₁', linewidth=2, markersize=8, color='C0')
axes[1].loglog(gamma_dot_test, N1_hill, 's--', label='Hill N₁', linewidth=2, markersize=8, color='C1')
axes[1].set_xlabel('Shear Rate γ̇ [1/s]', fontsize=12)
axes[1].set_ylabel('Normal Stress N₁ [Pa]', fontsize=12)
axes[1].set_title('Yield Criterion Effect on Normal Stress', fontsize=13)
axes[1].grid(True, which='both', alpha=0.3)
axes[1].legend(fontsize=10)

plt.tight_layout()
plt.show()

# Quantify differences
rel_diff_sigma = np.mean(np.abs(sigma_hill - sigma_vm) / sigma_vm) * 100
rel_diff_N1 = np.mean(np.abs(N1_hill - N1_vm) / N1_vm) * 100

print("\nAverage relative differences:")
print(f"  Shear stress: {rel_diff_sigma:.1f}%")
print(f"  Normal stress: {rel_diff_N1:.1f}%")
print("\nInterpretation:")
print("  - Hill criterion with H≠1, N≠3 changes the effective yield threshold")
print("  - Stronger coupling to normal stresses (H>1) increases anisotropy")
print("  - Use Hill for materials with directional microstructure (fiber suspensions, liquid crystalline polymers)")

## Example 6: Animation of Avalanche Dynamics

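Create an animation showing the evolution of stress fields during shear startup.

**Note**: `animate_tensorial_evolution` returns a matplotlib animation. Display it inline with `HTML(anim.to_jshtml())`, or save it as an animated GIF with `anim.save(...)`, as in the last cell of this section.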

In [None]:
from IPython.display import HTML

# Create synthetic stress evolution (shear startup)
print("Creating synthetic stress evolution...")

np.random.seed(0)  # Reproducible disorder in the synthetic fields
L = 16  # Small lattice for faster animation
T = 30  # Number of time frames
time = np.linspace(0, 3.0, T)

stress_history = np.zeros((T, 3, L, L))
x_grid, y_grid = np.meshgrid(np.arange(L), np.arange(L), indexing='ij')

for t_idx, t in enumerate(time):
    # Build up stress over time (exponential approach to steady state)
    amplitude = 1.0 - np.exp(-t)
    
    # Add spatial patterns + disorder
    stress_history[t_idx, 0] = 0.3 * amplitude * np.sin(2 * np.pi * x_grid / L) + 0.05 * np.random.randn(L, L)
    stress_history[t_idx, 1] = -0.3 * amplitude * np.sin(2 * np.pi * x_grid / L) + 0.05 * np.random.randn(L, L)
    stress_history[t_idx, 2] = amplitude * np.cos(2 * np.pi * x_grid / L) * np.sin(2 * np.pi * y_grid / L) + 0.1 * np.random.randn(L, L)

history_dict = {
    'stress': stress_history,
    'time': time
}

print(f"Stress history shape: {stress_history.shape}")
print(f"Time points: {T}")

In [None]:
# Create animation (all components)
print("Creating animation of all stress components...")
anim_all = animate_tensorial_evolution(history_dict, component='all', interval=100, nu=0.48)

# Display in notebook
HTML(anim_all.to_jshtml())

In [None]:
# Create animation (von Mises effective stress only)
print("Creating animation of von Mises stress...")
anim_vm = animate_tensorial_evolution(history_dict, component='vm', interval=100, nu=0.48)

# Display in notebook
HTML(anim_vm.to_jshtml())

In [None]:
# Optionally save to file
# anim_all.save('tensorial_evolution.gif', writer='pillow', fps=10)
# print("Animation saved to tensorial_evolution.gif")
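
`animate_tensorial_evolution` hides the plotting details. A plain-matplotlib sketch of a single-component version follows, using only `FuncAnimation` and `Animation.to_jshtml`. It assumes a history array of shape (T, 3, L, L) like `stress_history` above. `animate_component` is a hypothetical helper, not a RheoJAX function.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def animate_component(stress_history, component=2, interval=100):
    """Animate one stress component from an array of shape (T, 3, L, L)."""
    frames = stress_history[:, component]
    vmax = np.max(np.abs(frames))  # fixed, symmetric color scale across all frames
    fig, ax = plt.subplots(figsize=(4, 4))
    im = ax.imshow(frames[0].T, origin="lower", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
    fig.colorbar(im, ax=ax)
    title = ax.set_title("frame 0")

    def update(i):
        im.set_data(frames[i].T)
        title.set_text(f"frame {i}")
        return im, title

    anim = FuncAnimation(fig, update, frames=len(frames), interval=interval, blit=False)
    plt.close(fig)  # avoid an extra static figure in the notebook output
    return anim


rng = np.random.default_rng(0)
history = rng.standard_normal((5, 3, 8, 8))
anim = animate_component(history)
html = anim.to_jshtml()  # self-contained HTML/JavaScript player
```

In the notebook, `HTML(animate_component(stress_history, component=2).to_jshtml())` would display the σ_xy evolution. A fixed color scale across frames keeps the build-up of stress visible instead of being renormalized away on each frame.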

## Summary and Best Practices

### When to Use TensorialEPM

1. **Normal stress data available**: N₁ or N₂ measurements
2. **Anisotropic materials**: Fiber suspensions, liquid crystalline polymers
3. **Flow instabilities**: Shear banding, edge fracture analysis
4. **3D flow predictions**: Rod climbing, die swell, secondary flows

### Performance Tips

- **Lattice size**: Use L=32 for fitting, L=64+ for production simulations
- **Smooth mode**: Always use `smooth=True` for fitting (differentiable)
- **Reproducibility**: Set `seed` parameter for deterministic results
- **GPU acceleration**: Used automatically when JAX is installed with CUDA support and detects a GPU

### Fitting Strategy

1. **Shear-first approach**: Fit σ_xy data first (fast)
2. **Validate N₁**: Check normal stress predictions with fitted parameters
3. **Refine if needed**: Adjust `w_N1` weight and re-fit (future feature)
4. **Check convergence**: Monitor R² and RMSE for both σ_xy and N₁
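
The convergence check in step 4 can be wrapped in a small helper. It reports R², RMSE and RMSE relative to the mean magnitude for any observed/predicted pair, mirroring the calculations in Examples 2 and 3. `fit_report` is a hypothetical helper, not part of RheoJAX.

```python
import numpy as np


def fit_report(y_obs, y_pred):
    """R², RMSE and relative RMSE for one observed/predicted pair."""
    y_obs = np.asarray(y_obs, dtype=float)   # also accepts JAX arrays
    y_pred = np.asarray(y_pred, dtype=float)
    resid = y_obs - y_pred
    ss_res = np.sum(resid**2)
    ss_tot = np.sum((y_obs - y_obs.mean()) ** 2)  # zero (R² undefined) if y_obs is constant
    rmse = float(np.sqrt(np.mean(resid**2)))
    return {
        "r2": float(1.0 - ss_res / ss_tot),
        "rmse": rmse,
        "rel_rmse": rmse / float(np.mean(np.abs(y_obs))),
    }


report = fit_report([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
print(report)  # r2 = 0.5, rmse ≈ 0.577
```

For Example 3, `fit_report(sigma_xy_combined, sigma_xy_pred_combined)` and `fit_report(N1_combined, N1_pred_combined)` give both channels on the same footing.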

### Visualization Workflow

1. **Auto-detection**: Use `plot_lattice_fields()` for quick inspection
2. **Component analysis**: Use `plot_tensorial_fields()` for detailed view
3. **Von Mises**: Use `plot_von_mises_field()` to identify yielding regions
4. **Normal stress ratio**: Use `plot_normal_stress_ratio()` for power-law analysis
5. **Animations**: Use `animate_tensorial_evolution()` for dynamics

### Further Reading

- **Handbook**: See `docs/source/models/epm/tensorial_epm.rst` for theory
- **API Reference**: Full parameter documentation in `docs/source/api/models.rst`
- **Comparison**: LatticeEPM vs TensorialEPM decision tree in handbook
- **Examples**: Visualization demo in `examples/tensorial_epm_visualization_demo.py`