From b865827d06500680f19de46917b804dc8cd7348f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 18 Oct 2023 14:42:49 -0700 Subject: [PATCH] [random] deprecate jax.random.threefry_2x32 & threefry2x32_p --- jax/_src/random.py | 9 --------- jax/random.py | 18 ++++++++++++++++-- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 4b0ac00e3d66..dab08bbf5bba 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2056,15 +2056,6 @@ def _weibull_min(key, scale, concentration, shape, dtype) -> Array: return jnp.power(-jnp.log1p(-random_uniform), 1.0/concentration) * scale -# TODO(frostig): remove these aliases - -threefry2x32_p = prng.threefry2x32_p - -def threefry_2x32(keypair, count): - warnings.warn('jax.random.threefry_2x32 has moved to jax.prng.threefry_2x32 ' - 'and will be removed from `random` module.', FutureWarning) - return prng.threefry_2x32(keypair, count) - def orthogonal( key: KeyArray, n: int, diff --git a/jax/random.py b/jax/random.py index b4959ce1990e..afe52bc5063c 100644 --- a/jax/random.py +++ b/jax/random.py @@ -175,9 +175,7 @@ shuffle as shuffle, split as split, t as t, - threefry_2x32 as threefry_2x32, threefry2x32_key as _deprecated_threefry2x32_key, - threefry2x32_p as threefry2x32_p, triangular as triangular, truncated_normal as truncated_normal, uniform as uniform, @@ -187,6 +185,11 @@ wrap_key_data as wrap_key_data, ) +from jax._src.prng import ( + threefry_2x32 as _deprecated_threefry_2x32, + threefry2x32_p as _deprecated_threefry2x32_p, +) + # Deprecations from jax._src.prng import PRNGKeyArray as _PRNGKeyArray @@ -213,12 +216,23 @@ "unsafe_rbg_key": ( "jax.random.unsafe_rbg_key(seed) is deprecated. " "Use jax.random.PRNGKey(seed, 'unsafe_rbg')", _deprecated_unsafe_rbg_key), + # Added October 18, 2023 + "threefry_2x32": ( # Note: this has been raising a FutureWarning since 2021 + "jax.random.threefry_2x32 is deprecated. Use jax.extend.random.threefry_2x32.", + _deprecated_threefry_2x32, + ), + "threefry2x32_p": ( + "jax.random.threefry2x32_p is deprecated. Use jax.extend.random.threefry2x32_p.", + _deprecated_threefry2x32_p, + ), } import typing if typing.TYPE_CHECKING: PRNGKeyArray = typing.Any KeyArray = typing.Any + threefry_2x32 = _deprecated_threefry_2x32 + threefry_2x32_p = _deprecated_threefry2x32_p threefry2x32_key = _deprecated_threefry2x32_key rbg_key = _deprecated_rbg_key unsafe_rbg_key = _deprecated_unsafe_rbg_key