Skip to content

Commit

Permalink
Enable tests related to the Gamma distribution for non-default PRNG i…
Browse files Browse the repository at this point in the history
…mplementations only when jax_enable_custom_prng is enabled, for consistency with other tests.

PiperOrigin-RevId: 440300882
  • Loading branch information
jpuigcerver authored and jax authors committed Apr 8, 2022
1 parent 58bdcb8 commit 0c02f79
Showing 1 changed file with 12 additions and 25 deletions.
37 changes: 12 additions & 25 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,14 +869,12 @@ def testExponential(self, dtype):
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
prng_name),
"a": a, "dtype": dtype, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
def testGammaVsLogGamma(self, prng_impl, a, dtype):
key = prng.seed_with_impl(prng_impl, 0)
def testGammaVsLogGamma(self, a, dtype):
key = self.seed_prng(0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
crand_loggamma = jax.jit(rand_loggamma)
Expand All @@ -885,14 +883,12 @@ def testGammaVsLogGamma(self, prng_impl, a, dtype):
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
prng_name),
"a": a, "dtype": dtype, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
{"testcase_name": "_a={}_dtype={}".format(a, np.dtype(dtype).name),
"a": a, "dtype": dtype}
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
def testGamma(self, prng_impl, a, dtype):
key = prng.seed_with_impl(prng_impl, 0)
def testGamma(self, a, dtype):
key = self.seed_prng(0)
rand = lambda key, a: random.gamma(key, a, (10000,), dtype)
crand = jax.jit(rand)

Expand All @@ -908,13 +904,12 @@ def testGammaShape(self):
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_prng={}_logspace={}".format(alpha, prng_name, log_space),
"alpha": alpha, "log_space": log_space, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
{"testcase_name": "_a={}_logspace={}".format(alpha, log_space),
"alpha": alpha, "log_space": log_space}
for log_space in [True, False]
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, log_space, prng_impl, alpha):
rng = prng.seed_with_impl(prng_impl, 0)
def testGammaGrad(self, log_space, alpha):
rng = self.seed_prng(0)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
if log_space:
Expand Down Expand Up @@ -1609,18 +1604,10 @@ def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
raise SkipTest('sampler only implemented for default RNG')

for test_prefix in [
'testBeta',
'testDirichlet',
'testGamma',
'testGammaGrad',
'testGammaGradType',
'testGammaShape',
'testIssue1789',
'testPoisson',
'testPoissonBatched',
'testPoissonShape',
'testPoissonZeros',
'testT',
]:
for attr in dir(LaxRandomTest):
if attr.startswith(test_prefix):
Expand Down

0 comments on commit 0c02f79

Please sign in to comment.