diff --git a/jax/_src/random.py b/jax/_src/random.py index 2a395d04e52e..0aba6729ca42 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -131,6 +131,11 @@ def key(seed: Union[int, Array]) -> PRNGKeyArray: """ # TODO(frostig): Take impl as optional argument impl = default_prng_impl() + if isinstance(seed, prng.PRNGKeyArray): + raise TypeError("key accepts a scalar seed, but was given a PRNGKeyArray.") + if np.ndim(seed): + raise TypeError("key accepts a scalar seed, but was given an array of " + f"shape {np.shape(seed)} != (). Use jax.vmap for batching") return prng.seed_with_impl(impl, seed) def PRNGKey(seed: Union[int, Array]) -> KeyArray: diff --git a/tests/random_test.py b/tests/random_test.py index c99942519fe8..7041f8915850 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1654,6 +1654,22 @@ def test_key_as_seed(self): key = self.make_keys() with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): jax.random.PRNGKey(key) + with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"): + jax.random.key(key) + + def test_non_scalar_seed(self): + seed_arr = np.arange(4) + with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): + jax.random.PRNGKey(seed_arr) + with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"): + jax.random.key(seed_arr) + + def test_non_integer_seed(self): + seed = np.pi + with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): + jax.random.PRNGKey(seed) + with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): + jax.random.key(seed) def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys()