From 2bf9322ccccf20640c7262f13377f773cb93b43e Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 18 Sep 2023 14:06:53 -0700 Subject: [PATCH] move `wrap_key_data` to `jax.random` This is a fine function for the public API, rather than `jax.extend`. --- docs/jax.extend.rst | 1 - docs/jax.random.rst | 1 + jax/_src/extend/random.py | 23 ----------------------- jax/_src/random.py | 17 +++++++++++++++++ jax/extend/random.py | 4 ---- jax/random.py | 1 + tests/extend_test.py | 28 ---------------------------- tests/random_test.py | 25 +++++++++++++++++++++++++ 8 files changed, 44 insertions(+), 56 deletions(-) delete mode 100644 jax/_src/extend/random.py diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index f737e51ef081..e9894b1a1edb 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -35,4 +35,3 @@ threefry_prng_impl rbg_prng_impl unsafe_rbg_prng_impl - wrap_key_data diff --git a/docs/jax.random.rst b/docs/jax.random.rst index b22d5709cd0c..efe89679dce7 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -18,6 +18,7 @@ List of Available Functions PRNGKey key key_data + wrap_key_data ball bernoulli beta diff --git a/jax/_src/extend/random.py b/jax/_src/extend/random.py deleted file mode 100644 index f28be626dd45..000000000000 --- a/jax/_src/extend/random.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -from jax._src import prng -from jax._src import random -from jax._src.typing import Array - -def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None): - impl_obj = random.resolve_prng_impl(impl) - return prng.random_wrap(key_bits_array, impl=impl_obj) diff --git a/jax/_src/random.py b/jax/_src/random.py index 431fc307207e..bd4353b95214 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -275,9 +275,26 @@ def _key_data(keys: KeyArray) -> Array: return prng.random_unwrap(keys) def key_data(keys: KeyArray) -> Array: + """Recover the bits of key data underlying a PRNG key array.""" keys, _ = _check_prng_key(keys) return _key_data(keys) +def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None): + """Wrap an array of key data bits into a PRNG key array. + + Args: + key_bits_array: a ``uint32`` array with trailing shape corresponding to + the key shape of the PRNG implementation specified by ``impl``. + impl: optional, specifies a PRNG implementation, as in ``random.key``. + + Returns: + A PRNG key array, whose dtype is a subdtype of ``jax.dtypes.prng_key`` + corresponding to ``impl``, and whose shape equals the leading shape + of ``key_bits_array.shape`` up to the key bit dimensions. + """ + impl_obj = resolve_prng_impl(impl) + return prng.random_wrap(key_bits_array, impl=impl_obj) + ### random samplers diff --git a/jax/extend/random.py b/jax/extend/random.py index ed73f1c6a1e9..080fcdd2600a 100644 --- a/jax/extend/random.py +++ b/jax/extend/random.py @@ -24,7 +24,3 @@ rbg_prng_impl as rbg_prng_impl, unsafe_rbg_prng_impl as unsafe_rbg_prng_impl, ) - -from jax._src.extend.random import ( - wrap_key_data as wrap_key_data, -) diff --git a/jax/random.py b/jax/random.py index ea176999ca65..6e0843a222bd 100644 --- a/jax/random.py +++ b/jax/random.py @@ -183,6 +183,7 @@ unsafe_rbg_key as unsafe_rbg_key, wald as wald, weibull_min as weibull_min, + wrap_key_data as wrap_key_data, ) diff --git a/tests/extend_test.py b/tests/extend_test.py index 31c3e76be975..3c63cf70eb3a 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -14,7 +14,6 @@ from absl.testing import absltest -import jax import jax.extend as jex from jax._src import linear_util @@ -46,32 +45,5 @@ def test_symbols(self): self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init) -class RandomTest(jtu.JaxTestCase): - def test_wrap_key_default(self): - key1 = jax.random.key(17) - data = jax.random.key_data(key1) - key2 = jex.random.wrap_key_data(data) - self.assertEqual(key1.dtype, key2.dtype) - self.assertArraysEqual(jax.random.key_data(key1), - jax.random.key_data(key2)) - - impl = config.jax_default_prng_impl - key3 = jex.random.wrap_key_data(data, impl=impl) - self.assertEqual(key1.dtype, key3.dtype) - self.assertArraysEqual(jax.random.key_data(key1), - jax.random.key_data(key3)) - - def test_wrap_key_explicit(self): - key1 = jax.random.key(17, impl='rbg') - data = jax.random.key_data(key1) - key2 = jex.random.wrap_key_data(data, impl='rbg') - self.assertEqual(key1.dtype, key2.dtype) - self.assertArraysEqual(jax.random.key_data(key1), - jax.random.key_data(key2)) - - key3 = jex.random.wrap_key_data(data, impl='unsafe_rbg') - self.assertNotEqual(key1.dtype, key3.dtype) - - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/random_test.py b/tests/random_test.py index 19ba8395e7a9..37b9f20e04e2 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -2169,6 +2169,31 @@ def test_async(self): self.assertArraysEqual(key, key.block_until_ready()) self.assertIsNone(key.copy_to_host_async()) + def test_wrap_key_default(self): + key1 = jax.random.key(17) + data = jax.random.key_data(key1) + key2 = jax.random.wrap_key_data(data) + self.assertEqual(key1.dtype, key2.dtype) + self.assertArraysEqual(jax.random.key_data(key1), + jax.random.key_data(key2)) + + impl = config.jax_default_prng_impl + key3 = jax.random.wrap_key_data(data, impl=impl) + self.assertEqual(key1.dtype, key3.dtype) + self.assertArraysEqual(jax.random.key_data(key1), + jax.random.key_data(key3)) + + def test_wrap_key_explicit(self): + key1 = jax.random.key(17, impl='rbg') + data = jax.random.key_data(key1) + key2 = jax.random.wrap_key_data(data, impl='rbg') + self.assertEqual(key1.dtype, key2.dtype) + self.assertArraysEqual(jax.random.key_data(key1), + jax.random.key_data(key2)) + + key3 = jax.random.wrap_key_data(data, impl='unsafe_rbg') + self.assertNotEqual(key1.dtype, key3.dtype) + # TODO(frostig,mattjj): more polymorphic primitives tests