Problem with shapes in simple transformed distributions #42

Closed
chriscarmona opened this issue Aug 15, 2021 · 2 comments

Comments

@chriscarmona

Hi,

I'm working on an application with normalizing flows which requires sampling constrained values for some dimensions of the flow.
I found the dimension verification in distrax.Transformed to be a bit unpredictable.

For example, a simple sigmoid transformation of a multivariate normal works well with the TFP JAX substrate but fails with distrax:

from jax import numpy as jnp
import haiku as hk
import distrax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

prng_seq = hk.PRNGSequence(123)
event_shape = (3,)

base_dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(event_shape),
    scale_diag=jnp.ones(event_shape),
)
q_distr = tfd.TransformedDistribution(base_dist, tfb.Sigmoid())
q_distr.sample(seed=next(prng_seq)) # All good :)

base_dist = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(event_shape),
    scale_diag=jnp.ones(event_shape),
)
q_distr = distrax.Transformed(base_dist, distrax.Sigmoid()) # Doesn't work :(
# ValueError: Base distribution 'MultivariateNormalDiag' has event shape (3,), but bijector 'Sigmoid' expects events to have 0 dimensions. Perhaps use `distrax.Block` or `distrax.Independent`?
q_distr = distrax.Transformed(base_dist, tfb.Exp()) # Doesn't work either :(

Would be very nice to have this supported by distrax.

Thank you!!!

All best,
Chris

@franrruiz
Collaborator

Hi Chris,

Thank you for your comment. This was an intentional design choice for distrax to make bijectors closer to their mathematical definition. Note that distrax.Sigmoid is a bijector that transforms a scalar value x into another scalar value y by applying the transformation y = sigmoid(x). Since the bijector acts on scalars, pairing it with a base distribution whose events are 3-dimensional vectors (event shape (3,)) raises an error.

To avoid the error, you need a bijector that transforms a vector x into a vector y by applying the transformation y_i = sigmoid(x_i) to each component i of the vector. To obtain such a bijector, as suggested by the error message, you can use distrax.Block(). For example, this code snippet should work:

base_dist = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(event_shape),
    scale_diag=jnp.ones(event_shape),
)
bij = distrax.Block(distrax.Sigmoid(), 1)
q_distr = distrax.Transformed(base_dist, bij)
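
To make the scalar-versus-vector distinction concrete, here is a minimal pure-Python sketch of what blocking an elementwise bijector amounts to mathematically: the transformation is applied to each component, and the per-component log-determinants are summed into one value per event. The helper names (`sigmoid_forward_log_det`, `blocked_forward_and_log_det`) are hypothetical and chosen for illustration only; this is not how distrax implements `Block` internally.

```python
import math


def sigmoid(x):
    """Scalar sigmoid: maps a real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_forward_log_det(x):
    """Log |dy/dx| for y = sigmoid(x), using dy/dx = y * (1 - y)."""
    y = sigmoid(x)
    return math.log(y) + math.log(1.0 - y)


def blocked_forward_and_log_det(xs):
    """Treat the scalar bijector as acting on a whole vector event.

    Hypothetical helper: applies sigmoid to each component and sums the
    scalar log-determinants, because the Jacobian of an elementwise map
    is diagonal.
    """
    ys = [sigmoid(x) for x in xs]
    log_det = sum(sigmoid_forward_log_det(x) for x in xs)
    return ys, log_det


# One event of shape (3,) yields three outputs but a single log-determinant.
ys, log_det = blocked_forward_and_log_det([0.0, 1.0, -1.0])
```

The single summed log-determinant is what the density of a vector-valued transformed distribution needs, which is why the scalar bijector has to be told explicitly how many trailing dimensions form one event.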

@chriscarmona
Author

Hi Francisco,

Thanks so much for the quick and insightful reply.
I agree, making this explicit declaration of Block brings clarity in many cases. I'll modify my scripts as suggested.

Enjoy the rest of the weekend :)
Best, Chris
