Skip to content

Commit

Permalink
[random] deprecate named key creation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 21, 2023
1 parent 6a551a1 commit 22818d6
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,11 @@ Remember to align the itemized text with the first line of an item within a list

* Deprecations
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`

# jaxlib 0.4.17

Expand Down
4 changes: 2 additions & 2 deletions jax/_src/random.py
Expand Up @@ -195,8 +195,8 @@ def _check_default_impl_with_no_custom_prng(impl, name):
default_name = config.jax_default_prng_impl
if not config.jax_enable_custom_prng and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
f'to seed an RNG with an implementation "f{name}" '
f'differing from the default "f{default_name}".')
f'to seed an RNG with an implementation "{name}" '
f'differing from the default "{default_name}".')

def threefry2x32_key(seed: int) -> KeyArray:
"""Creates a threefry2x32 PRNG key from an integer seed."""
Expand Down
19 changes: 16 additions & 3 deletions jax/random.py
Expand Up @@ -170,17 +170,17 @@
randint as randint,
random_gamma_p as random_gamma_p,
rayleigh as rayleigh,
rbg_key as rbg_key,
rbg_key as _deprecated_rbg_key,
shuffle as shuffle,
split as split,
t as t,
threefry_2x32 as threefry_2x32,
threefry2x32_key as threefry2x32_key,
threefry2x32_key as _deprecated_threefry2x32_key,
threefry2x32_p as threefry2x32_p,
triangular as triangular,
truncated_normal as truncated_normal,
uniform as uniform,
unsafe_rbg_key as unsafe_rbg_key,
unsafe_rbg_key as _deprecated_unsafe_rbg_key,
wald as wald,
weibull_min as weibull_min,
wrap_key_data as wrap_key_data,
Expand All @@ -202,12 +202,25 @@
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
"typed prng keys.", _PRNGKeyArray
),
# Added September 21, 2023
"threefry2x32_key": (
"jax.random.threefry2x32_key(seed) is deprecated. "
"Use jax.random.PRNGKey(seed, 'threefry2x32')", _deprecated_threefry2x32_key),
"rbg_key": (
"jax.random.rbg_key(seed) is deprecated. "
"Use jax.random.PRNGKey(seed, 'rbg')", _deprecated_rbg_key),
"unsafe_rbg_key": (
"jax.random.unsafe_rbg_key(seed) is deprecated. "
"Use jax.random.PRNGKey(seed, 'unsafe_rbg')", _deprecated_unsafe_rbg_key),
}

import typing
if typing.TYPE_CHECKING:
PRNGKeyArray = typing.Any
KeyArray = typing.Any
threefry_2x32_key = _deprecated_threefry2x32_key
rbg_key = _deprecated_rbg_key
unsafe_rbg_key = _deprecated_unsafe_rbg_key
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
Expand Down
23 changes: 13 additions & 10 deletions tests/random_test.py
Expand Up @@ -522,18 +522,21 @@ def test_default_prng_selection(self, make_key, name, impl):

@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_threefry2x32_key(self):
self.check_key_has_impl(random.threefry2x32_key(42),
prng_internal.threefry_prng_impl)
with self.assertWarnsRegex(DeprecationWarning, "jax.random.threefry2x32_key"):
self.check_key_has_impl(random.threefry2x32_key(42),
prng_internal.threefry_prng_impl)

@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_rbg_key(self):
self.check_key_has_impl(random.rbg_key(42),
prng_internal.rbg_prng_impl)
with self.assertWarnsRegex(DeprecationWarning, "jax.random.rbg_key"):
self.check_key_has_impl(random.rbg_key(42),
prng_internal.rbg_prng_impl)

@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_unsafe_rbg_key(self):
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng_internal.unsafe_rbg_prng_impl)
with self.assertWarnsRegex(DeprecationWarning, "jax.random.unsafe_rbg_key"):
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng_internal.unsafe_rbg_prng_impl)

@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
for ctor in KEY_CTORS
Expand Down Expand Up @@ -579,7 +582,7 @@ def test_legacy_prng_key_flag(self):

class ThreefryPrngTest(jtu.JaxTestCase):
@parameterized.parameters([{'make_key': ctor} for ctor in [
random.threefry2x32_key,
jax_random.threefry2x32_key,
partial(random.PRNGKey, impl='threefry2x32'),
partial(random.key, impl='threefry2x32')]])
def test_seed_no_implicit_transfers(self, make_key):
Expand Down Expand Up @@ -640,7 +643,7 @@ def _CheckChiSquared(self, samples, pmf):
f'{expected_freq}\n{actual_freq}')

def make_key(self, seed):
return random.threefry2x32_key(seed)
return random.PRNGKey(seed, impl='threefry2x32')

@jtu.sample_product(
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
Expand Down Expand Up @@ -2296,7 +2299,7 @@ def test_grad_of_prng_key(self):
@jtu.with_config(jax_default_prng_impl='rbg')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def make_key(self, seed):
return random.rbg_key(seed)
return random.PRNGKey(seed, impl='rbg')

def test_split_shape(self):
key = self.make_key(73)
Expand Down Expand Up @@ -2372,7 +2375,7 @@ def test_randint_out_of_range(self):
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def make_key(self, seed):
return random.unsafe_rbg_key(seed)
return random.PRNGKey(seed, impl="unsafe_rbg")


def _sampler_unimplemented_with_custom_prng(*args, **kwargs):
Expand Down

0 comments on commit 22818d6

Please sign in to comment.