Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Adding two unit tests for the nnc_compile option
Browse files Browse the repository at this point in the history
  • Loading branch information
horizon-blue committed Mar 24, 2022
1 parent 8eda5f0 commit acd2287
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
47 changes: 47 additions & 0 deletions src/beanmachine/ppl/experimental/tests/nnc_test.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import beanmachine.ppl as bm
import torch.distributions as dist
import warnings
import torch
import pytest


class SampleModel:
    """Minimal two-variable Gaussian model used as an inference smoke test.

    ``foo`` is a standard-normal latent variable and ``bar`` is a
    unit-variance Gaussian observation centered at ``foo``, so there is
    exactly one latent quantity for the sampler to infer.
    """

    @bm.random_variable
    def foo(self):
        """Latent variable with a standard normal prior."""
        return dist.Normal(loc=0.0, scale=1.0)

    @bm.random_variable
    def bar(self):
        """Observed variable: a noisy (unit-variance) view of ``foo``."""
        return dist.Normal(loc=self.foo(), scale=1.0)


@pytest.mark.parametrize(
    "algorithm",
    [
        bm.GlobalNoUTurnSampler(nnc_compile=True),
        bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, nnc_compile=True),
    ],
)
def test_nnc_compile(algorithm):
    """End-to-end smoke test for NNC-compiled gradient-based inference.

    Runs a short chain of the given algorithm on ``SampleModel`` and
    checks that the posterior draws for the latent variable are valid
    (contain no NaNs).
    """
    model = SampleModel()
    n_samples = 30
    # NNC compilation can emit benign warnings; silence them so they do
    # not clutter the test output while we verify NNC can run through.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        samples = algorithm.infer(
            [model.foo()],
            {model.bar(): torch.tensor(0.5)},
            n_samples,
            num_adaptive_samples=n_samples,
            num_chains=2,
        )
    # Sanity check: make sure that the samples are valid.
    assert not torch.isnan(samples[model.foo()]).any()
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/inference/proposer/nuts_proposer.py
Expand Up @@ -306,7 +306,7 @@ def propose(self, world: World) -> Tuple[World, torch.Tensor]:
current_energy,
self._mass_inv,
)
if direction < 0:
if direction == -1:
new_tree = self._build_tree(tree.left, j, tree_args)
else:
new_tree = self._build_tree(tree.right, j, tree_args)
Expand Down

0 comments on commit acd2287

Please sign in to comment.