Question about integrating with bayeux #141
Hey! I heard about this library on the LearnBayesStat podcast, and was trying to integrate it with https://jax-ml.github.io/bayeux/. It seems like it should be easy, since both work with just a log density and an initial point. However, I am getting a somewhat cryptic error from a `jnp.stack` inside flowMC. I would guess it has to do with the fact that the initial state has shape [(), (), (8,)], and so 8 chains have shape [(8,), (8,), (8, 8)], and there is some problem that the last entry has more dimensions than the other two (a miniature reproduction of my guess is below).

Here's the colab I've been trying to get this to work in: https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing

Any help would be much appreciated!
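Here's the miniature version of my guess, with hypothetical shapes matching the state above:

```python
import jax.numpy as jnp

# Initial state with per-leaf shapes [(), (), (8,)], as described above.
state = [jnp.zeros(()), jnp.zeros(()), jnp.zeros(8)]

# Tiling each leaf for 8 chains gives shapes [(8,), (8,), (8, 8)] ...
chains = [jnp.broadcast_to(s, (8,) + s.shape) for s in state]

# ... so stacking them into one array fails:
jnp.stack(chains)  # raises ValueError: the shapes don't match
```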
flowMC expects the likelihood function to have the following signature:
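Roughly something like this, where the name `log_posterior` and the contents of `data` are placeholders:

```python
import jax.numpy as jnp

def log_posterior(pts: jnp.ndarray, data: dict) -> jnp.ndarray:
    # pts: flat array of the parameters being sampled
    # data: pytree of fixed auxiliary quantities that are not sampled
    return -0.5 * jnp.sum((pts - data["mu"]) ** 2)  # placeholder density
```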
where `pts` should be the parameters you want to sample, and `data` is a pytree containing auxiliary data you don't have to sample over. So in order to get this to work, I think you need to modify the likelihood defined in your notebook in two ways: it should take a single flat array of parameters as its first argument rather than a pytree, and it should accept the auxiliary `data` pytree as its second argument.
To be more concrete, the init points should have shape (n_chains, n_dims); a sketch is below. Let me know whether this helps resolve the problem. If this works in the end, would you mind if I link this example on our doc page so others can take a look at it as well? P.S. Would you mind pointing me to the LearnBayesStat episode?
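For example, with placeholder sizes:

```python
import jax

n_chains, n_dims = 12, 10  # placeholder values
rng_key = jax.random.PRNGKey(0)

# One flat parameter vector per chain: shape (n_chains, n_dims)
initial_position = jax.random.normal(rng_key, (n_chains, n_dims))
```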
Thanks for the pointers! I was able to get my example running -- I have to do some funny things to get my state flat (and unflat), roughly the dance sketched below, which I would guess hurts performance, particularly on accelerators. Performance is also terrible-ish, which I assume is user error. I updated the colab with the code that actually runs, using 12 chains. I'll keep looking at this tomorrow, but continue to appreciate suggestions for improving performance if you have any. For what it is worth, I also included samples from a reference sampler for comparison.
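The flattening dance, sketched -- `jax.flatten_util.ravel_pytree` does the heavy lifting, and `log_density` here stands in for the real model:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Stand-in pytree state with leaf shapes [(), (), (8,)].
state = {"a": jnp.array(0.0), "b": jnp.array(0.0), "c": jnp.zeros(8)}
flat_state, unravel = ravel_pytree(state)  # flat_state.shape == (10,)

def log_density(state_pytree):
    # Placeholder for the actual model's log density on the pytree.
    return -0.5 * sum(jnp.sum(leaf ** 2) for leaf in state_pytree.values())

def log_density_flat(pts, data):
    # flowMC only ever sees the flat vector; unflatten before evaluating.
    return log_density(unravel(pts))
```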
Update: I ran with HMC (which secretly requires an extra argument to set up). Does the fact that I am ignoring the `data` argument cause any problems? Also, the podcast is here: https://open.spotify.com/episode/1wRsmH8xXTpO8JOWajgWpL?si=df762a337c7d4361 (the website https://learnbayesstats.com/ has not updated with @marylou-gabrie's episode). She did an excellent job describing the algorithm, and I'm hoping this integration makes it easy for others to try it out.
I played with the notebook a bit and made a few changes; here is an updated version: https://colab.research.google.com/gist/kazewong/033c89e548ef59b3ceb649dcf2ffe9e5/bayeux_and_flowmc.ipynb
With all these changes, I think the flowMC result is more reasonable. Now there is one more problem that is actually interesting; I had it in the back of my mind but never really finished it. With everything else being the same as shown in the notebook, combining all the chains shows stripe-like artifacts in parts of the posterior (the comparison plots are in the notebook); you can see the stripes most clearly in the larger-scale dimensions.

The reason for this is probably that the samples produced by the local sampler are correlated, while the ones from the global sampler are uncorrelated (or way less correlated). Since every sample is kept when combining the chains, the correlated local-sampler samples show up as stripes.

To solve this, I think there is actual work to do. Basically, we want to keep only effective samples from the local sampler instead of every sample (roughly the thinning sketched below), which should provide a way smoother posterior and take away the stripe artifacts.

Last remark: I agree this is a problem HMC can probably solve quite well. flowMC adds the extra layer of a normalizing flow to deal with bad geometries, such as multimodality or really stretched-out geometries with local correlation (like a donut). This problem is rather unimodal and smooth, so HMC shouldn't have a hard time dealing with it. Please let me know if you have more problems regarding this; I am also happy to help make this an example so other users can follow the logic behind this discussion.
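By "effective samples" I mean something like the thinning below -- the stride `k` is a made-up knob here, and choosing it from the measured autocorrelation is the actual work:

```python
import jax.numpy as jnp

def thin_local_samples(chains: jnp.ndarray, k: int) -> jnp.ndarray:
    # chains: (n_chains, n_steps, n_dims). Keep every k-th step to cut
    # the autocorrelation of the local sampler's draws.
    return chains[:, ::k, :]

# e.g. keep every 10th local-sampler draw before combining chains:
# combined = thin_local_samples(chains, 10).reshape(-1, chains.shape[-1])
```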
Thanks again -- working on the PR now. It should do even slightly better than above, since bayeux has machinery to transform the support of models to all of R^n -- right now flowMC has no idea that some of the parameters are constrained.
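The idea of that transformation, sketched by hand (this is not bayeux's actual code; `log_density_constrained` is a placeholder for a model with a positivity constraint):

```python
import jax.numpy as jnp

def log_density_constrained(x, data):
    # Placeholder model whose parameter must be positive.
    return -jnp.sum(x)

def log_density_unconstrained(y, data):
    # Sample y in R^n, map to the positive parameter x = exp(y),
    # and add the log-Jacobian so the density stays correct in y-space.
    x = jnp.exp(y)
    log_jacobian = jnp.sum(y)  # log|det d exp(y)/dy| = sum(y)
    return log_density_constrained(x, data) + log_jacobian
```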
jax-ml/bayeux#23 is out now -- if you have time for comments/suggestions, please do! There's a fair amount of abstraction going on, and it may be easier to play around with it after it merges, then open one or more issues. I'll follow up with a colab using the new integration. A few notes came out of doing this -- lmk if you'd like those to be separate issues.
https://colab.research.google.com/drive/1-PhneVVik5GUq6w2HlKOsqvus13ZLaBH?usp=sharing is updated with the new integration. I'll probably add an example notebook based on that soon. First I have to fix the bug that doesn't allow setting keyword arguments!