# Physics Constraints in ReactorTwin

ReactorTwin enforces **7 physics constraints** to ensure that Neural ODE predictions respect
fundamental physical laws. Each constraint supports two modes:

- **Hard (projection)** -- directly modifies predictions to satisfy the constraint
- **Soft (penalty)** -- adds a differentiable penalty term to the training loss

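ReactorTwin provides the constraint types listed below. This notebook demonstrates the first four in code; the Port-Hamiltonian, GENERIC, and Thermodynamic constraints are imported and summarized but not exercised here.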

1. Positivity
2. Mass Balance
3. Energy Balance
4. Stoichiometric
5. Port-Hamiltonian
6. GENERIC
7. Thermodynamic

We also demonstrate the `ConstraintPipeline` for composing multiple constraints and show how
to integrate constraints into a training loop.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

from reactor_twin import (
    NeuralODE, PositivityConstraint, MassBalanceConstraint,
    EnergyBalanceConstraint, StoichiometricConstraint,
    PortHamiltonianConstraint, GENERICConstraint,
    ThermodynamicConstraint, ConstraintPipeline,
)

torch.manual_seed(42)

## 1. Hard vs Soft Modes

Every constraint in ReactorTwin can operate in one of two modes:

| Mode | Behavior | When to use |
|------|----------|-------------|
| **Hard** | Projects predictions onto the constraint manifold | Strict guarantees needed (e.g., deployment) |
| **Soft** | Adds a weighted penalty to the loss function | During training for smoother gradients |
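
The two rows of the table can be illustrated without the library. The following is a minimal NumPy sketch, not ReactorTwin's implementation: the function names `softplus_project` and `relu_penalty` are hypothetical, and the squared-ReLU penalty is an assumed form chosen for illustration.

```python
import numpy as np

def softplus_project(x):
    """Hard-mode sketch: map every value to a strictly positive one."""
    # log(1 + exp(x)), evaluated stably as logaddexp(0, x)
    return np.logaddexp(0.0, x)

def relu_penalty(x, weight=10.0):
    """Soft-mode sketch: x is left untouched; only a penalty is returned."""
    # Only negative entries contribute; the squared form is an assumption
    return weight * np.mean(np.maximum(-x, 0.0) ** 2)

x = np.array([-0.4, 0.0, 0.7])
projected = softplus_project(x)   # all entries > 0
penalty = relu_penalty(x)         # scalar added to the training loss
```

Note that a softplus projection also shifts values that were already positive (for example, it maps 0 to $\ln 2 \approx 0.693$), whereas the penalty form never alters the predictions themselves.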

Let's see the difference using the positivity constraint.

In [None]:
# Create synthetic predictions with some negative values
preds = torch.randn(1, 20, 2) * 0.5  # Some values will be negative

# Hard constraint (projection)
hard_constraint = PositivityConstraint(mode="hard", method="softplus")
hard_preds, hard_violation = hard_constraint(preds)

# Soft constraint (penalty)
soft_constraint = PositivityConstraint(mode="soft", method="relu", weight=10.0)
soft_preds, soft_violation = soft_constraint(preds)

print(f"Original min:     {preds.min().item():.4f}")
print(f"Hard min:         {hard_preds.min().item():.4f}")
print(f"Soft penalty:     {soft_violation.item():.4f}")
print(f"Soft min:         {soft_preds.min().item():.4f}")  # soft_preds = preds (unchanged)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in zip(axes, [preds, hard_preds, soft_preds],
                            ['Original', 'Hard (projected)', 'Soft (unchanged + penalty)']):
    ax.plot(data[0, :, 0].numpy(), label='Species A')
    ax.plot(data[0, :, 1].numpy(), label='Species B')
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 2. Mass Balance Constraint

The **mass balance** constraint enforces conservation of total mass. For a closed, constant-volume system with the equimolar reaction A -> B, each mole of A consumed yields one mole of B, so the total concentration $C_A + C_B$ must remain constant over time.
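
One way to picture a hard mass-balance correction is as an orthogonal projection onto the hyperplane $C_A + C_B = M$. The sketch below assumes that form; `project_to_total` is a hypothetical helper, and ReactorTwin may use a different correction (for example, rescaling).

```python
import numpy as np

def project_to_total(c, total=1.0):
    """Shift each row by an equal amount per species so it sums to `total`."""
    residual = c.sum(axis=-1, keepdims=True) - total   # signed violation per row
    return c - residual / c.shape[-1]                  # spread it evenly across species

# Same values as the violating predictions used in this section
c = np.array([[0.6, 0.5], [0.3, 0.6], [0.1, 0.7]])
c_proj = project_to_total(c)
row_sums = c_proj.sum(axis=-1)
```

This projection changes each species by the smallest possible amount in the Euclidean sense, but it can push a small concentration below zero, which is one reason to combine it with a positivity constraint.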

In [None]:
stoich = torch.tensor([[-1.0, 1.0]])  # A -> B

mass_hard = MassBalanceConstraint(
    mode="hard",
    stoichiometry=stoich,
    total_mass=1.0,
)

# Predictions that violate mass balance
preds_bad = torch.tensor([[[0.6, 0.5], [0.3, 0.6], [0.1, 0.7]]])  # Sum != 1.0

corrected, violation = mass_hard(preds_bad)
print("Before correction:")
print(f"  Sums: {preds_bad[0].sum(dim=-1).numpy()}")
print("After correction:")
print(f"  Sums: {corrected[0].sum(dim=-1).numpy()}")
print(f"  Violation: {violation.item():.6f}")

## 3. Energy Balance Constraint

The **energy balance** constraint enforces conservation of energy in non-isothermal reactor
systems. It monitors the temperature state variable and penalizes or corrects energy imbalances
based on heat capacity and reaction enthalpies.
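
As a concrete instance of such a residual, consider an adiabatic, constant-volume batch reactor with a single reaction A -> B and constant $\rho C_p$. Energy conservation then ties temperature to conversion: $T - T_0 = \frac{-\Delta H}{\rho C_p}(C_{A,0} - C_A)$. The sketch below evaluates that relation as a soft penalty. It is an illustrative assumption, not the formula `EnergyBalanceConstraint` necessarily uses, and the numerical values of `delta_h` and `rho` are hypothetical.

```python
import numpy as np

def adiabatic_energy_residual(c_a, temp, c_a0, t0, delta_h, rho, cp):
    """Residual of T - T0 = (-dH / (rho * cp)) * (C_A0 - C_A); zero when consistent."""
    expected_rise = (-delta_h) / (rho * cp) * (c_a0 - c_a)
    return temp - t0 - expected_rise

# Hypothetical exothermic reaction in a water-like fluid
c_a = np.array([1000.0, 600.0, 200.0])        # mol/m^3
rho, cp, delta_h = 1000.0, 4.186, -50.0       # kg/m^3, kJ/(kg*K), kJ/mol

# A temperature trajectory constructed to satisfy the balance exactly
temp = 350.0 + (50.0 / (rho * cp)) * (1000.0 - c_a)

residual = adiabatic_energy_residual(c_a, temp, 1000.0, 350.0, delta_h, rho, cp)
penalty = np.mean(residual ** 2)              # soft-mode penalty, zero here
```

Any drift of the predicted temperature away from this line shows up directly in the penalty: a uniform 2 K offset gives a mean squared residual of 4.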

In [None]:
energy_constraint = EnergyBalanceConstraint(
    mode="soft",
    state_dim=3,  # C_A, C_B, T
    temp_index=2,
    Cp=4.186,  # kJ/(kg*K)
    weight=1.0,
)

# Predictions with 3 state variables (C_A, C_B, T)
preds_energy = torch.randn(1, 10, 3) * 0.1
preds_energy[:, :, 2] += 350.0  # Temperature around 350 K

corrected_energy, energy_violation = energy_constraint(preds_energy)
print(f"Energy violation penalty: {energy_violation.item():.6f}")

## 4. Stoichiometric Constraint

The **stoichiometric constraint** ensures that predicted concentration changes follow the
stoichiometric relationships of the reaction network. For a cascade A -> B -> C, the changes
in each species must be consistent with the stoichiometric matrix.
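
In matrix terms, a concentration change is stoichiometrically consistent when it can be written as $\Delta C = \nu^\top \xi$ for some vector of reaction extents $\xi$, where $\nu$ is the stoichiometric matrix (rows are reactions, columns are species). A hard correction can therefore be sketched as a least-squares projection onto the row space of $\nu$. The helper name below is hypothetical and this is not necessarily how `StoichiometricConstraint` is implemented.

```python
import numpy as np

# A -> B -> C: rows are reactions, columns are species A, B, C
nu = np.array([[-1.0, 1.0, 0.0],
               [0.0, -1.0, 1.0]])

def project_onto_stoich(delta_c, nu):
    """Return the closest change to delta_c of the form nu.T @ xi, plus xi."""
    # Least-squares extents: minimize ||nu.T @ xi - delta_c||
    xi, *_ = np.linalg.lstsq(nu.T, delta_c, rcond=None)
    return nu.T @ xi, xi

delta_c = np.array([-0.5, 0.2, 0.4])   # net change sums to 0.1, so it is inconsistent
consistent, extents = project_onto_stoich(delta_c, nu)
```

For this cascade every reaction conserves the total number of moles, so the projection simply removes the net change in total concentration and returns a vector whose entries sum to zero.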

In [None]:
stoich_constraint = StoichiometricConstraint(
    mode="hard",
    stoichiometry=torch.tensor([[-1.0, 1.0, 0.0], [0.0, -1.0, 1.0]]),  # A->B->C
)

preds_stoich = torch.randn(1, 15, 3) * 0.5 + 0.5
corrected_stoich, stoich_violation = stoich_constraint(preds_stoich)
print(f"Stoichiometric violation: {stoich_violation.item():.6f}")

## 5. Constraint Pipeline

The `ConstraintPipeline` lets you compose multiple constraints and apply them sequentially.
Hard constraints are applied as projections, and soft constraint penalties are summed into a
single total violation term.
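
The sequential composition described above can be sketched in a few lines. Each constraint is modeled as a callable returning `(values, violation)`; the names `positivity_hard`, `mass_soft`, and `run_pipeline` are hypothetical stand-ins, not ReactorTwin classes.

```python
import numpy as np

def positivity_hard(x):
    """Hard step: project with softplus; contributes no penalty."""
    return np.logaddexp(0.0, x), 0.0

def mass_soft(x, total=1.0, weight=5.0):
    """Soft step: pass values through and report a weighted penalty."""
    return x, weight * np.mean((x.sum(axis=-1) - total) ** 2)

def run_pipeline(x, constraints):
    """Apply constraints in order, accumulating their violations."""
    total_violation = 0.0
    for constraint in constraints:
        x, violation = constraint(x)
        total_violation += violation
    return x, total_violation

x = np.array([[0.6, 0.5], [-0.1, 0.9]])
out, total_violation = run_pipeline(x, [positivity_hard, mass_soft])
```

In this sketch the order matters: the soft mass-balance penalty is evaluated on values that the positivity projection has already modified, so swapping the two steps would change the reported violation.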

In [None]:
pipeline = ConstraintPipeline([
    PositivityConstraint(mode="hard", method="softplus"),
    MassBalanceConstraint(mode="soft", stoichiometry=torch.tensor([[-1.0, 1.0]]), total_mass=1.0, weight=5.0),
])

preds_pipe = torch.randn(1, 20, 2) * 0.3 + 0.3
corrected_pipe, total_violation = pipeline(preds_pipe)

print(f"Pipeline output min: {corrected_pipe.min().item():.4f}")
print(f"Total violation penalty: {total_violation.item():.6f}")

## 6. Training with Constraints

In practice, constraints are integrated directly into the training loop. After the Neural ODE
produces predictions, we pass them through a constraint (or pipeline) before computing the loss.
Because the loss is computed on the constrained outputs, gradients flow back through the
projection and the optimizer only ever sees physically admissible values.

Below we train a small Neural ODE on batch reactor data with a hard positivity constraint.
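
For comparison, soft constraints enter training as extra loss terms rather than as a projection. The following PyTorch sketch shows that pattern on the same A -> B kinetics ($k = 0.5$, so $C_A = e^{-0.5t}$). It deliberately uses a plain feed-forward regressor as a stand-in instead of ReactorTwin's `NeuralODE`, and the helper `soft_penalties` and its penalty forms are assumptions; the weights 10.0 and 5.0 mirror values used earlier in this notebook.

```python
import torch

torch.manual_seed(0)

# Analytic batch-reactor solution for A -> B with k = 0.5 and C_A(0) = 1
t = torch.linspace(0.0, 8.0, 40).unsqueeze(-1)          # shape (40, 1)
c_a = torch.exp(-0.5 * t)
targets = torch.cat([c_a, 1.0 - c_a], dim=-1)           # shape (40, 2)

# Stand-in model mapping time to (C_A, C_B); not a Neural ODE
net = torch.nn.Sequential(
    torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 2)
)
opt = torch.optim.Adam(net.parameters(), lr=1e-2)

def soft_penalties(pred, weight_pos=10.0, weight_mass=5.0):
    """Weighted penalties for negative values and for C_A + C_B != 1."""
    negativity = torch.relu(-pred).pow(2).mean()
    mass = (pred.sum(dim=-1) - 1.0).pow(2).mean()
    return weight_pos * negativity + weight_mass * mass

history = []
for step in range(300):
    opt.zero_grad()
    pred = net(t)
    data_loss = torch.nn.functional.mse_loss(pred, targets)
    loss = data_loss + soft_penalties(pred)              # data fit + physics penalties
    loss.backward()
    opt.step()
    history.append(loss.item())
```

Unlike hard mode, nothing here forces the final predictions to be non-negative or to sum to one; the penalties only make violations costly during optimization.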

In [None]:
from scipy.integrate import solve_ivp
from reactor_twin import ArrheniusKinetics, BatchReactor

# Generate batch reactor data
kinetics = ArrheniusKinetics(
    name="A_to_B", num_reactions=1,
    params={"k0": np.array([0.5]), "Ea": np.array([0.0]),
            "stoich": np.array([[-1, 1]]), "orders": np.array([[1, 0]])},
)
reactor = BatchReactor(
    name="batch", num_species=2,
    params={"V": 1.0, "T": 350.0, "C_initial": [1.0, 0.0]},
    kinetics=kinetics, isothermal=True,
)
y0 = reactor.get_initial_state()
t_eval = np.linspace(0, 8, 40)
sol = solve_ivp(reactor.ode_rhs, [0, 8], y0, t_eval=t_eval, method="LSODA")

z0 = torch.tensor(y0, dtype=torch.float32).unsqueeze(0)
t_span = torch.tensor(t_eval, dtype=torch.float32)
targets = torch.tensor(sol.y.T, dtype=torch.float32).unsqueeze(0)

# Train with hard positivity
model = NeuralODE(state_dim=2, hidden_dim=32, num_layers=2, solver="rk4", adjoint=False)
constraint = PositivityConstraint(mode="hard", method="softplus")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
losses = []
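for epoch in range(200):
    optimizer.zero_grad()
    preds = model(z0, t_span)                 # integrate the Neural ODE from z0 over t_span
    preds_constrained, _ = constraint(preds)  # hard mode: project; the violation value is discarded
    loss_dict = model.compute_loss(preds_constrained, targets)  # loss on the projected outputs
    loss_dict["total"].backward()
    # Clip the gradient norm to guard against unstable updates
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    losses.append(loss_dict["total"].item())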

model.eval()
with torch.no_grad():
    final_preds = model(z0, t_span)
    final_constrained, _ = constraint(final_preds)

print(f"Final loss: {losses[-1]:.6f}")
print(f"Min prediction: {final_constrained.min().item():.6f} (guaranteed >= 0)")

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.semilogy(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
pred_np = final_constrained[0].numpy()
plt.plot(t_eval, sol.y[0], 'b--', label='True C_A')
plt.plot(t_eval, sol.y[1], 'r--', label='True C_B')
plt.plot(t_eval, pred_np[:, 0], 'b-', label='Pred C_A')
plt.plot(t_eval, pred_np[:, 1], 'r-', label='Pred C_B')
plt.xlabel('Time')
plt.ylabel('Concentration')
plt.title('Constrained Predictions')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Summary

This notebook covered the physics constraints available in ReactorTwin, with code demonstrations for positivity, mass balance, energy balance, and stoichiometry:

| Constraint | Purpose |
|------------|---------|
| **Positivity** | Concentrations >= 0 |
| **Mass Balance** | Conservation of total mass |
| **Energy Balance** | Conservation of energy |
| **Stoichiometric** | Consistent with reaction stoichiometry |
| **Port-Hamiltonian** | Structure-preserving energy dynamics |
| **GENERIC** | Thermodynamically consistent evolution |
| **Thermodynamic** | Second law of thermodynamics |

Key takeaways:

- **Hard mode** guarantees constraint satisfaction by projecting predictions onto the feasible set
- **Soft mode** leaves predictions unchanged and adds a differentiable penalty to the loss, so violations are discouraged during training rather than eliminated
- **ConstraintPipeline** applies multiple constraints sequentially and sums their penalties into one violation term
- Constraints slot into the Neural ODE training loop between the forward pass and the loss computation