Skip to content

Commit

Permalink
identify PRNG schemes on key arrays, and recognize them in key constr…
Browse files Browse the repository at this point in the history
…uctors

Specifically:

* Introduce `jax.random.key_impl`, which accepts a key array and
  returns a hashable identifier of its PRNG implementation.

* Accept this identifier optionally as the `impl` argument to
  `jax.random.key` and `wrap_key_data`.

This now works:

```python
k1 = jax.random.key(72, impl='threefry2x32')
impl = jax.random.key_impl(k1)
k2 = jax.random.key(72, impl=impl)
assert arrays_equal(k1, k2)
assert k1.dtype == k2.dtype
```

This change also set up an internal PRNG registry and register
built-in implementations, to simplify various places where we
essentially reconstruct such a registry from scratch (such as in
tests).

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
  • Loading branch information
froystig and jakevdp committed Oct 6, 2023
1 parent 19b900d commit 5158e25
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 24 deletions.
28 changes: 25 additions & 3 deletions jax/_src/prng.py
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

import abc
from collections.abc import Hashable, Iterator, Sequence
from collections.abc import Iterator, Sequence
from functools import partial, reduce
import math
import operator as op
Expand Down Expand Up @@ -95,6 +95,7 @@ class PRNGImpl(NamedTuple):
split: Callable
random_bits: Callable
fold_in: Callable
name: str = '<unnamed>'
tag: str = '?'

def __hash__(self) -> int:
Expand All @@ -104,12 +105,21 @@ def __str__(self) -> str:
return self.tag

def pprint(self):
return (pp.text(f"{self.__class__.__name__} [{self.tag}]:") +
ty = self.__class__.__name__
return (pp.text(f"{ty} [{self.tag}] {{{self.name}}}:") +
pp.nest(2, pp.group(pp.brk() + pp.join(pp.brk(), [
pp.text(f"{k} = {v}") for k, v in self._asdict().items()
]))))


prngs = {}

def register_prng(impl: PRNGImpl):
if impl.name in prngs:
raise ValueError(f'PRNG with name {impl.name} already registered: {impl}')
prngs[impl.name] = impl


# -- PRNG key arrays

def _check_prng_key_data(impl, key_data: typing.Array):
Expand Down Expand Up @@ -248,6 +258,7 @@ class behave like an array whose base elements are keys, hiding the
``random_bits``, ``fold_in``).
"""

# TODO(frostig,vanderplas): hide impl attribute
impl: PRNGImpl
_base_array: typing.Array

Expand Down Expand Up @@ -593,7 +604,8 @@ def device_put_replicated(val, aval, sharding, devices):


class KeyTy(dtypes.ExtendedDType):
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
# TODO(frostig,vanderplas): hide impl attribute
impl: PRNGImpl # TODO(mattjj,frostig): protocol really
_rules = KeyTyRules
type = dtypes.prng_key

Expand Down Expand Up @@ -1366,8 +1378,11 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape):
split=threefry_split,
random_bits=threefry_random_bits,
fold_in=threefry_fold_in,
name='threefry2x32',
tag='fry')

register_prng(threefry_prng_impl)


# -- RngBitGenerator PRNG implementation

Expand Down Expand Up @@ -1411,8 +1426,12 @@ def _rbg_random_bits(key: typing.Array, bit_width: int, shape: Sequence[int]
split=_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_rbg_fold_in,
name='rbg',
tag='rbg')

register_prng(rbg_prng_impl)


def _unsafe_rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
# treat 10 iterations of random bits as a 'hash function'
num = math.prod(shape)
Expand All @@ -1431,4 +1450,7 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
split=_unsafe_rbg_split,
random_bits=_rbg_random_bits,
fold_in=_unsafe_rbg_fold_in,
name='unsafe_rbg',
tag='urbg')

register_prng(unsafe_rbg_prng_impl)
83 changes: 65 additions & 18 deletions jax/_src/random.py
Expand Up @@ -16,7 +16,8 @@
from functools import partial
import math
from operator import index
from typing import Optional, Union
import typing
from typing import Hashable, Optional, Union
import warnings

import numpy as np
Expand Down Expand Up @@ -52,7 +53,9 @@
DTypeLikeFloat = DTypeLike
Shape = Sequence[int]

# TODO(frostig): simplify once we always enable_custom_prng
PRNGImpl = prng.PRNGImpl

# TODO(frostig,vanderplas): remove after deprecation window
KeyArray = Union[Array, prng.PRNGKeyArray]
PRNGKeyArray = prng.PRNGKeyArray

Expand Down Expand Up @@ -109,35 +112,64 @@ def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array:
return prng.random_bits(key, bit_width=bit_width, shape=shape)


PRNG_IMPLS = {
'threefry2x32': prng.threefry_prng_impl,
'rbg': prng.rbg_prng_impl,
'unsafe_rbg': prng.unsafe_rbg_prng_impl,
}

# TODO(frostig,vanderplas): remove from public API altogether, or at
# least change to return after asserting presence in `prng.prngs`
def default_prng_impl():
"""Get the default PRNG implementation.
The default implementation is determined by ``config.jax_default_prng_impl``,
which specifies it by name. This function returns the corresponding
``jax.prng.PRNGImpl`` instance.
which specifies it by name.
"""
impl_name = config.jax_default_prng_impl
assert impl_name in PRNG_IMPLS, impl_name
return PRNG_IMPLS[impl_name]
assert impl_name in prng.prngs, impl_name
return prng.prngs[impl_name]


### key operations

def resolve_prng_impl(impl_spec: Optional[str]):
# Wrapper around prng.PRNGImpl meant to hide its attributes from the
# public API.
# TODO(frostig,vanderplas): consider hiding all the attributes of
# PRNGImpl and directly returning it.
class PRNGSpec:
"""Specifies a PRNG key implementation."""

__slots__ = ['_impl']
_impl: PRNGImpl

def __init__(self, impl):
self._impl = impl

def __str__(self) -> str: return str(self._impl)
def __hash__(self) -> int: return hash(self._impl)

def __eq__(self, other) -> bool:
return self._impl == other._impl


def resolve_prng_impl(
impl_spec: Optional[Union[str, PRNGSpec, PRNGImpl]]) -> PRNGImpl:
if impl_spec is None:
return default_prng_impl()
if impl_spec in PRNG_IMPLS:
return PRNG_IMPLS[impl_spec]
if type(impl_spec) is PRNGImpl:
# TODO(frostig,vanderplas): remove this case once we remove
# default_prng_impl (and thus PRNGImpl) from the public API and
# PRNGImpl from jex. We won't need to handle these then, and we
# can remove them from the input type annotation above as well.
return impl_spec
if type(impl_spec) is PRNGSpec:
return impl_spec._impl
if type(impl_spec) is str:
if impl_spec in prng.prngs:
return prng.prngs[impl_spec]

keys_fmt = ', '.join(f'"{s}"' for s in prng.prngs.keys())
raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". '
f'Did you mean one of: {keys_fmt}?')

t = type(impl_spec)
raise TypeError(f'unrecognized type {t} for specifying PRNG implementation.')

keys_fmt = ', '.join(f'"{s}"' for s in PRNG_IMPLS.keys())
raise ValueError(f'unrecognized PRNG implementation "{impl_spec}". '
f'Did you mean one of: {keys_fmt}?')

def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str]
) -> PRNGKeyArray:
Expand Down Expand Up @@ -189,6 +221,7 @@ def PRNGKey(seed: Union[int, Array], *,
"""
return _return_prng_keys(True, _key('PRNGKey', seed, impl))


# TODO(frostig): remove once we always enable_custom_prng
def _check_default_impl_with_no_custom_prng(impl, name):
default_impl = default_prng_impl()
Expand Down Expand Up @@ -219,6 +252,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
key = prng.seed_with_impl(impl, seed)
return _return_prng_keys(True, key)


def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
# Alternative to fold_in() to use within random samplers.
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
Expand All @@ -245,6 +279,7 @@ def fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
key, wrapped = _check_prng_key(key)
return _return_prng_keys(wrapped, _fold_in(key, data))


def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
Expand All @@ -270,6 +305,17 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
key, wrapped = _check_prng_key(key)
return _return_prng_keys(wrapped, _split(key, num))


def _key_impl(keys: KeyArray) -> PRNGImpl:
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
return keys_dtype.impl

def key_impl(keys: KeyArray) -> Hashable:
keys, _ = _check_prng_key(keys)
return PRNGSpec(_key_impl(keys))


def _key_data(keys: KeyArray) -> Array:
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
return prng.random_unwrap(keys)
Expand All @@ -279,6 +325,7 @@ def key_data(keys: KeyArray) -> Array:
keys, _ = _check_prng_key(keys)
return _key_data(keys)


def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None):
"""Wrap an array of key data bits into a PRNG key array.
Expand Down
3 changes: 3 additions & 0 deletions jax/extend/random.py
Expand Up @@ -16,6 +16,9 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.prng import (
# TODO(frostig,vanderplas): expose a define_prng_impl instead of the
# PRNGImpl constructor, to leave some room for us to register or check input,
# or to change what output type we return.
PRNGImpl as PRNGImpl,
seed_with_impl as seed_with_impl,
threefry2x32_p as threefry2x32_p,
Expand Down
1 change: 1 addition & 0 deletions jax/random.py
Expand Up @@ -155,6 +155,7 @@
gumbel as gumbel,
key as key,
key_data as key_data,
key_impl as key_impl,
laplace as laplace,
logistic as logistic,
loggamma as loggamma,
Expand Down
46 changes: 43 additions & 3 deletions tests/random_test.py
Expand Up @@ -41,9 +41,8 @@
from jax import config
config.parse_flags_with_absl()

PRNG_IMPLS = [('threefry2x32', prng_internal.threefry_prng_impl),
('rbg', prng_internal.rbg_prng_impl),
('unsafe_rbg', prng_internal.unsafe_rbg_prng_impl)]

PRNG_IMPLS = list(prng_internal.prngs.items())


class OnX64(enum.Enum):
Expand Down Expand Up @@ -189,6 +188,14 @@ def check_key_has_impl(self, key, impl):
self.assertEqual(key.dtype, jnp.dtype('uint32'))
self.assertEqual(key.shape, impl.key_shape)

def test_config_prngs_registered(self):
# TODO(frostig): pull these string values somehow from the
# jax_default_prng_impl config enum state definition directly,
# rather than copying manually here?
self.assertIn('threefry2x32', prng_internal.prngs)
self.assertIn('rbg', prng_internal.prngs)
self.assertIn('unsafe_rbg', prng_internal.prngs)

def testThreefry2x32(self):
# We test the hash by comparing to known values provided in the test code of
# the original reference implementation of Threefry. For the values, see
Expand Down Expand Up @@ -1030,6 +1037,8 @@ def test_async(self):
self.assertArraysEqual(key, key.block_until_ready())
self.assertIsNone(key.copy_to_host_async())

# -- key construction and un/wrapping with impls

def test_wrap_key_default(self):
key1 = jax.random.key(17)
data = jax.random.key_data(key1)
Expand All @@ -1055,6 +1064,37 @@ def test_wrap_key_explicit(self):
key3 = jax.random.wrap_key_data(data, impl='unsafe_rbg')
self.assertNotEqual(key1.dtype, key3.dtype)

@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
def test_key_make_like_other_key(self, prng_name):
# start by specifying the implementation by string name, then
# round trip via whatever `key_impl` outputs
k1 = jax.random.key(42, impl=prng_name)
impl = jax.random.key_impl(k1)
k2 = jax.random.key(42, impl=impl)
self.assertArraysEqual(k1, k2)
self.assertEqual(k1.dtype, k2.dtype)

@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
def test_key_wrap_like_other_key(self, prng_name):
# start by specifying the implementation by string name, then
# round trip via whatever `key_impl` outputs
k1 = jax.random.key(42, impl=prng_name)
data = jax.random.key_data(k1)
impl = jax.random.key_impl(k1)
k2 = jax.random.wrap_key_data(data, impl=impl)
self.assertArraysEqual(k1, k2)
self.assertEqual(k1.dtype, k2.dtype)

def test_key_impl_from_string_error(self):
with self.assertRaisesRegex(ValueError, 'unrecognized PRNG implementation'):
jax.random.key(42, impl='unlikely name')

def test_key_impl_from_object_error(self):
class A: pass

with self.assertRaisesRegex(TypeError, 'unrecognized type .* PRNG'):
jax.random.key(42, impl=A())

# TODO(frostig,mattjj): more polymorphic primitives tests


Expand Down

0 comments on commit 5158e25

Please sign in to comment.