# Creating Custom Rheological Models

Learn to create custom rheological models by inheriting from `BaseModel`, which supplies Bayesian inference capabilities automatically.

## Learning Objectives
- Understand the Burgers model theory and its physical meaning
- Inherit from `BaseModel` to create a custom rheological model
- Implement `_fit()` and `_predict()` methods with NLSQ optimization
- Register models with `ModelRegistry` for ecosystem integration
- Write comprehensive tests for custom models
- Leverage automatic Bayesian inference via `BayesianMixin`
- Integrate with Pipeline API for production workflows
- Apply performance optimization with JAX `@jit` compilation

## Prerequisites
- Model fitting basics (Phase 1: `01-maxwell-fitting.ipynb`)
- Understanding of Zener and Maxwell models
- Basic Python class inheritance

**Estimated Time:** about 60 minutes

## Section 1: Burgers Model Theory (10 min)

### What is the Burgers Model?

The **Burgers model** is a 4-parameter viscoelastic model combining:
- **Zener element** (Standard Linear Solid): equilibrium modulus with relaxation
- **Maxwell element** in series: provides long-term viscous flow

### Physical Interpretation

```
Mechanical Analog:

  Zener Element         Maxwell Element
┌────────────────┐     ┌──────────────┐
│  ┌─Ge─┐        │     │              │
│  │    │  ┌─η1─┐│     │              │
├──┤    ├──┤    ││  ───┤  ┌─Gm─┬─η2─┐│
│  │    │  └────┘│     │  │    │     ││
│  └────┘        │     │  └────┴─────┘│
└────────────────┘     └──────────────┘
   (in parallel)          (in series)
```

### Mathematical Formulation

#### Relaxation Modulus
$$G(t) = G_e + G_m \exp\left(-\frac{t}{\tau_m}\right)$$

where:
- $G_e$: Equilibrium modulus (long-term elasticity)
- $G_m$: Maxwell modulus (transient elasticity)
- $\tau_m = \eta_1 / G_m$: Maxwell relaxation time
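
As a quick numerical check of the expression above, the following NumPy sketch evaluates $G(t)$ at three characteristic times. The function name `relaxation_modulus` and the parameter values are illustrative choices for this example and are not part of the Rheo API.

```python
import numpy as np

def relaxation_modulus(t, Ge, Gm, eta1):
    """Evaluate G(t) = Ge + Gm * exp(-t / tau_m), with tau_m = eta1 / Gm."""
    tau_m = eta1 / Gm
    return Ge + Gm * np.exp(-np.asarray(t, dtype=float) / tau_m)

# Illustrative values (assumed for this sketch)
Ge, Gm, eta1 = 1e4, 5e4, 1e3
tau_m = eta1 / Gm  # Maxwell relaxation time in seconds

G = relaxation_modulus([0.0, tau_m, 100.0], Ge, Gm, eta1)
print(G[0])  # instantaneous modulus, Ge + Gm
print(G[1])  # after one relaxation time, Ge + Gm / e
print(G[2])  # long-time plateau, approaches Ge
```

The three values bracket the response: the full elastic modulus at $t = 0$, a partially relaxed state at $t = \tau_m$, and the equilibrium plateau at long times.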

#### Creep Compliance
$$J(t) = \frac{1}{G_e + G_m} + \frac{1}{G_e}\left(1 - \exp\left(-\frac{t}{\tau_r}\right)\right) + \frac{t}{\eta_2}$$

where:
- $\tau_r = \eta_1(G_e + G_m)/(G_e G_m)$: Retardation time
- $\eta_2$: Terminal viscosity (long-term flow)

### Four Parameters

1. **$G_e$** (Pa): Equilibrium modulus - material stiffness at long times
2. **$G_m$** (Pa): Maxwell modulus - transient elastic response
3. **$\eta_1$** (Pa·s): Maxwell viscosity - controls relaxation rate
4. **$\eta_2$** (Pa·s): Terminal viscosity - controls long-term flow
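
To get a feel for the time scales these four parameters imply, the short sketch below computes the relaxation and retardation times from the definitions given earlier. The numerical values are assumptions chosen only for illustration.

```python
# Illustrative parameter values (assumed for this sketch)
Ge, Gm = 1e4, 5e4        # Pa
eta1, eta2 = 1e3, 1e4    # Pa·s

tau_m = eta1 / Gm                      # Maxwell relaxation time (s)
tau_r = eta1 * (Ge + Gm) / (Ge * Gm)   # retardation time (s)

print(f"tau_m = {tau_m:.3g} s")
print(f"tau_r = {tau_r:.3g} s")
print(f"tau_r / tau_m = {tau_r / tau_m:.3g}")  # equals (Ge + Gm) / Ge
```

The ratio $\tau_r / \tau_m = (G_e + G_m)/G_e$ depends only on the two moduli, so the retardation time is always longer than the relaxation time.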

### Applications
- **Polymers**: Captures both elastic recovery and viscous flow
- **Soft solids**: Models materials with significant creep
- **Complex fluids**: Describes time-dependent viscosity

In [None]:
# Essential imports for custom model development
from rheojax.core.base import BaseModel
from rheojax.core.parameters import Parameter, ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.data import RheoData
from rheojax.core.test_modes import TestMode
from rheojax.utils.optimization import create_least_squares_objective, nlsq_optimize
import numpy as np

# Always use safe JAX imports in Rheo modules
jax, jnp = safe_import_jax()

print('✓ Imports for custom model development')
print(f'  JAX version: {jax.__version__}')
print(f'  JAX devices: {jax.devices()}')

## Section 2: Complete Burgers Model Implementation (15 min)

We'll implement the full 4-parameter Burgers model following Rheo's architecture patterns.

In [None]:
@ModelRegistry.register('burgers')
class BurgersModel(BaseModel):
    """Burgers viscoelastic model (Zener + Maxwell in series).
    
    The Burgers model combines a Zener element (equilibrium spring Ge with
    Maxwell element Gm-eta1 in parallel) in series with a Maxwell element (Gm-eta2).
    This provides both elastic recovery and long-term viscous flow.
    
    Parameters:
        Ge (float): Equilibrium modulus in Pa, range [1e2, 1e9], default 1e4
        Gm (float): Maxwell modulus in Pa, range [1e3, 1e9], default 5e4
        eta1 (float): Maxwell viscosity in Pa·s, range [1e1, 1e12], default 1e3
        eta2 (float): Terminal viscosity in Pa·s, range [1e2, 1e12], default 1e4
    
    Supported test modes:
        - Relaxation: Stress relaxation under constant strain
        - Creep: Strain development under constant stress
    
    Example:
        >>> model = BurgersModel()
        >>> model.fit(t, G_data)  # NLSQ optimization
        >>> result = model.fit_bayesian(t, G_data)  # Bayesian inference
        >>> G_pred = model.predict(t_new)
    
    References:
        - Ferry, J. D. (1980). Viscoelastic Properties of Polymers.
        - Tschoegl, N. W. (1989). The Phenomenological Theory of Linear 
          Viscoelastic Behavior.
    """
    
    def __init__(self):
        """Initialize Burgers model with default parameters."""
        super().__init__()
        
        # Define parameters with physical bounds
        self.parameters = ParameterSet()
        self.parameters.add(
            name='Ge',
            value=1e4,
            bounds=(1e2, 1e9),
            units='Pa',
            description='Equilibrium modulus'
        )
        self.parameters.add(
            name='Gm',
            value=5e4,
            bounds=(1e3, 1e9),
            units='Pa',
            description='Maxwell modulus'
        )
        self.parameters.add(
            name='eta1',
            value=1e3,
            bounds=(1e1, 1e12),
            units='Pa·s',
            description='Maxwell viscosity'
        )
        self.parameters.add(
            name='eta2',
            value=1e4,
            bounds=(1e2, 1e12),
            units='Pa·s',
            description='Terminal viscosity'
        )
        
        self.fitted_ = False
        self._test_mode = TestMode.RELAXATION  # Store for Bayesian inference
    
    def _fit(self, X, y, **kwargs):
        """Fit Burgers model to data using NLSQ optimization.
        
        Args:
            X: RheoData object or independent variable array
            y: Dependent variable array (if X is not RheoData)
            **kwargs: Additional fitting options
                - method: Optimization method (default: 'auto')
                - max_iter: Maximum iterations (default: 1000)
                - use_jax: Enable JAX acceleration (default: True)
        
        Returns:
            self for method chaining
        """
        # Handle RheoData input
        if isinstance(X, RheoData):
            rheo_data = X
            x_data = jnp.array(rheo_data.x)
            y_data = jnp.array(rheo_data.y)
            test_mode = rheo_data.test_mode
        else:
            x_data = jnp.array(X)
            y_data = jnp.array(y)
            test_mode = kwargs.get('test_mode', TestMode.RELAXATION)
        
        # Store test mode for model_function (Bayesian inference)
        self._test_mode = test_mode
        
        # Create objective function with stateless predictions
        def model_fn(x, params):
            """Model function for optimization (stateless)."""
            Ge, Gm, eta1, eta2 = params[0], params[1], params[2], params[3]
            
            # Direct prediction based on test mode
            if test_mode == TestMode.RELAXATION:
                return self._predict_relaxation(x, Ge, Gm, eta1, eta2)
            elif test_mode == TestMode.CREEP:
                return self._predict_creep(x, Ge, Gm, eta1, eta2)
            else:
                raise ValueError(f'Unsupported test mode: {test_mode}')
        
        # Create least squares objective
        objective = create_least_squares_objective(
            model_fn, x_data, y_data, normalize=True
        )
        
        # Optimize using NLSQ (GPU-accelerated)
        result = nlsq_optimize(
            objective,
            self.parameters,
            use_jax=kwargs.get('use_jax', True),
            method=kwargs.get('method', 'auto'),
            max_iter=kwargs.get('max_iter', 1000),
        )
        
        # Store result for diagnostics
        self._nlsq_result = result
        self.fitted_ = True
        
        return self
    
    def _predict(self, X):
        """Predict response based on input data.
        
        Args:
            X: RheoData object or independent variable array
        
        Returns:
            Predicted values as JAX array
        """
        # Handle RheoData input
        if isinstance(X, RheoData):
            rheo_data = X
            x_data = jnp.array(rheo_data.x)
            test_mode = rheo_data.test_mode
        else:
            x_data = jnp.array(X)
            # Fall back to the test mode stored by the last fit
            test_mode = getattr(self, '_test_mode', TestMode.RELAXATION)
        
        # Get parameter values
        Ge = self.parameters.get_value('Ge')
        Gm = self.parameters.get_value('Gm')
        eta1 = self.parameters.get_value('eta1')
        eta2 = self.parameters.get_value('eta2')
        
        # Dispatch to appropriate prediction method
        if test_mode == TestMode.RELAXATION:
            return self._predict_relaxation(x_data, Ge, Gm, eta1, eta2)
        elif test_mode == TestMode.CREEP:
            return self._predict_creep(x_data, Ge, Gm, eta1, eta2)
        else:
            raise ValueError(f'Unsupported test mode: {test_mode}')
    
    def model_function(self, X, params):
        """Model function for Bayesian inference.
        
        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.
        
        Args:
            X: Independent variable (time for relaxation/creep)
            params: Array of parameter values [Ge, Gm, eta1, eta2]
        
        Returns:
            Model predictions as JAX array
        """
        # Extract parameters from array
        Ge, Gm, eta1, eta2 = params[0], params[1], params[2], params[3]
        
        # Use stored test mode from last fit
        test_mode = getattr(self, '_test_mode', TestMode.RELAXATION)
        
        # Dispatch to appropriate prediction method
        if test_mode == TestMode.RELAXATION:
            return self._predict_relaxation(X, Ge, Gm, eta1, eta2)
        elif test_mode == TestMode.CREEP:
            return self._predict_creep(X, Ge, Gm, eta1, eta2)
        else:
            raise ValueError(f'Unsupported test mode: {test_mode}')
    
    @staticmethod
    @jax.jit
    def _predict_relaxation(
        t: jnp.ndarray, Ge: float, Gm: float, eta1: float, eta2: float
    ) -> jnp.ndarray:
        """Predict relaxation modulus G(t).
        
        Theory: G(t) = Ge + Gm * exp(-t/tau_m)
        where tau_m = eta1/Gm is the Maxwell relaxation time
        
        Args:
            t: Time array (s)
            Ge: Equilibrium modulus (Pa)
            Gm: Maxwell modulus (Pa)
            eta1: Maxwell viscosity (Pa·s)
            eta2: Terminal viscosity (Pa·s) - not used in relaxation
        
        Returns:
            Relaxation modulus G(t) in Pa
        """
        tau_m = eta1 / Gm  # Maxwell relaxation time
        return Ge + Gm * jnp.exp(-t / tau_m)
    
    @staticmethod
    @jax.jit
    def _predict_creep(
        t: jnp.ndarray, Ge: float, Gm: float, eta1: float, eta2: float
    ) -> jnp.ndarray:
        """Predict creep compliance J(t).
        
        Theory:
        J(t) = 1/(Ge + Gm) + (1/Ge) * (1 - exp(-t/tau_r)) + t/eta2
        
        where tau_r = eta1*(Ge + Gm)/(Ge*Gm) is the retardation time
        
        Args:
            t: Time array (s)
            Ge: Equilibrium modulus (Pa)
            Gm: Maxwell modulus (Pa)
            eta1: Maxwell viscosity (Pa·s)
            eta2: Terminal viscosity (Pa·s)
        
        Returns:
            Creep compliance J(t) in 1/Pa
        """
        # Instantaneous compliance
        J0 = 1.0 / (Ge + Gm)
        
        # Retardation time
        tau_r = eta1 * (Ge + Gm) / (Ge * Gm)
        
        # Delayed elastic compliance
        J_delayed = (1.0 / Ge) * (1.0 - jnp.exp(-t / tau_r))
        
        # Viscous flow compliance
        J_viscous = t / eta2
        
        return J0 + J_delayed + J_viscous
    
    def get_relaxation_time(self) -> float:
        """Get Maxwell relaxation time tau_m = eta1/Gm.
        
        Returns:
            Relaxation time in seconds
        """
        Gm = self.parameters.get_value('Gm')
        eta1 = self.parameters.get_value('eta1')
        return eta1 / Gm
    
    def get_retardation_time(self) -> float:
        """Get retardation time for creep.
        
        Theory: tau_r = eta1 * (Ge + Gm) / (Ge * Gm)
        
        Returns:
            Retardation time in seconds
        """
        Ge = self.parameters.get_value('Ge')
        Gm = self.parameters.get_value('Gm')
        eta1 = self.parameters.get_value('eta1')
        return eta1 * (Ge + Gm) / (Ge * Gm)
    
    def __repr__(self) -> str:
        """String representation of Burgers model."""
        Ge = self.parameters.get_value('Ge')
        Gm = self.parameters.get_value('Gm')
        eta1 = self.parameters.get_value('eta1')
        eta2 = self.parameters.get_value('eta2')
        tau_m = self.get_relaxation_time()
        return (
            f'BurgersModel(Ge={Ge:.2e} Pa, Gm={Gm:.2e} Pa, '
            f'eta1={eta1:.2e} Pa·s, eta2={eta2:.2e} Pa·s, tau={tau_m:.2e} s)'
        )

print('✓ BurgersModel implemented and registered')
print(f'  Registered models: {ModelRegistry.list_models()}')

## Section 3: Testing the Custom Model (10 min)

Comprehensive testing is critical for reliable custom models.
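
One way to organize such checks outside the notebook is a small `unittest` case. The sketch below is only an illustration of the structure: `relaxation_modulus` is a NumPy stand-in for a model's prediction method, and the class name and parameter values are hypothetical.

```python
import unittest
import numpy as np

def relaxation_modulus(t, Ge, Gm, eta1):
    """Stand-in for a model's relaxation prediction: Ge + Gm * exp(-t * Gm / eta1)."""
    return Ge + Gm * np.exp(-np.asarray(t, dtype=float) * Gm / eta1)

class RelaxationModulusTests(unittest.TestCase):
    Ge, Gm, eta1 = 1e4, 5e4, 1e3  # illustrative values

    def test_limits(self):
        G = relaxation_modulus([0.0, 1e6], self.Ge, self.Gm, self.eta1)
        self.assertAlmostEqual(G[0], self.Ge + self.Gm)  # t = 0
        self.assertAlmostEqual(G[1], self.Ge)            # very long time

    def test_monotonic_decay(self):
        t = np.logspace(-4, 2, 200)
        G = relaxation_modulus(t, self.Ge, self.Gm, self.eta1)
        self.assertTrue(np.all(np.diff(G) <= 0))

    def test_finite_at_extreme_times(self):
        G = relaxation_modulus([1e-9, 1e9], self.Ge, self.Gm, self.eta1)
        self.assertTrue(np.all(np.isfinite(G)))

# Run the test case programmatically (works inside a notebook cell as well)
suite = unittest.defaultTestLoader.loadTestsFromTestCase(RelaxationModulusTests)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
```

The same three ideas (known limits, a qualitative property, numerical stability) carry over to tests written against the real model class.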

In [None]:
# Test 1: Model creation and parameters
print('Test 1: Basic model functionality')
print('-' * 50)

model = BurgersModel()
print(f'Model created: {model}')
print(f'\nParameters:')
for name in model.parameters:
    param = model.parameters.get(name)
    print(f'  {name}: {param.value:.2e} {param.units} (bounds: {param.bounds})')

# Verify registry integration
assert 'burgers' in ModelRegistry.list_models()
model_from_registry = ModelRegistry.create('burgers')
assert isinstance(model_from_registry, BurgersModel)
print(f'\n✓ Model registered and retrievable from ModelRegistry')

In [None]:
# Test 2: Relaxation modulus prediction
print('\nTest 2: Relaxation modulus prediction')
print('-' * 50)

# Set known parameters
Ge_true = 1e4
Gm_true = 5e4
eta1_true = 1e3
eta2_true = 1e4

model.parameters.set_value('Ge', Ge_true)
model.parameters.set_value('Gm', Gm_true)
model.parameters.set_value('eta1', eta1_true)
model.parameters.set_value('eta2', eta2_true)

# Generate time points
t = jnp.logspace(-2, 2, 50)

# Predict relaxation
G_t = model._predict_relaxation(t, Ge_true, Gm_true, eta1_true, eta2_true)

# Verify analytical solution
tau_m = eta1_true / Gm_true
G_expected = Ge_true + Gm_true * np.exp(-np.array(t) / tau_m)

max_error = float(jnp.max(jnp.abs(G_t - G_expected)))
rel_error = max_error / float(jnp.max(G_expected))

print(f'Time range: {float(t[0]):.2e} - {float(t[-1]):.2e} s')
print(f'G(t) range: {float(jnp.min(G_t)):.2e} - {float(jnp.max(G_t)):.2e} Pa')
print(f'Max absolute error: {max_error:.2e} Pa')
print(f'Max relative error: {rel_error:.2e}')
assert rel_error < 1e-6, f'Prediction error too large: {rel_error}'
print('✓ Relaxation prediction matches analytical solution')

In [None]:
# Test 3: Fitting to synthetic data
print('\nTest 3: NLSQ optimization on noisy data')
print('-' * 50)

# Generate noisy synthetic data
np.random.seed(42)
t_data = np.logspace(-2, 2, 50)
G_t_true = Ge_true + Gm_true * np.exp(-t_data * Gm_true / eta1_true)
noise_level = 0.02  # 2% noise
G_t_noisy = G_t_true + np.random.normal(0, noise_level * G_t_true, size=t_data.shape)

# Create fresh model with different initial values
model_fit = BurgersModel()
model_fit.parameters.set_value('Ge', 2e4)  # Wrong initial guess
model_fit.parameters.set_value('Gm', 3e4)  # Wrong initial guess
model_fit.parameters.set_value('eta1', 5e2)  # Wrong initial guess
model_fit.parameters.set_value('eta2', 5e3)  # Wrong initial guess

print(f'Initial guesses (intentionally wrong):')
print(f'  Ge: {model_fit.parameters.get_value("Ge"):.2e} Pa (true: {Ge_true:.2e})')
print(f'  Gm: {model_fit.parameters.get_value("Gm"):.2e} Pa (true: {Gm_true:.2e})')
print(f'  eta1: {model_fit.parameters.get_value("eta1"):.2e} Pa·s (true: {eta1_true:.2e})')
print(f'  eta2: {model_fit.parameters.get_value("eta2"):.2e} Pa·s (true: {eta2_true:.2e})')

# Fit using NLSQ
print('\nFitting with NLSQ optimization...')
model_fit.fit(t_data, G_t_noisy)

# Extract fitted parameters
Ge_fit = model_fit.parameters.get_value('Ge')
Gm_fit = model_fit.parameters.get_value('Gm')
eta1_fit = model_fit.parameters.get_value('eta1')
eta2_fit = model_fit.parameters.get_value('eta2')

print(f'\nFitted Parameters:')
print(f'  Ge: {Ge_fit:.2e} Pa (error: {abs(Ge_fit-Ge_true)/Ge_true*100:.1f}%)')
print(f'  Gm: {Gm_fit:.2e} Pa (error: {abs(Gm_fit-Gm_true)/Gm_true*100:.1f}%)')
print(f'  eta1: {eta1_fit:.2e} Pa·s (error: {abs(eta1_fit-eta1_true)/eta1_true*100:.1f}%)')
print(f'  eta2: {eta2_fit:.2e} Pa·s (note: not identifiable from relaxation data)')

# Check NLSQ result
result = model_fit.get_nlsq_result()
if result:
    print(f'\nOptimization diagnostics:')
    print(f'  Converged: {result.success}')
    print(f'  Iterations: {result.nit}')
    print(f'  Final cost: {result.fun:.2e}')
    print(f'  Message: {result.message}')

# Verify reasonable fit (Ge, Gm, eta1 should be recovered, eta2 is not identifiable)
assert abs(Ge_fit - Ge_true) / Ge_true < 0.1, 'Ge fit error > 10%'
assert abs(Gm_fit - Gm_true) / Gm_true < 0.1, 'Gm fit error > 10%'
assert abs(eta1_fit - eta1_true) / eta1_true < 0.15, 'eta1 fit error > 15%'
print('\n✓ NLSQ optimization successfully recovered parameters')

In [None]:
# Test 4: Edge cases and error handling
print('\nTest 4: Edge cases and robustness')
print('-' * 50)

# Test parameter bounds
try:
    model_test = BurgersModel()
    model_test.parameters.set_value('Ge', -100)  # Negative modulus
    print('✗ Should have raised ValueError for negative Ge')
except ValueError:
    print('✓ Bounds enforcement: negative Ge rejected')

try:
    model_test = BurgersModel()
    model_test.parameters.set_value('eta1', 1e15)  # Out of bounds
    print('✗ Should have raised ValueError for out-of-bounds eta1')
except ValueError:
    print('✓ Bounds enforcement: out-of-bounds eta1 rejected')

# Test very short and very long times
model_test = BurgersModel()
t_extreme = jnp.array([1e-6, 1e6])  # Microsecond to million seconds
G_extreme = model_test._predict_relaxation(
    t_extreme, Ge_true, Gm_true, eta1_true, eta2_true
)

print(f'\nExtreme time predictions:')
print(f'  t={t_extreme[0]:.2e} s → G={float(G_extreme[0]):.2e} Pa')
print(f'  t={t_extreme[1]:.2e} s → G={float(G_extreme[1]):.2e} Pa')
assert jnp.all(jnp.isfinite(G_extreme)), 'Non-finite values at extreme times'
print('✓ Model stable at extreme time scales')

# Test unsupported test mode
try:
    model_test = BurgersModel()
    model_test._test_mode = 'invalid_mode'
    model_test.predict(t_data)
    print('✗ Should have raised ValueError for invalid test mode')
except (ValueError, AttributeError):
    print('✓ Error handling: invalid test mode rejected')

## Section 4: Bayesian Inference with Custom Models (10 min)

Every model that inherits from `BaseModel` gains Bayesian capabilities through `BayesianMixin`, provided it implements `model_function()`.
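
The cells that follow report 95% credible intervals computed from posterior samples. As a rough illustration of what such an interval means, the sketch below computes an equal-tailed interval with NumPy. Whether `get_credible_intervals` uses equal-tailed or highest-density intervals is not stated here, so treat this as an assumption. The helper name `equal_tailed_interval` and the synthetic draws are hypothetical.

```python
import numpy as np

def equal_tailed_interval(samples, credibility=0.95):
    """Return the (lower, upper) equal-tailed credible interval of 1-D samples."""
    tail = 100.0 * (1.0 - credibility) / 2.0
    lower, upper = np.percentile(
        np.asarray(samples, dtype=float), [tail, 100.0 - tail]
    )
    return float(lower), float(upper)

# Synthetic "posterior" draws for illustration only (not real inference output)
rng = np.random.default_rng(0)
draws = rng.normal(loc=1e4, scale=200.0, size=5000)

lower, upper = equal_tailed_interval(draws, credibility=0.95)
inside = np.mean((draws >= lower) & (draws <= upper))
print(f"95% interval: [{lower:.3e}, {upper:.3e}]")
print(f"Fraction of draws inside: {inside:.3f}")
```

By construction, about 95% of the draws fall inside the interval, with 2.5% in each tail.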

In [None]:
# Bayesian inference with warm-start from NLSQ
print('Bayesian Inference on Custom Burgers Model')
print('=' * 60)

# Use fitted model from previous test (warm-start)
print(f'Starting NUTS sampling with warm-start from NLSQ...')
print(f'  Initial Ge: {model_fit.parameters.get_value("Ge"):.2e} Pa')
print(f'  Initial Gm: {model_fit.parameters.get_value("Gm"):.2e} Pa')
print(f'  Initial eta1: {model_fit.parameters.get_value("eta1"):.2e} Pa·s')

# Run Bayesian inference (uses NLSQ values as initial_values automatically)
result = model_fit.fit_bayesian(
    t_data, G_t_noisy,
    num_warmup=500,
    num_samples=1000,
    num_chains=1
)

print(f'\nBayesian inference completed!')
print(f'  Posterior samples: {result.num_samples}')
print(f'  Parameters sampled: {list(result.posterior_samples.keys())}')

In [None]:
# Convergence diagnostics
print('\nConvergence Diagnostics (R-hat and ESS)')
print('-' * 60)

for param_name in ['Ge', 'Gm', 'eta1']:
    r_hat = result.diagnostics['r_hat'][param_name]
    ess = result.diagnostics['ess'][param_name]
    
    # Convergence criteria
    converged = r_hat < 1.1 and ess > 400
    status = '✓' if converged else '✗'
    
    print(f'{status} {param_name:8s}: R-hat = {r_hat:.4f}, ESS = {ess:.0f}')

# Check for divergences
n_divergences = result.diagnostics.get('num_divergences', 0)
print(f'\nDivergences: {n_divergences}')
if n_divergences == 0:
    print('✓ No divergences detected (healthy sampling)')
else:
    print(f'⚠ {n_divergences} divergences (consider increasing num_warmup or target_accept_prob)')

In [None]:
# Posterior summary statistics
print('\nPosterior Summary (Mean ± Std)')
print('-' * 60)

# True values used to generate the synthetic data
true_values = {'Ge': Ge_true, 'Gm': Gm_true, 'eta1': eta1_true}

for param_name in ['Ge', 'Gm', 'eta1']:
    mean = result.summary[param_name]['mean']
    std = result.summary[param_name]['std']
    
    # Compare to true values
    true_val = true_values[param_name]
    
    # Check if true value is within 2σ
    within_2sigma = abs(mean - true_val) < 2 * std
    status = '✓' if within_2sigma else '⚠'
    
    print(f'{status} {param_name:8s}: {mean:.2e} ± {std:.2e} (true: {true_val:.2e})')

# Credible intervals
print('\n95% Credible Intervals')
print('-' * 60)

intervals = model_fit.get_credible_intervals(result.posterior_samples, credibility=0.95)

for param_name in ['Ge', 'Gm', 'eta1']:
    lower, upper = intervals[param_name]
    true_val = true_values[param_name]
    
    # Check if true value is within credible interval
    contains_true = lower <= true_val <= upper
    status = '✓' if contains_true else '✗'
    
    print(f'{status} {param_name:8s}: [{lower:.2e}, {upper:.2e}]')
    print(f'           True value: {true_val:.2e}')

In [None]:
# Posterior predictive distribution
print('\nPosterior Predictive Distribution')
print('-' * 60)

# Sample from posterior to generate predictions
n_posterior_samples = 100  # Use subset for efficiency
posterior_Ge = result.posterior_samples['Ge'][:n_posterior_samples]
posterior_Gm = result.posterior_samples['Gm'][:n_posterior_samples]
posterior_eta1 = result.posterior_samples['eta1'][:n_posterior_samples]
posterior_eta2 = result.posterior_samples['eta2'][:n_posterior_samples]

# Generate predictions for each posterior sample
predictions = []
for i in range(n_posterior_samples):
    G_pred = model_fit._predict_relaxation(
        jnp.array(t_data),
        posterior_Ge[i],
        posterior_Gm[i],
        posterior_eta1[i],
        posterior_eta2[i]
    )
    predictions.append(G_pred)

predictions = jnp.stack(predictions)

# Compute prediction statistics
pred_mean = jnp.mean(predictions, axis=0)
pred_std = jnp.std(predictions, axis=0)

print(f'Generated {n_posterior_samples} posterior predictive samples')
print(f'Prediction mean range: [{float(jnp.min(pred_mean)):.2e}, {float(jnp.max(pred_mean)):.2e}] Pa')
print(f'Prediction std range: [{float(jnp.min(pred_std)):.2e}, {float(jnp.max(pred_std)):.2e}] Pa')
print(f'\n✓ Posterior predictive captures uncertainty in model predictions')

## Section 5: Integration with Pipeline API (5 min)

Custom models work seamlessly with Rheo's Pipeline API.

In [None]:
from rheojax.pipeline.base import Pipeline
from rheojax.pipeline.bayesian import BayesianPipeline
import matplotlib.pyplot as plt

print('Pipeline Integration with Custom Model')
print('=' * 60)

# Method 1: Direct model instance
pipeline1 = Pipeline()
pipeline1.model = BurgersModel()  # Use custom model directly
pipeline1.data = RheoData(x=t_data, y=G_t_noisy, domain='time')
pipeline1.fit()

print('Method 1: Direct model instance')
print(f'  Fitted: {pipeline1.model.fitted_}')
print(f'  Ge: {pipeline1.model.parameters.get_value("Ge"):.2e} Pa')

# Method 2: Registry-based creation
pipeline2 = Pipeline()
data_dict = {'time': t_data, 'G': G_t_noisy}
import tempfile
import pandas as pd

# Save to temporary CSV
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
    df = pd.DataFrame(data_dict)
    df.to_csv(f.name, index=False)
    csv_path = f.name

# Use fluent API
(pipeline2
 .load(csv_path, x_col='time', y_col='G')
 .fit('burgers'))  # Load custom model by name from registry

print('\nMethod 2: Registry-based (fluent API)')
print(f'  Fitted: {pipeline2.model.fitted_}')
print(f'  Ge: {pipeline2.model.parameters.get_value("Ge"):.2e} Pa')

print('\n✓ Custom model integrates seamlessly with Pipeline API')

# Cleanup
import os
os.unlink(csv_path)

In [None]:
# BayesianPipeline for comprehensive workflow
print('\nBayesianPipeline with Custom Model')
print('-' * 60)

bayesian_pipeline = BayesianPipeline()
bayesian_pipeline.model = BurgersModel()
bayesian_pipeline.data = RheoData(x=t_data, y=G_t_noisy, domain='time')

# NLSQ → NUTS workflow
print('Step 1: NLSQ point estimation...')
bayesian_pipeline.fit_nlsq()

print('Step 2: Bayesian inference with warm-start...')
bayesian_pipeline.fit_bayesian(num_samples=1000, num_warmup=500)

# Access diagnostics
diagnostics = bayesian_pipeline.get_diagnostics()
summary = bayesian_pipeline.get_posterior_summary()

print('\nPipeline Results:')
print(f'  NLSQ converged: {bayesian_pipeline.model.get_nlsq_result().success}')
print(f'  Bayesian samples: {bayesian_pipeline._bayesian_result.num_samples}')
print(f'  All R-hat < 1.1: {all(r < 1.1 for r in diagnostics["r_hat"].values())}')
print('\n✓ BayesianPipeline provides complete NLSQ→NUTS workflow')

## Section 6: Performance Optimization with JAX (5 min)

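JAX `@jit` compiles a function once per input shape and dtype and reuses the compiled version afterwards, which can speed up repeated model evaluation considerably. The benchmark below measures the gain.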
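
Because `jax.jit` caches compiled code per input shape and dtype, a call with a new array shape triggers a fresh trace and compilation. The sketch below makes this visible by recording each trace in a Python list. It imports `jax` directly only to stay self-contained (inside Rheo modules, use `safe_import_jax()` as described in this tutorial), and the function name `relax` is hypothetical.

```python
# Direct import keeps this sketch self-contained; Rheo modules should use safe_import_jax()
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function, not on cached calls

@jax.jit
def relax(t, Ge, Gm, eta1):
    trace_log.append(t.shape)  # Python side effect: runs once per new input signature
    return Ge + Gm * jnp.exp(-t * Gm / eta1)

relax(jnp.ones(10), 1e4, 5e4, 1e3)     # first call: trace and compile for shape (10,)
relax(jnp.ones(10), 2e4, 5e4, 1e3)     # same shapes, new values: cached version reused
relax(jnp.ones(1000), 1e4, 5e4, 1e3)   # new shape: traced and compiled again

print(trace_log)
```

A practical consequence for benchmarking is that the warm-up call should use the same array shape as the timed calls. Otherwise compilation time ends up inside the measurement.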

In [None]:
import time

print('Performance Benchmarking: JAX JIT Compilation')
print('=' * 60)

# Create large time array for benchmarking
t_large = jnp.logspace(-3, 3, 10000)

# Warm-up: JIT compiles once per input shape, so warm up with the full-size array
_ = BurgersModel._predict_relaxation(
    t_large, Ge_true, Gm_true, eta1_true, eta2_true
).block_until_ready()

# Benchmark JIT-compiled version
n_iterations = 100
start = time.time()
for _ in range(n_iterations):
    G_jit = BurgersModel._predict_relaxation(t_large, Ge_true, Gm_true, eta1_true, eta2_true)
    G_jit.block_until_ready()  # Wait for asynchronous dispatch to finish (CPU or GPU)
elapsed_jit = time.time() - start

# Create non-JIT version for comparison
def predict_relaxation_nojit(t, Ge, Gm, eta1, eta2):
    """Non-JIT version for comparison."""
    tau_m = eta1 / Gm
    return Ge + Gm * jnp.exp(-t / tau_m)

# Benchmark non-JIT version
start = time.time()
for _ in range(n_iterations):
    G_nojit = predict_relaxation_nojit(t_large, Ge_true, Gm_true, eta1_true, eta2_true)
    G_nojit.block_until_ready()
elapsed_nojit = time.time() - start

speedup = elapsed_nojit / elapsed_jit

print(f'Benchmark results ({n_iterations} iterations on {len(t_large)} points):')
print(f'  JIT-compiled:     {elapsed_jit:.4f} s ({elapsed_jit/n_iterations*1000:.2f} ms/iter)')
print(f'  Non-JIT:          {elapsed_nojit:.4f} s ({elapsed_nojit/n_iterations*1000:.2f} ms/iter)')
print(f'  Speedup:          {speedup:.1f}x')
print(f'\n✓ JAX @jit provides {speedup:.1f}x speedup for model evaluation')

# Verify both produce same results
assert jnp.allclose(G_jit, G_nojit, rtol=1e-6)
print('✓ JIT and non-JIT versions agree within rtol=1e-6')

In [None]:
# Memory efficiency and float64 verification
print('\nMemory and Precision Verification')
print('-' * 60)

# Check data types
print(f'JAX array dtype: {G_jit.dtype}')
print(f'Expected: float64 (enforced by NLSQ import order)')

assert G_jit.dtype == jnp.float64, 'Expected float64 precision'
print('✓ Float64 precision maintained throughout computation')

# Check memory usage
memory_mb = G_jit.nbytes / (1024**2)
print(f'\nMemory usage for {len(t_large)} points:')
print(f'  Array size: {memory_mb:.2f} MB')
print(f'  Per point: {G_jit.nbytes / len(t_large):.1f} bytes')

## Section 7: Best Practices and Guidelines (5 min)

### Custom Model Development Checklist

#### 1. Model Class Structure
- ✓ Inherit from `BaseModel` (provides scikit-learn API + Bayesian capabilities)
- ✓ Register with `@ModelRegistry.register('model_name')` decorator
- ✓ Define `ParameterSet` in `__init__()` with bounds and units
- ✓ Implement `_fit()` for optimization (use `nlsq_optimize`)
- ✓ Implement `_predict()` for predictions
- ✓ Implement `model_function()` for Bayesian inference
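
The split between public methods (`fit`, `predict`) and the hooks you implement (`_fit`, `_predict`) resembles the template-method pattern. The sketch below shows that pattern in plain Python. `SketchBase` and `MeanModel` are hypothetical classes written for this illustration and make no claim about how Rheo's `BaseModel` is actually implemented.

```python
from abc import ABC, abstractmethod

class SketchBase(ABC):
    """Hypothetical stand-in illustrating the fit/_fit split; not Rheo's BaseModel."""

    def __init__(self):
        self.fitted_ = False

    def fit(self, X, y=None, **kwargs):
        # Public entry point: delegate to the subclass hook, then mark as fitted
        self._fit(X, y, **kwargs)
        self.fitted_ = True
        return self

    def predict(self, X):
        if not self.fitted_:
            raise RuntimeError("Call fit() before predict()")
        return self._predict(X)

    @abstractmethod
    def _fit(self, X, y, **kwargs): ...

    @abstractmethod
    def _predict(self, X): ...

class MeanModel(SketchBase):
    """Toy subclass: predicts the mean of the training targets."""

    def _fit(self, X, y, **kwargs):
        self.mean_ = sum(y) / len(y)

    def _predict(self, X):
        return [self.mean_ for _ in X]

model = MeanModel().fit([0, 1, 2], [1.0, 2.0, 3.0])
print(model.predict([10, 20]))
```

With this arrangement the base class owns shared behavior such as state tracking and validation, while each subclass supplies only the model-specific logic.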

#### 2. JAX Integration
- ✓ **Always use `safe_import_jax()`** from `rheojax.core.jax_config`
- ✓ Never import JAX directly (breaks float64 precision)
- ✓ Use `@jax.jit` decorator on static prediction methods
- ✓ Use `jnp` (JAX NumPy) for all array operations
- ✓ Ensure all operations are JAX-compatible for autodiff

#### 3. Parameter Management
- ✓ Set physically meaningful bounds (prevent unphysical values)
- ✓ Provide reasonable default values (aid convergence)
- ✓ Include units in parameter definitions
- ✓ Add descriptions for documentation
- ✓ Consider parameter identifiability (some may not be recoverable)
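
Identifiability can be probed numerically before fitting. The sketch below estimates the sensitivity of the relaxation response to each parameter by central finite differences. A parameter whose sensitivity column is zero cannot be recovered from that data type. The helper names and parameter values are assumptions made for this example.

```python
import numpy as np

def relaxation_modulus(t, Ge, Gm, eta1, eta2):
    """Relaxation response; eta2 is accepted but does not enter the expression."""
    return Ge + Gm * np.exp(-t * Gm / eta1)

def sensitivities(t, params, rel_step=1e-6):
    """Central finite-difference dG/dp, one column per parameter."""
    params = np.asarray(params, dtype=float)
    cols = []
    for i in range(params.size):
        step = rel_step * params[i]
        up, down = params.copy(), params.copy()
        up[i] += step
        down[i] -= step
        diff = relaxation_modulus(t, *up) - relaxation_modulus(t, *down)
        cols.append(diff / (2.0 * step))
    return np.column_stack(cols)

t = np.logspace(-3, 1, 40)
S = sensitivities(t, [1e4, 5e4, 1e3, 1e4])  # columns: Ge, Gm, eta1, eta2
col_norms = np.linalg.norm(S, axis=0)
print(dict(zip(["Ge", "Gm", "eta1", "eta2"], col_norms)))
```

Here the `eta2` column is identically zero, which matches the note in the fitting test that `eta2` is not identifiable from relaxation data.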

#### 4. Testing Requirements
- ✓ Test basic instantiation and parameter access
- ✓ Test prediction against analytical solutions
- ✓ Test fitting with synthetic noisy data
- ✓ Test edge cases (extreme parameter values, time scales)
- ✓ Test Bayesian inference (convergence diagnostics)
- ✓ Test registry integration
- ✓ Validate numerical precision (float64)

#### 5. Documentation
- ✓ Comprehensive docstring with theory and equations
- ✓ Parameter descriptions with units and ranges
- ✓ Usage examples in docstring
- ✓ References to literature
- ✓ Supported test modes clearly stated

#### 6. Performance Optimization
- ✓ Use `@staticmethod` for prediction functions (keeps `self` out of the traced arguments so `@jax.jit` can be applied)
- ✓ Apply `@jax.jit` to all computational functions
- ✓ Avoid Python loops (use JAX vectorization)
- ✓ Profile code to identify bottlenecks
- ✓ Test with large datasets to verify scalability

#### 7. Error Handling
- ✓ Validate input data types and shapes
- ✓ Check for unsupported test modes
- ✓ Handle optimization failures gracefully
- ✓ Provide informative error messages
- ✓ Test boundary conditions

### Common Pitfalls to Avoid

1. **JAX Import Order**: Never `import jax` directly. Use `safe_import_jax()`.
2. **Float32 Precision**: If NLSQ is not imported before JAX, JAX defaults to float32 (insufficient for rheology).
3. **Mutable State in JIT**: Don't access `self.parameters` inside `@jax.jit` functions.
4. **Unbounded Parameters**: Always set bounds (prevents optimizer divergence).
5. **Forgetting model_function()**: Required for Bayesian inference to work.
6. **Poor Initial Guesses**: Provide sensible defaults (aids NLSQ convergence).
7. **No Registry**: Forgetting `@ModelRegistry.register()` prevents Pipeline integration.
8. **Inadequate Testing**: Models must be validated with known solutions.
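
To see why float32 can be a problem (pitfall 2), consider adding a small contribution to a modulus-sized baseline. The NumPy sketch below uses made-up magnitudes purely for illustration.

```python
import numpy as np

plateau = 1.0e4     # Pa, a modulus-sized baseline (illustrative)
transient = 1.0e-4  # Pa, a small late-time contribution (illustrative)

sum32 = np.float32(plateau) + np.float32(transient)
sum64 = np.float64(plateau) + np.float64(transient)

print(sum32 == np.float32(plateau))  # True: the transient is lost in float32
print(sum64 == np.float64(plateau))  # False: float64 still resolves it
print(np.finfo(np.float32).eps, np.finfo(np.float64).eps)
```

Rheological data often span many decades in both time and modulus, so losses of this kind can distort residuals and gradients during fitting.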

### File Organization

```
rheojax/models/
    my_model.py          # Model implementation
tests/models/
    test_my_model.py     # Comprehensive test suite
examples/
    my_model_tutorial.ipynb  # Usage examples
```

### Template for New Models

```python
from rheojax.core.base import BaseModel
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.jax_config import safe_import_jax
from rheojax.utils.optimization import create_least_squares_objective, nlsq_optimize

jax, jnp = safe_import_jax()

@ModelRegistry.register('my_model')
class MyModel(BaseModel):
    """Docstring with theory, equations, references."""
    
    def __init__(self):
        super().__init__()
        self.parameters = ParameterSet()
        # Add parameters...
    
    def _fit(self, X, y, **kwargs):
        # Optimization logic with NLSQ...
        pass
    
    def _predict(self, X):
        # Prediction logic...
        pass
    
    def model_function(self, X, params):
        # For Bayesian inference...
        pass
    
    @staticmethod
    @jax.jit
    def _predict_relaxation(t, *params):
        # JIT-compiled prediction...
        pass
```

## Summary and Key Takeaways

### What We Learned

1. **Burgers Model Theory**: 4-parameter model combining Zener and Maxwell elements
2. **BaseModel Inheritance**: Provides full Rheo ecosystem integration
3. **NLSQ Optimization**: GPU-accelerated fitting with automatic differentiation
4. **Automatic Bayesian Inference**: `BayesianMixin` provides NUTS inference once `model_function()` is implemented
5. **ModelRegistry**: Dynamic model discovery and instantiation
6. **Testing Strategy**: Comprehensive validation ensures reliability
7. **Pipeline Integration**: Custom models work seamlessly with workflows
8. **JAX Performance**: `@jit` compilation provides significant speedup

### The Power of BaseModel

By inheriting from `BaseModel`, your custom model automatically gains:
- ✓ Scikit-learn compatible API (`fit`, `predict`, `score`)
- ✓ Bayesian inference via NUTS (`fit_bayesian`)
- ✓ Prior sampling (`sample_prior`)
- ✓ Credible intervals (`get_credible_intervals`)
- ✓ Pipeline compatibility
- ✓ Serialization (`to_dict`, `from_dict`)
- ✓ Parameter management

### Implementation Time

- **Core Model**: 30-60 min (including theory research)
- **Testing**: 30-60 min (comprehensive validation)
- **Documentation**: 15-30 min (docstrings and examples)
- **Total**: 1.25-2.5 hours for a production-ready custom model

### Next Steps

1. **Explore Advanced Models**: `04-fractional-models-deep-dive.ipynb`
2. **Bayesian Workflows**: `../basic/05-bayesian-uncertainty.ipynb`
3. **Model Comparison**: `../basic/03-comparing-models.ipynb`
4. **Production Pipelines**: `../basic/04-complete-analysis-pipeline.ipynb`

### Additional Resources

- **CLAUDE.md**: Project development guidelines
- **rheojax/core/base.py**: BaseModel implementation
- **rheojax/models/**: 20 example model implementations
- **tests/models/**: Comprehensive test examples
- **Rheo Documentation**: https://rheo.readthedocs.io

### Questions to Explore

1. How do I implement oscillatory shear predictions (`TestMode.OSCILLATION`)?
2. Can I create models with variable number of parameters?
3. How do I implement custom priors for Bayesian inference?
4. What if my model requires numerical integration?
5. How can I implement multi-mode models (different equations per test mode)?

**Congratulations!** You now have the knowledge to create production-ready custom rheological models in Rheo.