From 951d515701c2f2f7b8f81a37e936f053831a0526 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 20 Jun 2023 01:16:25 -0700 Subject: [PATCH] random.key: error for non-scalar seeds. Previously, this function's implementation would implicitly map over non-scalar seed inputs. This is not the behavior we want, because in the future we may want to allow arrays of integers as a single seed. --- jax/_src/random.py | 5 +++++ tests/random_test.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/jax/_src/random.py b/jax/_src/random.py index 2a395d04e52e..0aba6729ca42 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -131,6 +131,11 @@ def key(seed: Union[int, Array]) -> PRNGKeyArray: """ # TODO(frostig): Take impl as optional argument impl = default_prng_impl() + if isinstance(seed, prng.PRNGKeyArray): + raise TypeError("key accepts a scalar seed, but was given a PRNGKeyArray.") + if np.ndim(seed): + raise TypeError("key accepts a scalar seed, but was given an array of " + f"shape {np.shape(seed)} != (). Use jax.vmap for batching") return prng.seed_with_impl(impl, seed) def PRNGKey(seed: Union[int, Array]) -> KeyArray: diff --git a/tests/random_test.py b/tests/random_test.py index c99942519fe8..7041f8915850 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1654,6 +1654,22 @@ def test_key_as_seed(self): key = self.make_keys() with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): jax.random.PRNGKey(key) + with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"): + jax.random.key(key) + + def test_non_scalar_seed(self): + seed_arr = np.arange(4) + with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): + jax.random.PRNGKey(seed_arr) + with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"): + jax.random.key(seed_arr) + + def test_non_integer_seed(self): + seed = np.pi + with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): + jax.random.PRNGKey(seed) + with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): + jax.random.key(seed) def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys()