Skip to content

Commit

Permalink
haiku: fully support JAX typed PRNG keys
Browse files Browse the repository at this point in the history
For more details, see google/jax#17297. Previously, we had imagined a world where the jax_enable_custom_prng flag globally determined the presence of typed keys. This proved untenable for a number of reasons. Going forward, old-style and new-style keys are expected to exist side-by-side regardless of the value of `jax_enable_custom_prng`, which will soon be deprecated. Eventually old-style keys will also be deprecated and removed.

PiperOrigin-RevId: 565694585
  • Loading branch information
Jake VanderPlas authored and Copybara-Service committed Sep 15, 2023
1 parent cbcbbaa commit 8a001e6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
17 changes: 12 additions & 5 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,17 +972,24 @@ def assert_is_prng_key(key: PRNGKey):
# device-to-host copy.
make_error = lambda: ValueError( # pylint: disable=g-long-lambda
f"The provided key is not a JAX PRNGKey but a {type(key)}:\n{key}")
if jax.config.jax_enable_custom_prng:
if not isinstance(key, jax.random.KeyArray):
raise make_error()

if not hasattr(key, "shape") or not hasattr(key, "dtype"):
raise make_error()

if hasattr(jax.dtypes, "prng_key"): # JAX 0.4.14 or newer
is_typed_prng = jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
elif hasattr(jax.random, "PRNGKeyArray"): # Older JAX versions
is_typed_prng = isinstance(key, jax.random.PRNGKeyArray)
else: # Shouldn't get here, but just in case...
is_typed_prng = False

if is_typed_prng:
if key.shape:
raise ValueError(
"Provided key did not have expected shape and/or dtype: "
f"expected=(shape=(), dtype={key.dtype}), "
f"actual=(shape={key.shape}, dtype={key.dtype})")
else:
if not hasattr(key, "shape") or not hasattr(key, "dtype"):
raise make_error()
config_hint = ""
default_impl = jax.random.default_prng_impl()
expected_shape = default_impl.key_shape
Expand Down
18 changes: 5 additions & 13 deletions haiku/_src/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,13 @@ def test_rbg_default_impl_invalid_key_shape(self):
):
init(key)

def test_invalid_key(self):
init, _ = transform.transform(base.next_rng_key)
with self.assertRaisesRegex(ValueError, "Init must be called with an RNG"):
init([1, 2])

class CustomRNGTest(parameterized.TestCase):

def setUp(self):
super().setUp()
jax.config.update("jax_enable_custom_prng", True)

def tearDown(self):
super().tearDown()
jax.config.update("jax_enable_custom_prng", False)
class CustomRNGTest(parameterized.TestCase):

def test_non_custom_key(self):
init, _ = transform.transform(base.next_rng_key)
Expand Down Expand Up @@ -116,11 +113,6 @@ def count_splits(_, num):
init(key)
self.assertEqual(count, 1)

def test_invalid_custom_key(self):
init, _ = transform.transform(base.next_rng_key)
with self.assertRaisesRegex(ValueError, "Init must be called with an RNG"):
init(jnp.ones((2,), dtype=jnp.uint32))


def split_for_n(key, n):
for _ in range(n):
Expand Down

0 comments on commit 8a001e6

Please sign in to comment.