Skip to content

Commit

Permalink
consistently seed keys indirectly by test class method in `LaxRandomT…
Browse files Browse the repository at this point in the history
…est`
  • Loading branch information
froystig committed Jul 5, 2023
1 parent 556c112 commit ff70255
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,7 @@ def testMultivariateNormalSingularCovariance(self, method):
# Singular covariance matrix https://github.com/google/jax/discussions/13293
mu = jnp.zeros((2,))
sigma = jnp.ones((2, 2))
key = random.PRNGKey(0)
key = self.seed_prng(0)
result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method)
self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3)

Expand Down Expand Up @@ -1531,7 +1531,8 @@ def test_randint_out_of_range(self):
def test_large_prng(self):
# https://github.com/google/jax/issues/11010
def f():
return random.uniform(random.PRNGKey(3), (308000000, 128), dtype=jnp.bfloat16)
return random.uniform(
self.seed_prng(3), (308000000, 128), dtype=jnp.bfloat16)

# just lower, don't run, takes too long
jax.jit(f).lower()
Expand All @@ -1545,7 +1546,7 @@ def test_categorical_shape_argument(self, shape, logits_shape_base, axis):
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
assert logits_shape[axis] == 10
logits = jnp.ones(logits_shape)
samples = random.categorical(random.PRNGKey(0), logits=logits,
samples = random.categorical(self.seed_prng(0), logits=logits,
axis=axis, shape=shape)
self.assertEqual(samples.shape, shape)

Expand All @@ -1555,7 +1556,8 @@ def test_categorical_shape_argument(self, shape, logits_shape_base, axis):
def testChisquare(self, df, dtype):
key = self.seed_prng(0)

rand = lambda key, df: random.chisquare(key, df, shape=(10000, ), dtype=dtype)
def rand(key, df):
return random.chisquare(key, df, shape=(10000,), dtype=dtype)
crand = jax.jit(rand)

uncompiled_samples = rand(key, df)
Expand Down

0 comments on commit ff70255

Please sign in to comment.