-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pathfinder #194
Pathfinder #194
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leaving the first round of comments, largely style wise. Need a bit more indepth review of the pathfinder implementation (vi/pathfinder.py)
There are a lot of places single letter variable names are used. Suggest changing to more informative name, for example x
-> state
, x0
-> initial_state
, J
-> param_dims
, g
-> grad
@miclegr Thank you for contributing, this looks like it will be a valuable contribution to the library. Could you please rebase |
sure, did the rebasing now. Please note that when it comes to tests there's the problem I've outlined in my comment in #157 :
|
Tentative feature request to have L-BFGS optimization path exposed directly in jax, without the code duplication of this pull request: jax-ml/jax#10243 |
Per discussion in jax-ml/jax#10243, seems it is a better practice to use JAXopt - could you update the code and add a JAXopt dependency? |
Yep, I'm going to start working on that. Do you prefer JAXopt as a full requirement (i.e. as an entry in requirements.txt) or as an additional one (i.e. import within a try/except)? |
let's make it a full requirement - I imagine we will use it in other optimization related inference method |
@miclegr is there anything we can do to help? |
Hi! I was finally able to complete the refactoring to make the code use Now the tests run fine without the need to turn on the float64 mode, also the notebook :) So it should be mergeable. Michele |
Codecov Report
@@ Coverage Diff @@
## main #194 +/- ##
==========================================
- Coverage 98.40% 95.12% -3.29%
==========================================
Files 36 39 +3
Lines 1383 1580 +197
==========================================
+ Hits 1361 1503 +142
- Misses 22 77 +55
Continue to review full report at Codecov.
|
Great work @miclegr! I made some minor clean up - will merge once test pass. |
Thanks! and thanks for all the great work at pymc and tensorflow_probability 😉 |
* pathfinder module * pathfinder kernel * pathfinder adaptation * update docstrings * pathfinder example * pathfinder tests * some variable renaming * docstring module * revert change * inline ELBO function * better readability * bugfix: support pytrees when removing sample leading dim * update with new variable names * code formatting via black * comment to describe removal of leading dimension * refactoring: using jaxopt instead of jax.scipy.optimize * added jaxopt * fixed tests * update notebook * removed commented code * Fix pre-commit * Minor Clean up * Fix formatting and type hint Co-authored-by: Michele Gregori <michelegregorits@gmail.com> Co-authored-by: junpenglao <junpenglao@gmail.com>
ref #157