# Logging Solver Information

One of the key features of torchode is that all components are replaceable and any component can log its own outputs into a dictionary called `stats`. This means that you can inject your own code and log any information that is relevant to your use case. In this example, we will create a step size controller wrapper that logs the step start times `t`, the step sizes `dt`, and every accept decision, i.e. whether each step was accepted or rejected by the step size controller.

We begin by importing the relevant modules and defining a small neural network that serves as the right-hand side of the ODE.

In [1]:
import torch
import torch.nn as nn
import torchode as to
from torchode.step_size_controllers import StepSizeController

torch.random.manual_seed(180819023);

In [2]:
class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features)
        )
    
    def forward(self, t, y):
        return self.layers(y)

Now we define the wrapper that tracks the step size data. The actual step size control is delegated to another controller. This lets us reuse the existing controller implementations and focus on collecting the information we care about: the integration time points `t`, the step sizes `dt`, and whether each step was accepted.

To define this custom controller, we only have to implement the `StepSizeController` interface, i.e. the methods `init`, `adapt_step_size` and `merge_states`. Each of them forwards its arguments to the wrapped controller instance. In `init`, we additionally create empty lists in the statistics dictionary of the current solve, one each for `t`, `dt` and the accept decisions. `adapt_step_size` then appends the values of every attempted step to those lists.

Note that you could proceed in a similar way to track information about the stepping method, e.g. `Dopri5`, by defining a `SingleStepMethod` that wraps it.

In [3]:
class StepSizeControllerTracker(StepSizeController):
    """A wrapper that collects time step and step acceptance information."""

    def __init__(self, controller: StepSizeController):
        super().__init__()

        self.controller = controller

    def init(self, term, problem, method_order: int, dt0, *, stats, args):
        stats["all_t"] = []
        stats["all_dt"] = []
        stats["all_accept"] = []

        return self.controller.init(
            term, problem, method_order, dt0, stats=stats, args=args
        )

    def adapt_step_size(self, t0, dt, y0, step, state, stats):
        accept, dt_next, state, status = self.controller.adapt_step_size(
            t0, dt, y0, step, state, stats
        )

        stats["all_t"].append(t0)
        stats["all_dt"].append(dt)
        stats["all_accept"].append(accept)

        return accept, dt_next, state, status

    def merge_states(self, running, current, previous):
        return self.controller.merge_states(running, current, previous)
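
To see the delegation pattern in isolation, and to get a log that actually contains a rejected step, here is a minimal pure-Python sketch that does not use torchode. `ToyController`, `Tracker` and `toy_solve` are hypothetical names. Their simplified method signatures are not torchode's interface. The error estimates fed to the controller are made up rather than computed by a Runge-Kutta method. The acceptance rule (accept when the error ratio is at most 1) and the rescaling factor `safety * err**(-1/(order + 1))`, clipped to [0.2, 10], follow the textbook integral controller. They are not taken from torchode's `IntegralController`.

```python
from typing import Any, Dict, List, Tuple


class ToyController:
    """Hypothetical step size controller; NOT torchode's interface.

    A step is accepted when its (made-up) error ratio is <= 1, and the next
    step size is dt * safety * err**(-1/(order+1)), clipped to [0.2, 10].
    """

    def __init__(self, order: int = 5, safety: float = 0.9):
        self.order = order
        self.safety = safety

    def init(self, dt0: float, stats: Dict[str, Any]) -> float:
        stats["n_attempts"] = 0  # the wrapped controller logs into stats too
        return dt0

    def adapt_step_size(self, t0: float, dt: float, err: float,
                        stats: Dict[str, Any]) -> Tuple[bool, float]:
        stats["n_attempts"] += 1
        accept = err <= 1.0
        if err == 0.0:
            factor = 10.0  # avoid 0 ** negative
        else:
            factor = self.safety * err ** (-1.0 / (self.order + 1))
        factor = min(10.0, max(0.2, factor))
        return accept, dt * factor


class Tracker:
    """Delegates to any controller with the same methods and logs each decision."""

    def __init__(self, controller):
        self.controller = controller

    def init(self, dt0: float, stats: Dict[str, Any]) -> float:
        stats["all_t"], stats["all_dt"], stats["all_accept"] = [], [], []
        return self.controller.init(dt0, stats)

    def adapt_step_size(self, t0: float, dt: float, err: float,
                        stats: Dict[str, Any]) -> Tuple[bool, float]:
        accept, dt_next = self.controller.adapt_step_size(t0, dt, err, stats)
        stats["all_t"].append(t0)
        stats["all_dt"].append(dt)
        stats["all_accept"].append(accept)
        return accept, dt_next


def toy_solve(controller, t_end: float, dt0: float,
              errors: List[float]) -> Tuple[float, Dict[str, Any]]:
    """Drive a controller with a fixed sequence of fake error estimates."""
    stats: Dict[str, Any] = {}
    t, dt = 0.0, controller.init(dt0, stats)
    for err in errors:
        if t >= t_end:
            break
        dt = min(dt, t_end - t)  # do not step past t_end
        accept, dt_next = controller.adapt_step_size(t, dt, err, stats)
        if accept:
            t += dt  # only accepted steps advance time
        dt = dt_next
    return t, stats


t_final, stats = toy_solve(Tracker(ToyController()), t_end=1.0, dt0=0.1,
                           errors=[0.5, 2.0, 0.3, 0.8, 0.1, 0.1])
print(stats["all_accept"])  # [True, False, True, True, True, True]
print(stats["all_t"][1] == stats["all_t"][2])  # True: a rejection does not advance t
```

Because `Tracker` only forwards calls and appends to `stats`, it works with any controller that exposes the same methods. This mirrors why the torchode wrapper above defers everything to `self.controller`.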

Next, we construct the model and a `Dopri5` solver, and wrap its `IntegralController` with our tracker.

In [4]:
n_features = 5
batch_size = 3

model = Model(n_features=n_features, n_hidden=32)

In [5]:
dev = torch.device("cpu")
term = to.ODETerm(model)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
step_size_controller = StepSizeControllerTracker(step_size_controller)
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller).to(dev)

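Finally, we set up a batch of three initial value problems and solve the ODE defined by the randomly initialized model. All three problems start from `y0 = 0` and are evaluated at 10 equally spaced points in [0, 3].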

In [6]:
t_eval = torch.tile(torch.linspace(0.0, 3.0, 10), (batch_size, 1))
problem = to.InitialValueProblem(y0=torch.zeros((batch_size, n_features)).to(dev), t_eval=t_eval.to(dev))

sol = adjoint.solve(problem)

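After the solve, we can inspect the statistics recorded in the solution object and see that our custom step size controller collected the data. Each list holds one entry per step attempt, and each entry is a tensor with one value per batch instance. Stacking therefore gives a tensor with one row per step and one column per instance. The three columns are identical here because all instances share the same initial value and the same dynamics.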

In [7]:
print(torch.stack(sol.stats["all_t"]))

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e-04, 1.0000e-04, 1.0000e-04],
        [1.1000e-03, 1.1000e-03, 1.1000e-03],
        [1.1100e-02, 1.1100e-02, 1.1100e-02],
        [1.1110e-01, 1.1110e-01, 1.1110e-01],
        [9.3798e-01, 9.3798e-01, 9.3798e-01]], grad_fn=<StackBackward0>)


In [8]:
print(torch.stack(sol.stats["all_dt"]))

tensor([[1.0000e-04, 1.0000e-04, 1.0000e-04],
        [1.0000e-03, 1.0000e-03, 1.0000e-03],
        [1.0000e-02, 1.0000e-02, 1.0000e-02],
        [1.0000e-01, 1.0000e-01, 1.0000e-01],
        [8.2688e-01, 8.2688e-01, 8.2688e-01],
        [2.0620e+00, 2.0620e+00, 2.0620e+00]], grad_fn=<StackBackward0>)


In [9]:
print(torch.stack(sol.stats["all_accept"]))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])
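
As a sanity check on the log, each accepted step should begin where the previous accepted step ended (`t + dt`), and the last one should end at the final evaluation point 3.0. The sketch below checks this in plain Python. It uses the first column of the three printed tensors, copied by hand. On a live `sol`, something like `torch.stack(sol.stats["all_t"])[:, 0].detach().tolist()` would produce such a list. The helper `step_end_times` is a hypothetical name. The printed values are rounded to five significant digits, so the comparisons use a tolerance.

```python
# First column (batch instance 0) of the three printed tensors above.
all_t = [0.0, 1.0e-4, 1.1e-3, 1.11e-2, 1.111e-1, 9.3798e-1]
all_dt = [1.0e-4, 1.0e-3, 1.0e-2, 1.0e-1, 8.2688e-1, 2.0620]
all_accept = [True, True, True, True, True, True]


def step_end_times(ts, dts, accepts):
    """End time t + dt of every accepted step; rejected steps are skipped."""
    return [t + dt for t, dt, ok in zip(ts, dts, accepts) if ok]


ends = step_end_times(all_t, all_dt, all_accept)

# Each accepted step starts where the previous accepted step ended ...
accepted_starts = [t for t, ok in zip(all_t, all_accept) if ok]
for end, next_start in zip(ends, accepted_starts[1:]):
    assert abs(end - next_start) < 1e-4

# ... and the last accepted step ends at t_eval's final point (up to rounding).
n_acc = sum(all_accept)
print(f"{n_acc} accepted, {len(all_accept) - n_acc} rejected; "
      f"final time ~ {ends[-1]:.4f}")  # 6 accepted, 0 rejected; final time ~ 3.0000
```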
