Joss Review: Example Usage #26

Closed
animikhaich opened this issue Jan 31, 2024 · 5 comments

@animikhaich

animikhaich commented Jan 31, 2024

I installed JAX (CPU), followed by Surjectors, as a clean install inside a Docker container.

Once installed, I ran the example for constructing a simple normalizing flow from Readme#examples and got the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/surjectors/bijectors/lu_linear.py", line 16, in __init__
    super().__init__(n_keep, None, None, "bijection", dtype)
  File "/usr/local/lib/python3.10/dist-packages/surjectors/surjectors/surjector.py", line 24, in __init__
    raise ValueError(
ValueError: 'kind' argument needs to be either of: inference_surjector/generative_surjector/bijector/surjector

It failed on this line: transform = Chain([Slice(10, decoder_fn(10)), LULinear(5)])

Link to the Review Thread: openjournals/joss-reviews#6188

@dirmeier
Owner

dirmeier commented Feb 1, 2024

Hi, my bad, I hadn't updated the pip release to v0.3.0. Assuming you are running v0.3.0, the examples should run (at least the GitHub Actions runs pass and the documentation builds).
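
For anyone hitting the same traceback, one way to confirm which release is actually installed is to query the package metadata. This is only a minimal sketch using the standard library's importlib.metadata; the helper name installed_version is hypothetical and not part of Surjectors.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed.

    Hypothetical helper for illustration; not part of the Surjectors API.
    """
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # "surjectors" is the distribution name used in the pip command in this thread.
    print(installed_version("surjectors"))
```

If this prints something older than 0.3.0 (or None), the README example is likely running against a release that predates the fix.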

I updated the example in the README, though, to make it friendlier for users without a Haiku background. Could you test this:

pip install surjectors==0.3.0

and then

import distrax
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp

def decoder_fn(n_dim):
    def _fn(z):
        params = make_mlp([32, 32, n_dim * 2])(z)
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
    return _fn

@hk.without_apply_rng
@hk.transform
def flow(x):
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
    )
    transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
    pushforward = TransformedDistribution(base_distribution, transform)
    return pushforward.log_prob(x)

x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)
print(lp)

Thanks and cheers,
Simon

@dirmeier
Owner

dirmeier commented Feb 1, 2024

I think the example in the README is not very illustrative though. The ones in examples or here are better.

@animikhaich
Author

@dirmeier Thanks, that fixed it. But I found another minor bug in the documentation.

The example in the docs at https://surjectors.readthedocs.io/en/latest/#example is missing this import line: from surjectors import Slice, LULinear, Chain. Without it, the example fails. Please update the documentation with the import.

Thanks!

@dirmeier
Owner

dirmeier commented Feb 2, 2024

Thanks! I made a PR #29 with the fix. I'll merge it after the review is done (I don't want to add too many non-functional commits).

@animikhaich
Author

Understood, thanks! Closing this issue then and giving the go-ahead from my end. Good luck! 👍🏻
