diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 59b482179e21..5963bffd7c28 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -14,7 +14,7 @@ from __future__ import annotations import abc -from collections.abc import Hashable, Iterator, Sequence +from collections.abc import Iterator, Sequence from functools import partial, reduce import math import operator as op @@ -95,6 +95,7 @@ class PRNGImpl(NamedTuple): split: Callable random_bits: Callable fold_in: Callable + name: str = '' tag: str = '?' def __hash__(self) -> int: @@ -104,12 +105,21 @@ def __str__(self) -> str: return self.tag def pprint(self): - return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") + + ty = self.__class__.__name__ + return (pp.text(f"{ty} [{self.tag}] {{{self.name}}}:") + pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [ pp.text(f"{k} = {v}") for k, v in self._asdict().items() ])))) +prngs = {} + +def register_prng(impl: PRNGImpl): + if impl.name in prngs: + raise ValueError(f'PRNG with name {impl.name} already registered: {impl}') + prngs[impl.name] = impl + + # -- PRNG key arrays def _check_prng_key_data(impl, key_data: typing.Array): @@ -248,6 +258,7 @@ class behave like an array whose base elements are keys, hiding the ``random_bits``, ``fold_in``). """ + # TODO(frostig,vanderplas): hide impl attribute impl: PRNGImpl _base_array: typing.Array @@ -593,7 +604,8 @@ def device_put_replicated(val, aval, sharding, devices): class KeyTy(dtypes.ExtendedDType): - impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really + # TODO(frostig,vanderplas): hide impl attribute + impl: PRNGImpl # TODO(mattjj,frostig): protocol really _rules = KeyTyRules type = dtypes.prng_key @@ -1366,8 +1378,11 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape): split=threefry_split, random_bits=threefry_random_bits, fold_in=threefry_fold_in, + name='threefry2x32', tag='fry') +register_prng(threefry_prng_impl) + # -- RngBitGenerator PRNG implementation @@ -1411,8 +1426,12 @@ def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int] split=_rbg_split, random_bits=_rbg_random_bits, fold_in=_rbg_fold_in, + name='rbg', tag='rbg') +register_prng(rbg_prng_impl) + + def _unsafe_rbg_split(key: typing.Array, shape: Shape) -> typing.Array: # treat 10 iterations of random bits as a 'hash function' num = math.prod(shape) @@ -1431,4 +1450,7 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: split=_unsafe_rbg_split, random_bits=_rbg_random_bits, fold_in=_unsafe_rbg_fold_in, + name='unsafe_rbg', tag='urbg') + +register_prng(unsafe_rbg_prng_impl) diff --git a/jax/_src/random.py b/jax/_src/random.py index 463472b2f477..d5ed72a80341 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -16,7 +16,8 @@ from functools import partial import math from operator import index -from typing import Optional, Union +import typing +from typing import Hashable, Optional, Union import warnings import numpy as np @@ -52,7 +53,9 @@ DTypeLikeFloat = DTypeLike Shape = Sequence[int] -# TODO(frostig): simplify once we always enable_custom_prng +PRNGImpl = prng.PRNGImpl + +# TODO(frostig,vanderplas): remove after deprecation window KeyArray = Union[Array, prng.PRNGKeyArray] PRNGKeyArray = prng.PRNGKeyArray @@ -109,35 +112,64 @@ def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array: return prng.random_bits(key, bit_width=bit_width, shape=shape) -PRNG_IMPLS = { - 'threefry2x32': prng.threefry_prng_impl, - 'rbg': prng.rbg_prng_impl, - 'unsafe_rbg': prng.unsafe_rbg_prng_impl, -} - +# TODO(frostig,vanderplas): remove from public API altogether, or at +# least change to return after asserting presence in `prng.prngs` def default_prng_impl(): """Get the default PRNG implementation. The default implementation is determined by ``config.jax_default_prng_impl``, - which specifies it by name. This function returns the corresponding - ``jax.prng.PRNGImpl`` instance. + which specifies it by name. """ impl_name = config.jax_default_prng_impl - assert impl_name in PRNG_IMPLS, impl_name - return PRNG_IMPLS[impl_name] + assert impl_name in prng.prngs, impl_name + return prng.prngs[impl_name] ### key operations -def resolve_prng_impl(impl_spec: Optional[str]): +# Wrapper around prng.PRNGImpl meant to hide its attributes from the +# public API. +# TODO(frostig,vanderplas): consider hiding all the attributes of +# PRNGImpl and directly returning it. +class PRNGSpec: + """Specifies a PRNG key implementation.""" + + __slots__ = ['_impl'] + _impl: PRNGImpl + + def __init__(self, impl): + self._impl = impl + + def __str__(self) -> str: return str(self._impl) + def __hash__(self) -> int: return hash(self._impl) + + def __eq__(self, other) -> bool: + return self._impl == other._impl + + +def resolve_prng_impl( + impl_spec: Optional[Union[str, PRNGSpec, PRNGImpl]]) -> PRNGImpl: if impl_spec is None: return default_prng_impl() - if impl_spec in PRNG_IMPLS: - return PRNG_IMPLS[impl_spec] + if type(impl_spec) is PRNGImpl: + # TODO(frostig,vanderplas): remove this case once we remove + # default_prng_impl (and thus PRNGImpl) from the public API and + # PRNGImpl from jex. We won't need to handle these then, and we + # can remove them from the input type annotation above as well. + return impl_spec + if type(impl_spec) is PRNGSpec: + return impl_spec._impl + if type(impl_spec) is str: + if impl_spec in prng.prngs: + return prng.prngs[impl_spec] + + keys_fmt = ', '.join(f'"{s}"' for s in prng.prngs.keys()) + raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". ' + f'Did you mean one of: {keys_fmt}?') + + t = type(impl_spec) + raise TypeError(f'unrecognized type {t} for specifying PRNG implementation.') - keys_fmt = ', '.join(f'"{s}"' for s in PRNG_IMPLS.keys()) - raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". ' - f'Did you mean one of: {keys_fmt}?') def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str] ) -> PRNGKeyArray: @@ -189,6 +221,7 @@ def PRNGKey(seed: Union[int, Array], *, """ return _return_prng_keys(True, _key('PRNGKey', seed, impl)) + # TODO(frostig): remove once we always enable_custom_prng def _check_default_impl_with_no_custom_prng(impl, name): default_impl = default_prng_impl() @@ -219,6 +252,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray: key = prng.seed_with_impl(impl, seed) return _return_prng_keys(True, key) + def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray: # Alternative to fold_in() to use within random samplers. # TODO(frostig): remove and use fold_in() once we always enable_custom_prng @@ -245,6 +279,7 @@ def fold_in(key: KeyArray, data: IntegerArray) -> KeyArray: key, wrapped = _check_prng_key(key) return _return_prng_keys(wrapped, _fold_in(key, data)) + def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait @@ -270,6 +305,17 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: key, wrapped = _check_prng_key(key) return _return_prng_keys(wrapped, _split(key, num)) + +def _key_impl(keys: KeyArray) -> PRNGImpl: + assert jnp.issubdtype(keys.dtype, dtypes.prng_key) + keys_dtype = typing.cast(prng.KeyTy, keys.dtype) + return keys_dtype.impl + +def key_impl(keys: KeyArray) -> Hashable: + keys, _ = _check_prng_key(keys) + return PRNGSpec(_key_impl(keys)) + + def _key_data(keys: KeyArray) -> Array: assert jnp.issubdtype(keys.dtype, dtypes.prng_key) return prng.random_unwrap(keys) @@ -279,6 +325,7 @@ def key_data(keys: KeyArray) -> Array: keys, _ = _check_prng_key(keys) return _key_data(keys) + def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None): """Wrap an array of key data bits into a PRNG key array. diff --git a/jax/extend/random.py b/jax/extend/random.py index 080fcdd2600a..7b1e2ed1e4e0 100644 --- a/jax/extend/random.py +++ b/jax/extend/random.py @@ -16,6 +16,9 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.prng import ( + # TODO(frostig,vanderplas): expose a define_prng_impl instead of the + # PRNGImpl constructor, to leave some room for us to register or check input, + # or to change what output type we return. PRNGImpl as PRNGImpl, seed_with_impl as seed_with_impl, threefry2x32_p as threefry2x32_p, diff --git a/jax/random.py b/jax/random.py index 4c6c9d860611..b4959ce1990e 100644 --- a/jax/random.py +++ b/jax/random.py @@ -155,6 +155,7 @@ gumbel as gumbel, key as key, key_data as key_data, + key_impl as key_impl, laplace as laplace, logistic as logistic, loggamma as loggamma, diff --git a/tests/random_test.py b/tests/random_test.py index 69fb10d9815b..4345a31de3b7 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -41,9 +41,8 @@ from jax import config config.parse_flags_with_absl() -PRNG_IMPLS = [('threefry2x32', prng_internal.threefry_prng_impl), - ('rbg', prng_internal.rbg_prng_impl), - ('unsafe_rbg', prng_internal.unsafe_rbg_prng_impl)] + +PRNG_IMPLS = list(prng_internal.prngs.items()) class OnX64(enum.Enum): @@ -189,6 +188,14 @@ def check_key_has_impl(self, key, impl): self.assertEqual(key.dtype, jnp.dtype('uint32')) self.assertEqual(key.shape, impl.key_shape) + def test_config_prngs_registered(self): + # TODO(frostig): pull these string values somehow from the + # jax_default_prng_impl config enum state definition directly, + # rather than copying manually here? + self.assertIn('threefry2x32', prng_internal.prngs) + self.assertIn('rbg', prng_internal.prngs) + self.assertIn('unsafe_rbg', prng_internal.prngs) + def testThreefry2x32(self): # We test the hash by comparing to known values provided in the test code of # the original reference implementation of Threefry. For the values, see @@ -1030,6 +1037,8 @@ def test_async(self): self.assertArraysEqual(key, key.block_until_ready()) self.assertIsNone(key.copy_to_host_async()) + # -- key construction and un/wrapping with impls + def test_wrap_key_default(self): key1 = jax.random.key(17) data = jax.random.key_data(key1) @@ -1055,6 +1064,37 @@ def test_wrap_key_explicit(self): key3 = jax.random.wrap_key_data(data, impl='unsafe_rbg') self.assertNotEqual(key1.dtype, key3.dtype) + @jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS]) + def test_key_make_like_other_key(self, prng_name): + # start by specifying the implementation by string name, then + # round trip via whatever `key_impl` outputs + k1 = jax.random.key(42, impl=prng_name) + impl = jax.random.key_impl(k1) + k2 = jax.random.key(42, impl=impl) + self.assertArraysEqual(k1, k2) + self.assertEqual(k1.dtype, k2.dtype) + + @jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS]) + def test_key_wrap_like_other_key(self, prng_name): + # start by specifying the implementation by string name, then + # round trip via whatever `key_impl` outputs + k1 = jax.random.key(42, impl=prng_name) + data = jax.random.key_data(k1) + impl = jax.random.key_impl(k1) + k2 = jax.random.wrap_key_data(data, impl=impl) + self.assertArraysEqual(k1, k2) + self.assertEqual(k1.dtype, k2.dtype) + + def test_key_impl_from_string_error(self): + with self.assertRaisesRegex(ValueError, 'unrecognized PRNG implementation'): + jax.random.key(42, impl='unlikely name') + + def test_key_impl_from_object_error(self): + class A: pass + + with self.assertRaisesRegex(TypeError, 'unrecognized type .* PRNG'): + jax.random.key(42, impl=A()) + # TODO(frostig,mattjj): more polymorphic primitives tests