Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stan warmup #16

Merged
merged 20 commits into from
Jun 2, 2021
Merged

Stan warmup #16

merged 20 commits into from
Jun 2, 2021

Conversation

rlouf
Copy link
Member

@rlouf rlouf commented Apr 8, 2021

In this Pull Request we implement the Stan warmup for a single chain. We also change the signature of the HMC and NUTS kernel factories. Example:

kernel_factory = lambda step_size, inverse_mass_matrix: nuts.kernel(potential_fn, step_size, inverse_mass_matrix)

# `info` is a tuple that contains the chain, all warmup states and the chain info
state, (step_size, inverse_mass_matrix), info = stan_warmup(
    rng_key,
    kernel_factory,
    initial_state,
    num_steps=1000
)

# We can use different kernels in the HMC family for warmup and sampling
kernel = hmc.kernel(potential_fn, steps_size, inverse_mass_matrix)

@codecov
Copy link

codecov bot commented Apr 8, 2021

Codecov Report

Merging #16 (b064f27) into master (c6f75e9) will increase coverage by 1.31%.
The diff coverage is 98.36%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #16      +/-   ##
==========================================
+ Coverage   95.47%   96.78%   +1.31%     
==========================================
  Files          10       15       +5     
  Lines         398      623     +225     
==========================================
+ Hits          380      603     +223     
- Misses         18       20       +2     
Impacted Files Coverage Δ
blackjax/stan_warmup.py 96.19% <96.19%> (ø)
blackjax/adaptation/mass_matrix.py 100.00% <100.00%> (ø)
blackjax/adaptation/step_size.py 100.00% <100.00%> (ø)
blackjax/hmc.py 100.00% <100.00%> (+1.92%) ⬆️
blackjax/inference/proposal.py 100.00% <100.00%> (ø)
blackjax/nuts.py 100.00% <100.00%> (+1.85%) ⬆️
blackjax/optimizers/__init__.py 100.00% <100.00%> (ø)
blackjax/optimizers/dual_averaging.py 100.00% <100.00%> (ø)
... and 3 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update c6f75e9...b064f27. Read the comment docs.

@rlouf rlouf mentioned this pull request May 24, 2021
@rlouf rlouf marked this pull request as draft May 24, 2021 08:53
@rlouf rlouf force-pushed the dual_averaging branch 2 times, most recently from ad14448 to 0779d3e Compare May 25, 2021 11:25
@rlouf rlouf changed the title Step size adaptation Stan warmup May 25, 2021
@rlouf
Copy link
Member Author

rlouf commented May 29, 2021

I need to separate "run" from the function that updates the stan warmup, in case we want to share information across chains.

@rlouf
Copy link
Member Author

rlouf commented May 29, 2021

This is ready to review. We will just have to wait until #21 is merged to be able to test the Stan warmup with the NUTS kernel.

@rlouf rlouf marked this pull request as ready for review May 29, 2021 20:24
@rlouf rlouf force-pushed the dual_averaging branch 4 times, most recently from 8832413 to a0afa52 Compare June 2, 2021 10:25
@rlouf rlouf merged commit 7be2822 into master Jun 2, 2021
@rlouf rlouf deleted the dual_averaging branch June 2, 2021 12:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants