jax.random.poisson (#2805)
* jax.random.poisson

The implementation for lam < 10 was directly copied from TensorFlow Probability:
https://github.com/tensorflow/probability/blob/v0.10.0-rc0/tensorflow_probability/python/internal/backend/numpy/random_generators.py#L155

I adapted the implementation for lam >= 10 from TensorFlow:
https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc

The methods themselves match both TensorFlow and NumPy:
https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574
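
For reference, a minimal usage sketch of the new sampler (illustrative only, not
part of this commit):

    from jax import random

    key = random.PRNGKey(0)
    small = random.poisson(key, 3.0, shape=(5,))   # lam < 10: Knuth's algorithm
    large = random.poisson(key, 50.0, shape=(5,))  # lam >= 10: transformed rejection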

* add a check for even larger lambda

* increment iter count

* remove comment that makes no sense

* Fix chi-squared tests in random_test.py

As far as I can tell, the previous implementation of the chi-squared test
for samples from discrete probability distributions was broken. It should have
been asserting that the p-value was greater than 0.01, e.g., as illustrated here:
http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html

This hid a few other bugs, such as a miscalculation of expected frequencies.

Fortunately, the existing random tests for Bernoulli and Categorical *mostly*
still pass, with the exception of multi-dimensional logits for Categorical.
Those tests are disabled by this PR.
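
For illustration, a sketch of the corrected check (the helper name and binning
are illustrative, not the exact _CheckChiSquared helper used in random_test.py):

    import numpy as onp
    import scipy.stats

    def check_chisquared(samples, pmf, alpha=0.01):
      # Compare observed bin counts against expected frequencies under the pmf
      # and require that the chi-squared p-value exceed alpha (do not reject).
      values, counts = onp.unique(samples, return_counts=True)
      expected = pmf(values)
      expected = expected / expected.sum() * counts.sum()  # renormalize over observed bins
      _, p_value = scipy.stats.chisquare(counts, expected)
      assert p_value > alpha, f"chi-squared test rejected: p = {p_value}"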

* Fix accept condition (based on correct chi-squared test)

* Add moment checks for Poisson

* Add batching test, more Poisson rates
shoyer committed May 2, 2020
1 parent ee38e1b commit 46ce80b
Showing 2 changed files with 142 additions and 0 deletions.
110 changes: 110 additions & 0 deletions jax/random.py
@@ -1019,6 +1019,114 @@ def _gamma(key, a, shape, dtype):
  return random_gamma_p.bind(key, a)[0]


@partial(jit, static_argnums=(2, 3, 4))
def _poisson_knuth(key, lam, shape, dtype, max_iters):
  # Knuth's algorithm for generating Poisson random variates.
  # Reference:
  # https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables

  def body_fn(carry):
    i, k, rng, log_prod = carry
    rng, subkey = split(rng)
    k = lax.select(log_prod > -lam, k + 1, k)
    u = uniform(subkey, shape, onp.float32)
    return i + 1, k, rng, log_prod + np.log(u)

  def cond_fn(carry):
    i, log_prod = carry[0], carry[3]
    return (log_prod > -lam).any() & (i < max_iters)

  k_init = lax.full_like(lam, 0, dtype, shape)
  log_rate_init = lax.full_like(lam, 0, onp.float32, shape)
  k = lax.while_loop(cond_fn, body_fn, (0, k_init, key, log_rate_init))[1]
  return (k - 1).astype(dtype)


@partial(jit, static_argnums=(2, 3, 4))
def _poisson_rejection(key, lam, shape, dtype, max_iters):
  # Transformed rejection due to Hormann.
  # Reference:
  # http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
  log_lam = lax.log(lam)
  b = 0.931 + 2.53 * lax.sqrt(lam)
  a = -0.059 + 0.02483 * b
  inv_alpha = 1.1239 + 1.1328 / (b - 3.4)
  v_r = 0.9277 - 3.6224 / (b - 2)

  def body_fn(carry):
    i, k_out, accepted, key = carry
    key, subkey_0, subkey_1 = split(key, 3)

    u = uniform(subkey_0, shape, lam.dtype) - 0.5
    v = uniform(subkey_1, shape, lam.dtype)
    u_shifted = 0.5 - abs(u)

    k = lax.floor((2 * a / u_shifted + b) * u + lam + 0.43)
    s = lax.log(v * inv_alpha / (a / (u_shifted * u_shifted) + b))
    t = -lam + k * log_lam - lax.lgamma(k + 1)

    accept1 = (u_shifted >= 0.07) & (v <= v_r)
    reject = (k < 0) | ((u_shifted < 0.013) & (v > u_shifted))
    accept2 = s <= t
    accept = accept1 | (~reject & accept2)

    k_out = lax.select(accept, k, k_out)
    accepted |= accept

    return i + 1, k_out, accepted, key

  def cond_fn(carry):
    i, k_out, accepted, key = carry
    return (~accepted).any() & (i < max_iters)

  k_init = lax.full_like(lam, -1, lam.dtype, shape)
  accepted = lax.full_like(lam, False, np.bool_, shape)
  k = lax.while_loop(cond_fn, body_fn, (0, k_init, accepted, key))[1]
  return k.astype(dtype)


@partial(jit, static_argnums=(2, 3))
def _poisson(key, lam, shape, dtype):
  # The implementation matches TensorFlow and NumPy:
  # https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc
  # https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574
  # For lambda < 10, we use the Knuth algorithm; otherwise, we use transformed
  # rejection sampling.
  use_knuth = lam < 10
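  # Note: lax.select evaluates both samplers on the full array, so the branch
  # that is not selected still needs a benign rate value to avoid NaNs and
  # runaway loops; hence the dummy substitutions below.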
  lam_knuth = lax.select(use_knuth, lam, lax.full_like(lam, 0.0))
  # The acceptance probability for rejection sampling maxes out at 89% as
  # λ -> ∞, so pick some arbitrary large value.
  lam_rejection = lax.select(use_knuth, lax.full_like(lam, 1e5), lam)
  max_iters = np.iinfo(dtype).max  # insanely conservative
  return lax.select(
      use_knuth,
      _poisson_knuth(key, lam_knuth, shape, dtype, max_iters),
      _poisson_rejection(key, lam_rejection, shape, dtype, max_iters),
  )


def poisson(key, lam, shape=(), dtype=onp.int64):
  """Sample Poisson random values with given shape and integer dtype.

  Args:
    key: a PRNGKey used as the random key.
    lam: rate parameter (mean of the distribution), must be >= 0.
    shape: optional, a tuple of nonnegative integers representing the result
      shape. Default ().
    dtype: optional, an integer dtype for the returned values (default int64 if
      jax_enable_x64 is true, otherwise int32).

  Returns:
    A random array with the specified shape and dtype.
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  shape = abstract_arrays.canonicalize_shape(shape)
  if onp.shape(lam) != shape:
    lam = np.broadcast_to(lam, shape)
  lam = lam.astype(onp.float32)
  return _poisson(key, lam, shape, dtype)


def gumbel(key, shape=(), dtype=onp.float64):
  """Sample Gumbel random values with given shape and float dtype.
@@ -1042,6 +1150,7 @@ def _gumbel(key, shape, dtype):
  return -np.log(-np.log(
      uniform(key, shape, dtype, minval=np.finfo(dtype).eps, maxval=1.)))


def categorical(key, logits, axis=-1, shape=None):
  """Sample random values from categorical distributions.
@@ -1071,6 +1180,7 @@ def categorical(key, logits, axis=-1, shape=None):
  sample_shape = shape[:len(shape)-len(batch_shape)]
  return np.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis)


def laplace(key, shape=(), dtype=onp.float64):
  """Sample Laplace random values with given shape and float dtype.
32 changes: 32 additions & 0 deletions tests/random_test.py
@@ -398,6 +398,38 @@ def testGammaGradType(self):
    # Should not crash with a type error.
    api.vjp(f, a, b)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lam={}_{}".format(lam, dtype),
       "lam": lam, "dtype": onp.dtype(dtype).name}
      for lam in [0.5, 3, 9, 11, 50, 500]
      for dtype in [onp.int32, onp.int64]))
  def testPoisson(self, lam, dtype):
    key = random.PRNGKey(0)
    rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key, lam)
    compiled_samples = crand(key, lam)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
      # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
      # based on the central limit theorem).
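      # One possible approach (a sketch, not implemented here): by the CLT, the
      # sample mean of n = 10000 draws has standard deviation sqrt(lam / n), so
      # an absolute tolerance of a few multiples of sqrt(lam / 10000) would be a
      # principled alternative to the hand-tuned rtol values below.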
      self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
      self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)

  def testPoissonBatched(self):
    key = random.PRNGKey(0)
    lam = np.concatenate([2 * np.ones(10000), 20 * np.ones(10000)])
    samples = random.poisson(key, lam, shape=(20000,))
    self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
    self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)

  def testPoissonShape(self):
    key = random.PRNGKey(0)
    x = random.poisson(key, onp.array([2.0, 20.0]), shape=(3, 2))
    assert x.shape == (3, 2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(dtype), "dtype": onp.dtype(dtype).name}
      for dtype in [onp.float32, onp.float64]))
