Skip to content

Commit

Permalink
Fixes #3.
Browse files Browse the repository at this point in the history
  • Loading branch information
lxuechen committed Aug 6, 2020
1 parent 0f4572f commit 647b6e5
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 19 deletions.
2 changes: 1 addition & 1 deletion benchmarks/brownian.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def swiss_knife_plotter(img_path, plots=None, scatters=None, options=None):
if scatters is None: scatters = ()
if options is None: options = {}

plt.figure()
plt.figure(dpi=300)
if 'xscale' in options: plt.xscale(options['xscale'])
if 'yscale' in options: plt.yscale(options['yscale'])
if 'xlabel' in options: plt.xlabel(options['xlabel'])
Expand Down
14 changes: 14 additions & 0 deletions examples/jit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torchsde
from torchsde import sdeint
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
import torch
from torch.utils import cpp_extension
except ModuleNotFoundError:
raise ModuleNotFoundError("Unable to import torch. Please install torch>=1.5.0 at "
"https://pytorch.org.")
raise ModuleNotFoundError("Unable to import torch. Please install torch>=1.6.0 at https://pytorch.org.")

extra_compile_args = []
extra_link_args = []
Expand Down Expand Up @@ -68,7 +67,7 @@
optional=True)
],
cmdclass={'build_ext': cpp_extension.BuildExtension},
install_requires=['torch>=1.6.0', 'blist', 'numpy>=1.17.0', 'scipy'],
install_requires=['torch>=1.6.0', 'blist', 'numpy>=1.17.0'],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
Expand Down
7 changes: 6 additions & 1 deletion torchsde/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,9 @@

from . import brownian_lib
from ._brownian import BrownianPath, BrownianTree
from ._core import sdeint, sdeint_adjoint, SDEIto, SDEStratonovich
from ._core.adjoint import sdeint_adjoint
from ._core.base_sde import SDEIto, SDEStratonovich
from ._core.sdeint import sdeint

sdeint.__annotations__ = {}
sdeint_adjoint.__annotations__ = {}
7 changes: 0 additions & 7 deletions torchsde/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchsde._core.adjoint import sdeint_adjoint
from torchsde._core.base_sde import SDEIto, SDEStratonovich
from torchsde._core.sdeint import sdeint

sdeint.__annotations__ = {}
sdeint_adjoint.__annotations__ = {}
14 changes: 7 additions & 7 deletions torchsde/_core/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torchsde._core import methods
from torchsde._core import misc
from torchsde._core.types import TensorOrTensors, Scalar, Vector
import torchsde._core.sdeint as sdeint_module
from torchsde._core import sdeint


class _SdeintAdjointMethod(torch.autograd.Function):
Expand All @@ -46,7 +46,7 @@ def forward(ctx, *args):
ctx.adjoint_options) = sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options

sde = base_sde.ForwardSDEIto(sde)
ans = sdeint_module.integrate(
ans = sdeint.integrate(
sde=sde,
y0=y0,
ts=ts,
Expand Down Expand Up @@ -86,7 +86,7 @@ def backward(ctx, *grad_outputs):
ans_i = [ans_[i] for ans_ in ans]
aug_y0 = (*ans_i, *adj_y, adj_params)

aug_ans = sdeint_module.integrate(
aug_ans = sdeint.integrate(
sde=adjoint_sde,
y0=aug_y0,
ts=torch.tensor([-ts[i], -ts[i - 1]]).to(ts),
Expand Down Expand Up @@ -125,7 +125,7 @@ def forward(ctx, *args):
ctx.adjoint_options) = sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options

sde = base_sde.ForwardSDEIto(sde)
ans_and_logqp = sdeint_module.integrate(
ans_and_logqp = sdeint.integrate(
sde=sde,
y0=y0,
ts=ts,
Expand Down Expand Up @@ -170,7 +170,7 @@ def backward(ctx, *grad_outputs):
ans_i = [ans_[i] for ans_ in ans]
aug_y0 = (*ans_i, *adj_y, *adj_l, adj_params)

aug_ans = sdeint_module.integrate(
aug_ans = sdeint.integrate(
sde=adjoint_sde,
y0=aug_y0,
ts=torch.tensor([-ts[i], -ts[i - 1]]).to(ts),
Expand Down Expand Up @@ -260,10 +260,10 @@ def sdeint_adjoint(sde,
if not isinstance(sde, nn.Module):
raise ValueError('sde is required to be an instance of nn.Module.')

names_to_change = sdeint_module.get_names_to_change(names)
names_to_change = sdeint.get_names_to_change(names)
if len(names_to_change) > 0:
sde = base_sde.RenameMethodsSDE(sde, **names_to_change)
sdeint_module.check_contract(sde=sde, method=method, logqp=logqp, adjoint_method=adjoint_method)
sdeint.check_contract(sde=sde, method=method, logqp=logqp, adjoint_method=adjoint_method)

if bm is None:
bm = BrownianPath(t0=ts[0], w0=torch.zeros_like(y0).cpu())
Expand Down

0 comments on commit 647b6e5

Please sign in to comment.