diff --git a/tests/test_strat.py b/tests/test_strat.py
new file mode 100644
index 0000000..be1618b
--- /dev/null
+++ b/tests/test_strat.py
@@ -0,0 +1,129 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Temporary tests for Stratonovich support.
+
+This should eventually be refactored, and the file should be removed.
+"""
+
+import time
+
+import torch
+from torch import nn
+from torchsde._core.base_sde import ForwardSDE  # noqa
+from torchsde import settings
+
+torch.set_default_dtype(torch.float64)
+cpu, gpu = torch.device('cpu'), torch.device('cuda')
+device = gpu if torch.cuda.is_available() else cpu
+
+
+def _column_wise_func(y, t, i):
+    # This function is designed so that the diffusion has nonzero mixed partial derivatives.
+    return (torch.cos(y ** 2 * i + t * 0.1) +
+            torch.tan(y[..., 0:1] * y[..., -2:-1]) +
+            torch.sum(y ** 2, dim=-1, keepdim=True))
+
+
+class SDE(nn.Module):
+
+    def __init__(self):
+        super(SDE, self).__init__()
+        self.noise_type = settings.NOISE_TYPES.general
+        self.sde_type = settings.SDE_TYPES.stratonovich
+
+    def f(self, t, y):
+        return [torch.sin(y_) + t for y_ in y]
+
+    def g(self, t, y):
+        return [
+            torch.stack([_column_wise_func(y_, t, i) for i in range(m)], dim=-1)
+            for y_ in y
+        ]
+
+
+batch_size, d, m = 3, 5, 12
+
+
+def _batch_jacobian(output, input_):
+    # Create a batch of Jacobians for an output of size (batch_size, d_o) and an input of size (batch_size, d_i).
+    assert output.dim() == input_.dim() == 2
+    assert output.size(0) == input_.size(0)
+    jacs = []
+    for i in range(output.size(0)):  # batch_size.
+        jac = []
+        for j in range(output.size(1)):  # d_o.
+            grad, = torch.autograd.grad(output[i, j], input_, retain_graph=True, allow_unused=True)
+            grad = torch.zeros_like(input_[i]) if grad is None else grad[i].detach()
+            jac.append(grad)
+        jac = torch.stack(jac, dim=0)
+        jacs.append(jac)
+    return torch.stack(jacs, dim=0)
+
+
+def _gdg_jvp_brute_force(sde, t, y, a):
+    # Only returns the value for the first input-output pair.
+    with torch.enable_grad():
+        y = [y_.detach().requires_grad_(True) if not y_.requires_grad else y_ for y_ in y]
+        g_eval = sde.g(t, y)
+        v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)]
+
+        y0, g_eval0, v0 = y[0], g_eval[0], v[0]
+        num_brownian = g_eval0.size(-1)
+        jacobians_by_column = [_batch_jacobian(g_eval0[..., l], y0) for l in range(num_brownian)]
+        return [
+            sum(torch.bmm(jacobians_by_column[l], v0[..., l].unsqueeze(-1)).squeeze() for l in range(num_brownian))
+        ]
+
+
+def _make_inputs():
+    t = torch.rand(()).to(device)
+    y = [torch.randn(batch_size, d).to(device)]
+    a = torch.randn(batch_size, m, m).to(device)
+    a = [a - a.transpose(1, 2)]  # Anti-symmetric.
+    sde = ForwardSDE(SDE())
+    return sde, t, y, a
+
+
+def test_gdg_jvp():
+    sde, t, y, a = _make_inputs()
+    outs_brute_force = _gdg_jvp_brute_force(sde, t, y, a)  # Reference.
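+    # Both the registered implementation and the explicit v2 call should agree with the brute-force reference.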
+    outs = sde.gdg_jvp_column_sum(t, y, a)
+    outs_v2 = sde.gdg_jvp_column_sum_v2(t, y, a)
+    for out_brute_force, out, out_v2 in zip(outs_brute_force, outs, outs_v2):
+        assert torch.allclose(out_brute_force, out)
+        assert torch.allclose(out_brute_force, out_v2)
+
+
+def _time_function(func, reps=10):
+    now = time.perf_counter()
+    [func() for _ in range(reps)]
+    return time.perf_counter() - now
+
+
+def check_efficiency():
+    sde, t, y, a = _make_inputs()
+
+    func1 = lambda: sde.gdg_jvp_column_sum_v1(t, y, a)  # Linear in m.
+    time_elapsed = _time_function(func1)
+    print(f'Time elapsed for loop: {time_elapsed:.4f}')
+
+    func2 = lambda: sde.gdg_jvp_column_sum_v2(t, y, a)  # Almost constant in m.
+    time_elapsed = _time_function(func2)
+    print(f'Time elapsed for duplicate: {time_elapsed:.4f}')
+
+
+test_gdg_jvp()
+check_efficiency()
diff --git a/torchsde/_core/adjoint_sdes/additive.py b/torchsde/_core/adjoint_sdes/additive.py
index e8cbe19..ed90bd5 100644
--- a/torchsde/_core/adjoint_sdes/additive.py
+++ b/torchsde/_core/adjoint_sdes/additive.py
@@ -20,7 +20,7 @@
 from .. import misc
 
 
-class AdjointSDEAdditive(base_sde.AdjointSDEIto):
+class AdjointSDEAdditive(base_sde.AdjointSDE):
 
     def __init__(self, sde, params):
         super(AdjointSDEAdditive, self).__init__(sde, noise_type="general")
@@ -36,8 +36,6 @@ def f(self, t, y_aug):
 
             f_eval = sde.f(-t, y)
             f_eval = [-f_eval_ for f_eval_ in f_eval]
-            f_eval = misc.make_seq_requires_grad(f_eval)
-
             vjp_y_and_params = misc.grad(
                 outputs=f_eval,
                 inputs=y + params,
@@ -61,8 +59,6 @@ def g_prod(self, t, y_aug, noise):
             adj_y = [adj_y_.detach() for adj_y_ in adj_y]
 
             g_eval = [-g_ for g_ in sde.g(-t, y)]
-            g_eval = misc.make_seq_requires_grad(g_eval)
-
             vjp_y_and_params = misc.grad(
                 outputs=g_eval, inputs=y + params,
                 grad_outputs=[
@@ -90,7 +86,7 @@ def gdg_prod(self, t, y, v):
         raise NotImplementedError("This method shouldn't be called.")
 
 
-class AdjointSDEAdditiveLogqp(base_sde.AdjointSDEIto):
+class AdjointSDEAdditiveLogqp(base_sde.AdjointSDE):
     def __init__(self, sde, params):
         super(AdjointSDEAdditiveLogqp, self).__init__(sde, noise_type="general")
         self.params = params
@@ -106,8 +102,6 @@ def f(self, t, y_aug):
 
             f_eval = sde.f(-t, y)
             f_eval = [-f_eval_ for f_eval_ in f_eval]
-            f_eval = misc.make_seq_requires_grad(f_eval)
-
             vjp_y_and_params = misc.grad(
                 outputs=f_eval,
                 inputs=y + params,
@@ -128,7 +122,6 @@ def f(self, t, y_aug):
             u_eval = misc.seq_sub(f_eval, h_eval)
             u_eval = [torch.bmm(g_inv_eval_, u_eval_) for g_inv_eval_, u_eval_ in zip(g_inv_eval, u_eval)]
             log_ratio_correction = [.5 * torch.sum(u_eval_ ** 2., dim=1) for u_eval_ in u_eval]
-            log_ratio_correction = misc.make_seq_requires_grad(log_ratio_correction)
             corr_vjp_y_and_params = misc.grad(
                 outputs=log_ratio_correction, inputs=y + params, grad_outputs=adj_l,
@@ -154,8 +147,6 @@ def g_prod(self, t, y_aug, noise):
             adj_y = [adj_y_.detach() for adj_y_ in adj_y]
 
             g_eval = [-g_ for g_ in sde.g(-t, y)]
-            g_eval = misc.make_seq_requires_grad(g_eval)
-
             vjp_y_and_params = misc.grad(
                 outputs=g_eval, inputs=y + params,
                 grad_outputs=[-noise_.unsqueeze(1) * adj_y_.unsqueeze(2) for noise_, adj_y_ in zip(noise, adj_y)],
diff --git a/torchsde/_core/adjoint_sdes/diagonal.py b/torchsde/_core/adjoint_sdes/diagonal.py
index e8c56bb..af9a596 100644
--- a/torchsde/_core/adjoint_sdes/diagonal.py
+++ b/torchsde/_core/adjoint_sdes/diagonal.py
@@ -19,7 +19,7 @@
 from .. import misc
 
 
-class AdjointSDEDiagonal(base_sde.AdjointSDEIto):
+class AdjointSDEDiagonal(base_sde.AdjointSDE):
 
     def __init__(self, sde, params):
         super(AdjointSDEDiagonal, self).__init__(sde, noise_type="diagonal")
@@ -34,8 +34,6 @@ def f(self, t, y_aug):
             adj_y = [adj_y_.detach() for adj_y_ in adj_y]
 
             g_eval = sde.g(-t, y)
-            g_eval = misc.make_seq_requires_grad(g_eval)
-
             gdg = misc.grad(
                 outputs=g_eval, inputs=y, grad_outputs=g_eval,
@@ -45,10 +43,7 @@ def f(self, t, y_aug):
             gdg = misc.convert_none_to_zeros(gdg, y)
 
             f_eval = sde.f(-t, y)
-
             f_eval_corrected = misc.seq_sub(gdg, f_eval)  # Stratonovich correction for reverse-time.
-            f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)
-
             vjp_y_and_params = misc.grad(
                 outputs=f_eval_corrected,
                 inputs=y + params,
@@ -95,7 +90,6 @@ def g_prod(self, t, y_aug, noise):
             adj_y = [adj_y_.detach() for adj_y_ in adj_y]
 
             g_eval = [-g_ for g_ in sde.g(-t, y)]
-            g_eval = misc.make_seq_requires_grad(g_eval)
             vjp_y_and_params = misc.grad(
                 outputs=g_eval, inputs=y + params,
                 grad_outputs=[-noise_ * adj_y_ for noise_, adj_y_ in zip(noise, adj_y)],
@@ -119,7 +113,6 @@ def gdg_prod(self, t, y_aug, noise):
             adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y]
 
             g_eval = sde.g(-t, y)
-            g_eval = misc.make_seq_requires_grad(g_eval)
             gdg_times_v = misc.grad(
                 outputs=g_eval, inputs=y, grad_outputs=misc.seq_mul(g_eval, noise),
@@ -154,8 +147,6 @@ def gdg_prod(self, t, y_aug, noise):
                 allow_unused=True, create_graph=True
             )
             gdg_v = misc.convert_none_to_zeros(gdg_v, y)
-            gdg_v = misc.make_seq_requires_grad(gdg_v)
-
             mixed_partials_adj_y_and_params = misc.grad(
                 outputs=gdg_v, inputs=y + params,
                 grad_outputs=[torch.ones_like(p) for p in gdg_v],
@@ -180,7 +171,7 @@ def h(self, t, y):
         raise NotImplementedError("This method shouldn't be called.")
 
 
-class AdjointSDEDiagonalLogqp(base_sde.AdjointSDEIto):
+class AdjointSDEDiagonalLogqp(base_sde.AdjointSDE):
 
     def __init__(self, sde, params):
         super(AdjointSDEDiagonalLogqp, self).__init__(sde, noise_type="diagonal")
@@ -196,8 +187,6 @@ def f(self, t, y_aug):
             adj_y = [adj_y_.detach() for adj_y_ in adj_y]
 
             g_eval = sde.g(-t, y)
-            g_eval = misc.make_seq_requires_grad(g_eval)
-
             gdg = misc.grad(
                 outputs=g_eval, inputs=y, grad_outputs=g_eval,
@@ -208,8 +197,6 @@ def f(self, t, y_aug):
 
             f_eval = sde.f(-t, y)
             f_eval_corrected = misc.seq_sub(gdg, f_eval)
-            f_eval_corrected = misc.make_seq_requires_grad(f_eval_corrected)
-
             vjp_y_and_params = misc.grad(
                 outputs=f_eval_corrected, inputs=y + params,
                 grad_outputs=[-adj_y_ for adj_y_ in adj_y],
@@ -244,8 +231,6 @@ def f(self, t, y_aug):
             h_eval = sde.h(-t, y)
             u_eval = misc.seq_sub_div(f_eval, h_eval, g_eval)
             log_ratio_correction = [.5 * torch.sum(u_eval_ ** 2., dim=1) for u_eval_ in u_eval]
-
-            log_ratio_correction = misc.make_seq_requires_grad(log_ratio_correction)
             corr_vjp_y_and_params = misc.grad(
                 outputs=log_ratio_correction, inputs=y + params, grad_outputs=adj_l,
@@ -271,7 +256,6 @@ def g_prod(self, t, y_aug, noise):
             adj_y = [adj_y_.detach() for adj_y_ in adj_y]
 
             g_eval = sde.g(-t, y)
-            g_eval = misc.make_seq_requires_grad(g_eval)
             minus_g_eval = [-g_ for g_ in g_eval]
             minus_g_prod_eval = misc.seq_mul(minus_g_eval, noise)
@@ -297,8 +281,6 @@ def gdg_prod(self, t, y_aug, noise):
             adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y]
 
             g_eval = sde.g(-t, y)
-            g_eval = misc.make_seq_requires_grad(g_eval)
-
             gdg_times_v = misc.grad(
                 outputs=g_eval, inputs=y, grad_outputs=misc.seq_mul(g_eval, noise),
@@ -333,8 +315,6 @@ def gdg_prod(self, t, y_aug, noise):
                 create_graph=True,
             )
             gdg_v = misc.convert_none_to_zeros(gdg_v, y)
-            gdg_v = misc.make_seq_requires_grad(gdg_v)
-
             gdg_v = [gdg_v_.sum() for gdg_v_ in gdg_v]
             mixed_partials_adj_y_and_params = misc.grad(
                 outputs=gdg_v, inputs=y + params,
diff --git a/torchsde/_core/adjoint_sdes/scalar.py b/torchsde/_core/adjoint_sdes/scalar.py
index e402a8d..b05c4b8 100644
--- a/torchsde/_core/adjoint_sdes/scalar.py
+++ b/torchsde/_core/adjoint_sdes/scalar.py
@@ -17,7 +17,7 @@
 from .. import base_sde
 
 
-class AdjointSDEScalar(base_sde.AdjointSDEIto):
+class AdjointSDEScalar(base_sde.AdjointSDE):
 
     def __init__(self, sde, params):
         super(AdjointSDEScalar, self).__init__(sde, noise_type="scalar")
@@ -39,7 +39,7 @@ def gdg_prod(self, t, y, v):
         raise NotImplementedError("This method shouldn't be called.")
 
 
-class AdjointSDEScalarLogqp(base_sde.AdjointSDEIto):
+class AdjointSDEScalarLogqp(base_sde.AdjointSDE):
 
     def __init__(self, sde, params):
         super(AdjointSDEScalarLogqp, self).__init__(sde, noise_type="scalar")
diff --git a/torchsde/_core/base_sde.py b/torchsde/_core/base_sde.py
index 002cbfd..c1ffe69 100644
--- a/torchsde/_core/base_sde.py
+++ b/torchsde/_core/base_sde.py
@@ -17,15 +17,14 @@
 import torch
 from torch import nn
 
-from . import misc
 from ..settings import NOISE_TYPES, SDE_TYPES
+from . import misc
 
 
-class BaseSDE(nn.Module, metaclass=abc.ABCMeta):
+class BaseSDE(abc.ABC, nn.Module):
     """Base class for all SDEs.
 
-    Inheriting from this class ensures `noise_type` and `sde_type`, `f` and `g` are valid attributes, which the solver
-    depends on.
+    Inheriting from this class ensures `noise_type` and `sde_type` are valid attributes, which the solver depends on.
     """
 
     def __init__(self, noise_type, sde_type):
@@ -34,140 +33,189 @@ def __init__(self, noise_type, sde_type):
             raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {noise_type}")
         if sde_type not in SDE_TYPES:
             raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde_type}")
-        # TODO: Making these Python properties breaks `torch.jit.script`.
+        # Making these Python properties breaks `torch.jit.script`.
         self.noise_type = noise_type
         self.sde_type = sde_type
 
+
+class AdjointSDE(BaseSDE):
+    """Base class for reverse-time adjoint SDEs.
+
+    Each forward SDE noise type gives rise to a different adjoint SDE.
+    """
+
+    def __init__(self, sde, noise_type):
+        # `noise_type` must be supplied, since the adjoint may have a different noise type than the original SDE.
+        super(AdjointSDE, self).__init__(sde_type=sde.sde_type, noise_type=noise_type)
+        self._base_sde = sde
+
     @abc.abstractmethod
     def f(self, t, y):
-        raise NotImplementedError
+        pass
 
     @abc.abstractmethod
     def g(self, t, y):
-        raise NotImplementedError
-
+        pass
 
-# TODO: Lint error "Class SDEIto must implement all abstract methods" comes from changes in torch==1.6.0.
-# Should be gone in future version. See https://github.com/pytorch/pytorch/issues/42305 for more.
-class SDEIto(BaseSDE):
-    def __init__(self, noise_type):
-        super(SDEIto, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.ito)
+    @abc.abstractmethod
+    def h(self, t, y):
+        pass
 
+    @abc.abstractmethod
+    def g_prod(self, t, y, v):
+        pass
 
-class SDEStratonovich(BaseSDE):
-    def __init__(self, noise_type):
-        super(SDEStratonovich, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.stratonovich)
+    @abc.abstractmethod
+    def gdg_prod(self, t, y, v):
+        pass
 
 
 class ForwardSDE(BaseSDE):
-    """Wrapper SDE for the forward pass.
-    `g_prod` and `gdg_prod` are additional functions that high-order solvers will call.
- """ - - def __init__(self, base_sde): - super(ForwardSDE, self).__init__(sde_type=base_sde.sde_type, - noise_type=base_sde.noise_type) - self._base_sde = base_sde - self.f = self._base_sde.f - self.g = self._base_sde.g - self.h = self._base_sde.h - if self.noise_type == NOISE_TYPES.diagonal: - self.g_prod = self.g_prod_diagonal - elif self.noise_type == NOISE_TYPES.scalar: - self.g_prod = self.g_prod_scalar - else: - self.g_prod = self.g_prod_general_or_additive + def __init__(self, sde): + super(ForwardSDE, self).__init__(sde_type=sde.sde_type, noise_type=sde.noise_type) + self._base_sde = sde + + # Register the core function. This avoids polluting the codebase with if-statements. + self.g_prod = { + NOISE_TYPES.diagonal: self.g_prod_diagonal, + NOISE_TYPES.additive: self.g_prod_additive_or_general, + NOISE_TYPES.scalar: self.g_prod_scalar, + NOISE_TYPES.general: self.g_prod_additive_or_general + }[sde.noise_type] + self.gdg_prod = { + NOISE_TYPES.diagonal: self.gdg_prod_diagonal_or_scalar, + NOISE_TYPES.additive: self._skip, + NOISE_TYPES.scalar: self.gdg_prod_diagonal_or_scalar, + NOISE_TYPES.general: self.gdg_prod_general + } + self.gdg_jvp_column_sum = { + NOISE_TYPES.diagonal: self._skip, + NOISE_TYPES.additive: self._skip, + NOISE_TYPES.scalar: self._skip, + NOISE_TYPES.general: self.gdg_jvp_column_sum_v2 + }[sde.noise_type] + # TODO: Assign `gdg_jacobian_contraction`. def f(self, t, y): - # Make abstractmethod not complain, as we assign as an instance attribute instead - raise RuntimeError + return self._base_sde.f(t, y) def g(self, t, y): - # Make abstractmethod not complain, as we assign as an instance attribute instead - raise RuntimeError + return self._base_sde.g(t, y) + + def h(self, t, y): + return self._base_sde.h(t, y) + # g_prod functions. def g_prod_diagonal(self, t, y, v): return misc.seq_mul(self._base_sde.g(t, y), v) def g_prod_scalar(self, t, y, v): return misc.seq_mul_bc(self._base_sde.g(t, y), v) - def g_prod_general_or_additive(self, t, y, v): + def g_prod_additive_or_general(self, t, y, v): return misc.seq_batch_mvp(ms=self._base_sde.g(t, y), vs=v) - def gdg_prod(self, t, y, v): + # gdg_prod functions. + def gdg_prod_general(self, t, y, v): + # This function is used for Milstein. For general noise, Levy areas need to be supplied. + raise NotImplemented("This function should not be called.") + + def gdg_prod_diagonal_or_scalar(self, t, y, v): + requires_grad = torch.is_grad_enabled() # BP through solver. with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) if not y_.requires_grad else y_ for y_ in y] + y = [y_ if y_.requires_grad else y_.detach().requires_grad_(True) for y_ in y] val = self._base_sde.g(t, y) - val = misc.make_seq_requires_grad(val) vjp_val = misc.grad( - outputs=val, inputs=y, grad_outputs=misc.seq_mul(val, v), create_graph=True, allow_unused=True) - vjp_val = misc.convert_none_to_zeros(vjp_val, y) - return vjp_val - + outputs=val, + inputs=y, + grad_outputs=misc.seq_mul(val, v), + create_graph=requires_grad, + allow_unused=True + ) + return misc.convert_none_to_zeros(vjp_val, y) + + # gdg_jvp_column_sum functions. + # Computes: sum_{j,k,l} d sigma_{i,l} / d x_j sigma_{j,k} A_{k,l}. + def gdg_jvp_column_sum_v1(self, t, y, a): + # Assumes `a` is anti-symmetric and `base_sde` is not of diagonal noise. + requires_grad = torch.is_grad_enabled() # BP through solver. 
+        with torch.enable_grad():
+            y = [y_ if y_.requires_grad else y_.detach().requires_grad_(True) for y_ in y]
+            g_eval = self._base_sde.g(t, y)
+            v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)]
+            gdg_jvp_eval = [
+                misc.jvp(
+                    outputs=[g_eval_[..., col_idx] for g_eval_ in g_eval],
+                    inputs=y,
+                    grad_inputs=[v_[..., col_idx] for v_ in v],
+                    retain_graph=True,
+                    create_graph=requires_grad,
+                    allow_unused=True
+                )
+                for col_idx in range(g_eval[0].size(-1))
+            ]
+            gdg_jvp_eval = misc.seq_add(*gdg_jvp_eval)
+            return misc.convert_none_to_zeros(gdg_jvp_eval, y)
+
+    def gdg_jvp_column_sum_v2(self, t, y, a):
+        # Faster, but more memory intensive: evaluates the diffusion on inputs repeated m times.
+        requires_grad = torch.is_grad_enabled()  # BP through solver.
+        with torch.enable_grad():
+            y = [y_ if y_.requires_grad else y_.detach().requires_grad_(True) for y_ in y]
+            g_eval = self._base_sde.g(t, y)
+            v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)]
+
+            batch_size, d, m = g_eval[0].size()  # TODO: Relax this assumption.
+            y_dup = [torch.repeat_interleave(y_, repeats=m, dim=0) for y_ in y]
+            g_eval_dup = self._base_sde.g(t, y_dup)
+            v_flat = [v_.transpose(1, 2).flatten(0, 1) for v_ in v]
+            gdg_jvp_eval = misc.jvp(
+                g_eval_dup, y_dup, grad_inputs=v_flat, create_graph=requires_grad, allow_unused=True
+            )
+            gdg_jvp_eval = misc.convert_none_to_zeros(gdg_jvp_eval, y)
+            gdg_jvp_eval = [jvp_.reshape(batch_size, m, d, m).permute(0, 2, 1, 3) for jvp_ in gdg_jvp_eval]
+            gdg_jvp_eval = [jvp_.diagonal(dim1=-2, dim2=-1).sum(-1) for jvp_ in gdg_jvp_eval]
+            return gdg_jvp_eval
+
+    def _skip(self, t, y, v):  # noqa
+        return [0.] * len(y)
 
 
-class AdjointSDEIto(SDEIto):
-    """Base class for reverse-time adjoint SDE.
-    Each forward SDE with different noise type has a different adjoint SDE.
-    """
+class TupleSDE(BaseSDE):
 
-    def __init__(self, base_sde, noise_type):
-        super(AdjointSDEIto, self).__init__(noise_type=noise_type)
-        self._base_sde = base_sde
+    def __init__(self, sde):
+        super(TupleSDE, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type)
+        self._base_sde = sde
 
-    @abc.abstractmethod
     def f(self, t, y):
-        raise NotImplementedError
+        return self._base_sde.f(t, y[0]),
 
-    @abc.abstractmethod
     def g(self, t, y):
-        raise NotImplementedError
+        return self._base_sde.g(t, y[0]),
 
-    @abc.abstractmethod
     def h(self, t, y):
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def g_prod(self, t, y, v):
-        raise NotImplementedError
-
-    @abc.abstractmethod
-    def gdg_prod(self, t, y, v):
-        raise NotImplementedError
-
+        return self._base_sde.h(t, y[0]),
 
 
-class TupleSDE(BaseSDE):
-
-    def __init__(self, base_sde):
-        super(TupleSDE, self).__init__(noise_type=base_sde.noise_type, sde_type=base_sde.sde_type)
-        self._base_sde = base_sde
 
-    def f(self, t, y):
-        return (self._base_sde.f(t, y[0]),)
+class RenameMethodsSDE(BaseSDE):
 
-    def g(self, t, y):
-        return (self._base_sde.g(t, y[0]),)
+    def __init__(self, sde, drift='f', diffusion='g', prior_drift='h'):
+        super(RenameMethodsSDE, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type)
+        self._base_sde = sde
+        self.f = getattr(sde, drift)
+        self.g = getattr(sde, diffusion)
+        if hasattr(sde, prior_drift):
+            self.h = getattr(sde, prior_drift)
 
-    def h(self, t, y):
-        return (self._base_sde.h(t, y[0]),)
+class SDEIto(BaseSDE):
 
-class RenameMethodsSDE(BaseSDE):
+    def __init__(self, noise_type):
+        super(SDEIto, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.ito)
 
-    def __init__(self, base_sde, drift='f', diffusion='g', prior_drift='h'):
-        super(RenameMethodsSDE, self).__init__(noise_type=base_sde.noise_type, sde_type=base_sde.sde_type)
-        self._base_sde = base_sde
-        self.f = getattr(base_sde, drift)
-        self.g = getattr(base_sde, diffusion)
-        if hasattr(base_sde, prior_drift):
-            self.h = getattr(base_sde, prior_drift)
 
-    def f(self, t, y):
-        # Make abstractmethod not complain, as we assign as an instance attribute instead
-        raise RuntimeError
+class SDEStratonovich(BaseSDE):
 
-    def g(self, t, y):
-        # Make abstractmethod not complain, as we assign as an instance attribute instead
-        raise RuntimeError
+    def __init__(self, noise_type):
+        super(SDEStratonovich, self).__init__(noise_type=noise_type, sde_type=SDE_TYPES.stratonovich)
diff --git a/torchsde/_core/methods/milstein.py b/torchsde/_core/methods/milstein.py
index 3c650df..0ed964b 100644
--- a/torchsde/_core/methods/milstein.py
+++ b/torchsde/_core/methods/milstein.py
@@ -32,10 +32,7 @@ def step(self, t0, y0, dt):
         f_eval = self.sde.f(t0, y0)
         g_prod_eval = self.sde.g_prod(t0, y0, I_k)
-        if self.sde.noise_type == NOISE_TYPES.additive:
-            gdg_prod_eval = [0] * len(g_prod_eval)
-        else:
-            gdg_prod_eval = self.sde.gdg_prod(t0, y0, v)
+        gdg_prod_eval = self.sde.gdg_prod(t0, y0, v)
 
         y1 = [
             y0_i + f_eval_i * dt + g_prod_eval_i + .5 * gdg_prod_eval_i
             for y0_i, f_eval_i, g_prod_eval_i, gdg_prod_eval_i in zip(y0, f_eval, g_prod_eval, gdg_prod_eval)
diff --git a/torchsde/_core/misc.py b/torchsde/_core/misc.py
index 911dfec..2acbdf5 100644
--- a/torchsde/_core/misc.py
+++ b/torchsde/_core/misc.py
@@ -152,9 +152,23 @@ def batch_mvp(m, v):
     return mvp
 
 
-def grad(inputs, **kwargs):
+def grad(outputs, inputs, **kwargs):
     # Workaround for PyTorch bug #39784
+    outputs = make_seq_requires_grad(outputs)
     if torch.is_tensor(inputs):
         inputs = (inputs,)
     _inputs = [torch.as_strided(input_, (), ()) for input_ in inputs]
-    return torch.autograd.grad(inputs=inputs, **kwargs)
+    return torch.autograd.grad(outputs, inputs, **kwargs)
+
+
+def jvp(outputs, inputs, grad_inputs=None, **kwargs):
+    # `torch.autograd.functional.jvp` takes in `func` and requires re-evaluation.
+    # The present double-backward implementation avoids this; the first `grad` must build a graph.
+    outputs = make_seq_requires_grad(outputs)
+    if torch.is_tensor(inputs):
+        inputs = (inputs,)
+    _inputs = [torch.as_strided(input_, (), ()) for input_ in inputs]
+
+    dummy = [torch.zeros_like(o, requires_grad=True) for o in outputs]
+    vjp = torch.autograd.grad(outputs, inputs, grad_outputs=dummy, create_graph=True, allow_unused=True)
+    return torch.autograd.grad(vjp, dummy, grad_outputs=grad_inputs, **kwargs)
diff --git a/torchsde/_core/sdeint.py b/torchsde/_core/sdeint.py
index 101f869..eac3131 100644
--- a/torchsde/_core/sdeint.py
+++ b/torchsde/_core/sdeint.py
@@ -172,7 +172,7 @@ def check_contract(sde, method, logqp, ts, y0, bm, names, adjoint_method=None):
             if noise_channels != 1:
                 raise ValueError(f"Scalar noise must have only one channel; the diffusion has {noise_channels} noise "
                                  f"channels.")
-        else: # sde.noise_type == NOISE_TYPES.diagonal
+        else:  # sde.noise_type == NOISE_TYPES.diagonal
            batch_dimensions = diffusion_shape[0][:-1]
            for drift_shape_, diffusion_shape_ in zip(drift_shape, diffusion_shape):
                drift_shape_ = tuple(drift_shape_)