Joss Review: Example Usage #26
Hi, my bad, I hadn't updated the pip version. I updated the example in the README though to make it more friendly for users w/o Haiku background. Could you test this: pip install surjectors==0.3.0 and then

import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp


def decoder_fn(n_dim):
    # Conditional density over the dimensions dropped by the slice
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))

    return _fn


@hk.without_apply_rng
@hk.transform
def flow(x):
    # 5-dimensional base distribution for a 10-dimensional input
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)


x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)
print(lp)

Thanks and cheers,
@dirmeier Thanks, that fixed it. But I found another minor bug in the documentation. The example in the docs (https://surjectors.readthedocs.io/en/latest/#example) is missing this import line: Thanks!
Thanks! I made a PR #29 with the fix. I'll merge it after the review is done (I don't wanna add too many non-functional commits).
Understood, thanks! Closing this issue then and giving the go-ahead from my end. Good luck! 👍🏻
I installed JAX (CPU), followed by Surjectors, as a clean install inside a Docker container.
Once installed, I ran the example for constructing a simple normalizing flow from the README's examples section. I got the following error:
It failed in the line:
transform = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
Link to the Review Thread: openjournals/joss-reviews#6188
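For readers tracing the shapes in the updated example earlier in this thread: a 10-dimensional input is sliced down to 5 dimensions, which is why the base distribution and LULinear both use 5. The following is a minimal plain-Python sketch of that bookkeeping, not the Surjectors implementation. It assumes the slice keeps the first n_keep entries and scores the dropped entries with a conditional Gaussian; toy_decoder is a hypothetical stand-in for the MLP decoder, and the LULinear step is left out.

```python
import math


def normal_logpdf(x, mean, scale):
    """Log-density of a univariate normal distribution."""
    return (
        -0.5 * math.log(2.0 * math.pi)
        - math.log(scale)
        - 0.5 * ((x - mean) / scale) ** 2
    )


def toy_decoder(z, n_out):
    """Hypothetical stand-in for the MLP decoder.

    Maps the kept entries z to (means, scales) for the n_out dropped entries.
    """
    mean = sum(z) / len(z)
    return [mean] * n_out, [1.0] * n_out


def slice_log_prob(x, n_keep):
    """log p(x) = log p_base(z) + log q(x_dropped | z), with z = x[:n_keep].

    Assumption: the first n_keep entries are kept, and the base distribution
    is a standard normal over those n_keep entries.
    """
    z, dropped = x[:n_keep], x[n_keep:]
    means, scales = toy_decoder(z, len(dropped))
    base_lp = sum(normal_logpdf(zi, 0.0, 1.0) for zi in z)
    decoder_lp = sum(
        normal_logpdf(d, m, s) for d, m, s in zip(dropped, means, scales)
    )
    return base_lp + decoder_lp


x = [0.1 * i for i in range(10)]  # one 10-dimensional observation
lp = slice_log_prob(x, n_keep=5)  # 5 kept + 5 dropped entries
print(lp)
```

The point of the sketch is only that the dimension after slicing (here 5) has to match what the base distribution and any later transform expect, while the decoder has to cover the entries that were dropped.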