# Advanced Neural Differential Equations

ReactorTwin provides **5 Neural DE variants** for modeling chemical reactor dynamics:

1. **Neural ODE** - Standard neural ordinary differential equation
2. **Augmented Neural ODE** - Extended state space for more expressive dynamics
3. **Latent Neural ODE** - VAE-style encoder-decoder with ODE in latent space
4. **Neural SDE** - Stochastic differential equations with built-in uncertainty
5. **Neural CDE** - Controlled differential equations for irregular time series

This notebook demonstrates each variant on the same CSTR reference data, comparing their expressiveness, parameter counts, and training behavior.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

from reactor_twin import ArrheniusKinetics, CSTRReactor, NeuralODE
from reactor_twin.core import AugmentedNeuralODE, LatentNeuralODE, NeuralSDE, NeuralCDE

# Fix the NumPy and PyTorch seeds so results are reproducible across runs
np.random.seed(42)
torch.manual_seed(42)

## Generating Reference Data

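We create a simple CSTR dataset for a single first-order reaction, A -> B, that all five Neural DE models will learn from. With `Ea = 0` the Arrhenius rate constant reduces to `k0 = 0.5`, and the reactor runs isothermally with a residence time of `V/F = 10` in the model's time units.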
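
For orientation, the mass balance behind such a dataset can be sketched with SciPy alone. This is a minimal sketch assuming the textbook isothermal CSTR balance and a reactor that starts at feed composition; `cstr_rhs` is an illustrative name, and the library's `ode_rhs` and `get_initial_state` may differ in detail.

```python
import numpy as np
from scipy.integrate import solve_ivp

K, V, F = 0.5, 10.0, 1.0          # rate constant (Ea = 0), volume, flow rate
C_FEED = np.array([1.0, 0.0])     # feed concentrations of A and B


def cstr_rhs(t, c):
    """Isothermal CSTR mass balance for A -> B, first order in A."""
    rate = K * c[0]
    reaction = np.array([-rate, rate])
    return (F / V) * (C_FEED - c) + reaction


# Assumed initial condition: reactor filled with feed composition.
sol_demo = solve_ivp(cstr_rhs, [0, 10], C_FEED,
                     t_eval=np.linspace(0, 10, 50), method="LSODA")

# Steady state of A, obtained by setting the right-hand side to zero.
c_a_ss = (F / V) * C_FEED[0] / (F / V + K)
print(f"C_A at t=10: {sol_demo.y[0, -1]:.4f}, steady state: {c_a_ss:.4f}")
```

Under these assumptions the total concentration `C_A + C_B` stays at the feed value, which is a quick sanity check on any learned model.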

In [None]:
kinetics = ArrheniusKinetics(
    name="A_to_B", num_reactions=1,
    params={"k0": np.array([0.5]), "Ea": np.array([0.0]),
            "stoich": np.array([[-1, 1]]), "orders": np.array([[1, 0]])},
)
reactor = CSTRReactor(
    name="cstr", num_species=2,
    params={"V": 10.0, "F": 1.0, "C_feed": [1.0, 0.0], "T_feed": 350.0},
    kinetics=kinetics, isothermal=True,
)

y0 = reactor.get_initial_state()
t_eval = np.linspace(0, 10, 50)
sol = solve_ivp(reactor.ode_rhs, [0, 10], y0, t_eval=t_eval, method="LSODA")

z0 = torch.tensor(y0, dtype=torch.float32).unsqueeze(0)
t_span = torch.tensor(t_eval, dtype=torch.float32)
targets = torch.tensor(sol.y.T, dtype=torch.float32).unsqueeze(0)

print(f"Data: z0={z0.shape}, t={t_span.shape}, targets={targets.shape}")
plt.plot(t_eval, sol.y[0], 'b-', label='C_A')
plt.plot(t_eval, sol.y[1], 'r-', label='C_B')
plt.xlabel('Time'); plt.ylabel('Concentration'); plt.legend(); plt.grid(True, alpha=0.3)
plt.title('Reference CSTR Trajectory')
plt.show()

## 1. Standard Neural ODE

The baseline Neural ODE models dynamics as `dy/dt = f_\theta(t, y)` where `f_\theta` is a simple MLP. This is the most straightforward variant and works well for smooth, deterministic systems.
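
To make the formulation concrete, here is a minimal sketch of a Neural ODE forward pass in plain PyTorch: an MLP vector field integrated with classical fixed-step RK4. `VectorField` and `rk4_integrate` are illustrative names, not part of ReactorTwin, and the library's `NeuralODE` may be implemented differently.

```python
import torch
import torch.nn as nn


class VectorField(nn.Module):
    """MLP f_theta(t, y); time is appended to the state as an extra input."""

    def __init__(self, state_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + 1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, state_dim),
        )

    def forward(self, t: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Broadcast the scalar time over the batch dimension
        t_col = t.reshape(1, 1).expand(y.shape[0], 1)
        return self.net(torch.cat([y, t_col], dim=-1))


def rk4_integrate(f, y0, t):
    """Classical fixed-step RK4; returns a (batch, time, state) trajectory."""
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        h, y = t1 - t0, ys[-1]
        k1 = f(t0, y)
        k2 = f(t0 + h / 2, y + h / 2 * k1)
        k3 = f(t0 + h / 2, y + h / 2 * k2)
        k4 = f(t1, y + h * k3)
        ys.append(y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4))
    return torch.stack(ys, dim=1)


torch.manual_seed(0)
field = VectorField(state_dim=2)
y0_demo = torch.tensor([[1.0, 0.0]])
t_demo = torch.linspace(0.0, 10.0, 50)
traj = rk4_integrate(field, y0_demo, t_demo)
print(traj.shape)  # torch.Size([1, 50, 2])
```

Because every step is a differentiable tensor operation, gradients flow through the solver directly, which corresponds to training without the adjoint method.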

In [None]:
def train_model(model, z0, t_span, targets, num_epochs=200, lr=1e-3):
    """Train any Neural DE model and return losses."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        preds = model(z0, t_span)
        loss_dict = model.compute_loss(preds, targets)
        loss_dict["total"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        losses.append(loss_dict["total"].item())
    return losses

model_ode = NeuralODE(state_dim=2, hidden_dim=32, num_layers=2, solver="rk4", adjoint=False)
losses_ode = train_model(model_ode, z0, t_span, targets)
print(f"Neural ODE - Final loss: {losses_ode[-1]:.6f}, Params: {sum(p.numel() for p in model_ode.parameters()):,}")

## 2. Augmented Neural ODE

The Augmented Neural ODE appends extra dimensions to the state, which lets the model learn more expressive dynamics. This is particularly useful when the observed trajectories appear to cross in the original state space. A standard Neural ODE cannot represent such crossings: the solution through any given state at a given time is unique, so two trajectories can never pass through the same point at the same time.
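
The augmentation itself is simple to sketch. The example below assumes the common zero-padding scheme, in which the extra channels start at zero and are dropped again after integration; `augment_state` and `project_state` are hypothetical helpers, and `AugmentedNeuralODE` may handle this differently internally.

```python
import torch


def augment_state(z0: torch.Tensor, augment_dim: int) -> torch.Tensor:
    """Append `augment_dim` zero-initialized channels to each state vector."""
    zeros = torch.zeros(*z0.shape[:-1], augment_dim, dtype=z0.dtype, device=z0.device)
    return torch.cat([z0, zeros], dim=-1)


def project_state(traj: torch.Tensor, state_dim: int) -> torch.Tensor:
    """Keep only the physical channels of an augmented trajectory."""
    return traj[..., :state_dim]


z0_demo = torch.tensor([[1.0, 0.0]])            # (batch=1, state_dim=2)
z_aug = augment_state(z0_demo, augment_dim=4)   # (batch=1, state_dim + augment_dim=6)
print(z_aug.shape, project_state(z_aug, state_dim=2).shape)
```

The vector field then operates on the six-dimensional state, so its input and output layers grow accordingly, which is one reason the parameter count rises relative to the baseline.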

In [None]:
model_aug = AugmentedNeuralODE(
    state_dim=2, augment_dim=4,  # 4 extra dimensions
    hidden_dim=32, num_layers=2, solver="rk4", adjoint=False,
)
losses_aug = train_model(model_aug, z0, t_span, targets)
print(f"Augmented ODE - Final loss: {losses_aug[-1]:.6f}, Params: {sum(p.numel() for p in model_aug.parameters()):,}")

## 3. Latent Neural ODE

The Latent Neural ODE uses an encoder-decoder architecture:
1. **Encoder**: Maps observations into a latent space (VAE-style with reparameterization)
2. **ODE**: Evolves the latent state forward in time
3. **Decoder**: Maps latent trajectories back to observation space

This is especially powerful for high-dimensional systems where the intrinsic dynamics are low-dimensional.
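
The VAE-style encoder step can be sketched as follows. This assumes a diagonal Gaussian posterior with a standard normal prior, which is the usual choice; the function names are illustrative, and the terms that `LatentNeuralODE.compute_loss` actually includes are not shown in this notebook.

```python
import torch


def reparameterize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z = mean + std * eps so gradients flow through mean and logvar."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mean + eps * std


def kl_to_standard_normal(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL(N(mean, diag(exp(logvar))) || N(0, I)), averaged over the batch."""
    kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=-1)
    return kl.mean()


torch.manual_seed(0)
mean_demo = torch.zeros(1, 4)      # latent_dim = 4, matching the model below
logvar_demo = torch.zeros(1, 4)
z_latent = reparameterize(mean_demo, logvar_demo)
print(z_latent.shape, kl_to_standard_normal(mean_demo, logvar_demo).item())
```

The sampled `z_latent` would serve as the initial condition for the latent ODE, and the KL term is typically added to the reconstruction error with a weighting factor.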

In [None]:
model_latent = LatentNeuralODE(
    state_dim=2, latent_dim=4,
    encoder_hidden_dim=32, decoder_hidden_dim=32,
    encoder_type="mlp",
    hidden_dim=32, num_layers=2,
    solver="rk4", adjoint=False,
)

# Reuse the shared helper: same optimizer, 200 epochs, and gradient clipping as above
losses_latent = train_model(model_latent, z0, t_span, targets)

print(f"Latent ODE - Final loss: {losses_latent[-1]:.6f}, Params: {sum(p.numel() for p in model_latent.parameters()):,}")

# Examine latent space
model_latent.eval()
with torch.no_grad():
    z_mean, z_logvar = model_latent.encode(z0)
print(f"Latent mean: {z_mean[0].numpy()}")
print(f"Latent logvar: {z_logvar[0].numpy()}")

## 4. Neural SDE

The Neural SDE models stochastic dynamics: `dy = f_\theta(t,y)dt + g_\phi(t,y)dW`, where `f` is the drift and `g` is the diffusion. This provides **built-in uncertainty quantification** since each forward pass samples a different Brownian path.

**Note:** Requires the `torchsde` package (`pip install torchsde`).
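
To illustrate how sampling different Brownian paths produces a spread of predictions, here is a minimal Euler-Maruyama sketch in plain PyTorch. It does not use `torchsde`, and the drift and diffusion functions are hypothetical stand-ins for the learned networks `f_\theta` and `g_\phi`.

```python
import torch


def euler_maruyama(drift, diffusion, y0, t, generator=None):
    """Integrate dy = drift dt + diffusion dW (Ito, diagonal noise)."""
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        h, y = t1 - t0, ys[-1]
        # Brownian increment: independent N(0, h) draws per state channel
        dW = torch.randn(y.shape, generator=generator) * torch.sqrt(h)
        ys.append(y + drift(t0, y) * h + diffusion(t0, y) * dW)
    return torch.stack(ys, dim=1)


# Hypothetical drift and diffusion: first-order decay with constant noise.
drift = lambda t, y: -0.5 * y
diffusion = lambda t, y: 0.05 * torch.ones_like(y)

y0_demo = torch.tensor([[1.0, 0.0]])
t_demo = torch.linspace(0.0, 10.0, 50)
paths = torch.stack([
    euler_maruyama(drift, diffusion, y0_demo, t_demo, torch.Generator().manual_seed(s))
    for s in range(5)
])  # (samples, batch, time, state)
print(f"Mean spread across samples: {paths.std(dim=0).mean():.4f}")
```

Each seed yields a different trajectory from the same initial state, and the standard deviation across samples is the kind of spread the demo below reports.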

In [None]:
try:
    model_sde = NeuralSDE(
        state_dim=2, hidden_dim=32, num_layers=2,
        noise_type="diagonal", sde_type="ito",
        solver="euler", adjoint=False,
    )
    losses_sde = train_model(model_sde, z0, t_span, targets, num_epochs=100)
    print(f"Neural SDE - Final loss: {losses_sde[-1]:.6f}, Params: {sum(p.numel() for p in model_sde.parameters()):,}")
    
    # Multiple forward passes give different results (stochastic)
    model_sde.eval()
    sde_preds = []
    with torch.no_grad():
        for _ in range(5):
            pred = model_sde(z0, t_span)
            sde_preds.append(pred[0].numpy())
    sde_preds = np.array(sde_preds)
    print(f"Prediction spread (std): {sde_preds.std(axis=0).mean():.6f}")
except ImportError:
    print("torchsde not installed - skipping Neural SDE demo")
    print("Install with: pip install torchsde")
    losses_sde = None

## 5. Neural CDE

Neural Controlled Differential Equations are designed for **irregular time series**. Instead of a fixed ODE, the dynamics are driven by a continuous control signal (interpolated from observations):

`dy/dt = f_\theta(y) dX/dt`, or equivalently `dy = f_\theta(y) dX`

This makes them naturally suited to data with missing values or non-uniform sampling.
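
A discretized version of this equation shows why irregular sampling is not a problem: each update uses the increment of the control path rather than a fixed time step. The sketch below assumes a piecewise-linear control path with time as its first channel and a simple Euler scheme; `CDEFunc` and `cde_euler` are illustrative names and do not reflect how `NeuralCDE` or `torchcde` are implemented.

```python
import torch
import torch.nn as nn


class CDEFunc(nn.Module):
    """f_theta(y): maps the state to a (state_dim x control_dim) matrix."""

    def __init__(self, state_dim: int, control_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.state_dim, self.control_dim = state_dim, control_dim
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, state_dim * control_dim),
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        return self.net(y).view(-1, self.state_dim, self.control_dim)


def cde_euler(func, y0, X):
    """Euler scheme y_{k+1} = y_k + f(y_k) (X_{k+1} - X_k) on a sampled path X."""
    ys = [y0]
    for k in range(X.shape[1] - 1):
        dX = X[:, k + 1] - X[:, k]                        # (batch, control_dim)
        step = torch.bmm(func(ys[-1]), dX.unsqueeze(-1)).squeeze(-1)
        ys.append(ys[-1] + step)
    return torch.stack(ys, dim=1)


torch.manual_seed(0)
# Irregularly spaced observation times; time is the first control channel.
t_obs = torch.tensor([0.0, 0.3, 0.4, 1.5, 2.0, 4.5])
signal = torch.exp(-0.5 * t_obs)
X = torch.stack([t_obs, signal], dim=-1).unsqueeze(0)     # (1, 6, 2)
y0_demo = torch.tensor([[1.0, 0.0]])
traj = cde_euler(CDEFunc(state_dim=2, control_dim=2), y0_demo, X)
print(traj.shape)  # torch.Size([1, 6, 2])
```

Uneven gaps between observations simply produce larger or smaller increments `dX`, so no resampling onto a regular grid is needed.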

In [None]:
try:
    model_cde = NeuralCDE(
        state_dim=2, hidden_dim=32, num_layers=2,
        solver="rk4", adjoint=False,
    )
    losses_cde = train_model(model_cde, z0, t_span, targets, num_epochs=100)
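    print(f"Neural CDE - Final loss: {losses_cde[-1]:.6f}, Params: {sum(p.numel() for p in model_cde.parameters()):,}")
except Exception as e:
    # Any failure lands here (for example a missing torchcde install), so report the actual error
    print(f"Skipping Neural CDE demo ({type(e).__name__}): {e}")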
    losses_cde = None

## Comparison

Let's compare all trained models on training loss convergence and prediction quality.
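
Alongside the plots, two scalar summaries make the comparison easier to read: trainable parameter count and mean squared error against the reference trajectory. The helpers below are a small sketch with hypothetical names, demonstrated on a toy MLP rather than on the models trained above.

```python
import torch
import torch.nn as nn


def count_parameters(model: nn.Module) -> int:
    """Number of trainable scalar parameters in a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def trajectory_mse(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Mean squared error over batch, time, and state dimensions."""
    return torch.mean((pred - target) ** 2).item()


toy = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))
print(count_parameters(toy))  # 162
print(trajectory_mse(torch.ones(1, 50, 2), torch.zeros(1, 50, 2)))  # 1.0
```

Applied to each trained model, these would give one row per variant; for the Neural SDE the error should be averaged over several sampled paths, since a single forward pass is stochastic.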

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
axes[0].semilogy(losses_ode, label='Neural ODE')
axes[0].semilogy(losses_aug, label='Augmented ODE')
axes[0].semilogy(losses_latent, label='Latent ODE')
if losses_sde is not None:
    axes[0].semilogy(losses_sde, label='Neural SDE')
if losses_cde is not None:
    axes[0].semilogy(losses_cde, label='Neural CDE')
axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Comparison'); axes[0].legend(); axes[0].grid(True, alpha=0.3)

# Predictions
true_np = targets[0].numpy()
axes[1].plot(t_eval, true_np[:, 0], 'k--', linewidth=2, label='True C_A')

for model, name in [(model_ode, 'Neural ODE'), (model_aug, 'Augmented'), (model_latent, 'Latent')]:
    model.eval()
    with torch.no_grad():
        pred = model(z0, t_span)
    axes[1].plot(t_eval, pred[0, :, 0].numpy(), label=name)

axes[1].set_xlabel('Time'); axes[1].set_ylabel('C_A')
axes[1].set_title('C_A Predictions'); axes[1].legend(); axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Summary

| Variant | Key Feature | Best For |
|---------|-------------|----------|
| **Neural ODE** | Simple, efficient | Smooth deterministic dynamics |
| **Augmented Neural ODE** | Extended state space | Complex dynamics with crossing trajectories |
| **Latent Neural ODE** | Encoder-decoder + latent ODE | High-dimensional systems with low-dimensional intrinsic dynamics |
| **Neural SDE** | Stochastic diffusion term | Systems with inherent noise, uncertainty quantification |
| **Neural CDE** | Control-driven dynamics | Irregular time series, missing data |

All variants share the same training interface (`compute_loss`, `forward`) and can be used interchangeably in the digital twin pipeline.