diff --git a/torchdiffeq/_impl/adjoint.py b/torchdiffeq/_impl/adjoint.py index ca9ba7f..001a53d 100644 --- a/torchdiffeq/_impl/adjoint.py +++ b/torchdiffeq/_impl/adjoint.py @@ -5,6 +5,8 @@ from .misc import _check_inputs, _flat_to_shape from .misc import _mixed_norm +from snopt import SNOptAdjointCollector + class OdeintAdjointMethod(torch.autograd.Function): @@ -109,6 +111,8 @@ def augmented_dynamics(t, y_aug): # Solve adjoint ODE # ################################## + snopt_collector = SNOptAdjointCollector(func) if adjoint_options['use_snopt'] else None + if t_requires_grad: time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device) else: @@ -123,9 +127,14 @@ def augmented_dynamics(t, y_aug): time_vjps[i] = dLd_cur_t # Run the augmented system backwards in time. + samp_t = t[i - 1:i + 1].flip(0) + + if snopt_collector: + adjoint_options, samp_t = snopt_collector.check_inputs(adjoint_options, samp_t) + aug_state = odeint( augmented_dynamics, tuple(aug_state), - t[i - 1:i + 1].flip(0), + samp_t, rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options ) aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index 6915f2b..4a79f59 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -5,7 +5,7 @@ class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta): - def __init__(self, dtype, y0, norm, **unused_kwargs): + def __init__(self, dtype, y0, norm, snopt_collector=None, **unused_kwargs): _handle_unused_kwargs(self, unused_kwargs) del unused_kwargs @@ -13,6 +13,7 @@ def __init__(self, dtype, y0, norm, **unused_kwargs): self.dtype = dtype self.norm = norm + self.snopt_collector = snopt_collector def _before_integrate(self, t): pass @@ -28,6 +29,8 @@ def integrate(self, t): self._before_integrate(t) for i in range(1, len(t)): solution[i] = self._advance(t[i]) + if self.snopt_collector: + self.snopt_collector.call_invoke(self.func, t[i], solution[i]) return solution @@ -48,7 +51,7 @@ def integrate_until_event(self, t0, event_fn): class FixedGridODESolver(metaclass=abc.ABCMeta): order: int - def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs): + def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, snopt_collector=None, **unused_kwargs): self.atol = unused_kwargs.pop('atol') unused_kwargs.pop('rtol', None) unused_kwargs.pop('norm', None) @@ -62,6 +65,7 @@ def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="line self.step_size = step_size self.interp = interp self.perturb = perturb + self.snopt_collector = snopt_collector if step_size is None: if grid_constructor is None: @@ -113,6 +117,8 @@ def integrate(self, t): solution[j] = self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t[j]) else: raise ValueError(f"Unknown interpolation method {self.interp}") + if self.snopt_collector: + self.snopt_collector.call_invoke(self.func, t[j], solution[j]) j += 1 y0 = y1