Skip to content

Commit

Permalink
Re-parameterize jax.random.gamma for better behavior at endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 19, 2023
1 parent 0c4c020 commit 7205160
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 33 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Expand Up @@ -18,6 +18,12 @@ Remember to align the itemized text with the first line of an item within a list
correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual
parameters listed in either donate_argnums or donate_argnames will
be donated.
* {func}`jax.random.gamma` has been re-factored to a more efficient algorithm
with more robust endpoint behavior ({jax-issue}`#16779`). This means that the
sequence of values returned for a given `key` will change between JAX v0.4.13
and v0.4.14 for `gamma` and related samplers (including {func}`jax.random.ball`,
{func}`jax.random.beta`, {func}`jax.random.chisquare`, {func}`jax.random.dirichlet`,
{func}`jax.random.generalized_normal`, {func}`jax.random.loggamma`, {func}`jax.random.t`).

* Deletions
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit since
Expand Down
22 changes: 9 additions & 13 deletions jax/_src/random.py
Expand Up @@ -1087,11 +1087,9 @@ def _gamma_one(key: KeyArray, alpha, log_space) -> Array:
# in floating point underflow; for this reason we compute it in log space if
# specified by the `log_space` argument:
# log[Gamma(alpha)] ~ log[Gamma(alpha + 1)] + log[Uniform()] / alpha
# Note that log[Uniform()] ~ Exponential(), but the exponential() function is
# computed via log[1 - Uniform()] to avoid taking log(0). We want the generated
# sequence to match between log_space=True and log_space=False, so we avoid this
# for now to maintain backward compatibility with the original implementation.
# TODO(jakevdp) should we change the convention to avoid -inf in log-space?
# Note that log[Uniform()] ~ -Exponential(), but to avoid problems at x=0
# exponential is computed in terms of log[1 - Uniform()]; we must account for this
# so that log-space and non-log-space samples match.
boost_mask = lax.ge(alpha, one)
alpha_orig = alpha
alpha = lax.select(boost_mask, alpha, lax.add(alpha, one))
Expand Down Expand Up @@ -1128,17 +1126,15 @@ def _next_kxv(kxv):

# initial state is chosen such that _cond_fn will return True
key, subkey = _split(key)
u_boost = uniform(subkey, (), dtype=dtype)
_, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _lax_const(alpha, 2)))
if log_space:
# TODO(jakevdp): there are negative infinities here due to issues mentioned above. How should
# we handle those?
log_boost = lax.select(boost_mask, zero, lax.mul(lax.log(u_boost), lax.div(one, alpha_orig)))
log_samples = lax.neg(exponential(subkey, (), dtype=dtype))
log_boost = lax.select(boost_mask, zero, lax.mul(log_samples, lax.div(one, alpha_orig)))
return lax.add(lax.add(lax.log(d), lax.log(V)), log_boost)
else:
boost = lax.select(boost_mask, one, lax.pow(u_boost, lax.div(one, alpha_orig)))
z = lax.mul(lax.mul(d, V), boost)
return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
samples = 1 - uniform(subkey, (), dtype=dtype)
boost = lax.select(boost_mask, one, lax.pow(samples, lax.div(one, alpha_orig)))
return lax.mul(lax.mul(d, V), boost)


def _gamma_grad(sample, a, *, log_space):
Expand All @@ -1147,7 +1143,7 @@ def _gamma_grad(sample, a, *, log_space):
if log_space:
# d[log(sample)] = d[sample] / sample
# This requires computing exp(log_sample), which may be zero due to float roundoff.
# In this case, we use the same zero-correction used in gamma() above.
# In this case, correct it to smallest representable float.
samples = lax.exp(samples)
zero = lax_internal._const(sample, 0)
tiny = lax.full_like(samples, jnp.finfo(samples.dtype).tiny)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/shape_poly_test.py
Expand Up @@ -2339,7 +2339,7 @@ def f_jax(operand, start_indices, x):
PolyHarness("random_gamma", f"{flags_name}",
lambda key, a: jax.random.gamma(key, a),
arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)],
polymorphic_shapes=["b, ...", "b, w, ..."],
polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5,
override_jax_config_flags=override_jax_config_flags), # type: ignore
# The known dimensions product must be even.
PolyHarness("random_categorical", f"axis=0_{flags_name}",
Expand Down
43 changes: 24 additions & 19 deletions tests/random_test.py
Expand Up @@ -110,9 +110,9 @@ def _seed(self):
RandomValuesCase("bernoulli", "rbg", (5,), None, {'p': 0.5},
np.array([True, True, True, True, True]), on_x64=OnX64.SKIP),
RandomValuesCase("beta", "threefry2x32", (5,), np.float32, {'a': 0.8, 'b': 0.9},
np.array([0.533685, 0.843179, 0.063495, 0.573444, 0.459514], dtype='float32')),
np.array([0.13259 , 0.824893, 0.948363, 0.964155, 0.235448], dtype='float32')),
RandomValuesCase("beta", "rbg", (5,), np.float32, {'a': 0.8, 'b': 0.9},
np.array([0.841308, 0.669989, 0.731763, 0.985127, 0.022745], dtype='float32')),
np.array([0.93215 , 0.833959, 0.121902, 0.270003, 0.429541], dtype='float32')),
# TODO(frostig,jakevdp) add coverage for non-threefry bits
RandomValuesCase("bits", "threefry2x32", (5,), np.uint8, {},
np.array([10, 158, 82, 54, 158], dtype='uint8')),
Expand All @@ -129,9 +129,9 @@ def _seed(self):
RandomValuesCase("cauchy", "rbg", (5,), np.float32, {},
np.array([0.008389, 0.108793, -0.031826, -0.01876, 0.963218], dtype='float32')),
RandomValuesCase("dirichlet", "threefry2x32", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
np.array([[0.556287, 0.304219, 0.139494], [0.15221 , 0.632251, 0.21554]], dtype='float32')),
np.array([[0.003128, 0.009694, 0.987178], [0.025938, 0.479091, 0.494971]], dtype='float32')),
RandomValuesCase("dirichlet", "rbg", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
np.array([[0.024769, 0.002189, 0.973041], [0.326, 0.00244, 0.67156]], dtype='float32')),
np.array([[0.080742, 0.525493, 0.393765], [0.006837, 0.804796, 0.188366]], dtype='float32')),
RandomValuesCase("double_sided_maxwell", "threefry2x32", (5,), np.float32, {"loc": 1, "scale": 2},
np.array([-2.408914, -3.370437, 3.235352, -0.907734, -1.708732], dtype='float32'), on_x64=OnX64.SKIP),
RandomValuesCase("double_sided_maxwell", "rbg", (5,), np.float32, {"loc": 1, "scale": 2},
Expand All @@ -141,9 +141,9 @@ def _seed(self):
RandomValuesCase("exponential", "rbg", (5,), np.float32, {},
np.array([0.231303, 0.684814, 0.017181, 0.089552, 0.345087], dtype='float32')),
RandomValuesCase("gamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
np.array([0.332641, 0.10187 , 1.816109, 0.023457, 0.487853], dtype='float32')),
np.array([0.824221, 1.724476, 0.502882, 5.386132, 0.685543], dtype='float32')),
RandomValuesCase("gamma", "rbg", (5,), np.float32, {'a': 0.8},
np.array([0.235293, 0.446747, 0.146372, 0.79252 , 0.294762], dtype='float32')),
np.array([0.994946, 0.519941, 1.754347, 0.479223, 1.16932 ], dtype='float32')),
RandomValuesCase("gumbel", "threefry2x32", (5,), np.float32, {},
np.array([2.06701, 0.911726, 0.145736, 0.185427, -0.00711], dtype='float32')),
RandomValuesCase("gumbel", "rbg", (5,), np.float32, {},
Expand All @@ -153,9 +153,9 @@ def _seed(self):
RandomValuesCase("laplace", "rbg", (5,), np.float32, {},
np.array([-2.970422, 1.925082, -0.757887, -4.444797, 0.561983], dtype='float32')),
RandomValuesCase("loggamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
np.array([-0.899633, -0.424083, 0.631593, 0.102374, -1.07189], dtype='float32')),
np.array([ 0.240559, -3.575443, -0.450946, -2.161372, -2.943277], dtype='float32')),
RandomValuesCase("loggamma", "rbg", (5,), np.float32, {'a': 0.8},
np.array([-1.333825, 0.287259, -0.343074, -0.998258, -0.773598], dtype='float32')),
np.array([-0.107021, -0.809968, -0.25546 , -1.212273, -1.946579], dtype='float32')),
RandomValuesCase("logistic", "threefry2x32", (5,), np.float32, {},
np.array([0.19611, -1.709053, -0.274093, -0.208322, -1.675489], dtype='float32')),
RandomValuesCase("logistic", "rbg", (5,), np.float32, {},
Expand Down Expand Up @@ -913,6 +913,7 @@ def testBeta(self, a, b, dtype):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)

@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBetaSmallParameters(self, dtype=np.float32):
# Regression test for beta version of https://github.com/google/jax/issues/9896
key = self.make_key(0)
Expand Down Expand Up @@ -959,10 +960,11 @@ def testDirichlet(self, alpha, dtype):
for i, a in enumerate(alpha):
self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)

@jtu.skip_on_devices("tpu") # lower accuracy leads to failures.
def testDirichletSmallAlpha(self, dtype=np.float32):
# Regression test for https://github.com/google/jax/issues/9896
key = self.make_key(0)
alpha = 0.0001 * jnp.ones(3)
alpha = 0.00001 * jnp.ones(3)
samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)

# Check that results lie on the simplex.
Expand Down Expand Up @@ -990,21 +992,26 @@ def testExponential(self, dtype):
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # low accuracy leads to failures.
def testGammaVsLogGamma(self, a, dtype):
# Test that gamma() and loggamma() produce equivalent samples.
key = self.make_key(0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
rand_gamma = lambda key, a: random.gamma(key, a, (100,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (100,), dtype)
crand_loggamma = jax.jit(rand_loggamma)
tol = {np.float32: 1E-6, np.float64: 1E-12}

self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)))
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))
self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)),
atol=tol, rtol=tol)
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)),
atol=tol, rtol=tol)

@jtu.sample_product(
a=[0.1, 1., 10.],
dtype=jtu.dtypes.floating,
)
def testGamma(self, a, dtype):
key = self.make_key(0)
key = self.make_key(1)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = jax.jit(rand)

Expand All @@ -1029,9 +1036,6 @@ def testGammaGrad(self, log_space, alpha):
z = random.gamma(rng, alphas)
if log_space:
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas)
# TODO(jakevdp): this NaN correction is required because we generate negative infinities
# in the log-space computation; see related TODO in the source of random._gamma_one().
actual_grad = jnp.where(jnp.isnan(actual_grad), 0.0, actual_grad)
else:
actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)

Expand Down Expand Up @@ -1179,8 +1183,9 @@ def testGeneralizedNormal(self, p, shape, dtype):
shape=[(), (5,), (10, 5)],
dtype=jtu.dtypes.floating,
)
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBall(self, d, p, shape, dtype):
key = self.make_key(0)
key = self.make_key(123)
rand = lambda key, p: random.ball(key, d, p, shape, dtype)
crand = jax.jit(rand)
uncompiled_samples = rand(key, p)
Expand Down Expand Up @@ -1577,7 +1582,7 @@ def test_categorical_shape_argument(self, shape, logits_shape_base, axis):
df = [0.2, 1., 10., 100.],
dtype=jtu.dtypes.floating)
def testChisquare(self, df, dtype):
key = self.make_key(0)
key = self.make_key(1)

def rand(key, df):
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
Expand Down

0 comments on commit 7205160

Please sign in to comment.