Skip to content

Commit

Permalink
Rollback to fix internal breakage
Browse files Browse the repository at this point in the history
Reverts 7d203ae

PiperOrigin-RevId: 574101804
  • Loading branch information
jax authors committed Oct 17, 2023
1 parent 2604d0c commit 2be6019
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 31 deletions.
23 changes: 15 additions & 8 deletions jax/_src/random.py
Expand Up @@ -53,7 +53,10 @@
Shape = Sequence[int]

PRNGImpl = prng.PRNGImpl
KeyArray = Array

# TODO(frostig,vanderplas): remove after deprecation window
KeyArray = Union[Array, prng.PRNGKeyArray]
PRNGKeyArray = prng.PRNGKeyArray

UINT_DTYPES = prng.UINT_DTYPES

Expand All @@ -66,8 +69,9 @@ def _isnan(x: ArrayLike) -> Array:
return lax.ne(x, x)


def _check_prng_key(key) -> tuple[KeyArray, bool]:
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
# TODO(frostig): remove once we always enable_custom_prng
if isinstance(key, prng.PRNGKeyArray):
return key, False
elif _arraylike(key):
# Call random_wrap here to surface errors for invalid keys.
Expand Down Expand Up @@ -102,7 +106,7 @@ def _return_prng_keys(was_wrapped, key):
return prng.random_unwrap(key) if was_wrapped else key


def _random_bits(key: KeyArray, bit_width, shape) -> Array:
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array:
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
return prng.random_bits(key, bit_width=bit_width, shape=shape)

Expand Down Expand Up @@ -166,18 +170,20 @@ def resolve_prng_impl(
raise TypeError(f'unrecognized type {t} for specifying PRNG implementation.')


def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str] ) -> KeyArray:
def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str]
) -> PRNGKeyArray:
impl = resolve_prng_impl(impl_spec)
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
raise TypeError(
f"{ctor_name} accepts a scalar seed, but was given a PRNG key.")
f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.")
if np.ndim(seed):
raise TypeError(
f"{ctor_name} accepts a scalar seed, but was given an array of "
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
return prng.seed_with_impl(impl, seed)

def key(seed: Union[int, Array], *, impl: Optional[str] = None) -> KeyArray:
def key(seed: Union[int, Array], *,
impl: Optional[str] = None) -> PRNGKeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The result is a scalar array with a key that indicates the default PRNG
Expand All @@ -195,7 +201,8 @@ def key(seed: Union[int, Array], *, impl: Optional[str] = None) -> KeyArray:
"""
return _key('key', seed, impl)

def PRNGKey(seed: Union[int, Array], *, impl: Optional[str] = None) -> KeyArray:
def PRNGKey(seed: Union[int, Array], *,
impl: Optional[str] = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The resulting key carries the default PRNG implementation, as
Expand Down
46 changes: 23 additions & 23 deletions tests/random_test.py
Expand Up @@ -592,7 +592,7 @@ def assertKeysEqual(self, key1, key2):

def test_construction(self):
key = random.key(42)
self.assertIsInstance(key, prng_internal.PRNGKeyArray)
self.assertIsInstance(key, jax_random.PRNGKeyArray)

def test_issubdtype(self):
key = random.key(42)
Expand All @@ -610,7 +610,7 @@ def test_issubdtype(self):
@skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag')
def test_construction_upgrade_flag(self):
key = random.PRNGKey(42)
self.assertIsInstance(key, prng_internal.PRNGKeyArray)
self.assertIsInstance(key, jax_random.PRNGKeyArray)

def make_keys(self, *shape, seed=28):
seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
Expand Down Expand Up @@ -677,13 +677,13 @@ def test_key_copy(self):
def test_isinstance(self):
@jax.jit
def f(k):
self.assertIsInstance(k, prng_internal.PRNGKeyArray)
self.assertIsInstance(k, jax_random.PRNGKeyArray)
return k

k1 = self.make_keys()
k2 = f(k1)
self.assertIsInstance(k1, prng_internal.PRNGKeyArray)
self.assertIsInstance(k2, prng_internal.PRNGKeyArray)
self.assertIsInstance(k1, jax_random.PRNGKeyArray)
self.assertIsInstance(k2, jax_random.PRNGKeyArray)

def test_cpp_dispatch_normal(self):
# Ensure we stay on the C++ dispatch path when calling a jitted
Expand Down Expand Up @@ -748,10 +748,10 @@ def test_random_wrap_vmap(self):
f = partial(prng_internal.random_wrap, impl=prng_internal.threefry_prng_impl)
base_arr = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
keys = jax.vmap(f, in_axes=0)(base_arr)
self.assertIsInstance(keys, prng_internal.PRNGKeyArray)
self.assertIsInstance(keys, jax_random.PRNGKeyArray)
self.assertEqual(keys.shape, (3,))
keys = jax.vmap(f, in_axes=1)(base_arr.T)
self.assertIsInstance(keys, prng_internal.PRNGKeyArray)
self.assertIsInstance(keys, jax_random.PRNGKeyArray)
self.assertEqual(keys.shape, (3,))

@jtu.sample_product(use_internal=[False, True])
Expand Down Expand Up @@ -848,72 +848,72 @@ def test_scan_lowering(self):
ks = self.make_keys(3, 4)
f = lambda ks: jax.lax.scan(lambda _, k: (None, k.T), None, ks)
_, out = jax.jit(f)(ks) # doesn't crash
self.assertIsInstance(out, prng_internal.PRNGKeyArray)
self.assertIsInstance(out, jax_random.PRNGKeyArray)
self.assertEqual(out.shape, (3, 4))

def test_slice(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 4))

def test_dynamic_slice(self):
ks = self.make_keys(3, 4)
index = np.int16(1) # non-default int type to catch type errors.
ys = jax.jit(partial(lax.dynamic_slice_in_dim, slice_size=2))(ks, index)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 4))

def test_dynamic_update_slice(self):
ks = self.make_keys(3, 4)
k = self.make_keys(1, 4)
index = np.int16(1) # non-default int type to catch type errors.
ys = jax.jit(partial(lax.dynamic_update_slice_in_dim, axis=0))(ks, k, index)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 4))

def test_transpose(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x.T)(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (4, 3))

def test_gather(self):
ks = self.make_keys(3, 4)
ys = jax.jit(lambda x: x[1])(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (4,))

ks = self.make_keys(3, 4, 5)

ys = jax.jit(lambda x: x[1])(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (4, 5))

ys = jax.jit(lambda x: x[1, 2:4])(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2, 5))

ys = jax.jit(lambda x: x[1, 2:4, 3])(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (2,))

ys = jax.jit(lambda x: x[:, 2:4, 3:4])(ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 2, 1))

def test_select(self):
ks = self.make_keys(3, 2)
cs = jnp.array([True, False, False, True, False, True]).reshape(3, 2)
ys = jax.jit(lax.select)(cs, ks, ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3, 2))

def test_select_scalar_cond(self):
# regression test for https://github.com/google/jax/issues/16422
ks = self.make_keys(3)
ys = lax.select(True, ks, ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertIsInstance(ys, jax_random.PRNGKeyArray)
self.assertEqual(ys.shape, (3,))

def test_vmap_of_cond(self):
Expand Down Expand Up @@ -986,7 +986,7 @@ def f_jvp(primals, tangents):
custom_result = jax.grad(f)(0.0, key)

self.assertAllClose(default_result, custom_result)
self.assertIsInstance(key_dot, prng_internal.PRNGKeyArray)
self.assertIsInstance(key_dot, jax_random.PRNGKeyArray)
self.assertArraysEqual(random.key_data(key_dot), np.uint32(0))

def test_key_array_indexing_0d(self):
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def assertKeysEqual(self, key1, key2):
def check_shape(self, func, *args):
like = lambda keys: jnp.ones(keys.shape)
out_key = func(*args)
self.assertIsInstance(out_key, prng_internal.PRNGKeyArray)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
out_like_key = func(*tree_util.tree_map(like, args))
self.assertIsInstance(out_like_key, jax.Array)
self.assertEqual(out_key.shape, out_like_key.shape)
Expand All @@ -1166,11 +1166,11 @@ def check_against_reference(self, key_func, arr_func, *key_args):
self.assertIsInstance(out_arr, jax.Array)

out_key = key_func(*key_args)
self.assertIsInstance(out_key, prng_internal.PRNGKeyArray)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
self.assertArraysEqual(random.key_data(out_key), out_arr)

out_key = jax.jit(key_func)(*key_args)
self.assertIsInstance(out_key, prng_internal.PRNGKeyArray)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
self.assertArraysEqual(random.key_data(out_key), out_arr)

@parameterized.parameters([
Expand Down

0 comments on commit 2be6019

Please sign in to comment.