Skip to content

Commit

Permalink
Run random_test with rank_promotion='raise'
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 8, 2021
1 parent f0c3049 commit 233d9f7
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@

class LaxRandomTest(jtu.JaxTestCase):

def setUp(self):
self._jax_numpy_rank_promotion = config.jax_numpy_rank_promotion
config.update("jax_numpy_rank_promotion", "raise")
super().setUp()

def tearDown(self):
super().tearDown()
config.update("jax_numpy_rank_promotion", self._jax_numpy_rank_promotion)

def _CheckCollisions(self, samples, nbits):
fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev
nitems = len(samples)
Expand Down Expand Up @@ -379,7 +388,6 @@ def testBernoulli(self, p, dtype):
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in jtu.dtypes.floating))
@jtu.disable_implicit_rank_promotion
def testCategorical(self, p, axis, dtype, sample_shape):
key = random.PRNGKey(0)
p = np.array(p, dtype=dtype)
Expand Down Expand Up @@ -409,7 +417,8 @@ def testCategorical(self, p, axis, dtype, sample_shape):

def testBernoulliShape(self):
key = random.PRNGKey(0)
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
with jax.numpy_rank_promotion('allow'):
x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -628,7 +637,8 @@ def testPareto(self, b, dtype):

def testParetoShape(self):
key = random.PRNGKey(0)
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
with jax.numpy_rank_promotion('allow'):
x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down Expand Up @@ -666,8 +676,9 @@ def testMultivariateNormal(self, dim, dtype, method):
shape=(10000,), method=method)
crand = api.jit(rand)

uncompiled_samples = np.asarray(rand(key), np.float64)
compiled_samples = np.asarray(crand(key), np.float64)
with jax.numpy_rank_promotion('allow'):
uncompiled_samples = np.asarray(rand(key), np.float64)
compiled_samples = np.asarray(crand(key), np.float64)

inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0]
for samples in [uncompiled_samples, compiled_samples]:
Expand Down Expand Up @@ -701,7 +712,8 @@ def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor)
cov += 1e-3 * np.eye(dim)
shape = shape + eff_batch_size
samples = random.multivariate_normal(key, mean, cov, shape=shape)
with jax.numpy_rank_promotion('allow'):
samples = random.multivariate_normal(key, mean, cov, shape=shape)
assert samples.shape == shape + (dim,)

def testMultivariateNormalCovariance(self):
Expand All @@ -716,7 +728,8 @@ def testMultivariateNormalCovariance(self):
out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N)

key = random.PRNGKey(0)
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))
with jax.numpy_rank_promotion('allow'):
out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,))

var_np = out_np.var(axis=0)
var_jnp = out_jnp.var(axis=0)
Expand Down Expand Up @@ -819,10 +832,11 @@ def testRandomBroadcast(self):
# test for broadcast issue in https://github.com/google/jax/issues/4033
key = random.PRNGKey(0)
shape = (10, 2)
x = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
assert x.shape == shape
x = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
assert x.shape == shape
with jax.numpy_rank_promotion('allow'):
x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
x2 = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
assert x1.shape == shape
assert x2.shape == shape

def testMaxwellSample(self):
num_samples = 10**5
Expand Down

0 comments on commit 233d9f7

Please sign in to comment.