Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Throw an exception when step size becomes zero in NUTS/HMC (#1606)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1606

- Throw an exception when the step size becomes zero in NUTS/HMC
- Used error messages that stan uses (https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/base_hmc.hpp#L131-L139)

Reviewed By: horizon-blue

Differential Revision: D38726138

fbshipit-source-id: d2fec5e47d589fa775576b962f2f4556e72de1cf
  • Loading branch information
CactusWin authored and facebook-github-bot committed Aug 17, 2022
1 parent 9e4d000 commit 15a98fe
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/beanmachine/ppl/inference/proposer/hmc_proposer.py
Expand Up @@ -258,6 +258,15 @@ def _find_reasonable_step_size(
step_size_scale = 2**direction
while new_direction == direction:
step_size *= step_size_scale
if step_size == 0:
raise ValueError(
f"Current step size is {step_size}. No acceptably small step size could be found."
"Perhaps the posterior is not continuous?"
)
if step_size > 1e7:
raise ValueError(
f"Current step size is {step_size}. Posterior is improper. Please check your model"
)
# not covered in the paper, but both Stan and Pyro re-sample the momentum
# after each update
momentums = self._initialize_momentums(positions)
Expand Down
21 changes: 21 additions & 0 deletions src/beanmachine/ppl/inference/proposer/tests/hmc_proposer_test.py
Expand Up @@ -64,3 +64,24 @@ def test_leapfrog_step(hmc):
)
assert momentums == new_momentums
assert new_positions == hmc._positions


@pytest.mark.parametrize(
# forcing the step_size to be 0 for HMC/ NUTS
"algorithm",
[
bm.GlobalNoUTurnSampler(initial_step_size=0.0),
bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, initial_step_size=0.0),
],
)
def test_step_size_exception(algorithm):
queries = [foo()]
observations = {bar(): torch.tensor(0.5)}

with pytest.raises(ValueError):
algorithm.infer(
queries,
observations,
num_samples=20,
num_chains=1,
)

0 comments on commit 15a98fe

Please sign in to comment.