From ce9c2d650a5c4bf6f6f418a265d12937c3949201 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 5 Jul 2023 13:12:13 -0700 Subject: [PATCH] rename `seed_prng` test method to `make_key` --- tests/random_test.py | 178 +++++++++++++++++++++---------------------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/tests/random_test.py b/tests/random_test.py index 9fab0b84d0f9..c32f3e875eb2 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -609,7 +609,7 @@ def _CheckChiSquared(self, samples, pmf): 'Expected vs. actual frequencies:\n' f'{expected_freq}\n{actual_freq}') - def seed_prng(self, seed): + def make_key(self, seed): return random.threefry2x32_key(seed) @jtu.sample_product(dtype=jtu.dtypes.floating) @@ -622,7 +622,7 @@ def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype): @jtu.sample_product(dtype=float_dtypes) def testRngUniform(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.uniform(key, (10000,), dtype) crand = jax.jit(rand) @@ -638,7 +638,7 @@ def testRngRandint(self, dtype): lo = 5 hi = 10 - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.randint(key, (10000,), lo, hi, dtype) crand = jax.jit(rand) @@ -651,7 +651,7 @@ def testRngRandint(self, dtype): @jtu.sample_product(dtype=float_dtypes) def testNormal(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.normal(key, (10000,), dtype) crand = jax.jit(rand) @@ -664,13 +664,13 @@ def testNormal(self, dtype): def testNormalBfloat16(self): # Passing bfloat16 as dtype string. # https://github.com/google/jax/issues/6813 - res_bfloat16_str = random.normal(self.seed_prng(0), dtype='bfloat16') - res_bfloat16 = random.normal(self.seed_prng(0), dtype=jnp.bfloat16) + res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16') + res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16) self.assertAllClose(res_bfloat16, res_bfloat16_str) @jtu.sample_product(dtype=complex_dtypes) def testNormalComplex(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.normal(key, (10000,), dtype) crand = jax.jit(rand) @@ -684,7 +684,7 @@ def testNormalComplex(self, dtype): @jtu.sample_product(dtype=float_dtypes) def testTruncatedNormal(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000,), dtype) crand = jax.jit(rand) @@ -700,7 +700,7 @@ def testTruncatedNormal(self, dtype): @jtu.sample_product(dtype=jtu.dtypes.floating + jtu.dtypes.integer) def testShuffle(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) x = np.arange(100).astype(dtype) rand = lambda key: random.shuffle(key, x) crand = jax.jit(rand) @@ -734,7 +734,7 @@ def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis np_choice = np.random.default_rng(0).choice p_dtype = dtypes.to_inexact_dtype(dtype) - key = self.seed_prng(0) + key = self.make_key(0) is_range = type(input_range_or_shape) is int x = (input_range_or_shape if is_range else self.rng().permutation(np.arange(math.prod( @@ -774,7 +774,7 @@ def lsort(x): independent=[True, False], ) def testPermutation(self, dtype, range_or_shape, axis, independent): - key = self.seed_prng(0) + key = self.make_key(0) is_range = type(range_or_shape) is int x = (range_or_shape if is_range else self.rng().permutation(np.arange( @@ -803,7 +803,7 @@ def lsort(x): 'x' if is_range else None)(key, x)) def testPermutationErrors(self): - key = self.seed_prng(0) + 
key = self.make_key(0) with self.assertRaises(ValueError): random.permutation(key, 10, axis=3) with self.assertRaises(TypeError): @@ -816,7 +816,7 @@ def testPermutationErrors(self): dtype=jtu.dtypes.floating, ) def testBernoulli(self, p, dtype): - key = self.seed_prng(0) + key = self.make_key(0) p = np.array(p, dtype=dtype) rand = lambda key, p: random.bernoulli(key, p, (10000,)) crand = jax.jit(rand) @@ -840,7 +840,7 @@ def testBernoulli(self, p, dtype): dtype=jtu.dtypes.floating, ) def testCategorical(self, p, axis, dtype, sample_shape): - key = self.seed_prng(0) + key = self.make_key(0) p = np.array(p, dtype=dtype) logits = np.log(p) - 42 # test unnormalized out_shape = tuple(np.delete(logits.shape, axis)) @@ -867,7 +867,7 @@ def testCategorical(self, p, axis, dtype, sample_shape): self._CheckChiSquared(samples, pmf=pmf) def testBernoulliShape(self): - key = self.seed_prng(0) + key = self.make_key(0) with jax.numpy_rank_promotion('allow'): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @@ -880,7 +880,7 @@ def testBernoulliShape(self): def testBeta(self, a, b, dtype): if not config.x64_enabled: raise SkipTest("skip test except on X64") - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, a, b: random.beta(key, a, b, (10000,), dtype) crand = jax.jit(rand) @@ -892,7 +892,7 @@ def testBeta(self, a, b, dtype): def testBetaSmallParameters(self, dtype=np.float32): # Regression test for beta version of https://github.com/google/jax/issues/9896 - key = self.seed_prng(0) + key = self.make_key(0) a, b = 0.0001, 0.0002 samples = random.beta(key, a, b, shape=(100,), dtype=dtype) @@ -907,7 +907,7 @@ def testBetaSmallParameters(self, dtype=np.float32): @jtu.sample_product(dtype=float_dtypes) def testCauchy(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.cauchy(key, (10000,), dtype) crand = jax.jit(rand) @@ -923,7 +923,7 @@ def testCauchy(self, dtype): ) @jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times def testDirichlet(self, alpha, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, alpha: random.dirichlet(key, alpha, (10000,), dtype) crand = jax.jit(rand) @@ -938,7 +938,7 @@ def testDirichlet(self, alpha, dtype): def testDirichletSmallAlpha(self, dtype=np.float32): # Regression test for https://github.com/google/jax/issues/9896 - key = self.seed_prng(0) + key = self.make_key(0) alpha = 0.0001 * jnp.ones(3) samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype) @@ -953,7 +953,7 @@ def testDirichletSmallAlpha(self, dtype=np.float32): @jtu.sample_product(dtype=float_dtypes) def testExponential(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.exponential(key, (10000,), dtype) crand = jax.jit(rand) @@ -968,7 +968,7 @@ def testExponential(self, dtype): dtype=jtu.dtypes.floating, ) def testGammaVsLogGamma(self, a, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype) rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype) crand_loggamma = jax.jit(rand_loggamma) @@ -981,7 +981,7 @@ def testGammaVsLogGamma(self, a, dtype): dtype=jtu.dtypes.floating, ) def testGamma(self, a, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, a: random.gamma(key, a, (10000,), dtype) crand = jax.jit(rand) @@ -992,7 +992,7 @@ def testGamma(self, a, dtype): self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf) 
def testGammaShape(self): - key = self.seed_prng(0) + key = self.make_key(0) x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @@ -1001,7 +1001,7 @@ def testGammaShape(self): alpha=[1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4], ) def testGammaGrad(self, log_space, alpha): - rng = self.seed_prng(0) + rng = self.make_key(0) alphas = np.full((100,), alpha) z = random.gamma(rng, alphas) if log_space: @@ -1025,7 +1025,7 @@ def testGammaGrad(self, log_space, alpha): def testGammaGradType(self): # Regression test for https://github.com/google/jax/issues/2130 - key = self.seed_prng(0) + key = self.make_key(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y @@ -1037,7 +1037,7 @@ def testGammaGradType(self): dtype=[np.int16, np.int32, np.int64], ) def testPoisson(self, lam, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype) crand = jax.jit(rand) @@ -1052,38 +1052,38 @@ def testPoisson(self, lam, dtype): self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False) def testPoissonBatched(self): - key = self.seed_prng(1) + key = self.make_key(1) lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)]) samples = random.poisson(key, lam, shape=(20000,)) self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf) self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf) def testPoissonWithoutShape(self): - key = self.seed_prng(1) + key = self.make_key(1) lam = 2 * jnp.ones(10000) samples = random.poisson(key, lam) self._CheckChiSquared(samples, scipy.stats.poisson(2.0).pmf) def testPoissonShape(self): - key = self.seed_prng(0) + key = self.make_key(0) x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2)) assert x.shape == (3, 2) def testPoissonZeros(self): - key = self.seed_prng(0) + key = self.make_key(0) lam = jnp.concatenate([jnp.zeros(10), 20 * jnp.ones(10)]) samples = random.poisson(key, lam, shape=(2, 20)) self.assertArraysEqual(samples[:, :10], jnp.zeros_like(samples[:, :10])) def testPoissonCornerCases(self): - key = self.seed_prng(0) + key = self.make_key(0) lam = jnp.array([-1, 0, jnp.nan]) samples = random.poisson(key, lam, shape=(3,)) self.assertArraysEqual(samples, jnp.array([-1, 0, -1]), check_dtypes=False) @jtu.sample_product(dtype=jtu.dtypes.floating) def testGumbel(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.gumbel(key, (10000,), dtype) crand = jax.jit(rand) @@ -1095,7 +1095,7 @@ def testGumbel(self, dtype): @jtu.sample_product(dtype=float_dtypes) def testLaplace(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.laplace(key, (10000,), dtype) crand = jax.jit(rand) @@ -1107,7 +1107,7 @@ def testLaplace(self, dtype): @jtu.sample_product(dtype=float_dtypes) def testLogistic(self, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.logistic(key, (10000,), dtype) crand = jax.jit(rand) @@ -1124,7 +1124,7 @@ def testLogistic(self, dtype): ) @jax.default_matmul_precision("float32") def testOrthogonal(self, n, shape, dtype): - key = self.seed_prng(0) + key = self.make_key(0) q = random.orthogonal(key, n, shape, dtype) self.assertEqual(q.shape, (*shape, n, n)) self.assertEqual(q.dtype, dtype) @@ -1140,7 +1140,7 @@ def testOrthogonal(self, n, shape, dtype): dtype=jtu.dtypes.floating, ) def testGeneralizedNormal(self, p, shape, dtype): 
- key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, p: random.generalized_normal(key, p, shape, dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, p) @@ -1157,7 +1157,7 @@ def testGeneralizedNormal(self, p, shape, dtype): dtype=jtu.dtypes.floating, ) def testBall(self, d, p, shape, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, p: random.ball(key, d, p, shape, dtype) crand = jax.jit(rand) uncompiled_samples = rand(key, p) @@ -1174,7 +1174,7 @@ def testBall(self, d, p, shape, dtype): dtype=jtu.dtypes.floating, ) def testPareto(self, b, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key, b: random.pareto(key, b, (10000,), dtype) crand = jax.jit(rand) @@ -1185,7 +1185,7 @@ def testPareto(self, b, dtype): self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf) def testParetoShape(self): - key = self.seed_prng(0) + key = self.make_key(0) with jax.numpy_rank_promotion('allow'): x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @@ -1196,7 +1196,7 @@ def testParetoShape(self): ) @jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times def testT(self, df, dtype): - key = self.seed_prng(1) + key = self.make_key(1) rand = lambda key, df: random.t(key, df, (10000,), dtype) crand = jax.jit(rand) @@ -1217,7 +1217,7 @@ def testMultivariateNormal(self, dim, dtype, method): cov_factor = r.randn(dim, dim) cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim) - key = self.seed_prng(0) + key = self.make_key(0) rand = partial(random.multivariate_normal, mean=mean, cov=cov, shape=(10000,), method=method) crand = jax.jit(rand) @@ -1246,7 +1246,7 @@ def testMultivariateNormal(self, dim, dtype, method): def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, shape, method): r = self.rng() - key = self.seed_prng(0) + key = self.make_key(0) eff_batch_size = mean_batch_size \ if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size mean = r.randn(*(mean_batch_size + (dim,))) @@ -1269,7 +1269,7 @@ def testMultivariateNormalCovariance(self): out_np = self.rng().multivariate_normal(mean, cov, N) - key = self.seed_prng(0) + key = self.make_key(0) with jax.numpy_rank_promotion('allow'): out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N,)) @@ -1289,7 +1289,7 @@ def testMultivariateNormalSingularCovariance(self, method): # Singular covariance matrix https://github.com/google/jax/discussions/13293 mu = jnp.zeros((2,)) sigma = jnp.ones((2, 2)) - key = self.seed_prng(0) + key = self.make_key(0) result = random.multivariate_normal(key, mean=mu, cov=sigma, shape=(10,), method=method) self.assertAllClose(result[:, 0], result[:, 1], atol=1e-3, rtol=1e-3) @@ -1300,16 +1300,16 @@ def testMultivariateNormalSingularCovariance(self, method): self.assertFalse(np.any(np.isnan(result))) def testIssue222(self): - x = random.randint(self.seed_prng(10003), (), 0, 0) + x = random.randint(self.make_key(10003), (), 0, 0) assert x == 0 def testFoldIn(self): - key = self.seed_prng(0) + key = self.make_key(0) keys = [_prng_key_as_array(random.fold_in(key, i)) for i in range(10)] assert np.unique(keys, axis=0).shape[0] == 10 def testFoldInBig(self): - key = self.seed_prng(0) + key = self.make_key(0) seeds = [2 ** 32 - 2, 2 ** 32 - 1] keys = [_prng_key_as_array(random.fold_in(key, seed)) for seed in seeds] assert np.unique(keys, axis=0).shape[0] == 2 @@ -1320,7 +1320,7 @@ def testStaticShapeErrors(self): @jax.jit def feature_map(n, d, sigma=1.0, 
seed=123): - key = self.seed_prng(seed) + key = self.make_key(seed) W = random.normal(key, (d, n)) / sigma w = random.normal(key, (d, )) / sigma b = 2 * jnp.pi * random.uniform(key, (d, )) @@ -1332,24 +1332,24 @@ def feature_map(n, d, sigma=1.0, seed=123): lambda: feature_map(5, 3)) def testIssue756(self): - key = self.seed_prng(0) + key = self.make_key(0) w = random.normal(key, ()) self.assertEqual(w.dtype, dtypes.canonicalize_dtype(jnp.float_)) def testIssue1789(self): def f(x): - return random.gamma(self.seed_prng(0), x) + return random.gamma(self.make_key(0), x) grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2)) def testDtypeErrorMessage(self): with self.assertRaisesRegex(ValueError, r"dtype argument to.*"): - random.normal(self.seed_prng(0), (), dtype=jnp.int32) + random.normal(self.make_key(0), (), dtype=jnp.int32) def testRandomBroadcast(self): """Issue 4033""" # test for broadcast issue in https://github.com/google/jax/issues/4033 - key = self.seed_prng(0) + key = self.make_key(0) shape = (10, 2) with jax.numpy_rank_promotion('allow'): x1 = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2)) @@ -1359,7 +1359,7 @@ def testRandomBroadcast(self): def testMaxwellSample(self): num_samples = 10**5 - rng = self.seed_prng(0) + rng = self.make_key(0) rand = lambda x: random.maxwell(x, (num_samples, )) crand = jax.jit(rand) @@ -1382,7 +1382,7 @@ def testMaxwellSample(self): ('test2', 2.0, 3.0)) def testWeibullSample(self, concentration, scale): num_samples = 10**5 - rng = self.seed_prng(0) + rng = self.make_key(0) rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples,)) crand = jax.jit(rand) @@ -1406,7 +1406,7 @@ def testWeibullSample(self, concentration, scale): ('test2', 2.0, 3.0)) def testDoublesidedMaxwellSample(self, loc, scale): num_samples = 10**4 - rng = self.seed_prng(0) + rng = self.make_key(0) rand = lambda key: random.double_sided_maxwell( rng, loc, scale, (num_samples,)) @@ -1443,7 +1443,7 @@ def double_sided_maxwell_cdf(x, loc, scale): samples, lambda x: double_sided_maxwell_cdf(x, loc, scale)) def testRadamacher(self): - rng = self.seed_prng(0) + rng = self.make_key(0) num_samples = 10**5 rand = lambda x: random.rademacher(x, (num_samples,)) @@ -1463,7 +1463,7 @@ def testRadamacher(self): counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02) def testChoiceShapeIsNotSequenceError(self): - key = self.seed_prng(0) + key = self.make_key(0) with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=False) with self.assertRaises(TypeError): @@ -1471,7 +1471,7 @@ def testChoiceShapeIsNotSequenceError(self): def test_eval_shape_big_random_array(self): def f(x): - return random.normal(self.seed_prng(x), (int(1e12),)) + return random.normal(self.make_key(x), (int(1e12),)) with jax.enable_checks(False): # check_jaxpr will materialize array jax.eval_shape(f, 0) # doesn't error @@ -1486,18 +1486,18 @@ def test_prng_jit_invariance(self, seed, type_): self.skipTest("Expected failure: Python int too large.") type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] args_maker = lambda: [type_(seed)] - f = lambda s: _maybe_unwrap(self.seed_prng(s)) + f = lambda s: _maybe_unwrap(self.make_key(s)) self._CompileAndCheck(f, args_maker) def test_prng_errors(self): seed = np.iinfo(np.int64).max + 1 with self.assertRaises(OverflowError): - self.seed_prng(seed) + self.make_key(seed) with self.assertRaises(OverflowError): - jax.jit(self.seed_prng)(seed) + jax.jit(self.make_key)(seed) def test_random_split_doesnt_device_put_during_tracing(self): 
- key = self.seed_prng(1).block_until_ready() + key = self.make_key(1).block_until_ready() with jtu.count_device_put() as count: jax.jit(random.split)(key) self.assertLessEqual(count[0], 1) # 1 for the argument device_put @@ -1506,7 +1506,7 @@ def test_random_split_doesnt_device_put_during_tracing(self): def test_randint_bounds(self, dtype): min = np.iinfo(dtype).min max = np.iinfo(dtype).max - key = self.seed_prng(1701) + key = self.make_key(1701) shape = (10,) if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits: expected = random.randint(key, shape, min, max, dtype) @@ -1515,7 +1515,7 @@ def test_randint_bounds(self, dtype): self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype) def test_randint_out_of_range(self): - key = self.seed_prng(0) + key = self.make_key(0) r = random.randint(key, (10,), 255, 256, np.uint8) self.assertAllClose(r, jnp.full_like(r, 255)) @@ -1532,7 +1532,7 @@ def test_large_prng(self): # https://github.com/google/jax/issues/11010 def f(): return random.uniform( - self.seed_prng(3), (308000000, 128), dtype=jnp.bfloat16) + self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) # just lower, don't run, takes too long jax.jit(f).lower() @@ -1546,7 +1546,7 @@ def test_categorical_shape_argument(self, shape, logits_shape_base, axis): logits_shape.insert(axis % (len(logits_shape_base) + 1), 10) assert logits_shape[axis] == 10 logits = jnp.ones(logits_shape) - samples = random.categorical(self.seed_prng(0), logits=logits, + samples = random.categorical(self.make_key(0), logits=logits, axis=axis, shape=shape) self.assertEqual(samples.shape, shape) @@ -1554,7 +1554,7 @@ def test_categorical_shape_argument(self, shape, logits_shape_base, axis): df = [0.2, 1., 10., 100.], dtype=jtu.dtypes.floating) def testChisquare(self, df, dtype): - key = self.seed_prng(0) + key = self.make_key(0) def rand(key, df): return random.chisquare(key, df, shape=(10000,), dtype=dtype) @@ -1571,7 +1571,7 @@ def rand(key, df): dfden = [1. ,2., 10., 100.], dtype=jtu.dtypes.floating) def testF(self, dfnum, dfden, dtype): - key = self.seed_prng(1) + key = self.make_key(1) rand = lambda key: random.f(key, dfnum, dfden, shape = (10000, ), dtype = dtype) crand = jax.jit(rand) @@ -1585,7 +1585,7 @@ def testF(self, dfnum, dfden, dtype): scale= [0.2, 1., 2., 10. ,100.], dtype=jtu.dtypes.floating) def testRayleigh(self, scale, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.rayleigh(key, scale, shape = (10000, ), dtype = dtype) crand = jax.jit(rand) @@ -1599,7 +1599,7 @@ def testRayleigh(self, scale, dtype): mean= [0.2, 1., 2., 10. 
,100.], dtype=jtu.dtypes.floating) def testWald(self, mean, dtype): - key = self.seed_prng(0) + key = self.make_key(0) rand = lambda key: random.wald(key, mean, shape=(10000, ), dtype=dtype) crand = jax.jit(rand) @@ -1613,7 +1613,7 @@ def testWald(self, mean, dtype): p= [0.2, 0.3, 0.4, 0.5 ,0.6], dtype= [np.int16, np.int32, np.int64]) def testGeometric(self, p, dtype): - key = self.seed_prng(1) + key = self.make_key(1) rand = lambda key: random.geometric(key, p, shape=(10000, ), dtype=dtype) crand = jax.jit(rand) @@ -2069,17 +2069,17 @@ def _double_threefry_fold_in(key, data): @jtu.with_config(jax_default_prng_impl='threefry2x32') class LaxRandomWithCustomPRNGTest(LaxRandomTest): - def seed_prng(self, seed): + def make_key(self, seed): return prng.seed_with_impl(double_threefry_prng_impl, seed) def test_split_shape(self): - key = self.seed_prng(73) + key = self.make_key(73) keys = random.split(key, 10) self.assertEqual(keys.shape, (10,)) def test_vmap_fold_in_shape(self): # broadcast with scalar - keys = random.split(self.seed_prng(73), 2) + keys = random.split(self.make_key(73), 2) msgs = jnp.arange(3) out = vmap(lambda i: random.fold_in(keys[0], i))(msgs) self.assertEqual(out.shape, (3,)) @@ -2096,7 +2096,7 @@ def test_vmap_fold_in_shape(self): self.assertEqual(out.shape, (2,)) # nested vmap - keys = random.split(self.seed_prng(73), 2 * 3).reshape((2, 3)) + keys = random.split(self.make_key(73), 2 * 3).reshape((2, 3)) msgs = jnp.arange(2 * 3).reshape((2, 3)) out = vmap(vmap(random.fold_in), in_axes=(0, 1))(keys, msgs.T) self.assertEqual(out.shape, (2, 3)) @@ -2104,7 +2104,7 @@ def test_vmap_fold_in_shape(self): self.assertEqual(out.shape, (3, 2)) def test_vmap_split_mapped_key(self): - key = self.seed_prng(73) + key = self.make_key(73) mapped_keys = random.split(key, num=3) forloop_keys = [random.split(k) for k in mapped_keys] vmapped_keys = vmap(random.split)(mapped_keys) @@ -2114,7 +2114,7 @@ def test_vmap_split_mapped_key(self): vk.unsafe_raw_array()) def test_cannot_add(self): - key = self.seed_prng(73) + key = self.make_key(73) self.assertRaisesRegex( ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.', lambda: key + 47) @@ -2122,7 +2122,7 @@ def test_cannot_add(self): @skipIf(np.__version__ == "1.21.0", "https://github.com/numpy/numpy/issues/19305") def test_grad_of_prng_key(self): - key = self.seed_prng(73) + key = self.make_key(73) with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'): jax.grad(lambda x: 1.)(key) out = jax.grad(lambda x: 1., allow_int=True)(key) @@ -2131,17 +2131,17 @@ def test_grad_of_prng_key(self): @jtu.with_config(jax_default_prng_impl='rbg') class LaxRandomWithRBGPRNGTest(LaxRandomTest): - def seed_prng(self, seed): - return random.PRNGKey(seed, impl='rbg') + def make_key(self, seed): + return random.rbg_key(seed) def test_split_shape(self): - key = self.seed_prng(73) + key = self.make_key(73) keys = random.split(key, 10) self.assertEqual(keys.shape, (10, *key.shape)) def test_vmap_fold_in_shape(self): # broadcast with scalar - keys = random.split(self.seed_prng(73), 2) + keys = random.split(self.make_key(73), 2) msgs = jnp.arange(3) out = vmap(lambda i: random.fold_in(keys[0], i))(msgs) @@ -2155,7 +2155,7 @@ def test_vmap_fold_in_shape(self): self.assertEqual(out.shape, keys.shape) def test_vmap_split_not_mapped_key(self): - key = self.seed_prng(73) + key = self.make_key(73) single_split_key = random.split(key) vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,)) 
self.assertEqual(vmapped_keys.shape, (3, 2, *key.shape)) @@ -2164,7 +2164,7 @@ def test_vmap_split_not_mapped_key(self): _prng_key_as_array(single_split_key)) def test_vmap_split_mapped_key(self): - key = self.seed_prng(73) + key = self.make_key(73) mapped_keys = random.split(key, num=3) forloop_keys = [random.split(k) for k in mapped_keys] vmapped_keys = vmap(random.split)(mapped_keys) @@ -2175,7 +2175,7 @@ def test_vmap_split_mapped_key(self): def test_vmap_random_bits(self): rand_fun = lambda key: random.randint(key, (), 0, 100) - key = self.seed_prng(73) + key = self.make_key(73) mapped_keys = random.split(key, num=3) forloop_rand_nums = [rand_fun(k) for k in mapped_keys] rand_nums = vmap(rand_fun)(mapped_keys) @@ -2183,7 +2183,7 @@ def test_vmap_random_bits(self): self.assertArraysEqual(rand_nums, jnp.array(forloop_rand_nums)) def test_cannot_add(self): - key = self.seed_prng(73) + key = self.make_key(73) if not isinstance(key, random.PRNGKeyArray): raise SkipTest('relies on typed key arrays') self.assertRaisesRegex( @@ -2193,7 +2193,7 @@ def test_cannot_add(self): @skipIf(np.__version__ == "1.21.0", "https://github.com/numpy/numpy/issues/19305") def test_grad_of_prng_key(self): - key = self.seed_prng(73) + key = self.make_key(73) with self.assertRaisesRegex(TypeError, 'grad requires real- or complex-valued inputs'): jax.grad(lambda x: 1.)(key) out = jax.grad(lambda x: 1., allow_int=True)(key) @@ -2209,7 +2209,7 @@ def test_randint_out_of_range(self): @jtu.with_config(jax_default_prng_impl='unsafe_rbg') class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): - def seed_prng(self, seed): + def make_key(self, seed): return random.unsafe_rbg_key(seed)
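For reviewers skimming the patch, here is a minimal, self-contained sketch (not part of the patch) of the pattern the renamed hook follows: the base test class builds every key through `make_key`, and each PRNG-specific subclass overrides only that method, so all inherited tests rerun against that implementation. The `unittest.TestCase` base, the class names, and the single `test_uniform_shape` test are hypothetical simplifications of the real test harness; the key constructors are the ones visible in the hunks above and assume a JAX version contemporary with this patch.

```python
# Sketch of the make_key override pattern used in tests/random_test.py.
# The TestCase base and the single test below are simplified stand-ins; the
# key constructors mirror the ones in the hunks above (JAX circa mid-2023).
import unittest

from jax import random


class LaxRandomSketchTest(unittest.TestCase):
  """Base class: every test obtains keys via self.make_key()."""

  def make_key(self, seed):
    # Default implementation under test: threefry2x32 keys.
    return random.threefry2x32_key(seed)

  def test_uniform_shape(self):
    key = self.make_key(0)
    self.assertEqual(random.uniform(key, (3,)).shape, (3,))


class LaxRandomSketchWithRBGTest(LaxRandomSketchTest):
  """Reruns every inherited test with rbg-implementation keys."""

  def make_key(self, seed):
    return random.rbg_key(seed)


class LaxRandomSketchWithUnsafeRBGTest(LaxRandomSketchWithRBGTest):
  """Same again, with the unsafe_rbg implementation."""

  def make_key(self, seed):
    return random.unsafe_rbg_key(seed)


if __name__ == "__main__":
  unittest.main()
```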