Skip to content

Commit

Permalink
[jex] replace extend.random.PRNGImpl with `extend.random.define_prn…
Browse files Browse the repository at this point in the history
…g_impl`

Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 575027938
  • Loading branch information
froystig authored and jax authors committed Nov 1, 2023
1 parent 49fedb1 commit 37cac26
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/jax.extend.rst
Expand Up @@ -28,7 +28,7 @@
.. autosummary::
:toctree: _autosummary

PRNGImpl
define_prng_impl
seed_with_impl
threefry2x32_p
threefry_2x32
Expand Down
34 changes: 34 additions & 0 deletions jax/_src/extend/random.py
@@ -0,0 +1,34 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Hashable

from jax import Array

from jax._src import prng
from jax._src import random

Shape = tuple[int, ...]

def define_prng_impl(*,
key_shape: Shape,
seed: Callable[[Array], Array],
split: Callable[[Array, Shape], Array],
random_bits: Callable[[Array, int, Shape], Array],
fold_in: Callable[[Array, int], Array],
name: str = '<unnamed>',
tag: str = '?') -> Hashable:
return random.PRNGSpec(prng.PRNGImpl(
key_shape, seed, split, random_bits, fold_in,
name=name, tag=tag))
20 changes: 14 additions & 6 deletions jax/_src/random.py
Expand Up @@ -145,8 +145,12 @@ def __eq__(self, other) -> bool:
return self._impl == other._impl


def resolve_prng_impl(
impl_spec: Optional[Union[str, PRNGSpec, PRNGImpl]]) -> PRNGImpl:
# TODO(frostig,vanderplas): remove PRNGImpl from this union when it's
# no longer in the public API because `default_prng_impl` is gone
PRNGSpecDesc = Union[str, PRNGSpec, PRNGImpl]


def resolve_prng_impl(impl_spec: Optional[PRNGSpecDesc]) -> PRNGImpl:
if impl_spec is None:
return default_prng_impl()
if type(impl_spec) is PRNGImpl:
Expand All @@ -169,7 +173,8 @@ def resolve_prng_impl(
raise TypeError(f'unrecognized type {t} for specifying PRNG implementation.')


def _key(ctor_name: str, seed: int | ArrayLike, impl_spec: Optional[str] ) -> KeyArray:
def _key(ctor_name: str, seed: int | ArrayLike,
impl_spec: Optional[PRNGSpecDesc]) -> KeyArray:
impl = resolve_prng_impl(impl_spec)
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
raise TypeError(
Expand All @@ -180,7 +185,8 @@ def _key(ctor_name: str, seed: int | ArrayLike, impl_spec: Optional[str] ) -> Ke
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
return prng.random_seed(seed, impl=impl)

def key(seed: int | ArrayLike, *, impl: Optional[str] = None) -> KeyArray:
def key(seed: int | ArrayLike, *,
impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The result is a scalar array with a key that indicates the default PRNG
Expand All @@ -198,7 +204,8 @@ def key(seed: int | ArrayLike, *, impl: Optional[str] = None) -> KeyArray:
"""
return _key('key', seed, impl)

def PRNGKey(seed: int | ArrayLike, *, impl: Optional[str] = None) -> KeyArray:
def PRNGKey(seed: int | ArrayLike, *,
impl: Optional[PRNGSpecDesc] = None) -> KeyArray:
"""Create a pseudo-random number generator (PRNG) key given an integer seed.
The resulting key carries the default PRNG implementation, as
Expand Down Expand Up @@ -321,7 +328,8 @@ def key_data(keys: KeyArrayLike) -> Array:
return _key_data(keys)


def wrap_key_data(key_bits_array: Array, *, impl: Optional[str] = None):
def wrap_key_data(key_bits_array: Array, *,
impl: Optional[PRNGSpecDesc] = None):
"""Wrap an array of key data bits into a PRNG key array.
Args:
Expand Down
8 changes: 4 additions & 4 deletions jax/extend/random.py
Expand Up @@ -15,11 +15,11 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.extend.random import (
define_prng_impl as define_prng_impl,
)

from jax._src.prng import (
# TODO(frostig,vanderplas): expose a define_prng_impl instead of the
# PRNGImpl constructor, to leave some room for us to register or check input,
# or to change what output type we return.
PRNGImpl as PRNGImpl,
random_seed as random_seed,
seed_with_impl as seed_with_impl,
threefry2x32_p as threefry2x32_p,
Expand Down
13 changes: 8 additions & 5 deletions tests/extend_test.py
Expand Up @@ -31,7 +31,6 @@ class ExtendTest(jtu.JaxTestCase):

def test_symbols(self):
# Assume these are tested in random_test.py, only check equivalence
self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl)
self.assertIs(jex.random.seed_with_impl, prng.seed_with_impl)
self.assertIs(jex.random.threefry2x32_p, prng.threefry2x32_p)
self.assertIs(jex.random.threefry_2x32, prng.threefry_2x32)
Expand Down Expand Up @@ -61,21 +60,25 @@ def seed_rule(_):
def no_rule(*args, **kwargs):
assert False, 'unreachable'

impl = jex.random.PRNGImpl(shape, seed_rule, no_rule, no_rule, no_rule)
impl = jex.random.define_prng_impl(
key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule,
random_bits=no_rule)
k = jax.random.key(42, impl=impl)
self.assertEqual(k.shape, ())
self.assertEqual(impl, jax.random.key_impl(k)._impl)
self.assertEqual(impl, jax.random.key_impl(k))

def test_key_wrap_with_custom_impl(self):
def no_rule(*args, **kwargs):
assert False, 'unreachable'

shape = (4, 2, 7)
impl = jex.random.PRNGImpl(shape, no_rule, no_rule, no_rule, no_rule)
impl = jex.random.define_prng_impl(
key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule,
random_bits=no_rule)
data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
k = jax.random.wrap_key_data(data, impl=impl)
self.assertEqual(k.shape, (3,))
self.assertEqual(impl, jax.random.key_impl(k)._impl)
self.assertEqual(impl, jax.random.key_impl(k))


if __name__ == "__main__":
Expand Down

0 comments on commit 37cac26

Please sign in to comment.