Bad return values from jax.random.poisson when lam is jnp.inf #16164

Open
JiaYaobo opened this issue May 27, 2023 · 2 comments
@JiaYaobo
Contributor

JiaYaobo commented May 27, 2023

Description

import jax.random as jr
import jax.numpy as jnp
key = jr.PRNGKey(0)  # any key; the report did not show how `key` was created
jr.poisson(key, jnp.inf, shape=(2,))

Output

Array([2147483647,          0], dtype=int32)

As shown above, a +inf input yields 0 for one of the samples, which looks like an overflow. For lam = inf, 2147483647 seems like the right answer for every element: a large output is more convincing for a +inf rate (expectation). By the way, NumPy raises a ValueError for an inf rate:

ValueError: lam value too large
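
For reference, the NumPy behaviour mentioned above can be reproduced with a short sketch. It assumes NumPy's `Generator` API via `np.random.default_rng`; the seed and the variable name `error_message` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

error_message = None
try:
    # NumPy validates the rate eagerly, so an infinite lam is rejected
    # before any sample is drawn.
    rng.poisson(np.inf, size=2)
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

NumPy can do this because it inspects concrete values at call time.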

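In tensorflow-probability, an inf rate produces inf samples:

from tensorflow_probability.substrates.jax.distributions import Poisson  # assumed import (JAX substrate)
Poisson(jnp.inf).sample(seed=key, sample_shape=(10,))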

Output

Array([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf], dtype=float32)

JAX presumably chose an integer return dtype for compatibility with NumPy, while tfp chose float. As discussed by @axch in #16134 (review), a float return dtype may be more reasonable.

What jax/jaxlib version are you using?

jax v0.4.10, jaxlib v0.4.10

Which accelerator(s) are you using?

CPU

Additional system info

Mac, Python 3.10.9

NVIDIA GPU info

None

@JiaYaobo JiaYaobo added the bug Something isn't working label May 27, 2023
@jakevdp
Collaborator

jakevdp commented Jun 1, 2023

Thanks for the report. JAX has some constraints here, namely:

  • we cannot raise a runtime error triggered by the value of an input – this is quite similar, for example, to the reason JAX cannot raise IndexErrors for out-of-bound indices.
  • we cannot change the type of the output based on the value of the input, so only integer arrays are an option.
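
Both constraints can be seen in a minimal sketch. It assumes a recent JAX release with default settings; the helper names `take_index` and `sample` are illustrative and do not come from this thread.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)

@jax.jit
def take_index(arr, i):
    # Under jit, `i` is a tracer: its value is unknown while the function
    # is traced, so no IndexError can be raised based on it.
    return arr[i]

# An out-of-bounds index does not raise; for retrieval JAX clamps it
# to the valid range instead.
out_of_bounds = take_index(x, 10)

def sample(key, lam):
    return jax.random.poisson(key, lam, shape=(2,))

# The output shape and dtype are derived from the input shapes and dtypes
# alone, without looking at any input value.
key = jax.random.PRNGKey(0)
abstract = jax.eval_shape(sample, key, jnp.float32(1.0))

print(out_of_bounds, abstract.shape, abstract.dtype)
```

Since the result dtype is settled at trace time, a float `inf` cannot be returned for some inputs and integers for others.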

With that in mind, given that the API returns an integer array, all we can do is choose some valid integer value to return when lambda is too large. There's no good return value here, but the largest representable integer is probably the least bad option. Alternatively, we could change the API contract so that it returns a float array, in which case we could use NaN or inf for invalid outputs (though this would require a deprecation cycle for the API change, so would be somewhat complicated) – what do you think?
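
The two options can be sketched as user-level wrappers. This is only an illustration under stated assumptions, not how `jax.random.poisson` is or will be implemented; `poisson_saturating` and `poisson_float` are hypothetical names, and the placeholder rate of 1.0 is arbitrary.

```python
import jax
import jax.numpy as jnp

def poisson_saturating(key, lam, shape):
    """Integer output: a non-finite rate maps to the largest int32 (hypothetical helper)."""
    lam = jnp.asarray(lam, dtype=jnp.float32)
    finite = jnp.isfinite(lam)
    # Swap non-finite rates for a harmless placeholder before sampling.
    safe_lam = jnp.where(finite, lam, 1.0)
    draws = jax.random.poisson(key, safe_lam, shape=shape, dtype=jnp.int32)
    return jnp.where(finite, draws, jnp.iinfo(jnp.int32).max)

def poisson_float(key, lam, shape):
    """Float output: a non-finite rate is passed through, so inf stays inf (hypothetical helper)."""
    lam = jnp.asarray(lam, dtype=jnp.float32)
    finite = jnp.isfinite(lam)
    safe_lam = jnp.where(finite, lam, 1.0)
    draws = jax.random.poisson(key, safe_lam, shape=shape, dtype=jnp.int32)
    return jnp.where(finite, draws.astype(jnp.float32), lam)

key = jax.random.PRNGKey(0)
as_int = poisson_saturating(key, jnp.inf, (2,))
as_float = poisson_float(key, jnp.inf, (2,))
print(as_int, as_float)
```

Both wrappers keep a fixed output dtype and never raise on the value of `lam`, so they stay within the constraints listed above. The float variant is the one that would need an API change if adopted in `jax.random.poisson` itself.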

@jakevdp jakevdp self-assigned this Jun 1, 2023
@JiaYaobo
Contributor Author

JiaYaobo commented Jun 2, 2023

Hi @jakevdp, thanks for your reply! In my opinion, JAX should ultimately support a float array return, for a few reasons: as you mentioned, NaN or inf are more intuitive and mathematically meaningful than -1; other libraries (tfp) and frameworks (torch) return a float array too (though a bug seems to exist there as well 🤔 pytorch/pytorch#102811); and in my use case, a log transformation is often applied.

Of course it's not a must; it depends on personal taste and may affect later features, such as which type other discrete distributions (binomial, multinomial) will return. Consistency between these APIs seems important too.
