diff --git a/setup.py b/setup.py
index 4b1f3ed113..c89c024647 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,24 @@
 REQUIRED_MAJOR = 3
 REQUIRED_MINOR = 7
-
+INSTALL_REQUIRES = [
+    "arviz>=0.11.0",
+    "astor>=0.7.1",
+    "botorch>=0.5.1",
+    "flowtorch>=0.3",
+    "gpytorch>=1.3.0",
+    "graphviz>=0.17",
+    "numpy>=1.18.1",
+    "pandas>=0.24.2",
+    "parameterized>=0.8.1",
+    "plotly>=2.2.1",
+    "scipy>=0.16",
+    "statsmodels>=0.12.0",
+    "torch>=1.9.0",
+    "tqdm>=4.46.0",
+    "typing-extensions>=3.10",
+    "xarray>=0.16.0",
+]
 TEST_REQUIRES = ["pytest>=7.0.0", "pytest-cov"]
 TUTORIALS_REQUIRES = [
     "bokeh",
@@ -48,6 +65,7 @@
     CPP_COMPILE_ARGS = ["/WX", "/permissive-", "-DEIGEN_HAS_C99_MATH"]
 else:
     CPP_COMPILE_ARGS = ["-std=c++17", "-Werror"]
+    INSTALL_REQUIRES.append("functorch>=0.1.0")
 
 
 # Check for python version
@@ -125,24 +143,7 @@
     long_description=long_description,
     long_description_content_type="text/markdown",
     python_requires=">={}.{}".format(REQUIRED_MAJOR, REQUIRED_MINOR),
-    install_requires=[
-        "arviz>=0.11.0",
-        "astor>=0.7.1",
-        "botorch>=0.5.1",
-        "flowtorch>=0.3",
-        "gpytorch>=1.3.0",
-        "graphviz>=0.17",
-        "numpy>=1.18.1",
-        "pandas>=0.24.2",
-        "parameterized>=0.8.1",
-        "plotly>=2.2.1",
-        "scipy>=0.16",
-        "statsmodels>=0.12.0",
-        "torch>=1.9.0",
-        "tqdm>=4.46.0",
-        "typing-extensions>=3.10",
-        "xarray>=0.16.0",
-    ],
+    install_requires=INSTALL_REQUIRES,
     packages=find_packages("src"),
     package_dir={"": "src"},
     ext_modules=[
diff --git a/src/beanmachine/ppl/experimental/nnc/__init__.py b/src/beanmachine/ppl/experimental/nnc/__init__.py
new file mode 100644
index 0000000000..574698dad8
--- /dev/null
+++ b/src/beanmachine/ppl/experimental/nnc/__init__.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from typing import Callable, TypeVar, Optional, Tuple
+
+from typing_extensions import ParamSpec
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def nnc_jit(
+    f: Callable[P, R], static_argnums: Optional[Tuple[int]] = None
+) -> Callable[P, R]:
+    """
+    A helper function that lazily imports the NNC utils, which initialize the
+    compiler and display an experimental warning, then invokes the underlying
+    nnc_jit on the function f.
+    """
+    try:
+        # The setup code in `nnc.utils` will only be executed once in a Python session
+        from beanmachine.ppl.experimental.nnc.utils import nnc_jit as raw_nnc_jit
+    except ImportError as e:
+        if sys.platform.startswith("win"):
+            message = "functorch is not available on Windows."
+        else:
+            message = (
+                "Failed to initialize NNC. This is likely caused by a version "
+                "mismatch between PyTorch and functorch. Please check out the "
+                "functorch project for an installation guide."
+            )
+        raise RuntimeError(message) from e
+
+    return raw_nnc_jit(f, static_argnums)
+
+
+__all__ = ["nnc_jit"]
diff --git a/src/beanmachine/ppl/experimental/nnc/utils.py b/src/beanmachine/ppl/experimental/nnc/utils.py
new file mode 100644
index 0000000000..2ffe7e78eb
--- /dev/null
+++ b/src/beanmachine/ppl/experimental/nnc/utils.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
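+
+# How the pieces below fit together: `nnc_jit` wraps a function with functorch's
+# `aot_function`, which traces it into an FX graph and hands that graph to
+# `simple_ts_compile`, where it is lowered to TorchScript via trace + freeze so
+# the NNC/TensorExpr fuser (with reductions enabled below) can optimize it. The
+# registered decompositions rewrite ops such as `aten.mv` and `aten.dot` into
+# pointwise multiplies followed by sums, a form the fuser handles well.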
+
+import warnings
+
+import functorch
+import torch
+import torch.jit
+import torch.utils._pytree as pytree
+from functorch.compile import (
+    nop,
+    aot_function,
+    decomposition_table,
+    register_decomposition,
+)
+
+# the warning will only be shown to the user once, when this module is imported
+warnings.warn(
+    "The support of NNC compiler is experimental and the API is subject to "
+    "change in future releases of Bean Machine. For questions regarding NNC, "
+    "please check out the functorch project (https://github.com/pytorch/functorch)."
+)
+
+torch._C._jit_set_texpr_reductions_enabled(True)
+
+
+# Override the usage of torch.jit.script, which has issues handling
+# empty lists (functorch#440)
+def simple_ts_compile(fx_g, example_inps):
+    f = torch.jit.trace(fx_g, example_inps, strict=False)
+    f = torch.jit.freeze(f.eval())
+    torch._C._jit_pass_remove_mutation(f.graph)
+
+    return f
+
+
+aten = torch.ops.aten
+decompositions = [aten.detach]
+bm_decompositions = {
+    k: v for k, v in decomposition_table.items() if k in decompositions
+}
+
+
+@register_decomposition(aten.mv, bm_decompositions)
+def mv(a, b):
+    return (a * b).sum(dim=-1)
+
+
+@register_decomposition(aten.dot, bm_decompositions)
+def dot(a, b):
+    return (a * b).sum(dim=-1)
+
+
+# @register_decomposition(aten.nan_to_num, bm_decompositions)
+# def nan_to_num(a, val):
+#     return aten.where(a != a, val, a)
+
+
+@register_decomposition(aten.zeros_like, bm_decompositions)
+def zeros_like(a, **kwargs):
+    return a * 0
+
+
+@register_decomposition(aten.ones_like, bm_decompositions)
+def ones_like(a, **kwargs):
+    return a * 0 + 1
+
+
+def nnc_jit(f, static_argnums=None):
+    return aot_function(
+        f,
+        simple_ts_compile,
+        nop,
+        static_argnums=static_argnums,
+        decompositions=bm_decompositions,
+    )
+
+
+functorch._src.compilers.simple_ts_compile = simple_ts_compile
+
+
+# override default dict flatten (which requires keys to be sortable)
+def _dict_flatten(d):
+    keys = list(d.keys())
+    values = [d[key] for key in keys]
+    return values, keys
+
+
+def _dict_unflatten(values, context):
+    return {key: value for key, value in zip(context, values)}
+
+
+pytree._register_pytree_node(dict, _dict_flatten, _dict_unflatten)
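+
+# For example, `_dict_flatten({"b": 1, "a": 2})` returns `([1, 2], ["b", "a"])`:
+# keys are kept in insertion order rather than sorted, so dicts keyed by
+# non-sortable objects (e.g. RVIdentifier) can be flattened too.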
+
+__all__ = ["nnc_jit"]
diff --git a/src/beanmachine/ppl/experimental/tests/nnc_test.py b/src/beanmachine/ppl/experimental/tests/nnc_test.py
new file mode 100644
index 0000000000..59ee64d9cd
--- /dev/null
+++ b/src/beanmachine/ppl/experimental/tests/nnc_test.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import warnings
+
+import beanmachine.ppl as bm
+import pytest
+import torch
+import torch.distributions as dist
+
+if sys.platform.startswith("win"):
+    pytest.skip("functorch is not available on Windows", allow_module_level=True)
+
+
+class SampleModel:
+    @bm.random_variable
+    def foo(self):
+        return dist.Normal(0.0, 1.0)
+
+    @bm.random_variable
+    def bar(self):
+        return dist.Normal(self.foo(), 1.0)
+
+
+@pytest.mark.parametrize(
+    "algorithm",
+    [
+        bm.GlobalNoUTurnSampler(nnc_compile=True),
+        bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, nnc_compile=True),
+    ],
+)
+def test_nnc_compile(algorithm):
+    model = SampleModel()
+    queries = [model.foo()]
+    observations = {model.bar(): torch.tensor(0.5)}
+    num_samples = 30
+    num_chains = 2
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        # verify that inference can run end-to-end with NNC enabled
+        samples = algorithm.infer(
+            queries,
+            observations,
+            num_samples,
+            num_adaptive_samples=num_samples,
+            num_chains=num_chains,
+        )
+    # sanity check: make sure that the samples are valid
+    assert not torch.isnan(samples[model.foo()]).any()
diff --git a/src/beanmachine/ppl/inference/hmc_inference.py b/src/beanmachine/ppl/inference/hmc_inference.py
index c0714b0ce6..cbac9b053f 100644
--- a/src/beanmachine/ppl/inference/hmc_inference.py
+++ b/src/beanmachine/ppl/inference/hmc_inference.py
@@ -30,6 +30,8 @@ class GlobalHamiltonianMonteCarlo(BaseInference):
         adapt_mass_matrix (bool): Whether to adapt the mass matrix. Defaults to True,
         target_accept_prob (float): Target accept prob. Increasing this value would
             lead to smaller step size. Defaults to 0.8.
+        nnc_compile: (Experimental) If True, the NNC compiler will be used to
+            accelerate inference (defaults to False).
     """
 
     def __init__(
@@ -39,12 +41,14 @@ def __init__(
         adapt_step_size: bool = True,
         adapt_mass_matrix: bool = True,
         target_accept_prob: float = 0.8,
+        nnc_compile: bool = False,
     ):
         self.trajectory_length = trajectory_length
         self.initial_step_size = initial_step_size
         self.adapt_step_size = adapt_step_size
         self.adapt_mass_matrix = adapt_mass_matrix
         self.target_accept_prob = target_accept_prob
+        self.nnc_compile = nnc_compile
         self._proposer = None
 
     def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -66,6 +70,7 @@ def get_proposers(
             self.adapt_step_size,
             self.adapt_mass_matrix,
             self.target_accept_prob,
+            self.nnc_compile,
         )
         return [self._proposer]
 
@@ -94,12 +99,14 @@ def __init__(
         adapt_step_size: bool = True,
         adapt_mass_matrix: bool = True,
         target_accept_prob: float = 0.8,
+        nnc_compile: bool = False,
     ):
         self.trajectory_length = trajectory_length
         self.initial_step_size = initial_step_size
         self.adapt_step_size = adapt_step_size
         self.adapt_mass_matrix = adapt_mass_matrix
         self.target_accept_prob = target_accept_prob
+        self.nnc_compile = nnc_compile
         self._proposers = {}
 
     def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -123,6 +130,7 @@ def get_proposers(
                 self.adapt_step_size,
                 self.adapt_mass_matrix,
                 self.target_accept_prob,
+                self.nnc_compile,
             )
             proposers.append(self._proposers[node])
         return proposers
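Note: with the hunks above and below, NNC acceleration becomes an opt-in constructor
flag on the inference methods. A minimal usage sketch, mirroring nnc_test.py above
(`queries` and `observations` are the test's, not part of the API):

    samples = bm.GlobalNoUTurnSampler(nnc_compile=True).infer(
        queries, observations, num_samples=100, num_chains=2
    )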
diff --git a/src/beanmachine/ppl/inference/nuts_inference.py b/src/beanmachine/ppl/inference/nuts_inference.py
index e0b01da011..3370aae9aa 100644
--- a/src/beanmachine/ppl/inference/nuts_inference.py
+++ b/src/beanmachine/ppl/inference/nuts_inference.py
@@ -38,6 +38,8 @@ class GlobalNoUTurnSampler(BaseInference):
             defaults to True.
         target_accept_prob (float): Target accept probability. Increasing this
             would lead to smaller step size. Defaults to 0.8.
+        nnc_compile: (Experimental) If True, the NNC compiler will be used to
+            accelerate inference (defaults to False).
     """
 
     def __init__(
@@ -49,6 +51,7 @@ def __init__(
         adapt_mass_matrix: bool = True,
         multinomial_sampling: bool = True,
         target_accept_prob: float = 0.8,
+        nnc_compile: bool = False,
     ):
         self.max_tree_depth = max_tree_depth
         self.max_delta_energy = max_delta_energy
@@ -57,6 +60,7 @@ def __init__(
         self.adapt_mass_matrix = adapt_mass_matrix
         self.multinomial_sampling = multinomial_sampling
         self.target_accept_prob = target_accept_prob
+        self.nnc_compile = nnc_compile
         self._proposer = None
 
     def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -80,6 +84,7 @@ def get_proposers(
             self.adapt_mass_matrix,
             self.multinomial_sampling,
             self.target_accept_prob,
+            self.nnc_compile,
         )
         return [self._proposer]
 
@@ -106,6 +111,8 @@ class SingleSiteNoUTurnSampler(BaseInference):
             defaults to True.
         target_accept_prob (float): Target accept probability. Increasing this
             would lead to smaller step size. Defaults to 0.8.
+        nnc_compile: (Experimental) If True, the NNC compiler will be used to
+            accelerate inference (defaults to False).
     """
 
     def __init__(
@@ -117,6 +124,7 @@ def __init__(
         adapt_mass_matrix: bool = True,
         multinomial_sampling: bool = True,
         target_accept_prob: float = 0.8,
+        nnc_compile: bool = False,
     ):
         self.max_tree_depth = max_tree_depth
         self.max_delta_energy = max_delta_energy
@@ -125,6 +133,7 @@ def __init__(
         self.adapt_mass_matrix = adapt_mass_matrix
         self.multinomial_sampling = multinomial_sampling
         self.target_accept_prob = target_accept_prob
+        self.nnc_compile = nnc_compile
         self._proposers = {}
 
     def _get_default_num_adaptive_samples(self, num_samples: int) -> int:
@@ -150,6 +159,7 @@ def get_proposers(
                 self.adapt_mass_matrix,
                 self.multinomial_sampling,
                 self.target_accept_prob,
+                self.nnc_compile,
             )
             proposers.append(self._proposers[node])
         return proposers
diff --git a/src/beanmachine/ppl/inference/proposer/hmc_proposer.py b/src/beanmachine/ppl/inference/proposer/hmc_proposer.py
index 23dd94051a..df6535ae44 100644
--- a/src/beanmachine/ppl/inference/proposer/hmc_proposer.py
+++ b/src/beanmachine/ppl/inference/proposer/hmc_proposer.py
@@ -8,6 +8,7 @@
 from typing import Callable, Optional, Tuple, cast, Set
 
 import torch
+from beanmachine.ppl.experimental.nnc import nnc_jit
 from beanmachine.ppl.inference.proposer.base_proposer import (
     BaseProposer,
 )
@@ -46,6 +47,8 @@ class HMCProposer(BaseProposer):
         adapt_step_size: Flag whether to adapt step size, defaults to True.
         adapt_mass_matrix: Flag whether to adapt mass matrix, defaults to True.
         target_accept_prob: Target accept prob, defaults to 0.8.
+        nnc_compile: (Experimental) If True, the NNC compiler will be used to
+            accelerate inference (defaults to False).
     """
 
     def __init__(
@@ -58,6 +61,7 @@ def __init__(
         adapt_step_size: bool = True,
         adapt_mass_matrix: bool = True,
         target_accept_prob: float = 0.8,
+        nnc_compile: bool = False,
     ):
         self.world = initial_world
         self._target_rvs = target_rvs
@@ -76,13 +80,16 @@ def __init__(
         self._mass_matrix_adapter = MassMatrixAdapter()
         if self.adapt_step_size:
             self.step_size = self._find_reasonable_step_size(
-                initial_step_size, self._positions, self._pe, self._pe_grad
+                torch.as_tensor(initial_step_size),
+                self._positions,
+                self._pe,
+                self._pe_grad,
             )
             self._step_size_adapter = DualAverageAdapter(
                 self.step_size, target_accept_prob
             )
         else:
-            self.step_size = initial_step_size
+            self.step_size = torch.as_tensor(initial_step_size)
         if self.adapt_mass_matrix:
             self._window_scheme = WindowScheme(num_adaptive_samples)
         else:
@@ -90,6 +97,10 @@ def __init__(
         # alpha will store the accept prob and will be used to adapt step size
         self._alpha = None
 
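+        # Compile the leapfrog step once up front so every later proposal reuses
+        # the NNC-compiled kernel. This pairs with `step_size` being kept as a
+        # torch.Tensor above: a plain Python float would have to be treated as a
+        # static (compile-time) argument, forcing recompilation whenever step
+        # size adaptation changes it.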
""" def __init__( @@ -58,6 +61,7 @@ def __init__( adapt_step_size: bool = True, adapt_mass_matrix: bool = True, target_accept_prob: float = 0.8, + nnc_compile: bool = False, ): self.world = initial_world self._target_rvs = target_rvs @@ -76,13 +80,16 @@ def __init__( self._mass_matrix_adapter = MassMatrixAdapter() if self.adapt_step_size: self.step_size = self._find_reasonable_step_size( - initial_step_size, self._positions, self._pe, self._pe_grad + torch.as_tensor(initial_step_size), + self._positions, + self._pe, + self._pe_grad, ) self._step_size_adapter = DualAverageAdapter( self.step_size, target_accept_prob ) else: - self.step_size = initial_step_size + self.step_size = torch.as_tensor(initial_step_size) if self.adapt_mass_matrix: self._window_scheme = WindowScheme(num_adaptive_samples) else: @@ -90,6 +97,10 @@ def __init__( # alpha will store the accept prob and will be used to adapt step size self._alpha = None + if nnc_compile: + # pyre-ignore[8] + self._leapfrog_step = nnc_jit(self._leapfrog_step) + @property def _initialize_momentums(self) -> Callable: return self._mass_matrix_adapter.initialize_momentums @@ -117,7 +128,7 @@ def _potential_energy(self, positions: RVDict) -> torch.Tensor: current values)""" constrained_vals = self._to_unconstrained.inv(positions) log_joint = self.world.replace(constrained_vals).log_prob() - log_joint -= self._to_unconstrained.log_abs_det_jacobian( + log_joint = log_joint - self._to_unconstrained.log_abs_det_jacobian( constrained_vals, positions ) return -log_joint @@ -175,7 +186,7 @@ def _leapfrog_step( self, positions: RVDict, momentums: RVDict, - step_size: float, + step_size: torch.Tensor, mass_inv: RVDict, pe_grad: Optional[RVDict] = None, ) -> Tuple[RVDict, RVDict, torch.Tensor, RVDict]: @@ -206,28 +217,28 @@ def _leapfrog_updates( positions: RVDict, momentums: RVDict, trajectory_length: float, - step_size: float, + step_size: torch.Tensor, mass_inv: RVDict, pe_grad: Optional[RVDict] = None, ) -> Tuple[RVDict, RVDict, torch.Tensor, RVDict]: """Run multiple iterations of leapfrog integration until the length of the trajectory is greater than the specified trajectory_length.""" # we should run at least 1 step - num_steps = max(math.ceil(trajectory_length / step_size), 1) + num_steps = max(math.ceil(trajectory_length / step_size.item()), 1) for _ in range(num_steps): positions, momentums, pe, pe_grad = self._leapfrog_step( positions, momentums, step_size, mass_inv, pe_grad ) - # pyre-fixme[61]: `pe` may not be initialized here. + # pyre-ignore[61]: `pe` may not be initialized here. 
diff --git a/src/beanmachine/ppl/inference/proposer/hmc_utils.py b/src/beanmachine/ppl/inference/proposer/hmc_utils.py
index 265b5cf9a4..96b9c3a0a8 100644
--- a/src/beanmachine/ppl/inference/proposer/hmc_utils.py
+++ b/src/beanmachine/ppl/inference/proposer/hmc_utils.py
@@ -75,19 +75,19 @@ class DualAverageAdapter:
     https://arxiv.org/abs/1111.4246
     """
 
-    def __init__(self, initial_epsilon: float, delta: float = 0.8):
-        self._log_avg_epsilon = 0.0
-        self._H = 0.0
-        self._mu = math.log(10 * initial_epsilon)
+    def __init__(self, initial_epsilon: torch.Tensor, delta: float = 0.8):
+        self._log_avg_epsilon = torch.zeros_like(initial_epsilon)
+        self._H = torch.zeros_like(initial_epsilon)
+        self._mu = torch.log(10 * initial_epsilon)
         self._t0 = 10
         self._delta = delta  # target mean accept prob
         self._gamma = 0.05
         self._kappa = 0.75
         self._m = 1.0  # iteration count
 
-    def step(self, alpha: torch.Tensor) -> float:
+    def step(self, alpha: torch.Tensor) -> torch.Tensor:
         H_frac = 1.0 / (self._m + self._t0)
-        self._H = ((1 - H_frac) * self._H) + H_frac * (self._delta - alpha.item())
+        self._H = ((1 - H_frac) * self._H) + H_frac * (self._delta - alpha)
 
         log_epsilon = self._mu - (math.sqrt(self._m) / self._gamma) * self._H
         step_frac = self._m ** (-self._kappa)
@@ -95,10 +95,10 @@ def step(self, alpha: torch.Tensor) -> torch.Tensor:
         self._log_avg_epsilon = (
             step_frac * log_epsilon + (1 - step_frac) * self._log_avg_epsilon
         )
         self._m += 1
-        return math.exp(log_epsilon)
+        return torch.exp(cast(torch.Tensor, log_epsilon))
 
-    def finalize(self) -> float:
-        return math.exp(self._log_avg_epsilon)
+    def finalize(self) -> torch.Tensor:
+        return torch.exp(self._log_avg_epsilon)
 
 
 class MassMatrixAdapter:
@@ -232,7 +232,7 @@ def log_abs_det_jacobian(
         """Computes the sum of log det jacobian `log |dy/dx|` on the pairs of Tensors"""
         jacobian = torch.tensor(0.0)
         for node in untransformed_vals:
-            jacobian += (
+            jacobian = jacobian + (
                 self.transforms[node]
                 .log_abs_det_jacobian(untransformed_vals[node], transformed_vals[node])
                 .sum()
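Note on the next diff: only `_build_tree_base_case` (one leapfrog step plus
bookkeeping) is handed to nnc_jit; the recursive `_build_tree` and the doubling loop
in `propose` have data-dependent control flow that a tracing compiler cannot capture,
which is presumably why NUTS compiles at this level rather than reusing HMC's
compiled `_leapfrog_step`. Accordingly, `_Tree` and `_TreeArgs` fields such as
`num_proposals`, `turned_or_diverged`, and `direction` become Tensors, so values
crossing the compiled boundary stay symbolic instead of being frozen into Python
ints and bools at trace time.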
diff --git a/src/beanmachine/ppl/inference/proposer/nuts_proposer.py b/src/beanmachine/ppl/inference/proposer/nuts_proposer.py
index 78d1585161..e00e319100 100644
--- a/src/beanmachine/ppl/inference/proposer/nuts_proposer.py
+++ b/src/beanmachine/ppl/inference/proposer/nuts_proposer.py
@@ -6,6 +6,7 @@
 from typing import NamedTuple, Set, Tuple
 
 import torch
+from beanmachine.ppl.experimental.nnc import nnc_jit
 from beanmachine.ppl.inference.proposer.hmc_proposer import (
     HMCProposer,
 )
@@ -28,14 +29,14 @@ class _Tree(NamedTuple):
     log_weight: torch.Tensor
     sum_momentums: RVDict
     sum_accept_prob: torch.Tensor
-    num_proposals: int
-    turned_or_diverged: bool
+    num_proposals: torch.Tensor
+    turned_or_diverged: torch.Tensor
 
 
 class _TreeArgs(NamedTuple):
     log_slice: torch.Tensor
-    direction: int
-    step_size: float
+    direction: torch.Tensor
+    step_size: torch.Tensor
     initial_energy: torch.Tensor
     mass_inv: RVDict
 
@@ -67,6 +68,8 @@ class NUTSProposer(HMCProposer):
         adapt_mass_matrix: Whether to adapt mass matrix using Welford Scheme,
             defaults to True.
         multinomial_sampling: Whether to use multinomial sampling as in [2],
            defaults to True.
         target_accept_prob: Target accept probability. Increasing this would lead
             to smaller step size. Defaults to 0.8.
+        nnc_compile: (Experimental) If True, the NNC compiler will be used to
+            accelerate inference (defaults to False).
     """
 
     def __init__(
@@ -81,6 +84,7 @@ def __init__(
         adapt_mass_matrix: bool = True,
         multinomial_sampling: bool = True,
         target_accept_prob: float = 0.8,
+        nnc_compile: bool = False,
     ):
         # note that trajectory_length is not used in NUTS
         super().__init__(
@@ -92,10 +96,14 @@ def __init__(
             adapt_step_size=adapt_step_size,
             adapt_mass_matrix=adapt_mass_matrix,
             target_accept_prob=target_accept_prob,
+            nnc_compile=False,  # we will use NNC at the NUTS level, not the HMC level
         )
         self._max_tree_depth = max_tree_depth
         self._max_delta_energy = max_delta_energy
         self._multinomial_sampling = multinomial_sampling
+        if nnc_compile:
+            # pyre-ignore[8]
+            self._build_tree_base_case = nnc_jit(self._build_tree_base_case)
 
     def _is_u_turning(
         self,
@@ -103,12 +111,12 @@ def _is_u_turning(
         mass_inv: RVDict,
         left_momentums: RVDict,
         right_momentums: RVDict,
         sum_momentums: RVDict,
-    ) -> bool:
+    ) -> torch.Tensor:
         """The generalized U-turn condition, as described in [2] Appendix 4.2"""
         left_r = torch.cat([left_momentums[node] for node in mass_inv])
         right_r = torch.cat([right_momentums[node] for node in mass_inv])
         rho = torch.cat([mass_inv[node] * sum_momentums[node] for node in mass_inv])
-        return bool((torch.dot(left_r, rho) <= 0) or (torch.dot(right_r, rho) <= 0))
+        return (torch.dot(left_r, rho) <= 0) or (torch.dot(right_r, rho) <= 0)
 
     def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
         """Base case of the recursive tree building algorithm: take a single leapfrog
@@ -142,10 +150,8 @@ def _build_tree_base_case(self, root: _TreeNode, args: _TreeArgs) -> _Tree:
             log_weight=log_weight,
             sum_momentums=momentums,
             sum_accept_prob=torch.clamp(torch.exp(-delta_energy), max=1.0),
-            num_proposals=1,
-            turned_or_diverged=bool(
-                args.log_slice >= self._max_delta_energy - new_energy
-            ),
+            num_proposals=torch.tensor(1),
+            turned_or_diverged=args.log_slice >= self._max_delta_energy - new_energy,
         )
 
     def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tree:
@@ -174,7 +180,7 @@ def _combine_tree(
         self,
         old_tree: _Tree,
         new_tree: _Tree,
-        direction: int,
+        direction: torch.Tensor,
         mass_inv: RVDict,
         biased: bool,
     ) -> _Tree:
@@ -288,14 +294,18 @@ def propose(self, world: World) -> Tuple[World, torch.Tensor]:
             log_weight=torch.tensor(0.0),  # log accept prob of staying at current state
             sum_momentums=momentums,
             sum_accept_prob=torch.tensor(0.0),
-            num_proposals=0,
-            turned_or_diverged=False,
+            num_proposals=torch.tensor(0),
+            turned_or_diverged=torch.tensor(False),
         )
 
         for j in range(self._max_tree_depth):
-            direction = 1 if torch.rand(()) > 0.5 else -1
+            direction = torch.tensor(1 if torch.rand(()) > 0.5 else -1)
             tree_args = _TreeArgs(
-                log_slice, direction, self.step_size, current_energy, self._mass_inv
+                log_slice,
+                direction,
+                self.step_size,
+                current_energy,
+                self._mass_inv,
             )
             if direction == -1:
                 new_tree = self._build_tree(tree.left, j, tree_args)
diff --git a/src/beanmachine/ppl/inference/proposer/tests/hmc_utils_test.py b/src/beanmachine/ppl/inference/proposer/tests/hmc_utils_test.py
index 1a27628c0e..3885d8e91f 100644
--- a/src/beanmachine/ppl/inference/proposer/tests/hmc_utils_test.py
+++ b/src/beanmachine/ppl/inference/proposer/tests/hmc_utils_test.py
@@ -31,15 +31,15 @@ def bar(self):
 
 
 def test_dual_average_adapter():
-    adapter = DualAverageAdapter(0.1)
+    adapter = DualAverageAdapter(torch.tensor(0.1))
     epsilon1 = adapter.step(torch.tensor(1.0))
     epsilon2 = adapter.step(torch.tensor(0.0))
     assert epsilon2 < adapter.finalize() < epsilon1
 
 
 def test_dual_average_with_different_delta():
-    adapter1 = DualAverageAdapter(1.0, delta=0.8)
-    adapter2 = DualAverageAdapter(1.0, delta=0.2)
+    adapter1 = DualAverageAdapter(torch.tensor(1.0), delta=0.8)
+    adapter2 = DualAverageAdapter(torch.tensor(1.0), delta=0.2)
     prob = torch.tensor(0.5)
     # prob > delta means we can increase the step size, whereas prob < delta means
     # we need to decrease the step size
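Note: the final diff below applies the same pattern as the earlier `log_joint -= ...`
and `jacobian += ...` rewrites. `World.replace` and `World.log_prob` are called from
the compiled leapfrog's potential-energy computation, so code that may run under
functorch's tracer avoids in-place tensor updates in favor of out-of-place ops.
`simple_ts_compile` does run `_jit_pass_remove_mutation`, but keeping the traced code
mutation-free to begin with is presumably the more robust path; `torch.clone(value)`
is likewise assumed to behave better than the `value.clone()` method call on traced
tensor wrappers.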
diff --git a/src/beanmachine/ppl/world/world.py b/src/beanmachine/ppl/world/world.py
index d5fc66a060..3515d0dbce 100644
--- a/src/beanmachine/ppl/world/world.py
+++ b/src/beanmachine/ppl/world/world.py
@@ -99,7 +99,7 @@ def replace(self, values: RVDict) -> World:
         new_world = self.copy()
         for node, value in values.items():
             new_world._variables[node] = new_world._variables[node].replace(
-                value=value.clone()
+                value=torch.clone(value)
             )
         # changing the value of a node can change the dependencies of its children nodes
         nodes_to_update = set().union(
@@ -194,7 +194,7 @@ def log_prob(
 
         log_prob = torch.tensor(0.0)
         for node in set(nodes):
-            log_prob += torch.sum(self._variables[node].log_prob)
+            log_prob = log_prob + torch.sum(self._variables[node].log_prob)
         return log_prob
 
     def enumerate_node(self, node: RVIdentifier) -> torch.Tensor: