Bad return values from jax.random.poisson when lam is jnp.inf #16164
Labels: bug
Description
As above, a +inf input leads to an output of 0, which looks like an overflow. For lam = inf, shouldn't 214748364 be the right answer here? For a +inf rate (i.e. expectation), a large output would be more convincing. BTW, for numpy, inf leads to a ValueError. For tensorflow-probability, inf produces infs.
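The overflow question and the integer-vs-float contrast can be illustrated with a small, stdlib-only Python sketch. It assumes the return dtype is a 32-bit signed integer (int32, which is what JAX's default `int` dtype resolves to when 64-bit mode is off) and shows that such a type has a finite maximum and no representation for infinity. `to_integer_or_none` and `poisson_float` are hypothetical helpers, not JAX, NumPy, or TFP APIs, and the finite-rate path uses Knuth's multiplication method for simplicity, which is not necessarily the algorithm JAX uses.

```python
import math
import random

# Assumption: the default integer return dtype is 32-bit signed (int32).
INT32_MAX = 2**31 - 1  # the largest int32 value


def to_integer_or_none(x: float):
    """Return int(x), or None if x is inf or nan (no integer can represent it)."""
    try:
        return int(x)
    except (OverflowError, ValueError):  # inf -> OverflowError, nan -> ValueError
        return None


def poisson_float(lam: float, rng: random.Random) -> float:
    """Hypothetical Poisson sampler with a float return type.

    A +inf rate returns +inf instead of an integer sentinel. Finite rates use
    Knuth's multiplication method, which is only suitable for small lam
    (exp(-lam) underflows to 0.0 once lam exceeds roughly 745).
    """
    if math.isnan(lam) or lam < 0:
        raise ValueError("lam must be a non-negative number")
    if math.isinf(lam):
        return math.inf
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()  # running product of uniforms in [0, 1)
        if p <= threshold:
            return float(k)
        k += 1


rng = random.Random(0)
print(INT32_MAX)                     # 2147483647
print(to_integer_or_none(math.inf))  # None
print(poisson_float(math.inf, rng))  # inf
print(poisson_float(0.0, rng))       # 0.0
print(poisson_float(3.0, rng))       # a random non-negative whole number
```

With a float return type, +inf can propagate naturally, as in the tensorflow-probability behavior described above, whereas any integer answer for lam = inf has to be a sentinel such as 0 or the dtype maximum.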
Maybe for compatibility with numpy, jax chooses an integer return dtype, while tfp chooses float, as discussed by @axch in #16134 (review). It seems a float return dtype may be more reasonable.
What jax/jaxlib version are you using?
jax v0.4.10, jaxlib v0.4.10
Which accelerator(s) are you using?
CPU
Additional system info
Mac, Python 3.10.9
NVIDIA GPU info
None