Skip to content

Commit

Permalink
[random] deprecate jax.random.default_prng_impl()
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 19, 2023
1 parent 741b71f commit 53c4de4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 7 deletions.
5 changes: 2 additions & 3 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Expand Up @@ -38,7 +38,6 @@
"""

from collections.abc import Iterable, Sequence
import itertools
import operator
import os
from functools import partial
Expand Down Expand Up @@ -2676,9 +2675,9 @@ def _make_reducer_harness(prim,
def wrap_and_split():
key = jax.random.key(42)
if config.enable_custom_prng.value:
key = prng.random_wrap(key, impl=jax.random.default_prng_impl())
key = jax.random.wrap_key_data(key)
result = jax.random.split(key, 2)
return prng.random_unwrap(result)
return jax.random.key_data(result)

define(
"random_split",
Expand Down
9 changes: 8 additions & 1 deletion jax/random.py
Expand Up @@ -143,7 +143,7 @@
cauchy as cauchy,
chisquare as chisquare,
choice as choice,
default_prng_impl as default_prng_impl,
default_prng_impl as _deprecated_default_prng_impl,
dirichlet as dirichlet,
double_sided_maxwell as double_sided_maxwell,
exponential as exponential,
Expand Down Expand Up @@ -225,12 +225,19 @@
"jax.random.threefry2x32_p is deprecated. Use jax.extend.random.threefry2x32_p.",
_deprecated_threefry2x32_p,
),
# Added October 19. 2023
"default_prng_impl": (
"jax.random.default_prng_impl is deprecated. Typical uses can be replaced by "
"jax.random.key_impl(key), jax.eval_shape(jax.random.key, 0).dtype, or similar.",
_deprecated_default_prng_impl,
),
}

import typing
if typing.TYPE_CHECKING:
PRNGKeyArray = typing.Any
KeyArray = typing.Any
default_prng_impl = _deprecated_default_prng_impl
threefry_2x32 = _deprecated_threefry_2x32
threefry2x32_p = _deprecated_threefry2x32_p
threefry2x32_key = _deprecated_threefry2x32_key
Expand Down
3 changes: 1 addition & 2 deletions tests/api_test.py
Expand Up @@ -50,7 +50,6 @@
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.ad_checkpoint import saved_residuals
Expand Down Expand Up @@ -878,7 +877,7 @@ def test_omnistaging(self):
def wrap(arr):
arr = np.array(arr, dtype=np.uint32)
if config.enable_custom_prng.value:
return prng.random_wrap(arr, impl=jax.random.default_prng_impl())
return jax.random.wrap_key_data(arr)
else:
return arr

Expand Down
2 changes: 1 addition & 1 deletion tests/random_test.py
Expand Up @@ -511,7 +511,7 @@ def test_prng_seeds_and_keys(self, seed, typ, jit, key, make_key):
for name, impl in PRNG_IMPLS])
def test_default_prng_selection(self, make_key, name, impl):
with jax.default_prng_impl(name):
self.assertIs(random.default_prng_impl(), impl)
self.assertIs(jax_random.default_prng_impl(), impl)
key = make_key(42)
self.check_key_has_impl(key, impl)
k1, k2 = random.split(key, 2)
Expand Down

0 comments on commit 53c4de4

Please sign in to comment.