Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding RWMH #74

Merged
merged 21 commits into from
Aug 5, 2021
Merged

Adding RWMH #74

merged 21 commits into from
Aug 5, 2021

Conversation

AdrienCorenflos
Copy link
Contributor

This is to add RWMH as a base algo.
See this
#73

blackjax/hmc.py Outdated Show resolved Hide resolved
@@ -109,7 +113,7 @@ def kinetic_energy(momentum: PyTree) -> float:
return kinetic_energy_val

def is_turning(
momentum_left: PyTree, momentum_right: PyTree, momentum_sum: PyTree
momentum_left: PyTree, momentum_right: PyTree, momentum_sum: PyTree
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your linter adds extra spaces in function signatures, introducing changes in files that should not have been modified. Could you please revert these change?

blackjax/rwmh.py Outdated
The next state of the chain and additional information about the current step.
"""
momentum_key, proposal_key = jax.random.split(rng_key, 2)
momentum = momentum_generator(momentum_key, state.position, unravel=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should use the metrics in hmc even though I understand where you're coming from. Add a new gaussian_proposal function, in the spirit of what I did for MCX (RWMH kernel is at the bottom of kernels.py).

blackjax/rwmh.py Outdated
Comment on lines 108 to 110
ravelled_position, unravel_fn = ravel_pytree(state.position)
ravelled_proposed_position = ravelled_position + momentum
proposed_position = unravel_fn(ravelled_proposed_position)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would simplify this as follows:

_, unravel_fn = ravel_pytree(state.position)
new_position = jax.tree_util.tree_multimap(lambda q, dq: q + dq, position, unravel_fn(move_proposal))

blackjax/rwmh.py Outdated
Comment on lines 113 to 115
u = jax.random.uniform(proposal_key)
p_accept = jnp.exp(state.potential_energy - proposed_potential)
do_accept = u < p_accept
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about

delta = proposed_potential - state.potential_energy
delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
p_accept = jnp.clip(jnp.exp(delta), a_max=1)
do_accept = jax.random.bernoulli(key_proposal, p_accept)

Copy link
Contributor Author

@AdrienCorenflos AdrienCorenflos Aug 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you care about the nan case? nothing can compare to nan by definition, so u < jnp.nan will always be False (same for u> jnp.nan for that matter)

blackjax/rwmh.py Outdated
lambda _: (proposed_position, proposed_potential),
lambda _: (state.position, state.potential_energy),
operand=None)
return RWMHState(new_position, new_potential), RWMHInfo(p_accept, do_accept, proposed_position)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you could define

accept_state = RWMHState(proposed_position, proposed_potential)
reject_state = RWMHState(state.position, state.potential_energy)

And make lax.cond return either of those. I find it clearer.

blackjax/rwmh.py Outdated

def kernel(
potential_fn: Callable,
inverse_mass_matrix: Array,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to pass a proposal_generator instead of an inverse_mass_matrix here. We don't want to limit the algorithm to moves drawn from a normal distribution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now mostly addressed this, which made a lot of the above logic disappear. What do you think about it?

@codecov
Copy link

codecov bot commented Aug 2, 2021

Codecov Report

❗ No coverage uploaded for pull request base (main@39348d3). Click here to learn what that means.
The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##             main      #74   +/-   ##
=======================================
  Coverage        ?   98.55%           
=======================================
  Files           ?       25           
  Lines           ?      969           
  Branches        ?        0           
=======================================
  Hits            ?      955           
  Misses          ?       14           
  Partials        ?        0           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 39348d3...f99efa5. Read the comment docs.

@AdrienCorenflos
Copy link
Contributor Author

Isn't the code stype supposed to be autocorrected by the pre-commit actions?

@AdrienCorenflos
Copy link
Contributor Author

Conflicts are weird. I manually resolved them already.

@rlouf
Copy link
Member

rlouf commented Aug 4, 2021

If you do git rebase main you'll get many conflicts. As for pre-commit it doesn't fix automatically in the CI. You have to run it locally.

@rlouf
Copy link
Member

rlouf commented Aug 4, 2021

That's a lot better! I'll go over the PR whenever I find time.

@rlouf
Copy link
Member

rlouf commented Aug 5, 2021

Great work! Refactored the code and moved things around a little, now merging. We'll freeze features until 0.3 is released. Before releasing we'll need to merge #81, possibly #44 (if #81 does not come first) and do some light refactoring / documentation improvements.

@rlouf rlouf merged commit fe10807 into blackjax-devs:main Aug 5, 2021
junpenglao pushed a commit that referenced this pull request Mar 12, 2024
Co-authored-by: Rémi Louf <remilouf@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants