Skip to content

Commit

Permalink
Define reuse_key primitive in jax._src.prng
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 14, 2024
1 parent b9824d7 commit 49eb700
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/jax.experimental.key_reuse.rst
Expand Up @@ -9,5 +9,5 @@ API
.. autosummary::
:toctree: _autosummary

unconsumed_copy
reuse_key
KeyReuseError
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
Expand Up @@ -396,7 +396,7 @@ def body_fun(vals):
# because the scan body may consume any keys within it.
# Import here to avoid circular imports
from jax.experimental import key_reuse
xs_unconsumed = _map(key_reuse.unconsumed_copy, xs)
xs_unconsumed = _map(key_reuse.reuse_key, xs)
x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])
Expand Down
23 changes: 23 additions & 0 deletions jax/_src/prng.py
Expand Up @@ -1338,3 +1338,26 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
tag='urbg')

register_prng(unsafe_rbg_prng_impl)


# Primitives related to key reuse
reuse_key_p = core.Primitive("reuse_key")
reuse_key_p.def_impl(lambda x: x)
reuse_key_p.def_abstract_eval(lambda x: x)
batching.defvectorized(reuse_key_p)
mlir.register_lowering(reuse_key_p, lambda _, k: [k])

def reuse_key(key):
"""Explicitly mark a key as unconsumed.
Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`)
this function operates as an identity.
Example:
>>> import jax
>>> key = jax.random.key(0)
>>> data = jax.random.uniform(key)
>>> same_data = jax.random.uniform(reuse_key(key))
"""
return reuse_key_p.bind(key)
3 changes: 1 addition & 2 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -69,7 +69,6 @@
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
from jax.experimental.key_reuse._common import unconsumed_copy_p

import tensorflow as tf # type: ignore[import]

Expand Down Expand Up @@ -1529,7 +1528,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"consume",
]

tf_impl[unconsumed_copy_p] = lambda x: x
tf_impl[prng.reuse_key_p] = lambda x: x

tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient

Expand Down
4 changes: 3 additions & 1 deletion jax/experimental/key_reuse/__init__.py
Expand Up @@ -38,8 +38,10 @@
This flag can also be set globally if you wish to enagle key reuse checks in
every JIT-compiled function.
"""
from jax._src.prng import (
reuse_key as reuse_key,
)

from jax.experimental.key_reuse._common import (
unconsumed_copy as unconsumed_copy,
KeyReuseError as KeyReuseError,
)
11 changes: 0 additions & 11 deletions jax/experimental/key_reuse/_common.py
Expand Up @@ -62,17 +62,6 @@ def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)

unconsumed_copy_p = core.Primitive("unconsumed_copy")
unconsumed_copy_p.def_impl(lambda x: x)
unconsumed_copy_p.def_abstract_eval(lambda x: x)
batching.defvectorized(unconsumed_copy_p)
mlir.register_lowering(
unconsumed_copy_p,
mlir.lower_fun(lambda x: x, multiple_results=False))

def unconsumed_copy(key):
"""Return a copy of key marked as unconsumed."""
return unconsumed_copy_p.bind(key)

assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/key_reuse/_forwarding.py
Expand Up @@ -33,7 +33,7 @@
from jax._src.interpreters import partial_eval as pe

from jax.experimental.key_reuse._common import (
consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError,
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, KeyReuseSignature
)
import numpy as np
Expand All @@ -52,7 +52,7 @@ class KeyReuseSignatureWithForwards(NamedTuple):
key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {}

key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)])
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignatureWithForwards([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)])
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/key_reuse/_simple.py
Expand Up @@ -33,7 +33,7 @@
from jax._src.interpreters import partial_eval as pe

from jax.experimental.key_reuse._common import (
consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError,
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, KeyReuseSignature
)
import numpy as np
Expand All @@ -42,7 +42,7 @@
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}

key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [])
key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
# key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)])
Expand Down
15 changes: 7 additions & 8 deletions tests/key_reuse_test.py
Expand Up @@ -22,9 +22,8 @@
from jax._src import prng
from jax._src import test_util as jtu
from jax.experimental.key_reuse._common import (
assert_consumed, assert_unconsumed, consume, consume_p, unconsumed_copy_p)
from jax.experimental.key_reuse import (
_forwarding, _simple, KeyReuseError, unconsumed_copy)
assert_consumed, assert_unconsumed, consume, consume_p)
from jax.experimental.key_reuse import _forwarding, _simple, KeyReuseError

from jax import config
config.parse_flags_with_absl()
Expand All @@ -36,7 +35,7 @@

primitives_with_static_signatures = {
consume_p: (consume, key),
unconsumed_copy_p: (unconsumed_copy, key),
prng.reuse_key_p: (prng.reuse_key, key),
prng.random_bits_p: (jax.random.bits, key),
prng.random_fold_in_p: (jax.random.fold_in, key, 2),
prng.random_seed_p: (jax.random.key, 0),
Expand Down Expand Up @@ -91,12 +90,12 @@ def f(key):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))

def test_unconsumed_copy(self):
def test_reuse_key(self):
def f(key):
assert_unconsumed(key)
consume(key)
assert_consumed(key)
key2 = unconsumed_copy(key)
key2 = prng.reuse_key(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))

Expand Down Expand Up @@ -337,12 +336,12 @@ def f(key):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))

def test_unconsumed_copy(self):
def test_reuse_key(self):
def f(key):
assert_unconsumed(key)
consume(key)
assert_consumed(key)
key2 = unconsumed_copy(key)
key2 = prng.reuse_key(key)
assert_unconsumed(key2)
self.check_key_reuse(f, jax.random.key(0))

Expand Down

0 comments on commit 49eb700

Please sign in to comment.