From 53c4de477e2baa16abdf7236f12b1e9bdf79d4d6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 19 Oct 2023 13:59:01 -0700 Subject: [PATCH] [random] deprecate jax.random.default_prng_impl() --- jax/experimental/jax2tf/tests/primitive_harness.py | 5 ++--- jax/random.py | 9 ++++++++- tests/api_test.py | 3 +-- tests/random_test.py | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index 9ede6cc32db0..b7832fd7e6a1 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -38,7 +38,6 @@ """ from collections.abc import Iterable, Sequence -import itertools import operator import os from functools import partial @@ -2676,9 +2675,9 @@ def _make_reducer_harness(prim, def wrap_and_split(): key = jax.random.key(42) if config.enable_custom_prng.value: - key = prng.random_wrap(key, impl=jax.random.default_prng_impl()) + key = jax.random.wrap_key_data(key) result = jax.random.split(key, 2) - return prng.random_unwrap(result) + return jax.random.key_data(result) define( "random_split", diff --git a/jax/random.py b/jax/random.py index ff95ac30dc53..729de54f9d03 100644 --- a/jax/random.py +++ b/jax/random.py @@ -143,7 +143,7 @@ cauchy as cauchy, chisquare as chisquare, choice as choice, - default_prng_impl as default_prng_impl, + default_prng_impl as _deprecated_default_prng_impl, dirichlet as dirichlet, double_sided_maxwell as double_sided_maxwell, exponential as exponential, @@ -225,12 +225,19 @@ "jax.random.threefry2x32_p is deprecated. Use jax.extend.random.threefry2x32_p.", _deprecated_threefry2x32_p, ), + # Added October 19. 2023 + "default_prng_impl": ( + "jax.random.default_prng_impl is deprecated. Typical uses can be replaced by " + "jax.random.key_impl(key), jax.eval_shape(jax.random.key, 0).dtype, or similar.", + _deprecated_default_prng_impl, + ), } import typing if typing.TYPE_CHECKING: PRNGKeyArray = typing.Any KeyArray = typing.Any + default_prng_impl = _deprecated_default_prng_impl threefry_2x32 = _deprecated_threefry_2x32 threefry2x32_p = _deprecated_threefry2x32_p threefry2x32_key = _deprecated_threefry2x32_key diff --git a/tests/api_test.py b/tests/api_test.py index 11bdf8e36b08..e5bae6d066a0 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -50,7 +50,6 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import linear_util as lu -from jax._src import prng from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.ad_checkpoint import saved_residuals @@ -878,7 +877,7 @@ def test_omnistaging(self): def wrap(arr): arr = np.array(arr, dtype=np.uint32) if config.enable_custom_prng.value: - return prng.random_wrap(arr, impl=jax.random.default_prng_impl()) + return jax.random.wrap_key_data(arr) else: return arr diff --git a/tests/random_test.py b/tests/random_test.py index f878f96f9fa2..d32bdd2c8673 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -511,7 +511,7 @@ def test_prng_seeds_and_keys(self, seed, typ, jit, key, make_key): for name, impl in PRNG_IMPLS]) def test_default_prng_selection(self, make_key, name, impl): with jax.default_prng_impl(name): - self.assertIs(random.default_prng_impl(), impl) + self.assertIs(jax_random.default_prng_impl(), impl) key = make_key(42) self.check_key_has_impl(key, impl) k1, k2 = random.split(key, 2)