Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pathfinder #194

Merged
merged 24 commits into from
May 24, 2022
Merged

Pathfinder #194

merged 24 commits into from
May 24, 2022

Conversation

miclegr
Copy link
Contributor

@miclegr miclegr commented Apr 3, 2022

ref #157

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving the first round of comments, largely style wise. Need a bit more indepth review of the pathfinder implementation (vi/pathfinder.py)
There are a lot of places single letter variable names are used. Suggest changing to more informative name, for example x -> state, x0 -> initial_state, J -> param_dims, g -> grad

blackjax/adaptation/pathfinder_adaptation.py Outdated Show resolved Hide resolved
blackjax/kernels.py Outdated Show resolved Hide resolved
blackjax/kernels.py Outdated Show resolved Hide resolved
blackjax/kernels.py Outdated Show resolved Hide resolved
blackjax/kernels.py Outdated Show resolved Hide resolved
blackjax/vi/pathfinder.py Outdated Show resolved Hide resolved
blackjax/vi/pathfinder.py Outdated Show resolved Hide resolved
@rlouf
Copy link
Member

rlouf commented Apr 5, 2022

@miclegr Thank you for contributing, this looks like it will be a valuable contribution to the library. Could you please rebase main onto your branch? When you checked out to create your branch the CI tests were not properly configured and thus are not running.

@miclegr
Copy link
Contributor Author

miclegr commented Apr 6, 2022

@miclegr Thank you for contributing, this looks like it will be a valuable contribution to the library. Could you please rebase main onto your branch? When you checked out to create your branch the CI tests were not properly configured and thus are not running.

sure, did the rebasing now. Please note that when it comes to tests there's the problem I've outlined in my comment in #157 :

Unfortunately the L-BFGS optimizer is not working well when optimizing model's (negative) log-lilekihood while working in jax's default float32 mode, usually convergence fails. I've noticed that by simply dividing model's likelihood by number of observations (hence optimizing model's average likelihood) optimization converges. Unfortunately given blackjax design it's quite unnatural to ask for average likelihood. So in the end it's recommended to turn on double precision mode

A consequence of this is that, for running pathfinder tests, double precision mode is needed. Since double precision mode needs to be set at jax initialization time (see here), test suite should support some test in float32 mode and some in float64 mode (e.g. by running them in separate processes). It's feasible (see here for example) but not implemented at the moment. Right know i've just set the float64 mode for that tests: https://github.com/miclegr/blackjax/blob/128ce00bd2b28e06f79d126c9d3c097b6378ccc8/tests/test_pathfinder.py#L2-L3 But this wont work when running the full suite of tests. Happy to spend some time to get the multi float mode tests in place, if that's the solution.

@miclegr
Copy link
Contributor Author

miclegr commented Apr 12, 2022

Tentative feature request to have L-BFGS optimization path exposed directly in jax, without the code duplication of this pull request: jax-ml/jax#10243

@junpenglao
Copy link
Member

Per discussion in jax-ml/jax#10243, seems it is a better practice to use JAXopt - could you update the code and add a JAXopt dependency?

@miclegr
Copy link
Contributor Author

miclegr commented Apr 27, 2022

Yep, I'm going to start working on that. Do you prefer JAXopt as a full requirement (i.e. as an entry in requirements.txt) or as an additional one (i.e. import within a try/except)?

@junpenglao
Copy link
Member

let's make it a full requirement - I imagine we will use it in other optimization related inference method

@rlouf rlouf added enhancement New feature or request sampler Issue related to samplers labels Apr 28, 2022
@rlouf rlouf linked an issue Apr 28, 2022 that may be closed by this pull request
@rlouf
Copy link
Member

rlouf commented May 13, 2022

@miclegr is there anything we can do to help?

@miclegr
Copy link
Contributor Author

miclegr commented May 22, 2022

Hi! I was finally able to complete the refactoring to make the code use jaxopt instead of jax.scipy.optimize.
I've also taken some time to make it work as much as possible with float32, by making ad ad-hoc adjustment to the stepsize of the LBFGS during training: https://github.com/miclegr/blackjax/blob/pathfinder/blackjax/vi/pathfinder.py#L290-L295

Now the tests run fine without the need to turn on the float64 mode, also the notebook :) So it should be mergeable.
Happy to do any remaining work

Michele

@codecov
Copy link

codecov bot commented May 23, 2022

Codecov Report

Merging #194 (33ac071) into main (1393af3) will decrease coverage by 3.28%.
The diff coverage is 72.36%.

@@            Coverage Diff             @@
##             main     #194      +/-   ##
==========================================
- Coverage   98.40%   95.12%   -3.29%     
==========================================
  Files          36       39       +3     
  Lines        1383     1580     +197     
==========================================
+ Hits         1361     1503     +142     
- Misses         22       77      +55     
Impacted Files Coverage Δ
blackjax/kernels.py 84.70% <29.72%> (-15.30%) ⬇️
blackjax/adaptation/pathfinder_adaptation.py 37.14% <37.14%> (ø)
blackjax/vi/pathfinder.py 94.30% <94.30%> (ø)
blackjax/adaptation/__init__.py 100.00% <100.00%> (ø)
blackjax/vi/__init__.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1393af3...33ac071. Read the comment docs.

@junpenglao
Copy link
Member

Great work @miclegr! I made some minor clean up - will merge once test pass.

@junpenglao junpenglao merged commit f4221d0 into blackjax-devs:main May 24, 2022
@miclegr
Copy link
Contributor Author

miclegr commented May 24, 2022

Thanks! and thanks for all the great work at pymc and tensorflow_probability 😉

junpenglao added a commit that referenced this pull request Mar 12, 2024
* pathfinder module

* pathfinder kernel

* pathfinder adaptation

* update docstrings

* pathfinder example

* pathfinder tests

* some variable renaming

* docstring module

* revert change

* inline ELBO function

* better readability

* bugfix: support pytrees when removing sample leading dim

* update with new variable names

* code formatting via black

* comment to describe removal of leading dimension

* refactoring: using jaxopt instead of jax.scipy.optimize

* added jaxopt

* fixed tests

* update notebook

* removed commented code

* Fix pre-commit

* Minor Clean up

* Fix formatting and type hint

Co-authored-by: Michele Gregori <michelegregorits@gmail.com>
Co-authored-by: junpenglao <junpenglao@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request sampler Issue related to samplers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Pathfinder
4 participants