Commit 41fb567

Mask out private packages.

lxuechen committed Aug 3, 2020
1 parent b7f2f53 commit 41fb567

Showing 45 changed files with 77 additions and 72 deletions.
12 changes: 2 additions & 10 deletions torchsde/__init__.py
@@ -13,13 +13,5 @@
# limitations under the License.

from . import brownian_lib
from .brownian.brownian_path import BrownianPath
from .brownian.brownian_tree import BrownianTree
from .core.adjoint import sdeint_adjoint
from .core.base_sde import BaseSDE, SDEIto
from .core.sdeint import sdeint

sdeint.__annotations__ = {}
sdeint_adjoint.__annotations__ = {}
BrownianPath.__init__.__annotations__ = {}
BrownianTree.__init__.__annotations__ = {}
from ._brownian import BrownianPath, BrownianTree
from ._core import sdeint, sdeint_adjoint, SDEIto, SDEStratonovich
@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .brownian_path import BrownianPath
from .brownian_tree import BrownianTree

BrownianPath.__init__.__annotations__ = {}
BrownianTree.__init__.__annotations__ = {}
File renamed without changes.
@@ -23,8 +23,8 @@
import numpy as np
import torch

from torchsde.brownian import utils
from torchsde.brownian.base_brownian import Brownian
from torchsde._brownian import utils
from torchsde._brownian.base_brownian import Brownian


class BrownianPath(Brownian):
@@ -25,8 +25,8 @@
import torch
from numpy.random import SeedSequence

from torchsde.brownian import utils
from torchsde.brownian.base_brownian import Brownian
from torchsde._brownian import utils
from torchsde._brownian.base_brownian import Brownian


class BrownianTree(Brownian):
File renamed without changes.
@@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchsde._core.adjoint import sdeint_adjoint
from torchsde._core.base_sde import SDEIto, SDEStratonovich
from torchsde._core.sdeint import sdeint

sdeint.__annotations__ = {}
sdeint_adjoint.__annotations__ = {}
@@ -18,7 +18,7 @@

import torch

from torchsde.core import misc
from torchsde._core import misc


def update_step_size(error_estimate, prev_step_size, safety=0.9, facmin=0.2, facmax=1.4, prev_error_ratio=None):
26 changes: 13 additions & 13 deletions torchsde/core/adjoint.py → torchsde/_core/adjoint.py
@@ -24,14 +24,14 @@
try:
from torchsde.brownian_lib import BrownianPath
except Exception: # noqa
from torchsde.brownian.brownian_path import BrownianPath
from torchsde._brownian.brownian_path import BrownianPath # noqa

from torchsde.brownian.base_brownian import Brownian
from torchsde.core import base_sde
from torchsde.core import methods
from torchsde.core import misc
from torchsde.core import sdeint
from torchsde.core.types import TensorOrTensors, Scalar, Vector
from torchsde._brownian.base_brownian import Brownian # noqa
from torchsde._core import base_sde
from torchsde._core import methods
from torchsde._core import misc
from torchsde._core.types import TensorOrTensors, Scalar, Vector
import torchsde._core.sdeint as sdeint_module

@AdrienCorenflos commented on Aug 4, 2020:

This change breaks the library import.

Should be `from torchsde._core import sdeint as sdeint_module`.
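A likely mechanism behind this report (an inference, not confirmed in the thread): `torchsde/_core/__init__.py` imports from `torchsde._core.adjoint`, and `adjoint.py` in turn binds a sibling submodule with the dotted form `import torchsde._core.sdeint as sdeint_module`. While the `_core` package is still initialising, the `_core` attribute has not yet been set on the partially built `torchsde` module, so on older CPython (notably before 3.7) the dotted form can raise AttributeError, whereas `from torchsde._core import sdeint as sdeint_module` resolves the submodule through `sys.modules` and imports cleanly. The snippet below is a minimal, hypothetical reproduction using made-up package names (`pkg`, `pkg._core`), not torchsde itself; exact behaviour depends on the interpreter version.

# Hypothetical reproduction: builds a throwaway package on disk whose
# __init__.py triggers the circular chain pkg -> pkg._core -> pkg._core.adjoint.
import os
import sys
import tempfile

root = tempfile.mkdtemp()

layout = {
    "pkg/__init__.py": "from ._core import sdeint_adjoint\n",
    "pkg/_core/__init__.py": "from pkg._core.adjoint import sdeint_adjoint\n",
    "pkg/_core/sdeint.py": "def integrate():\n    return 'ok'\n",
    # On CPython < 3.7 the dotted import below raises AttributeError during the
    # circular import, because the `_core` attribute is not yet set on `pkg`;
    # "from pkg._core import sdeint as sdeint_module" avoids that lookup.
    "pkg/_core/adjoint.py": (
        "import pkg._core.sdeint as sdeint_module\n"
        "\n"
        "def sdeint_adjoint():\n"
        "    return sdeint_module.integrate()\n"
    ),
}

for relpath, source in layout.items():
    path = os.path.join(root, relpath)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write(source)

sys.path.insert(0, root)
import pkg  # AttributeError on older interpreters; fine with the "from" form.
print(pkg.sdeint_adjoint())

On CPython 3.7 and later the dotted `import ... as ...` form also appears to tolerate circular imports, which would explain why the breakage only shows up on some interpreters.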



class _SdeintAdjointMethod(torch.autograd.Function):
@@ -46,7 +46,7 @@ def forward(ctx, *args):
ctx.adjoint_options) = sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options

sde = base_sde.ForwardSDEIto(sde)
ans = sdeint.integrate(
ans = sdeint_module.integrate(
sde=sde,
y0=y0,
ts=ts,
@@ -86,7 +86,7 @@ def backward(ctx, *grad_outputs):
ans_i = [ans_[i] for ans_ in ans]
aug_y0 = (*ans_i, *adj_y, adj_params)

aug_ans = sdeint.integrate(
aug_ans = sdeint_module.integrate(
sde=adjoint_sde,
y0=aug_y0,
ts=torch.tensor([-ts[i], -ts[i - 1]]).to(ts),
@@ -125,7 +125,7 @@ def forward(ctx, *args):
ctx.adjoint_options) = sde, dt, bm, adjoint_method, adaptive, rtol, atol, dt_min, adjoint_options

sde = base_sde.ForwardSDEIto(sde)
ans_and_logqp = sdeint.integrate(
ans_and_logqp = sdeint_module.integrate(
sde=sde,
y0=y0,
ts=ts,
@@ -170,7 +170,7 @@ def backward(ctx, *grad_outputs):
ans_i = [ans_[i] for ans_ in ans]
aug_y0 = (*ans_i, *adj_y, *adj_l, adj_params)

aug_ans = sdeint.integrate(
aug_ans = sdeint_module.integrate(
sde=adjoint_sde,
y0=aug_y0,
ts=torch.tensor([-ts[i], -ts[i - 1]]).to(ts),
@@ -260,10 +260,10 @@ def sdeint_adjoint(sde,
if not isinstance(sde, nn.Module):
raise ValueError('sde is required to be an instance of nn.Module.')

names_to_change = sdeint.get_names_to_change(names)
names_to_change = sdeint_module.get_names_to_change(names)
if len(names_to_change) > 0:
sde = base_sde.RenameMethodsSDE(sde, **names_to_change)
sdeint.check_contract(sde=sde, method=method, logqp=logqp, adjoint_method=adjoint_method)
sdeint_module.check_contract(sde=sde, method=method, logqp=logqp, adjoint_method=adjoint_method)

if bm is None:
bm = BrownianPath(t0=ts[0], w0=torch.zeros_like(y0).cpu())
4 changes: 2 additions & 2 deletions torchsde/core/base_sde.py → torchsde/_core/base_sde.py
@@ -21,8 +21,8 @@
import torch
from torch import nn

from torchsde.core import misc
from torchsde.core import settings
from . import misc
from . import settings


class BaseSDE(abc.ABC, nn.Module):
@@ -21,9 +21,9 @@

import torch

from torchsde.core import adaptive_stepping
from torchsde.core import interp
from torchsde.core import misc
from . import adaptive_stepping
from . import interp
from . import misc


class SDESolver(abc.ABC):
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -19,8 +19,8 @@

import torch

from torchsde.core import base_sde
from torchsde.core import misc
from torchsde._core import base_sde
from torchsde._core import misc


class AdjointSDEAdditive(base_sde.AdjointSDEIto):
@@ -16,8 +16,8 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde.core.methods.general import euler
from torchsde._core import base_solver
from torchsde._core.methods.general import euler


class EulerAdditive(base_solver.GenericSDESolver):
@@ -25,10 +25,10 @@

import torch

from torchsde.core import base_solver
from torchsde.core import misc
from torchsde.core.methods import utils
from torchsde.core.methods.tableaus import sra1
from torchsde._core import base_solver
from torchsde._core import misc
from torchsde._core.methods import utils
from torchsde._core.methods.tableaus import sra1

STAGES, C0, C1, A0, B0, alpha, beta1, beta2 = (
sra1.STAGES, sra1.C0, sra1.C1, sra1.A0, sra1.B0, sra1.alpha, sra1.beta1, sra1.beta2
File renamed without changes.
@@ -19,8 +19,8 @@

import torch

from torchsde.core import base_sde
from torchsde.core import misc
from torchsde._core import base_sde
from torchsde._core import misc


class AdjointSDEDiagonal(base_sde.AdjointSDEIto):
@@ -16,7 +16,7 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde._core import base_solver


class EulerDiagonal(base_solver.GenericSDESolver):
@@ -16,7 +16,7 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde._core import base_solver


class MilsteinDiagonal(base_solver.GenericSDESolver):
@@ -25,9 +25,9 @@

import torch

from torchsde.core import base_solver
from torchsde.core.methods import utils
from torchsde.core.methods.tableaus import srid2
from torchsde._core import base_solver
from torchsde._core.methods import utils
from torchsde._core.methods.tableaus import srid2

STAGES, C0, C1, A0, A1, B0, B1, alpha, beta1, beta2, beta3, beta4 = (
srid2.STAGES,
File renamed without changes.
@@ -16,7 +16,7 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde._core import base_solver


class EulerGeneral(base_solver.GenericSDESolver):
File renamed without changes.
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_sde
from torchsde._core import base_sde


class AdjointSDEScalar(base_sde.AdjointSDEIto):
@@ -16,9 +16,9 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde.core.methods.diagonal import euler
from torchsde.core.methods.scalar import utils
from torchsde._core import base_solver
from torchsde._core.methods.diagonal import euler
from torchsde._core.methods.scalar import utils


class EulerScalar(base_solver.GenericSDESolver):
@@ -16,9 +16,9 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde.core.methods.diagonal import milstein
from torchsde.core.methods.scalar import utils
from torchsde._core import base_solver
from torchsde._core.methods.diagonal import milstein
from torchsde._core.methods.scalar import utils


class MilsteinScalar(base_solver.GenericSDESolver):
@@ -16,9 +16,9 @@
from __future__ import division
from __future__ import print_function

from torchsde.core import base_solver
from torchsde.core.methods.diagonal import srk
from torchsde.core.methods.scalar import utils
from torchsde._core import base_solver
from torchsde._core.methods.diagonal import srk
from torchsde._core.methods.scalar import utils


class SRKScalar(base_solver.GenericSDESolver):
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
12 changes: 6 additions & 6 deletions torchsde/core/sdeint.py → torchsde/_core/sdeint.py
@@ -24,13 +24,13 @@
try:
from torchsde.brownian_lib import BrownianPath
except Exception: # noqa
from torchsde.brownian.brownian_path import BrownianPath
from torchsde._brownian.brownian_path import BrownianPath # noqa

from torchsde.brownian.base_brownian import Brownian
from torchsde.core import base_sde
from torchsde.core import methods
from torchsde.core import settings
from torchsde.core.types import TensorOrTensors, Scalar, Vector
from torchsde._brownian.base_brownian import Brownian # noqa
from . import base_sde
from . import methods
from . import settings
from .types import TensorOrTensors, Scalar, Vector


def sdeint(sde,
Expand Down
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions torchsde/brownian_lib/__init__.py
@@ -15,10 +15,10 @@
import warnings

try:
from torchsde.brownian_lib.brownian_path import BrownianPath
from torchsde.brownian_lib.brownian_tree import BrownianTree
from .brownian_path import BrownianPath
from .brownian_tree import BrownianTree

BrownianPath.__init__.__annotations__ = {}
BrownianTree.__init__.__annotations__ = {}
except Exception: # noqa
warnings.warn('Failed to import `torchsde._brownian_lib`; falling back to `torchsde.brownian`.')
warnings.warn('Failed to import `torchsde._brownian_lib`; falling back to `torchsde._brownian`.')
4 changes: 2 additions & 2 deletions torchsde/brownian_lib/brownian_path.py
@@ -17,8 +17,8 @@
import torch
from torchsde._brownian_lib import BrownianPath as _BrownianPath # noqa

from torchsde.brownian import utils
from torchsde.brownian.base_brownian import Brownian
from torchsde._brownian import utils # noqa
from torchsde._brownian.base_brownian import Brownian # noqa


class BrownianPath(Brownian):
4 changes: 2 additions & 2 deletions torchsde/brownian_lib/brownian_tree.py
@@ -18,8 +18,8 @@
import torch
from torchsde._brownian_lib import BrownianTree as _BrownianTree # noqa

from torchsde.brownian import utils
from torchsde.brownian.base_brownian import Brownian
from torchsde._brownian import utils # noqa
from torchsde._brownian.base_brownian import Brownian # noqa


class BrownianTree(Brownian):