diff --git a/.gitignore b/.gitignore index b536f82..db3b392 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ benchmarks/plots/ CMakeLists.txt restats *-darwin.so +**.pyc diff --git a/diagnostics/ito_additive.py b/diagnostics/ito_additive.py index 2791f73..33d7a12 100644 --- a/diagnostics/ito_additive.py +++ b/diagnostics/ito_additive.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os -import argparse import matplotlib.pyplot as plt +import numpy as np import torch import tqdm -import numpy as np from scipy import stats from tests.basic_sde import AdditiveSDE @@ -51,7 +51,7 @@ def inspect_samples(): ts_, ys_em_, ys_srk_, ys_true_ = to_numpy(ts, ys_em, ys_srk, ys_true) # Visualize sample path. - img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive') + img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_additive') makedirs_if_not_found(img_dir) for i, (ys_em_i, ys_srk_i, ys_true_i) in enumerate(zip(ys_em_, ys_srk_, ys_true_)): @@ -105,7 +105,7 @@ def inspect_strong_order(): plt.yscale('log') plt.legend() - img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive') + img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_additive') makedirs_if_not_found(img_dir) plt.savefig(os.path.join(img_dir, 'rate')) plt.close() diff --git a/diagnostics/ito_diagonal.py b/diagnostics/ito_diagonal.py index 3ff6633..acf70e0 100644 --- a/diagnostics/ito_diagonal.py +++ b/diagnostics/ito_diagonal.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os -import argparse import matplotlib.pyplot as plt import numpy as np import torch @@ -56,7 +56,7 @@ def inspect_sample(): ts, ys_euler, ys_milstein, ys_milstein_grad_free, ys_srk, ys_analytical) # Visualize sample path. - img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_diagonal') + img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_diagonal') makedirs_if_not_found(img_dir) for i, (ys_euler_i, ys_milstein_i, ys_milstein_grad_free_i, ys_srk_i, ys_analytical_i) in enumerate( @@ -92,7 +92,8 @@ def inspect_strong_order(): # Only take end value. 
_, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler') _, ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein') - _, ys_milstein_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', options={'grad_free': True}) + _, ys_milstein_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', + options={'grad_free': True}) _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk') _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm) @@ -101,7 +102,8 @@ def inspect_strong_order(): milstein_grad_free_mse = compute_mse(ys_milstein_grad_free, ys_analytical) srk_mse = compute_mse(ys_srk, ys_analytical) - euler_mse_, milstein_mse_, milstein_grad_free_mse_, srk_mse_ = to_numpy(euler_mse, milstein_mse, milstein_grad_free_mse, srk_mse) + euler_mse_, milstein_mse_, milstein_grad_free_mse_, srk_mse_ = to_numpy( + euler_mse, milstein_mse, milstein_grad_free_mse, srk_mse) euler_mses_.append(euler_mse_) milstein_mses_.append(milstein_mse_) @@ -125,7 +127,7 @@ def inspect_strong_order(): plt.yscale('log') plt.legend() - img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_diagonal') + img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_diagonal') makedirs_if_not_found(img_dir) plt.savefig(os.path.join(img_dir, 'rate')) plt.close() diff --git a/diagnostics/ito_scalar.py b/diagnostics/ito_scalar.py index 69b463f..8f3eaa5 100644 --- a/diagnostics/ito_scalar.py +++ b/diagnostics/ito_scalar.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os -import argparse import matplotlib.pyplot as plt import numpy as np import torch import tqdm from scipy import stats -from tests.problems import Ex2 +from tests.problems import Ex2Scalar from torchsde import sdeint, BrownianInterval from torchsde.settings import LEVY_AREA_APPROXIMATIONS from .utils import to_numpy, makedirs_if_not_found, compute_mse @@ -34,7 +34,7 @@ def inspect_sample(): ts = torch.linspace(0., 5., steps=steps).to(device) dt = 1e-1 y0 = torch.ones(batch_size, d).to(device) - sde = Ex2(d=d).to(device) + sde = Ex2Scalar(d=d).to(device) sde.noise_type = "scalar" with torch.no_grad(): @@ -54,7 +54,7 @@ def inspect_sample(): ts, ys_euler, ys_milstein, ys_srk, ys_analytical) # Visualize sample path. - img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar') + img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_scalar') makedirs_if_not_found(img_dir) for i, (ys_euler_i, ys_milstein_i, ys_srk_i, ys_analytical_i) in enumerate( @@ -74,7 +74,7 @@ def inspect_strong_order(): ts = torch.tensor([0., 5.]).to(device) dts = tuple(2 ** -i for i in range(1, 9)) y0 = torch.ones(batch_size, d).to(device) - sde = Ex2(d=d).to(device) + sde = Ex2Scalar(d=d).to(device) euler_mses_ = [] milstein_mses_ = [] @@ -116,7 +116,7 @@ def inspect_strong_order(): plt.yscale('log') plt.legend() - img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar') + img_dir = os.path.join('.', 'diagnostics', 'plots', 'ito_scalar') makedirs_if_not_found(img_dir) plt.savefig(os.path.join(img_dir, 'rate')) plt.close() diff --git a/diagnostics/stratonovich_diagonal.py b/diagnostics/stratonovich_diagonal.py index 1e48191..dbc0275 100644 --- a/diagnostics/stratonovich_diagonal.py +++ b/diagnostics/stratonovich_diagonal.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse import os -import argparse import matplotlib.pyplot as plt import numpy as np import torch @@ -43,7 +43,8 @@ def inspect_sample(): ys_heun = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='heun', names={'drift': 'f_corr'}) ys_midpoint = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='midpoint', names={'drift': 'f_corr'}) ys_milstein_strat = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}) - ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}, options={'grad_free': True}) + ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}, + options={'grad_free': True}) ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm) ys_heun = ys_heun.squeeze().t() @@ -93,7 +94,8 @@ def inspect_strong_order(): _, ys_heun = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='heun', names={'drift': 'f_corr'}) _, ys_midpoint = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='midpoint', names={'drift': 'f_corr'}) _, ys_milstein_strat = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}) - _, ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', names={'drift': 'f_corr'}, options={'grad_free': True}) + _, ys_mil_strat_grad_free = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein', + names={'drift': 'f_corr'}, options={'grad_free': True}) _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm) heun_mse = compute_mse(ys_heun, ys_analytical) @@ -101,7 +103,9 @@ def inspect_strong_order(): milstein_strat_mse = compute_mse(ys_milstein_strat, ys_analytical) mil_strat_grad_free_mse = compute_mse(ys_mil_strat_grad_free, ys_analytical) - heun_mse_, midpoint_mse_, milstein_strat_mse_, mil_strat_grad_free_mse_ = to_numpy(heun_mse, midpoint_mse, milstein_strat_mse, mil_strat_grad_free_mse) + heun_mse_, midpoint_mse_, milstein_strat_mse_, mil_strat_grad_free_mse_ = to_numpy(heun_mse, midpoint_mse, + milstein_strat_mse, + mil_strat_grad_free_mse) heun_mses_.append(heun_mse_) midpoint_mses_.append(midpoint_mse_) diff --git a/tests/basic_sde.py b/tests/basic_sde.py index 19fc84b..64283ed 100644 --- a/tests/basic_sde.py +++ b/tests/basic_sde.py @@ -129,12 +129,9 @@ def h(self, t, y): class ScalarSDE(AdditiveSDE): def __init__(self, d=10, m=3): super(ScalarSDE, self).__init__(d=d, m=m) - self.g_param = nn.Parameter(torch.sigmoid(torch.randn(1, d)), requires_grad=True) + self.g_param = nn.Parameter(torch.sigmoid(torch.randn(1, d, 1)), requires_grad=True) self.noise_type = "scalar" - def g(self, t, y): - return self.g_param.repeat(y.size(0), 1) - class TupleSDE(SDEIto): def __init__(self, d=10): diff --git a/tests/problems.py b/tests/problems.py index 3e57d33..709c7b3 100644 --- a/tests/problems.py +++ b/tests/problems.py @@ -105,6 +105,15 @@ def nfe(self): return self._nfe +class Ex2Scalar(Ex2): + def __init__(self, d=10, sde_type='ito'): + super(Ex2Scalar, self).__init__(d=d, sde_type=sde_type) + self.noise_type = "scalar" + + def g(self, t, y): + return super(Ex2Scalar, self).g(t, y).unsqueeze(2) + + class Ex3(BaseSDE): def __init__(self, d=10, sde_type='ito'): super(Ex3, self).__init__(noise_type="diagonal", sde_type=sde_type) diff --git a/tests/test_adjoint_logqp.py b/tests/test_adjoint_logqp.py index 0d3f459..5cf179d 100644 --- a/tests/test_adjoint_logqp.py +++ b/tests/test_adjoint_logqp.py @@ -43,18 +43,26 @@ class TestAdjointLogqp(TorchTestCase): def test_basic_sde1(self): + 
self.skipTest("Temporarily deprecating logqp.") + sde = BasicSDE1(d).to(device) _test_forward_and_backward(sde) def test_basic_sde2(self): + self.skipTest("Temporarily deprecating logqp.") + sde = BasicSDE2(d).to(device) _test_forward_and_backward(sde) def test_basic_sde3(self): + self.skipTest("Temporarily deprecating logqp.") + sde = BasicSDE3(d).to(device) _test_forward_and_backward(sde) def test_basic_sde4(self): + self.skipTest("Temporarily deprecating logqp.") + sde = BasicSDE4(d).to(device) _test_forward_and_backward(sde) diff --git a/tests/test_sdeint.py b/tests/test_sdeint.py index 51dc89d..11b4ec7 100644 --- a/tests/test_sdeint.py +++ b/tests/test_sdeint.py @@ -61,11 +61,14 @@ class TestSdeint(TorchTestCase): def test_rename_methods(self): - # Test renaming works with a subset of names when `logqp=False`. + # Test renaming works with a subset of names. sde = basic_sde.CustomNamesSDE().to(device) ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'}) self.assertEqual(ans.shape, (T, batch_size, d)) + def test_rename_methods_logqp(self): + self.skipTest("Temporarily deprecating logqp.") + # Test renaming works with a subset of names when `logqp=True`. sde = basic_sde.CustomNamesSDELogqp().to(device) ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True) @@ -76,18 +79,36 @@ def test_sdeint_general(self): sde = basic_sde.GeneralSDE(d=d, m=m).to(device) for method in ('euler',): self._test_sdeint(sde, bm=bm_general, adaptive=False, method=method, dt=dt) + + def test_sdeint_general_logqp(self): + self.skipTest("Temporarily deprecating logqp.") + + sde = basic_sde.GeneralSDE(d=d, m=m).to(device) + for method in ('euler',): self._test_sdeint_logqp(sde, bm=bm_general, adaptive=False, method=method, dt=dt) def test_sdeint_additive(self): sde = basic_sde.AdditiveSDE(d=d, m=m).to(device) for method in ('euler', 'milstein', 'srk'): self._test_sdeint(sde, bm=bm_general, adaptive=False, method=method, dt=dt) + + def test_sdeint_additive_logqp(self): + self.skipTest("Temporarily deprecating logqp.") + + sde = basic_sde.AdditiveSDE(d=d, m=m).to(device) + for method in ('euler', 'milstein', 'srk'): self._test_sdeint_logqp(sde, bm=bm_general, adaptive=False, method=method, dt=dt) def test_sde_scalar(self): sde = basic_sde.ScalarSDE(d=d, m=m).to(device) for method in ('euler', 'milstein', 'srk'): self._test_sdeint(sde, bm=bm_scalar, adaptive=False, method=method, dt=dt) + + def test_sde_scalar_logqp(self): + self.skipTest("Temporarily deprecating logqp.") + + sde = basic_sde.ScalarSDE(d=d, m=m).to(device) + for method in ('euler', 'milstein', 'srk'): self._test_sdeint_logqp(sde, bm=bm_scalar, adaptive=False, method=method, dt=dt) def test_srk_determinism(self): @@ -116,24 +137,19 @@ def test_sdeint_adaptive(self): self._test_sdeint(sde, bm_diagonal, adaptive=True, method=method, dt=dt) def test_sdeint_logqp_fixed(self): + self.skipTest("Temporarily deprecating logqp.") + for sde in basic_sdes: for method in ('euler', 'milstein', 'srk'): self._test_sdeint_logqp(sde, bm_diagonal, adaptive=False, method=method, dt=dt) def test_sdeint_logqp_adaptive(self): + self.skipTest("Temporarily deprecating logqp.") + for sde in basic_sdes: for method in ('milstein', 'srk'): self._test_sdeint_logqp(sde, bm_diagonal, adaptive=True, method=method, dt=dt) - def test_sdeint_tuple_sde(self): - y0_ = (y0,) # Make tuple input. 
- sde = basic_sde.TupleSDE(d=d).to(device) - - for method in ('euler', 'milstein', 'srk'): - ans = sdeint(sde, y0_, ts, method=method, dt=dt) - self.assertTrue(isinstance(ans, tuple)) - self.assertEqual(ans[0].size(), (T, batch_size, d)) - def _test_sdeint(self, sde, bm, adaptive, method, dt): # Using `f` as drift. with torch.no_grad(): diff --git a/tests/test_strat.py b/tests/test_strat.py index 17f44b4..b56ae2b 100644 --- a/tests/test_strat.py +++ b/tests/test_strat.py @@ -24,7 +24,7 @@ from torchsde import sdeint_adjoint, BrownianInterval from torchsde import settings -from torchsde._core.base_sde import ForwardSDE, TupleSDE # noqa +from torchsde._core.base_sde import ForwardSDE # noqa torch.manual_seed(1147481649) torch.set_default_dtype(torch.float64) @@ -75,38 +75,34 @@ def _batch_jacobian(output, input_): return torch.stack(jacs, dim=0) -def _gdg_jvp_brute_force(sde, t, y, a): - # Only returns the value for the first input-output pair. +def _dg_ga_jvp_brute_force(sde, t, y, a): with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) if not y_.requires_grad else y_ for y_ in y] - g_eval = sde.g(t, y) - v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)] + y = y.detach().requires_grad_(True) if not y.requires_grad else y + g = sde.g(t, y) + ga = torch.bmm(g, a) - y0, g_eval0, v0 = y[0], g_eval[0], v[0] - num_brownian = g_eval0.size(-1) - jacobians_by_column = [_batch_jacobian(g_eval0[..., l], y0) for l in range(num_brownian)] - return [ - sum(torch.bmm(jacobians_by_column[l], v0[..., l].unsqueeze(-1)).squeeze() for l in range(num_brownian)) - ] + num_brownian = g.size(-1) + jacobians_by_column = [_batch_jacobian(g[..., l], y) for l in range(num_brownian)] + return sum(torch.bmm(jacobians_by_column[l], ga[..., l].unsqueeze(-1)).squeeze() for l in range(num_brownian)) def _make_inputs(): t = torch.rand((), device=device) - y = [torch.randn(batch_size, d, device=device)] + y = torch.randn(batch_size, d, device=device) a = torch.randn(batch_size, m, m, device=device) - a = [a - a.transpose(1, 2)] # Anti-symmetric. - sde = ForwardSDE(TupleSDE(SDE())) + a = a - a.transpose(1, 2) # Anti-symmetric. + sde = ForwardSDE(SDE()) return sde, t, y, a -def test_gdg_jvp(): +def test_dg_ga_jvp(): sde, t, y, a = _make_inputs() - outs_brute_force = _gdg_jvp_brute_force(sde, t, y, a) # Reference. - outs = sde.gdg_jvp_column_sum(t, y, a) - outs_v2 = sde.gdg_jvp_column_sum_v2(t, y, a) - for out_brute_force, out, out_v2 in zip(outs_brute_force, outs, outs_v2): - assert torch.allclose(out_brute_force, out) - assert torch.allclose(out_brute_force, out_v2) + outs_brute_force = _dg_ga_jvp_brute_force(sde, t, y, a) # Reference. + outs = sde.dg_ga_jvp_column_sum_v1(t, y, a) + outs_v2 = sde.dg_ga_jvp_column_sum_v2(t, y, a) + assert torch.is_tensor(outs_brute_force) and torch.is_tensor(outs) and torch.is_tensor(outs_v2) + assert torch.allclose(outs_brute_force, outs) + assert torch.allclose(outs_brute_force, outs_v2) def _time_function(func, reps=10): @@ -118,11 +114,11 @@ def _time_function(func, reps=10): def check_efficiency(): sde, t, y, a = _make_inputs() - func1 = lambda: sde.gdg_jvp_column_sum_v1(t, y, a) # Linear in m. + func1 = lambda: sde.dg_ga_jvp_column_sum_v1(t, y, a) # Linear in m. time_elapse = _time_function(func1) print(f'Time elapse for loop: {time_elapse:.4f}') - func2 = lambda: sde.gdg_jvp_column_sum_v2(t, y, a) # Almost constant in m. + func2 = lambda: sde.dg_ga_jvp_column_sum_v2(t, y, a) # Almost constant in m. 
time_elapse = _time_function(func2) print(f'Time elapse for duplicate: {time_elapse:.4f}') @@ -139,6 +135,6 @@ def func(y0): torch.autograd.gradcheck(func, y0_, rtol=1e-4, atol=1e-3, eps=1e-8) -test_gdg_jvp() +test_dg_ga_jvp() check_efficiency() test_adjoint() diff --git a/torchsde/_brownian/__init__.py b/torchsde/_brownian/__init__.py index 46910b4..840fe1e 100644 --- a/torchsde/_brownian/__init__.py +++ b/torchsde/_brownian/__init__.py @@ -16,7 +16,7 @@ from .brownian_interval import BrownianInterval from .brownian_path import BrownianPath from .brownian_tree import BrownianTree -from .modified import ReverseBrownian, TupleBrownian +from .modified import ReverseBrownian BrownianInterval.__init__.__annotations__ = {} BrownianPath.__init__.__annotations__ = {} diff --git a/torchsde/_brownian/brownian_path.py b/torchsde/_brownian/brownian_path.py index 3799a2e..e505c3a 100644 --- a/torchsde/_brownian/brownian_path.py +++ b/torchsde/_brownian/brownian_path.py @@ -72,7 +72,7 @@ def __init__(self, """ # TODO: Couple of things to optimize: 1) search based on local window, # 2) avoid the `return_U` and `return_A` arguments. - handle_unused_kwargs(self, unused_kwargs) + handle_unused_kwargs(unused_kwargs, msg=self.__class__.__name__) del unused_kwargs super(BrownianPath, self).__init__() diff --git a/torchsde/_brownian/brownian_tree.py b/torchsde/_brownian/brownian_tree.py index 9f537d6..c0585b2 100644 --- a/torchsde/_brownian/brownian_tree.py +++ b/torchsde/_brownian/brownian_tree.py @@ -86,7 +86,7 @@ def __init__(self, approximation type. This is needed for some higher-order SDE solvers. """ - handle_unused_kwargs(self, unused_kwargs) + handle_unused_kwargs(unused_kwargs, msg=self.__class__.__name__) del unused_kwargs super(BrownianTree, self).__init__() diff --git a/torchsde/_brownian/modified.py b/torchsde/_brownian/modified.py index ad623b0..d23ae13 100644 --- a/torchsde/_brownian/modified.py +++ b/torchsde/_brownian/modified.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch - from . import base_brownian @@ -50,11 +48,3 @@ def __call__(self, ta, tb, return_U=False, return_A=False): # Whether or not to negate the statistics depends on the return value of the adjoint SDE. Currently, the adjoint # returns negated drift and diffusion, so we don't negate here. 
return self.base_brownian(-tb, -ta, return_U=return_U, return_A=return_A) - - -class TupleBrownian(_ModifiedBrownian): - def __call__(self, ta, tb, return_U=False, return_A=False): - statistics = self.base_brownian(ta, tb, return_U=return_U, return_A=return_A) - if torch.is_tensor(statistics): - return (statistics,) - return [(i,) for i in statistics] diff --git a/torchsde/_brownian/utils.py b/torchsde/_brownian/utils.py index b3b4a0b..75d6134 100644 --- a/torchsde/_brownian/utils.py +++ b/torchsde/_brownian/utils.py @@ -27,7 +27,7 @@ class BrownianReturn: __slots__ = ('W', 'U', 'A') - + def __init__(self, W, U=None, A=None): self.W = W self.U = U diff --git a/torchsde/_core/adjoint.py b/torchsde/_core/adjoint.py index 6115f01..66306af 100644 --- a/torchsde/_core/adjoint.py +++ b/torchsde/_core/adjoint.py @@ -17,25 +17,20 @@ import torch from torch import nn -try: - from ..brownian_lib import BrownianPath -except Exception: # noqa - from .._brownian import BrownianPath -from .._brownian import BaseBrownian, ReverseBrownian -from ..types import TensorOrTensors, Scalar, Vector -from .adjoint_sde import AdjointSDE # Directly import to avoid conflicting names. - from . import base_sde from . import misc from . import sdeint +from .adjoint_sde import AdjointSDE +from .._brownian import BaseBrownian, ReverseBrownian from ..settings import METHODS, SDE_TYPES, NOISE_TYPES +from ..types import Scalar, Vector class _SdeintAdjointMethod(torch.autograd.Function): @staticmethod - def forward(ctx, sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, adjoint_adaptive, rtol, # noqa - adjoint_rtol, atol, adjoint_atol, dt_min, options, adjoint_options, *y0): + def forward(ctx, sde, ts, dt, bm, method, adjoint_method, adaptive, adjoint_adaptive, rtol, # noqa + adjoint_rtol, atol, adjoint_atol, dt_min, options, adjoint_options, y0, *params): ctx.sde = sde ctx.dt = dt ctx.bm = bm @@ -46,95 +41,7 @@ def forward(ctx, sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, ctx.dt_min = dt_min ctx.adjoint_options = adjoint_options - sde = base_sde.ForwardSDE(sde) - ans = sdeint.integrate( - sde=sde, - y0=y0, - ts=ts, - bm=bm, - method=method, - dt=dt, - adaptive=adaptive, - rtol=rtol, - atol=atol, - dt_min=dt_min, - options=options - ) - ctx.save_for_backward(ts, flat_params, *ans) - return ans - - @staticmethod - def backward(ctx, *grad_outputs): - ts, flat_params, *ans = ctx.saved_tensors - sde = ctx.sde - dt = ctx.dt - bm = ctx.bm - adjoint_method = ctx.adjoint_method - adjoint_adaptive = ctx.adjoint_adaptive - adjoint_rtol = ctx.adjoint_rtol - adjoint_atol = ctx.adjoint_atol - dt_min = ctx.dt_min - adjoint_options = ctx.adjoint_options - - params = misc.make_seq_requires_grad(sde.parameters()) - n_tensors, n_params = len(ans), len(params) - - reverse_bm = ReverseBrownian(bm) - adjoint_sde = AdjointSDE(forward_sde=sde, params=params, n_tensors=n_tensors) - - T = ans[0].size(0) - adj_y = [grad_outputs_[-1] for grad_outputs_ in grad_outputs] - adj_params = torch.zeros_like(flat_params) - - for i in range(T - 1, 0, -1): - ans_i = [ans_[i] for ans_ in ans] - aug_y0 = (*ans_i, *adj_y, adj_params) - - aug_ans = sdeint.integrate( - sde=adjoint_sde, - y0=aug_y0, - ts=torch.stack([-ts[i], -ts[i - 1]]), - bm=reverse_bm, - method=adjoint_method, - dt=dt, - adaptive=adjoint_adaptive, - rtol=adjoint_rtol, - atol=adjoint_atol, - dt_min=dt_min, - options=adjoint_options - ) - - adj_y = aug_ans[n_tensors:2 * n_tensors] - adj_params = aug_ans[-1] - - adj_y = [adj_y_[1] for adj_y_ in adj_y] - adj_params = 
adj_params[1] - - adj_y = misc.seq_add(adj_y, [grad_outputs_[i - 1] for grad_outputs_ in grad_outputs]) - - del aug_y0, aug_ans - - return (None, None, adj_params, None, None, None, None, None, None, None, None, None, None, None, None, None, - *adj_y) - - -class _SdeintLogqpAdjointMethod(torch.autograd.Function): - - @staticmethod - def forward(ctx, sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, adjoint_adaptive, rtol, # noqa - adjoint_rtol, atol, adjoint_atol, dt_min, options, adjoint_options, *y0): - ctx.sde = sde - ctx.dt = dt - ctx.bm = bm - ctx.adjoint_method = adjoint_method - ctx.adjoint_adaptive = adjoint_adaptive - ctx.adjoint_rtol = adjoint_rtol - ctx.adjoint_atol = adjoint_atol - ctx.dt_min = dt_min - ctx.adjoint_options = adjoint_options - - sde = base_sde.ForwardSDE(sde) - ans_and_logqp = sdeint.integrate( + ys = sdeint.integrate( sde=sde, y0=y0, ts=ts, @@ -146,17 +53,13 @@ def forward(ctx, sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, atol=atol, dt_min=dt_min, options=options, - logqp=True ) - ans, logqp = ans_and_logqp[:len(y0)], ans_and_logqp[len(y0):] - - # Don't need to save `logqp`, since it is never used in the backward pass to compute gradients. - ctx.save_for_backward(ts, flat_params, *ans) - return ans + logqp + ctx.save_for_backward(ys, ts, *params) + return ys @staticmethod - def backward(ctx, *grad_outputs): - ts, flat_params, *ans = ctx.saved_tensors + def backward(ctx, grad_ys): # noqa + ys, ts, *params = ctx.saved_tensors sde = ctx.sde dt = ctx.dt bm = ctx.bm @@ -167,24 +70,16 @@ def backward(ctx, *grad_outputs): dt_min = ctx.dt_min adjoint_options = ctx.adjoint_options - params = misc.make_seq_requires_grad(sde.parameters()) - n_tensors, n_params = len(ans), len(params) - + aug_state = [ys[-1], grad_ys[-1]] + [torch.zeros_like(param) for param in params] + shapes = [t.size() for t in aug_state] + adjoint_sde = AdjointSDE(sde, params, shapes) reverse_bm = ReverseBrownian(bm) - adjoint_sde = AdjointSDE(forward_sde=sde, params=params, n_tensors=n_tensors, logqp=True) - - T = ans[0].size(0) - adj_y = [grad_outputs_[-1] for grad_outputs_ in grad_outputs[:n_tensors]] - adj_l = [grad_outputs_[-1] for grad_outputs_ in grad_outputs[n_tensors:]] - adj_params = torch.zeros_like(flat_params) - for i in range(T - 1, 0, -1): - ans_i = [ans_[i] for ans_ in ans] - aug_y0 = (*ans_i, *adj_y, *adj_l, adj_params) - - aug_ans = sdeint.integrate( + for i in range(ys.size(0) - 1, 0, -1): + aug_state = misc.flatten(aug_state) + aug_state = sdeint.integrate( sde=adjoint_sde, - y0=aug_y0, + y0=aug_state, ts=torch.stack([-ts[i], -ts[i - 1]]), bm=reverse_bm, method=adjoint_method, @@ -195,28 +90,20 @@ def backward(ctx, *grad_outputs): dt_min=dt_min, options=adjoint_options ) + aug_state = misc.flat_to_shape(aug_state[1], shapes) # Unpack the state at time -ts[i - 1]. 
+ aug_state[0] = ys[i - 1] + aug_state[1] = aug_state[1] + grad_ys[i - 1] - adj_y = aug_ans[n_tensors:2 * n_tensors] - adj_params = aug_ans[-1] - - adj_y = [adj_y_[1] for adj_y_ in adj_y] - adj_params = adj_params[1] - - adj_y = misc.seq_add(adj_y, [grad_outputs_[i - 1] for grad_outputs_ in grad_outputs[:n_tensors]]) - adj_l = [grad_outputs_[i - 1] for grad_outputs_ in grad_outputs[n_tensors:]] - - del aug_y0, aug_ans - - return (None, None, adj_params, None, None, None, None, None, None, None, None, None, None, None, None, None, - *adj_y) + return ( + None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, *aug_state[1:] + ) -def sdeint_adjoint(sde, - y0: TensorOrTensors, +def sdeint_adjoint(sde: nn.Module, + y0: torch.Tensor, ts: Vector, bm: Optional[BaseBrownian] = None, - logqp: Optional[bool] = False, - method: Optional[str] = 'srk', + method: Optional[str] = "srk", adjoint_method: Optional[str] = None, dt: Optional[Scalar] = 1e-3, adaptive: Optional[bool] = False, @@ -228,24 +115,22 @@ def sdeint_adjoint(sde, dt_min: Optional[Scalar] = 1e-5, options: Optional[Dict[str, Any]] = None, adjoint_options: Optional[Dict[str, Any]] = None, - names: Optional[Dict[str, str]] = None) -> TensorOrTensors: + names: Optional[Dict[str, str]] = None, + **unused_kwargs) -> torch.Tensor: """Numerically integrate an Itô SDE with stochastic adjoint support. Args: sde (torch.nn.Module): Object with methods `f` and `g` representing the - drift and diffusion. The output of `g` should be a single - (or a tuple of) tensor(s) of size (batch_size, d) for diagonal - noise SDEs or (batch_size, d, m) for SDEs of other noise types; d - is the dimensionality of state and m is the dimensionality of - Brownian motion. - y0 (sequence of Tensor): Tensors for initial state. + drift and diffusion. The output of `g` should be a single tensor of + size (batch_size, d) for diagonal noise SDEs or (batch_size, d, m) + for SDEs of other noise types; d is the dimensionality of state and + m is the dimensionality of Brownian motion. + y0 (Tensor): A tensor for the initial state. ts (Tensor or sequence of float): Query times in non-descending order. The state at the first time of `ts` should be `y0`. - bm (Brownian, optional): A `BrownianPath` or `BrownianTree` object. - Should return tensors of size (batch_size, m) for `__call__`. - Defaults to `BrownianPath` for diagonal noise on CPU. - Currently does not support tuple outputs yet. - logqp (bool, optional): If `True`, also return the log-ratio penalty. + bm (Brownian, optional): A 'BrownianInterval', `BrownianPath` or + `BrownianTree` object. Should return tensors of size (batch_size, m) + for `__call__`. Defaults to `BrownianInterval`. method (str, optional): Name of numerical integration method. adjoint_method (str, optional): Name of numerical integration method for backward adjoint solve. Defaults to a sensible choice depending on @@ -265,42 +150,35 @@ def sdeint_adjoint(sde, options (dict, optional): Dict of options for the integration method. adjoint_options (dict, optional): Dict of options for the integration method of the backward adjoint solve. - names (dict, optional): Dict of method names for drift, diffusion, and - prior drift. Expected keys are "drift", "diffusion", and - "prior_drift". Serves so that users can use methods with names not - in `("f", "g", "h")`, e.g. to use the method "foo" for the drift, - we would supply `names={"drift": "foo"}`. + names (dict, optional): Dict of method names for drift and diffusion. 
+ Expected keys are "drift" and "diffusion". Serves so that users can + use methods with names not in `("f", "g")`, e.g. to use the + method "foo" for the drift, we supply `names={"drift": "foo"}`. Returns: - A single state tensor of size (T, batch_size, d) or a tuple of such - tensors. Also returns a single log-ratio tensor of size - (T - 1, batch_size) or a tuple of such tensors, if `logqp==True`. + A single state tensor of size (T, batch_size, d). Raises: ValueError: An error occurred due to unrecognized noise type/method, or `sde` is missing required methods. """ - if not isinstance(sde, nn.Module): - raise ValueError('sde is required to be an instance of nn.Module.') + misc.handle_unused_kwargs(unused_kwargs, msg="`sdeint_adjoint`") + del unused_kwargs - sde, y0, ts, bm, tensor_input = sdeint.check_contract(sde, y0, ts, bm, logqp, method, names) - adjoint_method = _check_and_select_default_adjoint_method(sde, adjoint_method) + if not isinstance(sde, nn.Module): + raise ValueError("`sde` is required to be an instance of nn.Module.") - flat_params = misc.flatten(sde.parameters()) - if logqp: - return _SdeintLogqpAdjointMethod.apply( # noqa - sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, adjoint_adaptive, rtol, adjoint_rtol, atol, - adjoint_atol, dt_min, options, adjoint_options, *y0 - ) + sde, y0, ts, bm = sdeint.check_contract(sde, y0, ts, bm, method, names) + adjoint_method = _select_default_adjoint_method(sde, adjoint_method) + params = list(filter(lambda x: x.requires_grad, sde.parameters())) - ys = _SdeintAdjointMethod.apply( # noqa - sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive, adjoint_adaptive, rtol, adjoint_rtol, atol, - adjoint_atol, dt_min, options, adjoint_options, *y0 + return _SdeintAdjointMethod.apply( # noqa + sde, ts, dt, bm, method, adjoint_method, adaptive, adjoint_adaptive, rtol, adjoint_rtol, atol, + adjoint_atol, dt_min, options, adjoint_options, y0, *params ) - return ys[0] if tensor_input else ys -def _check_and_select_default_adjoint_method(sde, adjoint_method: str) -> str: +def _select_default_adjoint_method(sde: base_sde.ForwardSDE, adjoint_method: str) -> str: sde_type, noise_type = sde.sde_type, sde.noise_type if adjoint_method is None: # Select the default based on noise type of forward. @@ -308,7 +186,7 @@ def _check_and_select_default_adjoint_method(sde, adjoint_method: str) -> str: SDE_TYPES.ito: { NOISE_TYPES.diagonal: METHODS.milstein, NOISE_TYPES.additive: METHODS.euler, - NOISE_TYPES.scalar: METHODS.euler, # TODO: Optimize this. + NOISE_TYPES.scalar: METHODS.euler, }.get(noise_type, "unsupported"), SDE_TYPES.stratonovich: { NOISE_TYPES.general: METHODS.midpoint, diff --git a/torchsde/_core/adjoint_sde.py b/torchsde/_core/adjoint_sde.py index 1d71c52..4433c5a 100644 --- a/torchsde/_core/adjoint_sde.py +++ b/torchsde/_core/adjoint_sde.py @@ -13,148 +13,117 @@ # limitations under the License. +from typing import Sequence + import torch from . import base_sde from . import misc from ..settings import SDE_TYPES, NOISE_TYPES +from ..types import TensorOrTensors class AdjointSDE(base_sde.BaseSDE): - def __init__(self, forward_sde, params, n_tensors, logqp=False): + def __init__(self, + sde: base_sde.ForwardSDE, + params: TensorOrTensors, + shapes: Sequence[torch.Size]): # There's a mapping from the noise type of the forward SDE to the noise type of the adjoint. # Usually, these two aren't the same, e.g. 
when the forward SDE has additive noise, the adjoint SDE's diffusion # is a linear function of the adjoint variable, so it is not of additive noise. - sde_type = forward_sde.sde_type + sde_type = sde.sde_type noise_type = { NOISE_TYPES.general: NOISE_TYPES.general, NOISE_TYPES.additive: NOISE_TYPES.general, NOISE_TYPES.scalar: NOISE_TYPES.scalar, NOISE_TYPES.diagonal: NOISE_TYPES.diagonal, - }[forward_sde.noise_type] - + }.get(sde.noise_type) super(AdjointSDE, self).__init__(sde_type=sde_type, noise_type=noise_type) - self._base_sde = forward_sde - self._params = params - self._n_tensors = n_tensors - - # Register the core function. This avoids polluting the codebase with if-statements and speeds things up. - # The `sde_type` and `noise_type` of the forward SDE determines the registered functions. - if logqp: - self.f = { - SDE_TYPES.ito: { - NOISE_TYPES.diagonal: self.f_corrected_diagonal_logqp, - NOISE_TYPES.additive: self.f_uncorrected_logqp, - NOISE_TYPES.scalar: self.f_corrected_default_logqp, - NOISE_TYPES.general: self.f_corrected_default_logqp - }[forward_sde.noise_type], - SDE_TYPES.stratonovich: self.f_uncorrected_logqp - }[forward_sde.sde_type] - - self.g_prod = { - NOISE_TYPES.diagonal: self.g_prod_diagonal_logqp - }.get(forward_sde.noise_type, self.g_prod_default_logqp) - - self.gdg_prod = { - NOISE_TYPES.diagonal: self.gdg_prod_diagonal_logqp, - }.get(forward_sde.noise_type, self.gdg_prod_default_logqp) - else: - self.f = { - SDE_TYPES.ito: { - NOISE_TYPES.diagonal: self.f_corrected_diagonal, - NOISE_TYPES.additive: self.f_uncorrected, - NOISE_TYPES.scalar: self.f_corrected_default, - NOISE_TYPES.general: self.f_corrected_default - }[forward_sde.noise_type], - SDE_TYPES.stratonovich: self.f_uncorrected - }[forward_sde.sde_type] - self.g_prod = { - NOISE_TYPES.diagonal: self.g_prod_diagonal - }.get(forward_sde.noise_type, self.g_prod_default) - - self.gdg_prod = { - NOISE_TYPES.diagonal: self.gdg_prod_diagonal, - }.get(forward_sde.noise_type, self.gdg_prod_default) + self._base_sde = sde + self._params = params + self._shapes = shapes + + # Register the core functions. This avoids polluting the codebase with if-statements and achieves speed-ups + # by making sure it's a one-time cost. The `sde_type` and `noise_type` of the forward SDE determines the + # registered functions. + self.f = { + SDE_TYPES.ito: { + NOISE_TYPES.diagonal: self.f_corrected_diagonal, + NOISE_TYPES.additive: self.f_uncorrected, + NOISE_TYPES.scalar: self.f_corrected_default, + NOISE_TYPES.general: self.f_corrected_default + }.get(sde.noise_type), + SDE_TYPES.stratonovich: self.f_uncorrected + }.get(sde.sde_type) + self.gdg_prod = { + NOISE_TYPES.diagonal: self.gdg_prod_diagonal, + }.get(sde.noise_type, self.gdg_prod_default) + + def _flat_to_shape(self, y_aug): + """Recover only first two tensors from the flattened augmented state.""" + return misc.flat_to_shape(y_aug, self._shapes[:2]) ######################################## # f # ######################################## def f_uncorrected(self, t, y_aug): # For Ito additive and Stratonovich. 
- sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] - with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - minus_adj_y = [-adj_y_.detach() for adj_y_ in adj_y] - minus_f = [-f_ for f_ in sde.f(-t, y)] + y_aug = self._flat_to_shape(y_aug) + y = y_aug[0].detach().requires_grad_(True) + adj_y = y_aug[1].detach() + f = self._base_sde.f(-t, y) vjp_y_and_params = misc.grad( - outputs=minus_f, - inputs=y + params, - grad_outputs=minus_adj_y, - allow_unused=True, + outputs=f, + inputs=[y] + self._params, + grad_outputs=adj_y, + allow_unused=True ) - vjp_y, vjp_params = vjp_y_and_params[:n_tensors], vjp_y_and_params[n_tensors:] - vjp_params = misc.flatten(vjp_params) - - return (*minus_f, *vjp_y, vjp_params) + return misc.flatten((-f, *vjp_y_and_params)) def f_corrected_default(self, t, y_aug): # For Ito general/scalar. raise NotImplementedError def f_corrected_diagonal(self, t, y_aug): # For Ito diagonal. - sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] - with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - adj_y = [adj_y_.detach() for adj_y_ in adj_y] - - g_eval = sde.g(-t, y) - gdg = misc.grad( - outputs=g_eval, + y_aug = self._flat_to_shape(y_aug) + y = y_aug[0].detach().requires_grad_(True) + adj_y = y_aug[1].detach() + g = self._base_sde.g(-t, y) + g_dg_vjp, = misc.grad( + outputs=g, inputs=y, - grad_outputs=g_eval, + grad_outputs=g, allow_unused=True, create_graph=True ) - f_eval = sde.f(-t, y) - # Stratonovich correction for reverse-time. - f_eval_corrected = misc.seq_sub(gdg, f_eval) + # Double Stratonovich correction. + f = self._base_sde.f(-t, y) - g_dg_vjp vjp_y_and_params = misc.grad( - outputs=f_eval_corrected, - inputs=y + params, - grad_outputs=[-adj_y_ for adj_y_ in adj_y], + outputs=f, + inputs=[y] + self._params, + grad_outputs=adj_y, allow_unused=True, create_graph=True ) - vjp_y, vjp_params = vjp_y_and_params[:n_tensors], vjp_y_and_params[n_tensors:] - vjp_params = misc.flatten(vjp_params) - - adj_times_dgdx = misc.grad( - outputs=g_eval, + # Convert the adjoint Stratonovich SDE to Itô form. + a_dg_vjp, = misc.grad( + outputs=g, inputs=y, grad_outputs=adj_y, allow_unused=True, create_graph=True ) - - # Converting the *adjoint* Stratonovich backward SDE to Itô. extra_vjp_y_and_params = misc.grad( - outputs=g_eval, - inputs=y + params, - grad_outputs=adj_times_dgdx, + outputs=g, + inputs=[y] + self._params, + grad_outputs=a_dg_vjp, allow_unused=True, ) - extra_vjp_y, extra_vjp_params = extra_vjp_y_and_params[:n_tensors], extra_vjp_y_and_params[n_tensors:] - extra_vjp_params = misc.flatten(extra_vjp_params) - - vjp_y = misc.seq_add(vjp_y, extra_vjp_y) - vjp_params = vjp_params + extra_vjp_params - - return (*f_eval_corrected, *vjp_y, vjp_params) + vjp_y_and_params = misc.seq_add(vjp_y_and_params, extra_vjp_y_and_params) + return misc.flatten((-f, *vjp_y_and_params)) ######################################## # g # @@ -171,47 +140,19 @@ def g(self, t, y): # g_prod # ######################################## - def g_prod_default(self, t, y_aug, v): # For Ito/Stratonovich general/additive/scalar. 
- sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] - + def g_prod(self, t, y_aug, v): with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - minus_adj_y = [-adj_y_.detach() for adj_y_ in adj_y] - minus_g = [-g_ for g_ in sde.g(-t, y)] - minus_g_prod = misc.seq_batch_mvp(minus_g, v) - minus_g_weighted = [(minus_g_ * v_.unsqueeze(-2)).sum(-1) for minus_g_, v_ in zip(minus_g, v)] + y_aug = self._flat_to_shape(y_aug) + y = y_aug[0].detach().requires_grad_(True) + adj_y = y_aug[1].detach() + g_prod = self._base_sde.g_prod(-t, y, v) vjp_y_and_params = misc.grad( - outputs=minus_g_weighted, - inputs=y + params, - grad_outputs=minus_adj_y, - allow_unused=True, - ) - vjp_y, vjp_params = vjp_y_and_params[:n_tensors], vjp_y_and_params[n_tensors:] - vjp_params = misc.flatten(vjp_params) - - return (*minus_g_prod, *vjp_y, vjp_params) - - def g_prod_diagonal(self, t, y_aug, v): # For Ito/Stratonovich diagonal. - sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] - - with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - adj_y = [adj_y_.detach() for adj_y_ in adj_y] - - g_eval = [-g_ for g_ in sde.g(-t, y)] - g_prod_eval = misc.seq_mul(g_eval, v) - vjp_y_and_params = misc.grad( - outputs=g_eval, - inputs=y + params, - grad_outputs=[-v_ * adj_y_ for v_, adj_y_ in zip(v, adj_y)], + outputs=g_prod, + inputs=[y] + self._params, + grad_outputs=adj_y, allow_unused=True, ) - vjp_y, vjp_params = vjp_y_and_params[:n_tensors], vjp_y_and_params[n_tensors:] - vjp_params = misc.flatten(vjp_params) - - return (*g_prod_eval, *vjp_y, vjp_params) + return misc.flatten((-g_prod, *vjp_y_and_params)) ######################################## # gdg_prod # @@ -221,208 +162,42 @@ def gdg_prod_default(self, t, y, v): # For Ito/Stratonovich general/additive/sc raise NotImplementedError def gdg_prod_diagonal(self, t, y_aug, v): # For Ito/Stratonovich diagonal. 
- sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] - with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - adj_y = [adj_y_.detach().requires_grad_(True) for adj_y_ in adj_y] - - g_eval = sde.g(-t, y) - gdg_times_v = misc.grad( - outputs=g_eval, + y_aug = self._flat_to_shape(y_aug) + y = y_aug[0].detach().requires_grad_(True) + adj_y = y_aug[1].detach() + g = self._base_sde.g(-t, y) + vg_dg_vjp, = misc.grad( + outputs=g, inputs=y, - grad_outputs=misc.seq_mul(g_eval, v), + grad_outputs=v * g, allow_unused=True, create_graph=True, ) - dgdy = misc.grad( - outputs=g_eval, + dgdy, = misc.grad( + outputs=g.sum(), inputs=y, - grad_outputs=[torch.ones_like(y_) for y_ in y], allow_unused=True, create_graph=True, ) prod_partials_adj_y_and_params = misc.grad( - outputs=g_eval, - inputs=y + params, - grad_outputs=misc.seq_mul(adj_y, v, dgdy), - allow_unused=True, - create_graph=True, - ) - prod_partials_adj_y = prod_partials_adj_y_and_params[:n_tensors] - prod_partials_params = prod_partials_adj_y_and_params[n_tensors:] - prod_partials_params = misc.flatten(prod_partials_params) - - gdg_v = misc.grad( - outputs=g_eval, - inputs=y, - grad_outputs=[p.detach() for p in misc.seq_mul(adj_y, v, g_eval)], - allow_unused=True, - create_graph=True - ) - mixed_partials_adj_y_and_params = misc.grad( - outputs=gdg_v, - inputs=y + params, - grad_outputs=[torch.ones_like(p) for p in gdg_v], - allow_unused=True, - ) - mixed_partials_adj_y = mixed_partials_adj_y_and_params[:n_tensors] - mixed_partials_params = mixed_partials_adj_y_and_params[n_tensors:] - mixed_partials_params = misc.flatten(mixed_partials_params) - - return ( - *gdg_times_v, - *misc.seq_sub(prod_partials_adj_y, mixed_partials_adj_y), - prod_partials_params - mixed_partials_params - ) - - ######################################## - # f_logqp # - ######################################## - - def f_uncorrected_logqp(self, t, y_aug): - sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y, adj_l = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors] - vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l] - - with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - adj_y = [adj_y_.detach() for adj_y_ in adj_y] - - f_eval = sde.f(-t, y) - f_eval = [-f_eval_ for f_eval_ in f_eval] - vjp_y_and_params = misc.grad( - outputs=f_eval, - inputs=y + params, - grad_outputs=[-adj_y_ for adj_y_ in adj_y], + outputs=g, + inputs=[y] + self._params, + grad_outputs=adj_y * v * dgdy, allow_unused=True, create_graph=True ) - vjp_y, vjp_params = vjp_y_and_params[:n_tensors], vjp_y_and_params[n_tensors:] - vjp_params = misc.flatten(vjp_params) - - # Vector field change due to log-ratio term, i.e. ||u||^2 / 2. 
- g_eval = sde.g(-t, y) - h_eval = sde.h(-t, y) - - g_inv_eval = [torch.pinverse(g_eval_) for g_eval_ in g_eval] - u_eval = misc.seq_sub(f_eval, h_eval) - u_eval = [torch.bmm(g_inv_eval_, u_eval_) for g_inv_eval_, u_eval_ in zip(g_inv_eval, u_eval)] - log_ratio_correction = [.5 * torch.sum(u_eval_ ** 2., dim=1) for u_eval_ in u_eval] - corr_vjp_y_and_params = misc.grad( - outputs=log_ratio_correction, - inputs=y + params, - grad_outputs=adj_l, - allow_unused=True, - ) - corr_vjp_y, corr_vjp_params = corr_vjp_y_and_params[:n_tensors], corr_vjp_y_and_params[n_tensors:] - corr_vjp_params = misc.flatten(corr_vjp_params) - - vjp_y = misc.seq_add(vjp_y, corr_vjp_y) - vjp_params = vjp_params + corr_vjp_params - - return (*f_eval, *vjp_y, *vjp_l, vjp_params) - - def f_corrected_default_logqp(self, t, y_aug): - raise NotImplementedError - - def f_corrected_diagonal_logqp(self, t, y_aug): - sde, params, n_tensors = self._base_sde, self._params, self._n_tensors - y, adj_y, adj_l = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors], y_aug[2 * n_tensors:3 * n_tensors] - vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l] - - with torch.enable_grad(): - y = [y_.detach().requires_grad_(True) for y_ in y] - adj_y = [adj_y_.detach() for adj_y_ in adj_y] - - g_eval = sde.g(-t, y) - gdg = misc.grad( - outputs=g_eval, + avg_dg_vjp, = misc.grad( + outputs=g, inputs=y, - grad_outputs=g_eval, - allow_unused=True, - create_graph=True, - ) - - f_eval = sde.f(-t, y) - f_eval_corrected = misc.seq_sub(gdg, f_eval) - vjp_y_and_params = misc.grad( - outputs=f_eval_corrected, - inputs=y + params, - grad_outputs=[-adj_y_ for adj_y_ in adj_y], + grad_outputs=(adj_y * v * g).detach(), allow_unused=True, create_graph=True ) - vjp_y, vjp_params = vjp_y_and_params[:n_tensors], vjp_y_and_params[n_tensors:] - vjp_params = misc.flatten(vjp_params) - - adj_times_dgdx = misc.grad( - outputs=g_eval, - inputs=y, - grad_outputs=adj_y, - allow_unused=True, - create_graph=True - ) - extra_vjp_y_and_params = misc.grad( - outputs=g_eval, - inputs=y + params, - grad_outputs=adj_times_dgdx, - allow_unused=True, - create_graph=True, - ) - extra_vjp_y, extra_vjp_params = extra_vjp_y_and_params[:n_tensors], extra_vjp_y_and_params[n_tensors:] - extra_vjp_params = misc.flatten(extra_vjp_params) - - # Vector field change due to log-ratio term, i.e. ||u||^2 / 2. 
- h_eval = sde.h(-t, y) - u_eval = misc.seq_sub_div(f_eval, h_eval, g_eval) - log_ratio_correction = [.5 * torch.sum(u_eval_ ** 2., dim=1) for u_eval_ in u_eval] - corr_vjp_y_and_params = misc.grad( - outputs=log_ratio_correction, - inputs=y + params, - grad_outputs=adj_l, + mixed_partials_adj_y_and_params = misc.grad( + outputs=avg_dg_vjp.sum(), + inputs=[y] + self._params, allow_unused=True, ) - corr_vjp_y, corr_vjp_params = corr_vjp_y_and_params[:n_tensors], corr_vjp_y_and_params[n_tensors:] - corr_vjp_params = misc.flatten(corr_vjp_params) - - vjp_y = misc.seq_add(vjp_y, extra_vjp_y, corr_vjp_y) - vjp_params = vjp_params + extra_vjp_params + corr_vjp_params - - return (*f_eval_corrected, *vjp_y, *vjp_l, vjp_params) - - ######################################## - # g_prod_logqp # - ######################################## - - def g_prod_default_logqp(self, t, y_aug, v): - n_tensors = self._n_tensors - adj_l = y_aug[2 * n_tensors:3 * n_tensors] - vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l] - results = self.g_prod_default(t, y_aug, v) - g_prod_eval, vjp_y, vjp_params = results[:n_tensors], results[n_tensors:2 * n_tensors], results[2 * n_tensors] - return (*g_prod_eval, *vjp_y, *vjp_l, vjp_params) - - def g_prod_diagonal_logqp(self, t, y_aug, v): - n_tensors = self._n_tensors - adj_l = y_aug[2 * n_tensors:3 * n_tensors] - vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l] - results = self.g_prod_diagonal(t, y_aug, v) - g_prod_eval, vjp_y, vjp_params = results[:n_tensors], results[n_tensors:2 * n_tensors], results[2 * n_tensors] - return (*g_prod_eval, *vjp_y, *vjp_l, vjp_params) - - ######################################## - # gdg_prod_logqp # - ######################################## - - def gdg_prod_diagonal_logqp(self, t, y_aug, v): - n_tensors = self._n_tensors - adj_l = y_aug[2 * n_tensors:3 * n_tensors] - vjp_l = [torch.zeros_like(adj_l_) for adj_l_ in adj_l] - results = self.gdg_prod_diagonal(t, y_aug, v) - gdg_v, vjp_y, vjp_params = results[:n_tensors], results[n_tensors:2 * n_tensors], results[2 * n_tensors] - return (*gdg_v, *vjp_y, *vjp_l, vjp_params) - - def gdg_prod_default_logqp(self, t, y_aug, v): - raise NotImplementedError + vjp_y_and_params = misc.seq_sub(prod_partials_adj_y_and_params, mixed_partials_adj_y_and_params) + return misc.flatten((vg_dg_vjp, *vjp_y_and_params)) diff --git a/torchsde/_core/base_sde.py b/torchsde/_core/base_sde.py index 61fa418..b161297 100644 --- a/torchsde/_core/base_sde.py +++ b/torchsde/_core/base_sde.py @@ -33,7 +33,7 @@ def __init__(self, noise_type, sde_type): raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {noise_type}") if sde_type not in SDE_TYPES: raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde_type}") - # Making these Python properties breaks `torch.jit.script` + # Making these Python properties breaks `torch.jit.script`. self.noise_type = noise_type self.sde_type = sde_type @@ -45,131 +45,126 @@ def __init__(self, sde): self._base_sde = sde self.f = sde.f self.g = sde.g - if hasattr(sde, "h"): - self.h = sde.h - # Register the core function. This avoids polluting the codebase with if-statements. + # Register the core functions. This avoids polluting the codebase with if-statements and achieves speed-ups + # by making sure it's a one-time cost. 
self.g_prod = { NOISE_TYPES.diagonal: self.g_prod_diagonal, - NOISE_TYPES.additive: self.g_prod_additive_or_general, - NOISE_TYPES.scalar: self.g_prod_scalar, - NOISE_TYPES.general: self.g_prod_additive_or_general - }[sde.noise_type] + }.get(sde.noise_type, self.g_prod_default) self.gdg_prod = { - NOISE_TYPES.diagonal: self.gdg_prod_diagonal_or_scalar, - NOISE_TYPES.additive: self._skip, - NOISE_TYPES.scalar: self.gdg_prod_diagonal_or_scalar, - NOISE_TYPES.general: self.gdg_prod_general - }[sde.noise_type] - self.gdg_jvp_column_sum = { - NOISE_TYPES.diagonal: self._skip, - NOISE_TYPES.additive: self._skip, - NOISE_TYPES.scalar: self._skip, - NOISE_TYPES.general: self.gdg_jvp_column_sum_v2 - }[sde.noise_type] - # TODO: Assign `gdg_jacobian_contraction`. - - # g_prod functions. + NOISE_TYPES.diagonal: self.gdg_prod_diagonal, + NOISE_TYPES.additive: self._return_zero, + }.get(sde.noise_type, self.gdg_prod_default) + self.dg_ga_jvp_column_sum = { + NOISE_TYPES.diagonal: self._return_zero, + NOISE_TYPES.additive: self._return_zero, + NOISE_TYPES.scalar: self._return_zero, + NOISE_TYPES.general: self.dg_ga_jvp_column_sum_v2 + }.get(sde.noise_type) + + ######################################## + # g_prod # + ######################################## + def g_prod_diagonal(self, t, y, v): - return misc.seq_mul(self._base_sde.g(t, y), v) + return self._base_sde.g(t, y) * v - def g_prod_scalar(self, t, y, v): - return misc.seq_mul_bc(self._base_sde.g(t, y), v) + def g_prod_default(self, t, y, v): + return misc.batch_mvp(self._base_sde.g(t, y), v) - def g_prod_additive_or_general(self, t, y, v): - return misc.seq_batch_mvp(ms=self._base_sde.g(t, y), vs=v) + ######################################## + # gdg_prod # + ######################################## - # gdg_prod functions. - def gdg_prod_general(self, t, y, v): - # This function is used for Milstein. For general noise, Levy areas need to be supplied. - raise NotImplemented("This function should not be called.") + # Computes: sum_{j, l} g_{j, l} d g_{j, l} d x_i v_l. + def gdg_prod_default(self, t, y, v): + requires_grad = torch.is_grad_enabled() + with torch.enable_grad(): + y = y if y.requires_grad else y.detach().requires_grad_(True) + g = self._base_sde.g(t, y) + vg_dg_vjp, = misc.grad( + outputs=g, + inputs=y, + grad_outputs=g * v.unsqueeze(-2), + create_graph=requires_grad, + allow_unused=True + ) + return vg_dg_vjp - def gdg_prod_diagonal_or_scalar(self, t, y, v): - requires_grad = torch.is_grad_enabled() # BP through solver. + def gdg_prod_diagonal(self, t, y, v): + requires_grad = torch.is_grad_enabled() with torch.enable_grad(): - y = [y_ if y_.requires_grad else y_.detach().requires_grad_(True) for y_ in y] - val = self._base_sde.g(t, y) - vjp_val = misc.grad( - outputs=val, + y = y if y.requires_grad else y.detach().requires_grad_(True) + g = self._base_sde.g(t, y) + vg_dg_vjp, = misc.grad( + outputs=g, inputs=y, - grad_outputs=misc.seq_mul(val, v), + grad_outputs=g * v, create_graph=requires_grad, allow_unused=True ) - return misc.convert_none_to_zeros(vjp_val, y) + return vg_dg_vjp + + ######################################## + # dg_ga_jvp # + ######################################## - # gdg_jvp_column_sum functions. - # Computes: sum_{j,k,l} d sigma_{i,l} / d x_j sigma_{j,k} A_{k,l}. - def gdg_jvp_column_sum_v1(self, t, y, a): - # Assumes `a` is anti-symmetric and `base_sde` is not of diagonal noise. - requires_grad = torch.is_grad_enabled() # BP through solver. + # Computes: sum_{j,k,l} d g_{i,l} / d x_j g_{j,k} A_{k,l}. 
+ def dg_ga_jvp_column_sum_v1(self, t, y, a): + # Assumes `a` is anti-symmetric and `_base_sde` is not of diagonal noise. + requires_grad = torch.is_grad_enabled() with torch.enable_grad(): - y = [y_ if y_.requires_grad else y_.detach().requires_grad_(True) for y_ in y] - g_eval = self._base_sde.g(t, y) - v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)] - gdg_jvp_eval = [ + y = y if y.requires_grad else y.detach().requires_grad_(True) + g = self._base_sde.g(t, y) + ga = torch.bmm(g, a) + dg_ga_jvp = [ misc.jvp( - outputs=[g_eval_[..., col_idx] for g_eval_ in g_eval], + outputs=g[..., col_idx], inputs=y, - grad_inputs=[v_[..., col_idx] for v_ in v], + grad_inputs=ga[..., col_idx], retain_graph=True, create_graph=requires_grad, allow_unused=True - ) - for col_idx in range(g_eval[0].size(-1)) + )[0] + for col_idx in range(g.size(-1)) ] - gdg_jvp_eval = misc.seq_add(*gdg_jvp_eval) - return misc.convert_none_to_zeros(gdg_jvp_eval, y) + dg_ga_jvp = sum(dg_ga_jvp) + return dg_ga_jvp - def gdg_jvp_column_sum_v2(self, t, y, a): + def dg_ga_jvp_column_sum_v2(self, t, y, a): # Faster, but more memory intensive. - requires_grad = torch.is_grad_enabled() # BP through solver. + requires_grad = torch.is_grad_enabled() with torch.enable_grad(): - y = [y_ if y_.requires_grad else y_.detach().requires_grad_(True) for y_ in y] - g_eval = self._base_sde.g(t, y) - v = [torch.bmm(g_eval_, a_) for g_eval_, a_ in zip(g_eval, a)] - - batch_size, d, m = g_eval[0].size() # TODO: Relax this assumption. - y_dup = [torch.repeat_interleave(y_, repeats=m, dim=0) for y_ in y] - g_eval_dup = self._base_sde.g(t, y_dup) - v_flat = [v_.transpose(1, 2).flatten(0, 1) for v_ in v] - gdg_jvp_eval = misc.jvp( - g_eval_dup, y_dup, grad_inputs=v_flat, create_graph=requires_grad, allow_unused=True + y = y if y.requires_grad else y.detach().requires_grad_(True) + g = self._base_sde.g(t, y) + ga = torch.bmm(g, a) + + batch_size, d, m = g.size() + y_dup = torch.repeat_interleave(y, repeats=m, dim=0) + g_dup = self._base_sde.g(t, y_dup) + ga_flat = ga.transpose(1, 2).flatten(0, 1) + dg_ga_jvp, = misc.jvp( + outputs=g_dup, + inputs=y_dup, + grad_inputs=ga_flat, + create_graph=requires_grad, + allow_unused=True ) - gdg_jvp_eval = misc.convert_none_to_zeros(gdg_jvp_eval, y) - gdg_jvp_eval = [t.reshape(batch_size, m, d, m).permute(0, 2, 1, 3) for t in gdg_jvp_eval] - gdg_jvp_eval = [t.diagonal(dim1=-2, dim2=-1).sum(-1) for t in gdg_jvp_eval] - return gdg_jvp_eval - - def _skip(self, t, y, v): # noqa - return [0.] * len(y) - - -class TupleSDE(BaseSDE): - - def __init__(self, sde): - super(TupleSDE, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type) - self._base_sde = sde - - def f(self, t, y): - return self._base_sde.f(t, y[0]), - - def g(self, t, y): - return self._base_sde.g(t, y[0]), + dg_ga_jvp = dg_ga_jvp.reshape(batch_size, m, d, m).permute(0, 2, 1, 3) + dg_ga_jvp = dg_ga_jvp.diagonal(dim1=-2, dim2=-1).sum(-1) + return dg_ga_jvp - def h(self, t, y): - return self._base_sde.h(t, y[0]), + def _return_zero(self, t, y, v): # noqa + return 0. 
class RenameMethodsSDE(BaseSDE): - def __init__(self, sde, drift='f', diffusion='g', prior_drift='h'): + def __init__(self, sde, drift='f', diffusion='g'): super(RenameMethodsSDE, self).__init__(noise_type=sde.noise_type, sde_type=sde.sde_type) self._base_sde = sde self.f = getattr(sde, drift) self.g = getattr(sde, diffusion) - if hasattr(sde, prior_drift): - self.h = getattr(sde, prior_drift) class SDEIto(BaseSDE): diff --git a/torchsde/_core/base_solver.py b/torchsde/_core/base_solver.py index 2cc1494..e8924e1 100644 --- a/torchsde/_core/base_solver.py +++ b/torchsde/_core/base_solver.py @@ -44,8 +44,8 @@ def __init__(self, sde, bm, y0, dt, adaptive, rtol, atol, dt_min, options, **kwa f"SDE solver requires one of {self.levy_area_approximations} set as the `levy_area_approximation` on the " f"Brownian motion." ) - if sde.noise_type == NOISE_TYPES.scalar and torch.Size(bm.shape[1:]).numel() != 1: - raise ValueError('The Brownian motion for scalar SDEs must of dimension 1.') + if sde.noise_type == NOISE_TYPES.scalar and torch.Size(bm.shape[1:]).numel() != 1: # noqa + raise ValueError("The Brownian motion for scalar SDEs must of dimension 1.") self.sde = sde self.bm = bm @@ -75,34 +75,6 @@ def step(self, t0, t1, y0): """ raise NotImplementedError - def step_logqp(self, t0, t1, y0, logqp0): - y1 = self.step(t0, t1, y0) - - dt = t1 - t0 - if self.sde.noise_type in (NOISE_TYPES.diagonal, NOISE_TYPES.scalar): - f_eval = self.sde.f(t0, y0) - g_eval = self.sde.g(t0, y0) - h_eval = self.sde.h(t0, y0) - u_eval = misc.seq_sub_div(f_eval, h_eval, g_eval) - logqp1 = [ - logqp0_i + .5 * torch.sum(u_eval_i ** 2., dim=1) * dt - for logqp0_i, u_eval_i in zip(logqp0, u_eval) - ] - else: - f_eval = self.sde.f(t0, y0) - g_eval = self.sde.g(t0, y0) - h_eval = self.sde.h(t0, y0) - - g_inv_eval = [torch.pinverse(g_eval_) for g_eval_ in g_eval] - u_eval = misc.seq_sub(f_eval, h_eval) - u_eval = misc.seq_batch_mvp(ms=g_inv_eval, vs=u_eval) - logqp1 = [ - logqp0_i + .5 * torch.sum(u_eval_i ** 2., dim=1) * dt - for logqp0_i, u_eval_i in zip(logqp0, u_eval) - ] - return y1, logqp1 - - # TODO: unify integrate and integrate_logqp? My IDE spits out so many warnings about duplicate code. def integrate(self, ts): """Integrate along trajectory. @@ -110,7 +82,7 @@ def integrate(self, ts): A single state tensor of size (T, batch_size, d) (or tuple). """ assert misc.is_strictly_increasing(ts), "Evaluation times `ts` must be strictly increasing." - y0, dt, adaptive, rtol, atol, dt_min = (self.y0, self.dt, self.adaptive, self.rtol, self.atol, self.dt_min) + y0, dt, adaptive, rtol, atol, dt_min = self.y0, self.dt, self.adaptive, self.rtol, self.atol, self.dt_min step_size = dt @@ -141,7 +113,7 @@ def integrate(self, ts): ) if step_size < dt_min: - warnings.warn('Hitting minimum allowed step size in adaptive time-stepping.') + warnings.warn("Hitting minimum allowed step size in adaptive time-stepping.") step_size = dt_min prev_error_ratio = None @@ -154,69 +126,4 @@ def integrate(self, ts): curr_t, curr_y = next_t, self.step(curr_t, next_t, curr_y) ys.append(interp.linear_interp(t0=prev_t, y0=prev_y, t1=curr_t, y1=curr_y, t=out_t)) - ans = tuple(torch.stack([ys[j][i] for j in range(len(ts))], dim=0) for i in range(len(y0))) - return ans - - def integrate_logqp(self, ts): - """Integrate along trajectory; also return the log-ratio. - - Returns: - A single state tensor of size (T, batch_size, d) (or tuple), and a single log-ratio tensor of - size (T - 1, batch_size) (or tuple). 
- """ - assert misc.is_strictly_increasing(ts), "Evaluation times `ts` must be strictly increasing." - y0, dt, adaptive, rtol, atol, dt_min = (self.y0, self.dt, self.adaptive, self.rtol, self.atol, self.dt_min) - - step_size = dt - - prev_t = curr_t = ts[0] - prev_y = curr_y = y0 - - ys = [y0] - prev_error_ratio = None - logqp = [[] for _ in y0] - - for out_t in ts[1:]: - curr_logqp = [0. for _ in y0] - prev_logqp = curr_logqp - while curr_t < out_t: - next_t = min(curr_t + step_size, ts[-1]) - if adaptive: - # Take 1 full step. - next_y_full = self.step(curr_t, next_t, curr_y) - # Take 2 half steps. - midpoint_t = 0.5 * (curr_t + next_t) - midpoint_y, midpoint_logqp = self.step_logqp(curr_t, midpoint_t, curr_y, curr_logqp) - next_y, next_logqp = self.step_logqp(midpoint_t, next_t, midpoint_y, midpoint_logqp) - - # Estimate error based on difference between 1 full step and 2 half steps. - with torch.no_grad(): - error_estimate = adaptive_stepping.compute_error(next_y_full, next_y, rtol, atol) - step_size, prev_error_ratio = adaptive_stepping.update_step_size( - error_estimate=error_estimate, - prev_step_size=step_size, - prev_error_ratio=prev_error_ratio - ) - - if step_size < dt_min: - warnings.warn('Hitting minimum allowed step size in adaptive time-stepping.') - step_size = dt_min - prev_error_ratio = None - - # Accept step. - if error_estimate <= 1 or step_size <= dt_min: - prev_t, prev_y, prev_logqp = curr_t, curr_y, curr_logqp - curr_t, curr_y, curr_logqp = next_t, next_y, next_logqp - else: - prev_t, prev_y, prev_logqp = curr_t, curr_y, curr_logqp - curr_y, curr_logqp = self.step_logqp(curr_t, next_t, curr_y, curr_logqp) - curr_t = next_t - ret_y, ret_logqp = interp.linear_interp_logqp(t0=prev_t, y0=prev_y, logqp0=prev_logqp, t1=curr_t, - y1=curr_y, logqp1=curr_logqp, t=out_t) - ys.append(ret_y) - [logqp_i.append(ret_logqp_i) for logqp_i, ret_logqp_i in zip(logqp, ret_logqp)] - - - ans = [torch.stack([ys[j][i] for j in range(len(ts))], dim=0) for i in range(len(y0))] - logqp = [torch.stack(logqp_i, dim=0) for logqp_i in logqp] - return (*ans, *logqp) + return torch.stack(ys, dim=0) diff --git a/torchsde/_core/interp.py b/torchsde/_core/interp.py index a14500e..a4ca576 100644 --- a/torchsde/_core/interp.py +++ b/torchsde/_core/interp.py @@ -14,13 +14,6 @@ def linear_interp(t0, y0, t1, y1, t): - assert t0 <= t <= t1, f'Incorrect time order for linear interpolation: t0={t0}, t={t}, t1={t1}.' - y = [(t1 - t) / (t1 - t0) * y0_ + (t - t0) / (t1 - t0) * y1_ for y0_, y1_ in zip(y0, y1)] + assert t0 <= t <= t1, f"Incorrect time order for linear interpolation: t0={t0}, t={t}, t1={t1}." + y = (t1 - t) / (t1 - t0) * y0 + (t - t0) / (t1 - t0) * y1 return y - - -def linear_interp_logqp(t0, y0, logqp0, t1, y1, logqp1, t): - assert t0 <= t <= t1, f'Incorrect time order for linear interpolation: t0={t0}, t={t}, t1={t1}.' 
- y = [(t1 - t) / (t1 - t0) * y0_ + (t - t0) / (t1 - t0) * y1_ for y0_, y1_ in zip(y0, y1)] - logqp = [(t1 - t) / (t1 - t0) * l0 + (t - t0) / (t1 - t0) * l1 for l0, l1 in zip(logqp0, logqp1)] - return y, logqp diff --git a/torchsde/_core/methods/euler.py b/torchsde/_core/methods/euler.py index 587f7ec..e3839dd 100644 --- a/torchsde/_core/methods/euler.py +++ b/torchsde/_core/methods/euler.py @@ -33,12 +33,8 @@ def step(self, t0, t1, y0): dt = t1 - t0 I_k = self.bm(t0, t1) - f_eval = self.sde.f(t0, y0) - g_prod_eval = self.sde.g_prod(t0, y0, I_k) - - y1 = [ - y0_ + f_eval_ * dt + g_prod_eval_ - for y0_, f_eval_, g_prod_eval_ in zip(y0, f_eval, g_prod_eval) - ] + f = self.sde.f(t0, y0) + g_prod = self.sde.g_prod(t0, y0, I_k) + y1 = y0 + f * dt + g_prod return y1 diff --git a/torchsde/_core/methods/heun.py b/torchsde/_core/methods/heun.py index de01070..5e0c021 100644 --- a/torchsde/_core/methods/heun.py +++ b/torchsde/_core/methods/heun.py @@ -14,13 +14,12 @@ """Stratonovich Heun method (strong order 1.0 scheme) from -Burrage K., Burrage P. M. and Tian T. 2004 "Numerical methods for strong solutions +Burrage K., Burrage P. M. and Tian T. 2004 "Numerical methods for strong solutions of stochastic differential equations: an overview" Proc. R. Soc. Lond. A. 460: 373–402. """ -from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS - from .. import base_solver +from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS class Heun(base_solver.BaseSDESolver): @@ -40,20 +39,14 @@ def step(self, t0, t1, y0): dt = t1 - t0 I_k = self.bm(t0, t1) - f_eval = self.sde.f(t0, y0) - g_prod_eval = self.sde.g_prod(t0, y0, I_k) + f = self.sde.f(t0, y0) + g_prod = self.sde.g_prod(t0, y0, I_k) - y0_prime = [ - y0_ + dt * f_eval_ + g_prod_eval_ - for y0_, f_eval_, g_prod_eval_ in zip(y0, f_eval, g_prod_eval) - ] + y0_prime = y0 + dt * f + g_prod - f_eval_prime = self.sde.f(t1, y0_prime) - g_prod_eval_prime = self.sde.g_prod(t1, y0_prime, I_k) + f_prime = self.sde.f(t1, y0_prime) + g_prod_prime = self.sde.g_prod(t1, y0_prime, I_k) - y1 = [ - y0_ + (dt * (f_eval_ + f_eval_prime_) + g_prod_eval_ + g_prod_eval_prime_) * 0.5 - for y0_, f_eval_, f_eval_prime_, g_prod_eval_, g_prod_eval_prime_ in zip(y0, f_eval, f_eval_prime, g_prod_eval, g_prod_eval_prime) - ] + y1 = y0 + (dt * (f + f_prime) + g_prod + g_prod_prime) * 0.5 return y1 diff --git a/torchsde/_core/methods/midpoint.py b/torchsde/_core/methods/midpoint.py index 98078e9..210e8d5 100644 --- a/torchsde/_core/methods/midpoint.py +++ b/torchsde/_core/methods/midpoint.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS - from .. 
import base_solver +from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS class Midpoint(base_solver.BaseSDESolver): @@ -34,23 +33,17 @@ def step(self, t0, t1, y0): dt = t1 - t0 I_k = self.bm(t0, t1) - f_eval = self.sde.f(t0, y0) - g_prod_eval = self.sde.g_prod(t0, y0, I_k) + f = self.sde.f(t0, y0) + g_prod = self.sde.g_prod(t0, y0, I_k) half_dt = 0.5 * dt t0_prime = t0 + half_dt - y0_prime = [ - y0_ + half_dt * f_eval_ + 0.5 * g_prod_eval_ - for y0_, f_eval_, g_prod_eval_ in zip(y0, f_eval, g_prod_eval) - ] - - f_eval_prime = self.sde.f(t0_prime, y0_prime) - g_prod_eval_prime = self.sde.g_prod(t0_prime, y0_prime, I_k) - - y1 = [ - y0_ + dt * f_eval_ + g_prod_eval_ - for y0_, f_eval_, g_prod_eval_ in zip(y0, f_eval_prime, g_prod_eval_prime) - ] + y0_prime = y0 + half_dt * f + 0.5 * g_prod + + f_prime = self.sde.f(t0_prime, y0_prime) + g_prod_prime = self.sde.g_prod(t0_prime, y0_prime, I_k) + + y1 = y0 + dt * f_prime + g_prod_prime return y1 diff --git a/torchsde/_core/methods/milstein.py b/torchsde/_core/methods/milstein.py index 1a2ecea..c9f3fb5 100644 --- a/torchsde/_core/methods/milstein.py +++ b/torchsde/_core/methods/milstein.py @@ -13,14 +13,12 @@ # limitations under the License. import abc -import math import torch -from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHOD_OPTIONS - from .. import adjoint_sde from .. import base_solver +from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHOD_OPTIONS class BaseMilstein(base_solver.BaseSDESolver, metaclass=abc.ABCMeta): @@ -50,7 +48,7 @@ def v_term(self, I_k, dt): raise NotImplementedError @abc.abstractmethod - def y_prime_f_factor(self, dt, f_eval): + def y_prime_f_factor(self, dt, f): raise NotImplementedError def step(self, t0, t1, y0): @@ -58,47 +56,39 @@ def step(self, t0, t1, y0): I_k = self.bm(t0, t1) v = self.v_term(I_k, dt) - f_eval = self.sde.f(t0, y0) - g_prod_eval = self.sde.g_prod(t0, y0, I_k) + f = self.sde.f(t0, y0) + g_prod_I_k = self.sde.g_prod(t0, y0, I_k) if self.options[METHOD_OPTIONS.grad_free]: - g_eval = self.sde.g(t0, y0) - g_prod_eval_v = self.sde.g_prod(t0, y0, v) - sqrt_dt = torch.sqrt(dt) if isinstance(dt, torch.Tensor) else math.sqrt(dt) - y0_prime = [ - y0_ + self.y_prime_f_factor(dt, f_eval_) + g_eval_ * sqrt_dt - for y0_, f_eval_, g_eval_ in zip(y0, f_eval, g_eval) - ] - g_prod_eval_prime = self.sde.g_prod(t0, y0_prime, v) - gdg_prod_eval = [ - (g_prod_eval_prime_ - g_prod_eval_v_) / sqrt_dt - for g_prod_eval_prime_, g_prod_eval_v_ in zip(g_prod_eval_prime, g_prod_eval_v) - ] + g = self.sde.g(t0, y0) + g_prod_v = self.sde.g_prod(t0, y0, v) + sqrt_dt = torch.sqrt(dt) + y0_prime = y0 + self.y_prime_f_factor(dt, f) + g * sqrt_dt + g_prod_v_prime = self.sde.g_prod(t0, y0_prime, v) + gdg_prod = (g_prod_v_prime - g_prod_v) / sqrt_dt else: - gdg_prod_eval = self.sde.gdg_prod(t0, y0, v) - - y1 = [ - y0_i + f_eval_i * dt + g_prod_eval_i + .5 * gdg_prod_eval_i - for y0_i, f_eval_i, g_prod_eval_i, gdg_prod_eval_i in zip(y0, f_eval, g_prod_eval, gdg_prod_eval) - ] + gdg_prod = self.sde.gdg_prod(t0, y0, v) + + y1 = y0 + f * dt + g_prod_I_k + .5 * gdg_prod + return y1 class MilsteinIto(BaseMilstein): sde_type = SDE_TYPES.ito - + def v_term(self, I_k, dt): - return [delta_bm_ ** 2 - dt for delta_bm_ in I_k] + return I_k ** 2 - dt - def y_prime_f_factor(self, dt, f_eval): - return dt * f_eval + def y_prime_f_factor(self, dt, f): + return dt * f class MilsteinStratonovich(BaseMilstein): sde_type = SDE_TYPES.stratonovich def v_term(self, I_k, dt): - 
return [delta_bm_ ** 2 for delta_bm_ in I_k] + return I_k ** 2 - def y_prime_f_factor(self, dt, f_eval): + def y_prime_f_factor(self, dt, f): return 0. diff --git a/torchsde/_core/methods/srk.py b/torchsde/_core/methods/srk.py index d0e8226..9591f64 100644 --- a/torchsde/_core/methods/srk.py +++ b/torchsde/_core/methods/srk.py @@ -19,21 +19,18 @@ no. 3 (2010): 922-952. """ -import math - import torch -from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS - +from .tableaus import sra1, srid2 from .. import adjoint_sde from .. import base_solver -from .. import misc +from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS -from .tableaus import sra1, srid2 +_r2 = 1 / 2 +_r6 = 1 / 6 class SRK(base_solver.BaseSDESolver): - # TODO: should the strong order be 2.0 for additive noise? Numerically it looks like it. strong_order = 1.5 weak_order = 1.5 sde_type = SDE_TYPES.ito @@ -61,44 +58,39 @@ def step(self, t0, t1, y): def diagonal_or_scalar_step(self, t0, t1, y0): dt = t1 - t0 - sqrt_dt = torch.sqrt(dt) if isinstance(dt, torch.Tensor) else math.sqrt(dt) + rdt = 1 / dt + sqrt_dt = torch.sqrt(dt) I_k, I_k0 = self.bm(t0, t1, return_U=True) - I_kk = [(delta_bm_ ** 2. - dt) / 2. for delta_bm_ in I_k] - I_kkk = [(delta_bm_ ** 3. - 3. * dt * delta_bm_) / 6. for delta_bm_ in I_k] + I_kk = (I_k ** 2 - dt) * _r2 + I_kkk = (I_k ** 3 - 3 * dt * I_k) * _r6 y1 = y0 H0, H1 = [], [] for s in range(srid2.STAGES): H0s, H1s = y0, y0 # Values at the current stage to be accumulated. for j in range(s): - f_eval = self.sde.f(t0 + srid2.C0[j] * dt, H0[j]) - g_eval = self.sde.g(t0 + srid2.C1[j] * dt, H1[j]) - H0s = [ - H0s_ + srid2.A0[s][j] * f_eval_ * dt + srid2.B0[s][j] * g_eval_ * I_k0_ / dt - for H0s_, f_eval_, g_eval_, I_k0_ in zip(H0s, f_eval, g_eval, I_k0) - ] - H1s = [ - H1s_ + srid2.A1[s][j] * f_eval_ * dt + srid2.B1[s][j] * g_eval_ * sqrt_dt - for H1s_, f_eval_, g_eval_ in zip(H1s, f_eval, g_eval) - ] + f = self.sde.f(t0 + srid2.C0[j] * dt, H0[j]) + g = self.sde.g(t0 + srid2.C1[j] * dt, H1[j]) + g = g.squeeze(2) if g.dim() == 3 else g + H0s = H0s + srid2.A0[s][j] * f * dt + srid2.B0[s][j] * g * I_k0 * rdt + H1s = H1s + srid2.A1[s][j] * f * dt + srid2.B1[s][j] * g * sqrt_dt H0.append(H0s) H1.append(H1s) - f_eval = self.sde.f(t0 + srid2.C0[s] * dt, H0s) - g_eval = self.sde.g(t0 + srid2.C1[s] * dt, H1s) - g_weight = [ - srid2.beta1[s] * I_k_ + srid2.beta2[s] * I_kk_ / sqrt_dt + - srid2.beta3[s] * I_k0_ / dt + srid2.beta4[s] * I_kkk_ / dt - for I_k_, I_kk_, I_k0_, I_kkk_ in zip(I_k, I_kk, I_k0, I_kkk) - ] - y1 = [ - y1_ + srid2.alpha[s] * f_eval_ * dt + g_weight_ * g_eval_ - for y1_, f_eval_, g_eval_, g_weight_ in zip(y1, f_eval, g_eval, g_weight) - ] + f = self.sde.f(t0 + srid2.C0[s] * dt, H0s) + g_weight = ( + srid2.beta1[s] * I_k + + srid2.beta2[s] * I_kk / sqrt_dt + + srid2.beta3[s] * I_k0 * rdt + + srid2.beta4[s] * I_kkk * rdt + ) + g_prod = self.sde.g_prod(t0 + srid2.C1[s] * dt, H1s, g_weight) + y1 = y1 + srid2.alpha[s] * f * dt + g_prod return y1 def additive_step(self, t0, t1, y0): dt = t1 - t0 + rdt = 1 / dt I_k, I_k0 = self.bm(t0, t1, return_U=True) y1 = y0 @@ -106,19 +98,13 @@ def additive_step(self, t0, t1, y0): for i in range(sra1.STAGES): H0i = y0 for j in range(i): - f_eval = self.sde.f(t0 + sra1.C0[j] * dt, H0[j]) - g_eval = self.sde.g(t0 + sra1.C1[j] * dt, y0) # The state should not affect the diffusion. 
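Stepping back to the milstein.py hunks above: for diagonal noise they implement the textbook update y1 = y0 + f dt + g dW + 0.5 (g dg/dy) v, with v = dW^2 - dt in the Itô scheme and v = dW^2 in the Stratonovich one, and the `grad_free` option replaces the analytic g dg/dy by a finite difference taken at y0 + f dt + g sqrt(dt). A toy single-step sketch (the helper below is hypothetical, not part of torchsde):

```python
import torch

def milstein_step_diagonal_ito(f, g, t0, y0, dt, dW, grad_free=False):
    # One Ito Milstein step for diagonal noise; `f` and `g` map (t, y) -> same shape as y.
    v = dW ** 2 - dt
    f0, g0 = f(t0, y0), g(t0, y0)
    if grad_free:
        # Derivative-free correction: (g(y') - g(y)) / sqrt(dt) ~ g * dg/dy,
        # with y' = y0 + f*dt + g*sqrt(dt).
        sqrt_dt = dt ** 0.5
        gdg = (g(t0, y0 + f0 * dt + g0 * sqrt_dt) - g0) / sqrt_dt
    else:
        y_ = y0.detach().requires_grad_(True)
        g_ = g(t0, y_)
        # For diagonal noise g_i depends only on y_i, so summing before autograd
        # recovers the elementwise derivative dg_i/dy_i.
        dg, = torch.autograd.grad(g_.sum(), y_)
        gdg = g0 * dg
    return y0 + f0 * dt + g0 * dW + 0.5 * gdg * v
```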
- H0i = [ - H0i_ + sra1.A0[i][j] * f_eval_ * dt + sra1.B0[i][j] * misc.batch_mvp(g_eval_, I_k0_) / dt - for H0i_, f_eval_, g_eval_, I_k0_ in zip(H0i, f_eval, g_eval, I_k0) - ] + f = self.sde.f(t0 + sra1.C0[j] * dt, H0[j]) + g_prod = self.sde.g_prod(t0 + sra1.C1[j] * dt, y0, I_k0) + H0i = H0i + sra1.A0[i][j] * f * dt + sra1.B0[i][j] * g_prod * rdt H0.append(H0i) - f_eval = self.sde.f(t0 + sra1.C0[i] * dt, H0i) - g_eval = self.sde.g(t0 + sra1.C1[i] * dt, y0) - g_weight = [sra1.beta1[i] * I_k_ + sra1.beta2[i] * I_k0_ / dt for I_k_, I_k0_ in zip(I_k, I_k0)] - y1 = [ - y1_ + sra1.alpha[i] * f_eval_ * dt + misc.batch_mvp(g_eval_, g_weight_) - for y1_, f_eval_, g_eval_, g_weight_ in zip(y1, f_eval, g_eval, g_weight) - ] + f = self.sde.f(t0 + sra1.C0[i] * dt, H0i) + g_weight = sra1.beta1[i] * I_k + sra1.beta2[i] * I_k0 * rdt + g_prod = self.sde.g_prod(t0 + sra1.C1[i] * dt, y0, g_weight) + y1 = y1 + sra1.alpha[i] * f * dt + g_prod return y1 diff --git a/torchsde/_core/misc.py b/torchsde/_core/misc.py index 7283621..d307c99 100644 --- a/torchsde/_core/misc.py +++ b/torchsde/_core/misc.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -import operator import warnings import torch -def handle_unused_kwargs(obj, unused_kwargs): +def handle_unused_kwargs(unused_kwargs, msg=None): if len(unused_kwargs) > 0: - warnings.warn(f'{obj.__class__.__name__}: Unexpected arguments {unused_kwargs}') + if msg is not None: + warnings.warn(f"{msg}: Unexpected arguments {unused_kwargs}") + else: + warnings.warn(f"Unexpected arguments {unused_kwargs}") def flatten(sequence): @@ -49,58 +50,22 @@ def seq_add(*seqs): return [sum(seq) for seq in zip(*seqs)] -def seq_mul(*seqs): - return [functools.reduce(operator.mul, seq) for seq in zip(*seqs)] - - -def seq_mul_bc(*seqs): # Supports broadcasting. - soln = [] - for seq in zip(*seqs): - cumprod = seq[0] - for tensor in seq[1:]: - # Insert dummy dims at the end of the tensor with fewer dims. - num_missing_dims = cumprod.dim() - tensor.dim() - if num_missing_dims > 0: - new_size = tensor.size() + (1,) * num_missing_dims - tensor = tensor.reshape(*new_size) - elif num_missing_dims < 0: - new_size = cumprod.size() + (1,) * num_missing_dims - cumprod = cumprod.reshape(*new_size) - cumprod = cumprod * tensor - soln += [cumprod] - return soln - - def seq_sub(xs, ys): return [x - y for x, y in zip(xs, ys)] -def seq_sub_div(xs, ys, zs): - return [_stable_div(x - y, z) for x, y, z in zip(xs, ys, zs)] - - -def _stable_div(x: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-7): - y = torch.where( - y.abs() > epsilon, - y, - torch.ones_like(y).fill_(epsilon) * y.sign() - ) - return x / y - - -def seq_batch_mvp(ms, vs): - return [batch_mvp(m, v) for m, v in zip(ms, vs)] - - def batch_mvp(m, v): return torch.bmm(m, v.unsqueeze(-1)).squeeze(dim=-1) def grad(outputs, inputs, **kwargs): + if torch.is_tensor(inputs): + inputs = [inputs] + _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. + + if torch.is_tensor(outputs): + outputs = [outputs] outputs = make_seq_requires_grad(outputs) - if torch.is_tensor(inputs): # Workaround for PyTorch bug #39784. 
- inputs = (inputs,) - _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] _grad = torch.autograd.grad(outputs, inputs, **kwargs) return convert_none_to_zeros(_grad, inputs) @@ -109,12 +74,26 @@ def grad(outputs, inputs, **kwargs): def jvp(outputs, inputs, grad_inputs=None, **kwargs): # `torch.autograd.functional.jvp` takes in `func` and requires re-evaluation. # The present implementation avoids this. + if torch.is_tensor(inputs): + inputs = [inputs] + _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. + + if torch.is_tensor(outputs): + outputs = [outputs] outputs = make_seq_requires_grad(outputs) - if torch.is_tensor(inputs): # Workaround for PyTorch bug #39784. - inputs = (inputs,) - _dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] dummy_outputs = [torch.zeros_like(o, requires_grad=True) for o in outputs] vjp = torch.autograd.grad(outputs, inputs, grad_outputs=dummy_outputs, **kwargs) _jvp = torch.autograd.grad(vjp, dummy_outputs, grad_outputs=grad_inputs, **kwargs) return convert_none_to_zeros(_jvp, dummy_outputs) + + +def flat_to_shape(tensor, shapes, length=()): + tensor_list = [] + total = 0 + for shape in shapes: + next_total = total + shape.numel() + # It's important that this be view((...)), not view(...). Else when length=(), shape=() it fails. + tensor_list.append(tensor[..., total:next_total].view((*length, *shape))) + total = next_total + return tensor_list diff --git a/torchsde/_core/sdeint.py b/torchsde/_core/sdeint.py index 3da9dc6..2ac967e 100644 --- a/torchsde/_core/sdeint.py +++ b/torchsde/_core/sdeint.py @@ -17,44 +17,41 @@ import torch -from . import adjoint_sde from . import base_sde from . import methods -from .._brownian import BaseBrownian, TupleBrownian, BrownianInterval +from . import misc +from .._brownian import BaseBrownian, BrownianInterval from ..settings import SDE_TYPES, NOISE_TYPES, METHODS, LEVY_AREA_APPROXIMATIONS -from ..types import TensorOrTensors, Scalar, Vector +from ..types import Scalar, Vector -def sdeint(sde: [base_sde.BaseSDE], - y0: TensorOrTensors, +def sdeint(sde: base_sde.BaseSDE, + y0: torch.Tensor, ts: Vector, bm: Optional[BaseBrownian] = None, - logqp: Optional[bool] = False, - method: Optional[str] = 'srk', + method: Optional[str] = "srk", dt: Optional[Scalar] = 1e-3, adaptive: Optional[bool] = False, rtol: Optional[float] = 1e-5, atol: Optional[float] = 1e-4, dt_min: Optional[Scalar] = 1e-5, options: Optional[Dict[str, Any]] = None, - names: Optional[Dict[str, str]] = None) -> TensorOrTensors: + names: Optional[Dict[str, str]] = None, + **unused_kwargs) -> torch.Tensor: """Numerically integrate an Itô SDE. Args: - sde: Object with methods `f` and `g` representing the drift and - diffusion. The output of `g` should be a single (or a tuple of) - tensor(s) of size (batch_size, d) for diagonal noise SDEs or - (batch_size, d, m) for SDEs of other noise types; d is the - dimensionality of state and m is the dimensionality of Brownian - motion. - y0 (sequence of Tensor): Tensors for initial state. + sde: Object with methods `f` and `g` representing the + drift and diffusion. The output of `g` should be a single tensor of + size (batch_size, d) for diagonal noise SDEs or (batch_size, d, m) + for SDEs of other noise types; d is the dimensionality of state and + m is the dimensionality of Brownian motion. + y0 (Tensor): A tensor for the initial state. ts (Tensor or sequence of float): Query times in non-descending order. The state at the first time of `ts` should be `y0`. 
bm (Brownian, optional): A 'BrownianInterval', `BrownianPath` or `BrownianTree` object. Should return tensors of size (batch_size, m) - for `__call__`. Defaults to `BrownianInterval`. Currently does not - support tuple outputs yet. - logqp (bool, optional): If `True`, also return the log-ratio penalty. + for `__call__`. Defaults to `BrownianInterval`. method (str, optional): Name of numerical integration method. dt (float, optional): The constant step size or initial step size for adaptive time-stepping. @@ -63,25 +60,23 @@ def sdeint(sde: [base_sde.BaseSDE], atol (float, optional): Absolute tolerance. dt_min (float, optional): Minimum step size during integration. options (dict, optional): Dict of options for the integration method. - names (dict, optional): Dict of method names for drift, diffusion, and - prior drift. Expected keys are "drift", "diffusion", and - "prior_drift". Serves so that users can use methods with names not - in `("f", "g", "h")`, e.g. to use the method "foo" for the drift, - we would supply `names={"drift": "foo"}`. + names (dict, optional): Dict of method names for drift and diffusion. + Expected keys are "drift" and "diffusion". Serves so that users can + use methods with names not in `("f", "g")`, e.g. to use the + method "foo" for the drift, we supply `names={"drift": "foo"}`. Returns: - A single state tensor of size (T, batch_size, d) or a tuple of such - tensors. Also returns a single log-ratio tensor of size - (T - 1, batch_size) or a tuple of such tensors, if `logqp==True`. + A single state tensor of size (T, batch_size, d). Raises: ValueError: An error occurred due to unrecognized noise type/method, or if `sde` is missing required methods. """ - sde, y0, ts, bm, tensor_input = check_contract(sde, y0, ts, bm, logqp, method, names) + misc.handle_unused_kwargs(unused_kwargs, msg="`sdeint`") + del unused_kwargs - sde = base_sde.ForwardSDE(sde) - results = integrate( + sde, y0, ts, bm = check_contract(sde, y0, ts, bm, method, names) + return integrate( sde=sde, y0=y0, ts=ts, @@ -93,134 +88,108 @@ def sdeint(sde: [base_sde.BaseSDE], atol=atol, dt_min=dt_min, options=options, - logqp=logqp ) - if not logqp and tensor_input: - return results[0] - return results -def check_contract(sde, y0, ts, bm, logqp, method, names): +def check_contract(sde, y0, ts, bm, method, names): if names is None: names_to_change = {} else: - names_to_change = {key: names[key] for key in ('drift', 'diffusion', 'prior_drift') if key in names} + names_to_change = {key: names[key] for key in ("drift", "diffusion") if key in names} if len(names_to_change) > 0: sde = base_sde.RenameMethodsSDE(sde, **names_to_change) - required_funcs = ('f', 'g', 'h') if logqp else ('f', 'g') + required_funcs = ("f", "g") missing_funcs = [func for func in required_funcs if not hasattr(sde, func)] if len(missing_funcs) > 0: - raise ValueError(f'sde is required to have the methods {required_funcs}. Missing functions: {missing_funcs}') + raise ValueError(f"sde is required to have the methods {required_funcs}. 
Missing functions: {missing_funcs}") - if not hasattr(sde, 'noise_type'): - raise ValueError(f'sde does not have the attribute noise_type.') + if not hasattr(sde, "noise_type"): + raise ValueError(f"sde does not have the attribute noise_type.") if sde.noise_type not in NOISE_TYPES: - raise ValueError(f'Expected noise type in {NOISE_TYPES}, but found {sde.noise_type}.') + raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {sde.noise_type}.") - if not hasattr(sde, 'sde_type'): - raise ValueError(f'sde does not have the attribute sde_type.') + if not hasattr(sde, "sde_type"): + raise ValueError(f"sde does not have the attribute sde_type.") if sde.sde_type not in SDE_TYPES: - raise ValueError(f'Expected sde type in {SDE_TYPES}, but found {sde.sde_type}.') + raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde.sde_type}.") if method not in METHODS: - raise ValueError(f'Expected method in {METHODS}, but found {method}.') + raise ValueError(f"Expected method in {METHODS}, but found {method}.") - tensor_input = torch.is_tensor(y0) - if tensor_input: - sde = base_sde.TupleSDE(sde) - y0 = (y0,) - if not isinstance(y0, tuple) or not all(torch.is_tensor(y0_) for y0_ in y0): - raise ValueError("`y0` must be a Tensor or a tuple of Tensors.") + if not torch.is_tensor(y0): + raise ValueError(f"`y0` must be a torch.Tensor.") if not torch.is_tensor(ts): if not isinstance(ts, (tuple, list)) or not all(isinstance(t, (float, int)) for t in ts): raise ValueError(f"Evaluation times `ts` must be a 1-D Tensor or list/tuple of floats.") - ts = torch.tensor(ts, dtype=y0[0].dtype, device=y0[0].device) + ts = torch.tensor(ts, dtype=y0.dtype, device=y0.device) + + drift_shape = sde.f(ts[0], y0).size() + if drift_shape != y0.size(): + raise ValueError(f"Drift must return a Tensor of the same shape as `y0`. " + f"Got drift shape {drift_shape}, but y0 shape {y0.size()}.") + + diffusion_shape = sde.g(ts[0], y0).size() + noise_channels = diffusion_shape[-1] + if sde.noise_type in (NOISE_TYPES.additive, NOISE_TYPES.general, NOISE_TYPES.scalar): + batch_dimensions = diffusion_shape[:-2] + drift_shape, diffusion_shape = tuple(drift_shape), tuple(diffusion_shape) + if len(drift_shape) == 0: + raise ValueError("Drift must be of shape (..., state_channels), but got shape ().") + if len(diffusion_shape) < 2: + raise ValueError(f"Diffusion must have shape (..., state_channels, noise_channels), " + f"but got shape {diffusion_shape}.") + if drift_shape != diffusion_shape[:-1]: + raise ValueError(f"Drift and diffusion shapes do not match. 
Got drift shape {drift_shape}, " + f"meaning {drift_shape[:-1]} batch dimensions and {drift_shape[-1]} channel " + f"dimensions, but diffusion shape {diffusion_shape}, meaning " + f"{diffusion_shape[:-2]} batch dimensions, {diffusion_shape[-2]} channel " + f"dimensions and {diffusion_shape[-1]} noise dimension.") + if diffusion_shape[:-2] != batch_dimensions: + raise ValueError("Every Tensor returned by the diffusion must have the same number and size of batch " + "dimensions.") + if diffusion_shape[-1] != noise_channels: + raise ValueError("Every Tensor returned by the diffusion must have the same number of noise channels.") + if sde.noise_type == NOISE_TYPES.scalar: + if noise_channels != 1: + raise ValueError(f"Scalar noise must have only one channel; " + f"the diffusion has {noise_channels} noise channels.") + else: # sde.noise_type == NOISE_TYPES.diagonal + batch_dimensions = diffusion_shape[:-1] + drift_shape, diffusion_shape = tuple(drift_shape), tuple(diffusion_shape) + if len(drift_shape) == 0: + raise ValueError("Drift must be of shape (..., state_channels), but got shape ().") + if len(diffusion_shape) == 0: + raise ValueError(f"Diffusion must have shape (..., state_channels), but got shape ().") + if drift_shape != diffusion_shape: + raise ValueError(f"Drift and diffusion shapes do not match. Got drift shape {drift_shape}, " + f"meaning {drift_shape[:-1]} batch dimensions and {drift_shape[-1]} channel " + f"dimensions, but diffusion shape {diffusion_shape}, meaning " + f"{diffusion_shape[:-1]} batch dimensions, {diffusion_shape[-1]} channel " + f"dimensions and {diffusion_shape[-1]} noise dimension.") + if diffusion_shape[:-1] != batch_dimensions: + raise ValueError("Every Tensor returned by the diffusion must have the same number and size of batch " + "dimensions.") + if diffusion_shape[-1] != noise_channels: + raise ValueError("Every Tensor returned by the diffusion must have the same number of noise " + "channels.") + sde = base_sde.ForwardSDE(sde) - drift_shape = [fi.shape for fi in sde.f(ts[0], y0)] + if bm is None: + if method == METHODS.srk: + levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time + else: + levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none + bm = BrownianInterval(t0=ts[0], t1=ts[-1], shape=(*batch_dimensions, noise_channels), dtype=y0.dtype, + device=y0.device, levy_area_approximation=levy_area_approximation) - for drift_shape_, y0_ in zip(drift_shape, y0): - if drift_shape_ != y0_.shape: - raise ValueError(f"Drift must return a Tensor of the same shape as y0. Got drift shape {drift_shape_} but " - f"y0 shape {y0_.shape}.") + return sde, y0, ts, bm - if isinstance(sde, adjoint_sde.AdjointSDE): - if bm is None: - raise ValueError("Adjoint SDEs should have a Brownian motion defined. Please report bug to torchsde.") - else: - diffusion_shape = [gi.shape for gi in sde.g(ts[0], y0)] - - if len(drift_shape) != len(diffusion_shape) or len(drift_shape) != len(y0): - raise ValueError("drift, diffusion and y0 must all return the same number of Tensors.") - - # TODO: Add back the scalar noise check and make it consistent with the underlying functionality. 
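The `bm is None` branch above mirrors what a caller could also do by hand: build a `BrownianInterval` spanning [ts[0], ts[-1]] whose shape is the diffusion's batch dimensions plus its noise channels, requesting space-time Lévy area only when the SRK method needs it. A rough sketch with illustrative sizes (the import paths are the ones used elsewhere in this patch):

```python
import torch
from torchsde import BrownianInterval
from torchsde.settings import LEVY_AREA_APPROXIMATIONS

batch_size, m = 16, 2                      # illustrative sizes
y0 = torch.zeros(batch_size, 3)
ts = torch.linspace(0., 1., steps=10)

# Mirrors the default construction in `check_contract`: one Brownian channel per
# diffusion column, space-time Levy area because SRK requests it.
bm = BrownianInterval(t0=ts[0], t1=ts[-1], shape=(batch_size, m),
                      dtype=y0.dtype, device=y0.device,
                      levy_area_approximation=LEVY_AREA_APPROXIMATIONS.space_time)
dW = bm(0., 0.5)                           # increment, shape (batch_size, m)
dW, U = bm(0., 0.5, return_U=True)         # SRK-style query for the space-time integral as well
```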
- noise_channels = diffusion_shape[0][-1] - if sde.noise_type in (NOISE_TYPES.additive, NOISE_TYPES.general): - batch_dimensions = diffusion_shape[0][:-2] - for drift_shape_, diffusion_shape_ in zip(drift_shape, diffusion_shape): - drift_shape_ = tuple(drift_shape_) - diffusion_shape_ = tuple(diffusion_shape_) - if len(drift_shape_) == 0: - raise ValueError("Drift must be of shape (..., state_channels), but got shape ().") - if len(diffusion_shape_) < 2: - raise ValueError(f"Diffusion must have shape (..., state_channels, noise_channels), but got shape " - f"{diffusion_shape_}.") - if drift_shape_ != diffusion_shape_[:-1]: - raise ValueError(f"Drift and diffusion shapes do not match. Got drift shape {drift_shape_}, " - f"meaning {drift_shape_[:-1]} batch dimensions and {drift_shape_[-1]} channel " - f"dimensions, but diffusion shape {diffusion_shape_}, meaning " - f"{diffusion_shape_[:-2]} batch dimensions, {diffusion_shape_[-2]} channel " - f"dimensions and {diffusion_shape_[-1]} noise dimension.") - if diffusion_shape_[:-2] != batch_dimensions: - raise ValueError("Every Tensor return by the diffusion must have the same number and size of batch " - "dimensions.") - if diffusion_shape_[-1] != noise_channels: - raise ValueError("Every Tensor return by the diffusion must have the same number of noise " - "channels.") - if sde.noise_type == NOISE_TYPES.scalar: - if noise_channels != 1: - raise ValueError(f"Scalar noise must have only one channel; the diffusion has {noise_channels} " - f"noise channels.") - else: # sde.noise_type == NOISE_TYPES.diagonal - batch_dimensions = diffusion_shape[0][:-1] - for drift_shape_, diffusion_shape_ in zip(drift_shape, diffusion_shape): - drift_shape_ = tuple(drift_shape_) - diffusion_shape_ = tuple(diffusion_shape_) - if len(drift_shape_) == 0: - raise ValueError("Drift must be of shape (..., state_channels), but got shape ().") - if len(diffusion_shape_) == 0: - raise ValueError(f"Diffusion must have shape (..., state_channels), but got shape ().") - if drift_shape_ != diffusion_shape_: - raise ValueError(f"Drift and diffusion shapes do not match. 
Got drift shape {drift_shape_}, " - f"meaning {drift_shape_[:-1]} batch dimensions and {drift_shape_[-1]} channel " - f"dimensions, but diffusion shape {diffusion_shape_}, meaning " - f"{diffusion_shape_[:-1]} batch dimensions, {diffusion_shape_[-1]} channel " - f"dimensions and {diffusion_shape_[-1]} noise dimension.") - if diffusion_shape_[:-1] != batch_dimensions: - raise ValueError("Every Tensor return by the diffusion must have the same number and size of batch " - "dimensions.") - if diffusion_shape_[-1] != noise_channels: - raise ValueError("Every Tensor return by the diffusion must have the same number of noise " - "channels.") - - if bm is None: - if method == METHODS.srk: - levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time - else: - levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none - bm = BrownianInterval(t0=ts[0], t1=ts[-1], shape=(*batch_dimensions, noise_channels), dtype=y0[0].dtype, - device=y0[0].device, levy_area_approximation=levy_area_approximation) - - if tensor_input: - bm = TupleBrownian(bm) - - return sde, y0, ts, bm, tensor_input - - -def integrate(sde, y0, ts, bm, method, dt, adaptive, rtol, atol, dt_min, options, logqp=False): + +def integrate(sde, y0, ts, bm, method, dt, adaptive, rtol, atol, dt_min, options): if options is None: options = {} @@ -237,8 +206,6 @@ def integrate(sde, y0, ts, bm, method, dt, adaptive, rtol, atol, dt_min, options options=options ) if adaptive and solver.strong_order < 1.0: - warnings.warn(f'Numerical solution is only guaranteed to converge to the correct solution ' - f'when a strong order >=1.0 scheme is used for adaptive time-stepping.') - if logqp: - return solver.integrate_logqp(ts) + warnings.warn(f"Numerical solution is only guaranteed to converge to the correct solution " + f"when a strong order >=1.0 scheme is used for adaptive time-stepping.") return solver.integrate(ts) diff --git a/torchsde/brownian_lib/brownian_path.py b/torchsde/brownian_lib/brownian_path.py index fdaaac7..b97772e 100644 --- a/torchsde/brownian_lib/brownian_path.py +++ b/torchsde/brownian_lib/brownian_path.py @@ -42,7 +42,7 @@ def __init__(self, w0: torch.Tensor, levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none, **unused_kwargs): - handle_unused_kwargs(self, unused_kwargs) + handle_unused_kwargs(unused_kwargs, msg=self.__class__.__name__) del unused_kwargs super(BrownianPath, self).__init__() diff --git a/torchsde/brownian_lib/brownian_tree.py b/torchsde/brownian_lib/brownian_tree.py index 7c98988..b9701b3 100644 --- a/torchsde/brownian_lib/brownian_tree.py +++ b/torchsde/brownian_lib/brownian_tree.py @@ -51,7 +51,7 @@ def __init__(self, safety: Optional[float] = None, levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none, **unused_kwargs): - handle_unused_kwargs(self, unused_kwargs) + handle_unused_kwargs(unused_kwargs, msg=self.__class__.__name__) del unused_kwargs super(BrownianTree, self).__init__()
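With the tuple code paths gone, the user-facing contract is: `f` returns a tensor shaped like `y0`, `g` returns (batch_size, d) for diagonal noise or (batch_size, d, m) otherwise, and `sdeint` returns a single (T, batch_size, d) tensor. A minimal usage sketch, assuming the usual `torchsde.sdeint` entry point; the toy SDE and its coefficients are made up:

```python
import torch
import torchsde

class ToyDiagonalSDE(torch.nn.Module):
    # Hypothetical example SDE; any object exposing `f`, `g`, `noise_type`
    # and `sde_type` satisfies the contract checked above.
    noise_type = "diagonal"
    sde_type = "ito"

    def f(self, t, y):   # drift, shape (batch_size, d)
        return -y

    def g(self, t, y):   # diagonal diffusion, same shape as y
        return 0.2 * torch.ones_like(y)

batch_size, d = 16, 3
y0 = torch.ones(batch_size, d)
ts = torch.linspace(0., 1., steps=20)
ys = torchsde.sdeint(ToyDiagonalSDE(), y0, ts, method="srk", dt=1e-2)
print(ys.shape)  # torch.Size([20, 16, 3]) -- a single tensor, no tuples
```

Passing the removed `logqp=True` here is no longer a supported argument; per the changes above it would now be reported by `handle_unused_kwargs` as an unexpected argument.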