From 49eb7008c0d9d4424ca311f67f108954e8199d2c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Feb 2024 14:01:08 -0800 Subject: [PATCH] Define reuse_key primitive in jax._src.prng --- docs/jax.experimental.key_reuse.rst | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/prng.py | 23 +++++++++++++++++++++++ jax/experimental/jax2tf/jax2tf.py | 3 +-- jax/experimental/key_reuse/__init__.py | 4 +++- jax/experimental/key_reuse/_common.py | 11 ----------- jax/experimental/key_reuse/_forwarding.py | 4 ++-- jax/experimental/key_reuse/_simple.py | 4 ++-- tests/key_reuse_test.py | 15 +++++++-------- 9 files changed, 40 insertions(+), 28 deletions(-) diff --git a/docs/jax.experimental.key_reuse.rst b/docs/jax.experimental.key_reuse.rst index 27975a9f153e..5c7caf80f0ce 100644 --- a/docs/jax.experimental.key_reuse.rst +++ b/docs/jax.experimental.key_reuse.rst @@ -9,5 +9,5 @@ API .. autosummary:: :toctree: _autosummary - unconsumed_copy + reuse_key KeyReuseError diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 762a3b963a2a..0997665f98de 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -396,7 +396,7 @@ def body_fun(vals): # because the scan body may consume any keys within it. # Import here to avoid circular imports from jax.experimental import key_reuse - xs_unconsumed = _map(key_reuse.unconsumed_copy, xs) + xs_unconsumed = _map(key_reuse.reuse_key, xs) x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed) out_flat = f_impl(*consts, *carry, *x) carry_out, y_updates = split_list(out_flat, [num_carry]) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 961ad5c4cd7c..9174407d39b2 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1338,3 +1338,26 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: tag='urbg') register_prng(unsafe_rbg_prng_impl) + + +# Primitives related to key reuse +reuse_key_p = core.Primitive("reuse_key") +reuse_key_p.def_impl(lambda x: x) +reuse_key_p.def_abstract_eval(lambda x: x) +batching.defvectorized(reuse_key_p) +mlir.register_lowering(reuse_key_p, lambda _, k: [k]) + +def reuse_key(key): + """Explicitly mark a key as unconsumed. + + Outside the context of key reuse checking (see :mod:`jax.experimental.key_reuse`) + this function operates as an identity. + + Example: + + >>> import jax + >>> key = jax.random.key(0) + >>> data = jax.random.uniform(key) + >>> same_data = jax.random.uniform(reuse_key(key)) + """ + return reuse_key_p.bind(key) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index a40acb1840af..e57e60583239 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -69,7 +69,6 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client from jax._src.numpy.ufuncs import logaddexp -from jax.experimental.key_reuse._common import unconsumed_copy_p import tensorflow as tf # type: ignore[import] @@ -1529,7 +1528,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "consume", ] -tf_impl[unconsumed_copy_p] = lambda x: x +tf_impl[prng.reuse_key_p] = lambda x: x tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient diff --git a/jax/experimental/key_reuse/__init__.py b/jax/experimental/key_reuse/__init__.py index c7ca0f9d0b62..6b330975a109 100644 --- a/jax/experimental/key_reuse/__init__.py +++ b/jax/experimental/key_reuse/__init__.py @@ -38,8 +38,10 @@ This flag can also be set globally if you wish to enagle key reuse checks in every JIT-compiled function. """ +from jax._src.prng import ( + reuse_key as reuse_key, +) from jax.experimental.key_reuse._common import ( - unconsumed_copy as unconsumed_copy, KeyReuseError as KeyReuseError, ) diff --git a/jax/experimental/key_reuse/_common.py b/jax/experimental/key_reuse/_common.py index f5237c4d157f..dc2f451dafce 100644 --- a/jax/experimental/key_reuse/_common.py +++ b/jax/experimental/key_reuse/_common.py @@ -62,17 +62,6 @@ def consume(key): """Consume the key and return a consumed copy.""" return consume_p.bind(key) -unconsumed_copy_p = core.Primitive("unconsumed_copy") -unconsumed_copy_p.def_impl(lambda x: x) -unconsumed_copy_p.def_abstract_eval(lambda x: x) -batching.defvectorized(unconsumed_copy_p) -mlir.register_lowering( - unconsumed_copy_p, - mlir.lower_fun(lambda x: x, multiple_results=False)) - -def unconsumed_copy(key): - """Return a copy of key marked as unconsumed.""" - return unconsumed_copy_p.bind(key) assert_consumed_value_p = core.Primitive("assert_consumed_value") assert_consumed_value_p.def_impl(lambda x, *, value: x) diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index ebe3737cc1b0..d25d9722483f 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -33,7 +33,7 @@ from jax._src.interpreters import partial_eval as pe from jax.experimental.key_reuse._common import ( - consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError, + consume_p, assert_consumed_value_p, KeyReuseError, Sink, Source, KeyReuseSignature ) import numpy as np @@ -52,7 +52,7 @@ class KeyReuseSignatureWithForwards(NamedTuple): key_reuse_signatures: dict[core.Primitive, KeyReuseSignatureWithForwards] = {} key_reuse_signatures[consume_p] = KeyReuseSignatureWithForwards([Sink(0)], [], [Forward(0, 0)]) -key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignatureWithForwards([], [Source(0)]) +key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignatureWithForwards([], [Source(0)]) key_reuse_signatures[prng.random_bits_p] = KeyReuseSignatureWithForwards([Sink(0)], []) # TODO(jakevdp): should fold_in sink its input key? # key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignatureWithForwards([Sink(0)], [Source(0)]) diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index 06f8e15b6a0e..2e26a95b1e3f 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -33,7 +33,7 @@ from jax._src.interpreters import partial_eval as pe from jax.experimental.key_reuse._common import ( - consume_p, unconsumed_copy_p, assert_consumed_value_p, KeyReuseError, + consume_p, assert_consumed_value_p, KeyReuseError, Sink, Source, KeyReuseSignature ) import numpy as np @@ -42,7 +42,7 @@ key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {} key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], []) -key_reuse_signatures[unconsumed_copy_p] = KeyReuseSignature([], [Source(0)]) +key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)]) key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], []) # TODO(jakevdp): should fold_in sink its input key? # key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature([Sink(0)], [Source(0)]) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index cae82aa34f40..f11c4844118a 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -22,9 +22,8 @@ from jax._src import prng from jax._src import test_util as jtu from jax.experimental.key_reuse._common import ( - assert_consumed, assert_unconsumed, consume, consume_p, unconsumed_copy_p) -from jax.experimental.key_reuse import ( - _forwarding, _simple, KeyReuseError, unconsumed_copy) + assert_consumed, assert_unconsumed, consume, consume_p) +from jax.experimental.key_reuse import _forwarding, _simple, KeyReuseError from jax import config config.parse_flags_with_absl() @@ -36,7 +35,7 @@ primitives_with_static_signatures = { consume_p: (consume, key), - unconsumed_copy_p: (unconsumed_copy, key), + prng.reuse_key_p: (prng.reuse_key, key), prng.random_bits_p: (jax.random.bits, key), prng.random_fold_in_p: (jax.random.fold_in, key, 2), prng.random_seed_p: (jax.random.key, 0), @@ -91,12 +90,12 @@ def f(key): assert_consumed(key2) self.check_key_reuse(f, jax.random.key(0)) - def test_unconsumed_copy(self): + def test_reuse_key(self): def f(key): assert_unconsumed(key) consume(key) assert_consumed(key) - key2 = unconsumed_copy(key) + key2 = prng.reuse_key(key) assert_unconsumed(key2) self.check_key_reuse(f, jax.random.key(0)) @@ -337,12 +336,12 @@ def f(key): assert_consumed(key2) self.check_key_reuse(f, jax.random.key(0)) - def test_unconsumed_copy(self): + def test_reuse_key(self): def f(key): assert_unconsumed(key) consume(key) assert_consumed(key) - key2 = unconsumed_copy(key) + key2 = prng.reuse_key(key) assert_unconsumed(key2) self.check_key_reuse(f, jax.random.key(0))