# Recreate torchdiffeq's defaults in torchode

In [1]:
import torch
import torch.nn as nn
import torchode as to
import torchdiffeq as tde

# Fix the global RNG so the random weights and initial states below are reproducible.
torch.manual_seed(180819023);

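Consider a randomly initialized MLP with two hidden layers. We use it as the vector field of an autonomous ODE, so it ignores the time argument `t`.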

In [2]:
class Model(nn.Module):
    """MLP vector field ``f(t, y)`` for an autonomous ODE.

    Maps a state of width ``n_features`` through two Softplus-activated hidden
    layers of width ``n_hidden`` back to width ``n_features``. The ``forward``
    signature takes ``(t, y)`` to match how the solvers in this notebook call
    the model, but ``t`` is ignored.
    """

    def __init__(self, n_features, n_hidden):
        super().__init__()
        widths = (n_features, n_hidden, n_hidden, n_features)
        blocks = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth:
                # Activation between consecutive linear maps, never after the last one.
                blocks.append(nn.Softplus())
            blocks.append(nn.Linear(fan_in, fan_out))
        # Keep the attribute name `layers` so state_dict keys stay "layers.{0,2,4}.*".
        self.layers = nn.Sequential(*blocks)

    def forward(self, t, y):
        del t  # time-independent dynamics
        return self.layers(y)
    
n_features = 5
# State dimension is n_features; each hidden layer has 16 units.
model = Model(n_features, n_hidden=16)

We would like to solve the ODE defined by this model from the initial states `y0` and evaluate the solution at the time points `t_eval`.

In [3]:
batch_size, n_steps = 16, 10
# One standard-normal initial state per sample: shape (batch_size, n_features).
y0 = torch.randn(batch_size, n_features)
# n_steps equally spaced output times on [0, 1]. This counts evaluation points;
# it is unrelated to the solver's `n_steps` statistic reported later.
t_eval = torch.linspace(0.0, 1.0, steps=n_steps)

With torchdiffeq and its default settings, this looks as follows.

In [4]:
# Spell out torchdiffeq's defaults (dopri5, rtol=1e-7, atol=1e-9) so the torchode
# configuration below can be checked against them; this is equivalent to
# `tde.odeint(model, y0, t_eval)`.
sol_tde = tde.odeint(model, y0, t_eval, rtol=1e-7, atol=1e-9, method="dopri5")

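In torchode, we set up the individual components (the ODE term, the step method, and the step size controller) and combine them into a solver. To match torchdiffeq's defaults, we use the Dopri5 method with torchdiffeq's default tolerances `rtol=1e-7` and `atol=1e-9`. Wrapping everything in `AutoDiffAdjoint` gives a solver that backpropagates by automatically differentiating through the solver's operations (discretize-then-optimize).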

In [5]:
term = to.ODETerm(model)
# Same tolerances as torchdiffeq's odeint defaults.
step_size_controller = to.IntegralController(rtol=1e-7, atol=1e-9, term=term)
# Dormand-Prince 5(4), torchdiffeq's default method.
step_method = to.Dopri5(term=term)
# Backpropagate through the solver's internal operations (discretize-then-optimize).
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller)

Now we can reuse the solver in `adjoint` for any problem we want to solve, for example the one from above. To do so, we create a problem instance and pass it to the solver. Note that we have to repeat the evaluation time points for each sample in the batch, because torchode solves each sample as an independent ODE with its own evaluation points.

In [6]:
# torchode integrates every sample independently, so each row of `y0` gets its own
# copy of the evaluation times: t_eval goes from (n_steps,) to (batch_size, n_steps).
problem = to.InitialValueProblem(
    y0=y0,
    t_eval=t_eval.repeat(batch_size, 1),
)
sol = adjoint.solve(problem)

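torchdiffeq returns its solution with time as the first dimension, while torchode puts the batch first, so we transpose the torchdiffeq solution before comparing. The two solutions are very close.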

In [7]:
# torchdiffeq returns solutions time-first (time, batch, ...), whereas torchode's
# `sol.ys` is batch-first, so swap the two leading axes before subtracting.
abs_err = (sol.ys - sol_tde.movedim(1, 0)).abs()

(abs_err.mean().item(), abs_err.max().item())

(3.3444638347646105e-07, 6.198883056640625e-06)

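Finally, let's look at the solution statistics that torchode gives us. They are reported separately for each sample in the batch. Because every sample is solved independently, the number of steps can differ between samples.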

In [8]:
# Solver statistics, with one entry per sample in the batch.
sol.stats

{'n_f_evals': tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]),
 'n_steps': tensor([5, 5, 6, 8, 5, 6, 6, 5, 6, 7, 5, 5, 5, 5, 5, 7]),
 'n_accepted': tensor([5, 5, 6, 7, 5, 6, 5, 5, 5, 7, 5, 5, 5, 5, 5, 7]),
 'n_initialized': tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10])}