From 2be6019f1c99b234b91bf736578cd3d6886a6f18 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 17 Oct 2023 04:23:44 -0700 Subject: [PATCH] Rollback to fix internal breakage Reverts 7d203aebfa6206affde207c884b50172e203d177 PiperOrigin-RevId: 574101804 --- jax/_src/random.py | 23 ++++++++++++++-------- tests/random_test.py | 46 ++++++++++++++++++++++---------------------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 8f7f11dcca2e..9b832821a64f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -53,7 +53,10 @@ Shape = Sequence[int] PRNGImpl = prng.PRNGImpl -KeyArray = Array + +# TODO(frostig,vanderplas): remove after deprecation window +KeyArray = Union[Array, prng.PRNGKeyArray] +PRNGKeyArray = prng.PRNGKeyArray UINT_DTYPES = prng.UINT_DTYPES @@ -66,8 +69,9 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -def _check_prng_key(key) -> tuple[KeyArray, bool]: - if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): +def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]: + # TODO(frostig): remove once we always enable_custom_prng + if isinstance(key, prng.PRNGKeyArray): return key, False elif _arraylike(key): # Call random_wrap here to surface errors for invalid keys. @@ -102,7 +106,7 @@ def _return_prng_keys(was_wrapped, key): return prng.random_unwrap(key) if was_wrapped else key -def _random_bits(key: KeyArray, bit_width, shape) -> Array: +def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array: assert jnp.issubdtype(key.dtype, dtypes.prng_key) return prng.random_bits(key, bit_width=bit_width, shape=shape) @@ -166,18 +170,20 @@ def resolve_prng_impl( raise TypeError(f'unrecognized type {t} for specifying PRNG implementation.') -def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str] ) -> KeyArray: +def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str] + ) -> PRNGKeyArray: impl = resolve_prng_impl(impl_spec) if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( - f"{ctor_name} accepts a scalar seed, but was given a PRNG key.") + f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.") if np.ndim(seed): raise TypeError( f"{ctor_name} accepts a scalar seed, but was given an array of " f"shape {np.shape(seed)} != (). Use jax.vmap for batching") return prng.seed_with_impl(impl, seed) -def key(seed: Union[int, Array], *, impl: Optional[str] = None) -> KeyArray: +def key(seed: Union[int, Array], *, + impl: Optional[str] = None) -> PRNGKeyArray: """Create a pseudo-random number generator (PRNG) key given an integer seed. The result is a scalar array with a key that indicates the default PRNG @@ -195,7 +201,8 @@ def key(seed: Union[int, Array], *, impl: Optional[str] = None) -> KeyArray: """ return _key('key', seed, impl) -def PRNGKey(seed: Union[int, Array], *, impl: Optional[str] = None) -> KeyArray: +def PRNGKey(seed: Union[int, Array], *, + impl: Optional[str] = None) -> KeyArray: """Create a pseudo-random number generator (PRNG) key given an integer seed. The resulting key carries the default PRNG implementation, as diff --git a/tests/random_test.py b/tests/random_test.py index 59c72247a1e1..2ad03710e9d0 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -592,7 +592,7 @@ def assertKeysEqual(self, key1, key2): def test_construction(self): key = random.key(42) - self.assertIsInstance(key, prng_internal.PRNGKeyArray) + self.assertIsInstance(key, jax_random.PRNGKeyArray) def test_issubdtype(self): key = random.key(42) @@ -610,7 +610,7 @@ def test_issubdtype(self): @skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag') def test_construction_upgrade_flag(self): key = random.PRNGKey(42) - self.assertIsInstance(key, prng_internal.PRNGKeyArray) + self.assertIsInstance(key, jax_random.PRNGKeyArray) def make_keys(self, *shape, seed=28): seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) @@ -677,13 +677,13 @@ def test_key_copy(self): def test_isinstance(self): @jax.jit def f(k): - self.assertIsInstance(k, prng_internal.PRNGKeyArray) + self.assertIsInstance(k, jax_random.PRNGKeyArray) return k k1 = self.make_keys() k2 = f(k1) - self.assertIsInstance(k1, prng_internal.PRNGKeyArray) - self.assertIsInstance(k2, prng_internal.PRNGKeyArray) + self.assertIsInstance(k1, jax_random.PRNGKeyArray) + self.assertIsInstance(k2, jax_random.PRNGKeyArray) def test_cpp_dispatch_normal(self): # Ensure we stay on the C++ dispatch path when calling a jitted @@ -748,10 +748,10 @@ def test_random_wrap_vmap(self): f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl) base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2) keys = jax.vmap(f, in_axes=0)(base_arr) - self.assertIsInstance(keys, prng_internal.PRNGKeyArray) + self.assertIsInstance(keys, jax_random.PRNGKeyArray) self.assertEqual(keys.shape, (3,)) keys = jax.vmap(f, in_axes=1)(base_arr.T) - self.assertIsInstance(keys, prng_internal.PRNGKeyArray) + self.assertIsInstance(keys, jax_random.PRNGKeyArray) self.assertEqual(keys.shape, (3,)) @jtu.sample_product(use_internal=[False, True]) @@ -848,20 +848,20 @@ def test_scan_lowering(self): ks = self.make_keys(3, 4) f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks) _, out = jax.jit(f)(ks) # doesn't crash - self.assertIsInstance(out, prng_internal.PRNGKeyArray) + self.assertIsInstance(out, jax_random.PRNGKeyArray) self.assertEqual(out.shape, (3, 4)) def test_slice(self): ks = self.make_keys(3, 4) ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (2, 4)) def test_dynamic_slice(self): ks = self.make_keys(3, 4) index = np.int16(1) # non-default int type to catch type errors. ys = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))(ks, index) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (2, 4)) def test_dynamic_update_slice(self): @@ -869,51 +869,51 @@ def test_dynamic_update_slice(self): k = self.make_keys(1, 4) index = np.int16(1) # non-default int type to catch type errors. ys = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))(ks, k, index) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (3, 4)) def test_transpose(self): ks = self.make_keys(3, 4) ys = jax.jit(lambda x: x.T)(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (4, 3)) def test_gather(self): ks = self.make_keys(3, 4) ys = jax.jit(lambda x: x[1])(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (4,)) ks = self.make_keys(3, 4, 5) ys = jax.jit(lambda x: x[1])(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (4, 5)) ys = jax.jit(lambda x: x[1, 2:4])(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (2, 5)) ys = jax.jit(lambda x: x[1, 2:4, 3])(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (2,)) ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (3, 2, 1)) def test_select(self): ks = self.make_keys(3, 2) cs = jnp.array([True, False, False, True, False, True]).reshape(3, 2) ys = jax.jit(lax.select)(cs, ks, ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (3, 2)) def test_select_scalar_cond(self): # regression test for https://github.com/google/jax/issues/16422 ks = self.make_keys(3) ys = lax.select(True, ks, ks) - self.assertIsInstance(ys, prng_internal.PRNGKeyArray) + self.assertIsInstance(ys, jax_random.PRNGKeyArray) self.assertEqual(ys.shape, (3,)) def test_vmap_of_cond(self): @@ -986,7 +986,7 @@ def f_jvp(primals, tangents): custom_result = jax.grad(f)(0.0, key) self.assertAllClose(default_result, custom_result) - self.assertIsInstance(key_dot, prng_internal.PRNGKeyArray) + self.assertIsInstance(key_dot, jax_random.PRNGKeyArray) self.assertArraysEqual(random.key_data(key_dot), np.uint32(0)) def test_key_array_indexing_0d(self): @@ -1155,7 +1155,7 @@ def assertKeysEqual(self, key1, key2): def check_shape(self, func, *args): like = lambda keys: jnp.ones(keys.shape) out_key = func(*args) - self.assertIsInstance(out_key, prng_internal.PRNGKeyArray) + self.assertIsInstance(out_key, jax_random.PRNGKeyArray) out_like_key = func(*tree_util.tree_map(like, args)) self.assertIsInstance(out_like_key, jax.Array) self.assertEqual(out_key.shape, out_like_key.shape) @@ -1166,11 +1166,11 @@ def check_against_reference(self, key_func, arr_func, *key_args): self.assertIsInstance(out_arr, jax.Array) out_key = key_func(*key_args) - self.assertIsInstance(out_key, prng_internal.PRNGKeyArray) + self.assertIsInstance(out_key, jax_random.PRNGKeyArray) self.assertArraysEqual(random.key_data(out_key), out_arr) out_key = jax.jit(key_func)(*key_args) - self.assertIsInstance(out_key, prng_internal.PRNGKeyArray) + self.assertIsInstance(out_key, jax_random.PRNGKeyArray) self.assertArraysEqual(random.key_data(out_key), out_arr) @parameterized.parameters([