Skip to content

Commit

Permalink
jax.random.poisson: fix return value for lam=0
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 8, 2021
1 parent 00c2957 commit c9d1ded
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,11 +1173,12 @@ def _poisson(key, lam, shape, dtype):
# λ -> ∞, so pick some arbitrary large value.
lam_rejection = lax.select(use_knuth, lax.full_like(lam, 1e5), lam)
max_iters = dtype.type(jnp.iinfo(dtype).max) # insanely conservative
return lax.select(
use_knuth,
_poisson_knuth(key, lam_knuth, shape, dtype, max_iters),
_poisson_rejection(key, lam_rejection, shape, dtype, max_iters),
result = lax.select(
use_knuth,
_poisson_knuth(key, lam_knuth, shape, dtype, max_iters),
_poisson_rejection(key, lam_rejection, shape, dtype, max_iters),
)
return lax.select(lam == 0, jnp.zeros_like(result), result)


def poisson(key, lam, shape=(), dtype=dtypes.int_):
Expand Down
6 changes: 6 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,12 @@ def testPoissonShape(self):
x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
assert x.shape == (3, 2)

def testPoissonZeros(self):
key = random.PRNGKey(0)
lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)])
samples = random.poisson(key, lam, shape=(2, 20))
self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in jtu.dtypes.floating))
Expand Down

0 comments on commit c9d1ded

Please sign in to comment.