Skip to content

Commit

Permalink
move wrap_key_data to jax.random
Browse files Browse the repository at this point in the history
This is a fine function for the public API, rather than `jax.extend`.
  • Loading branch information
froystig committed Sep 18, 2023
1 parent 0e3b12d commit 2bf9322
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 56 deletions.
1 change: 0 additions & 1 deletion docs/jax.extend.rst
Expand Up @@ -35,4 +35,3 @@
threefry_prng_impl
rbg_prng_impl
unsafe_rbg_prng_impl
wrap_key_data
1 change: 1 addition & 0 deletions docs/jax.random.rst
Expand Up @@ -18,6 +18,7 @@ List of Available Functions
PRNGKey
key
key_data
wrap_key_data
ball
bernoulli
beta
Expand Down
23 changes: 0 additions & 23 deletions jax/_src/extend/random.py

This file was deleted.

17 changes: 17 additions & 0 deletions jax/_src/random.py
Expand Up @@ -275,9 +275,26 @@ def _key_data(keys: KeyArray) -> Array:
return prng.random_unwrap(keys)

def key_data(keys: KeyArray) -> Array:
"""Recover the bits of key data underlying a PRNG key array."""
keys, _ = _check_prng_key(keys)
return _key_data(keys)

def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None):
"""Wrap an array of key data bits into a PRNG key array.
Args:
key_bits_array: a ``uint32`` array with trailing shape corresponding to
the key shape of the PRNG implementation specified by ``impl``.
impl: optional, specifies a PRNG implementation, as in ``random.key``.
Returns:
A PRNG key array, whose dtype is a subdtype of ``jax.dtypes.prng_key``
corresponding to ``impl``, and whose shape equals the leading shape
of ``key_bits_array.shape`` up to the key bit dimensions.
"""
impl_obj = resolve_prng_impl(impl)
return prng.random_wrap(key_bits_array, impl=impl_obj)


### random samplers

Expand Down
4 changes: 0 additions & 4 deletions jax/extend/random.py
Expand Up @@ -24,7 +24,3 @@
rbg_prng_impl as rbg_prng_impl,
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
)

from jax._src.extend.random import (
wrap_key_data as wrap_key_data,
)
1 change: 1 addition & 0 deletions jax/random.py
Expand Up @@ -183,6 +183,7 @@
unsafe_rbg_key as unsafe_rbg_key,
wald as wald,
weibull_min as weibull_min,
wrap_key_data as wrap_key_data,
)


Expand Down
28 changes: 0 additions & 28 deletions tests/extend_test.py
Expand Up @@ -14,7 +14,6 @@

from absl.testing import absltest

import jax
import jax.extend as jex

from jax._src import linear_util
Expand Down Expand Up @@ -46,32 +45,5 @@ def test_symbols(self):
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)


class RandomTest(jtu.JaxTestCase):
def test_wrap_key_default(self):
key1 = jax.random.key(17)
data = jax.random.key_data(key1)
key2 = jex.random.wrap_key_data(data)
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))

impl = config.jax_default_prng_impl
key3 = jex.random.wrap_key_data(data, impl=impl)
self.assertEqual(key1.dtype, key3.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key3))

def test_wrap_key_explicit(self):
key1 = jax.random.key(17, impl='rbg')
data = jax.random.key_data(key1)
key2 = jex.random.wrap_key_data(data, impl='rbg')
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))

key3 = jex.random.wrap_key_data(data, impl='unsafe_rbg')
self.assertNotEqual(key1.dtype, key3.dtype)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
25 changes: 25 additions & 0 deletions tests/random_test.py
Expand Up @@ -2169,6 +2169,31 @@ def test_async(self):
self.assertArraysEqual(key, key.block_until_ready())
self.assertIsNone(key.copy_to_host_async())

def test_wrap_key_default(self):
key1 = jax.random.key(17)
data = jax.random.key_data(key1)
key2 = jax.random.wrap_key_data(data)
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))

impl = config.jax_default_prng_impl
key3 = jax.random.wrap_key_data(data, impl=impl)
self.assertEqual(key1.dtype, key3.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key3))

def test_wrap_key_explicit(self):
key1 = jax.random.key(17, impl='rbg')
data = jax.random.key_data(key1)
key2 = jax.random.wrap_key_data(data, impl='rbg')
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(jax.random.key_data(key1),
jax.random.key_data(key2))

key3 = jax.random.wrap_key_data(data, impl='unsafe_rbg')
self.assertNotEqual(key1.dtype, key3.dtype)

# TODO(frostig,mattjj): more polymorphic primitives tests


Expand Down

0 comments on commit 2bf9322

Please sign in to comment.