-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding RWMH #74
Adding RWMH #74
Conversation
blackjax/inference/metrics.py
Outdated
@@ -109,7 +113,7 @@ def kinetic_energy(momentum: PyTree) -> float: | |||
return kinetic_energy_val | |||
|
|||
def is_turning( | |||
momentum_left: PyTree, momentum_right: PyTree, momentum_sum: PyTree | |||
momentum_left: PyTree, momentum_right: PyTree, momentum_sum: PyTree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your linter adds extra spaces in function signatures, introducing changes in files that should not have been modified. Could you please revert these change?
blackjax/rwmh.py
Outdated
The next state of the chain and additional information about the current step. | ||
""" | ||
momentum_key, proposal_key = jax.random.split(rng_key, 2) | ||
momentum = momentum_generator(momentum_key, state.position, unravel=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should use the metrics
in hmc
even though I understand where you're coming from. Add a new gaussian_proposal
function, in the spirit of what I did for MCX (RWMH kernel is at the bottom of kernels.py
).
blackjax/rwmh.py
Outdated
ravelled_position, unravel_fn = ravel_pytree(state.position) | ||
ravelled_proposed_position = ravelled_position + momentum | ||
proposed_position = unravel_fn(ravelled_proposed_position) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would simplify this as follows:
_, unravel_fn = ravel_pytree(state.position)
new_position = jax.tree_util.tree_multimap(lambda q, dq: q + dq, position, unravel_fn(move_proposal))
blackjax/rwmh.py
Outdated
u = jax.random.uniform(proposal_key) | ||
p_accept = jnp.exp(state.potential_energy - proposed_potential) | ||
do_accept = u < p_accept |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
delta = proposed_potential - state.potential_energy
delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
p_accept = jnp.clip(jnp.exp(delta), a_max=1)
do_accept = jax.random.bernoulli(key_proposal, p_accept)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you care about the nan case? nothing can compare to nan by definition, so u < jnp.nan will always be False (same for u> jnp.nan for that matter)
blackjax/rwmh.py
Outdated
lambda _: (proposed_position, proposed_potential), | ||
lambda _: (state.position, state.potential_energy), | ||
operand=None) | ||
return RWMHState(new_position, new_potential), RWMHInfo(p_accept, do_accept, proposed_position) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here you could define
accept_state = RWMHState(proposed_position, proposed_potential)
reject_state = RWMHState(state.position, state.potential_energy)
And make lax.cond
return either of those. I find it clearer.
blackjax/rwmh.py
Outdated
|
||
def kernel( | ||
potential_fn: Callable, | ||
inverse_mass_matrix: Array, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We want to pass a proposal_generator
instead of an inverse_mass_matrix
here. We don't want to limit the algorithm to moves drawn from a normal distribution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I now mostly addressed this, which made a lot of the above logic disappear. What do you think about it?
Codecov Report
@@ Coverage Diff @@
## main #74 +/- ##
=======================================
Coverage ? 98.55%
=======================================
Files ? 25
Lines ? 969
Branches ? 0
=======================================
Hits ? 955
Misses ? 14
Partials ? 0 Continue to review full report at Codecov.
|
Isn't the code stype supposed to be autocorrected by the pre-commit actions? |
Conflicts are weird. I manually resolved them already. |
If you do |
# Conflicts: # blackjax/hmc.py # blackjax/inference/hmc/base.py # blackjax/inference/hmc/integrators.py # blackjax/nuts.py # blackjax/stan_warmup.py # blackjax/types.py
That's a lot better! I'll go over the PR whenever I find time. |
Co-authored-by: Rémi Louf <remilouf@gmail.com>
This is to add RWMH as a base algo.
See this
#73