Skip to content

Commit

Permalink
Merge pull request #19837 from jakevdp:key-reuse-clerrs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607489807
  • Loading branch information
jax authors committed Feb 16, 2024
2 parents 243e7ed + 8eab599 commit 0203d15
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 65 deletions.
96 changes: 35 additions & 61 deletions jax/experimental/key_reuse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from collections import defaultdict
from functools import reduce
from typing import Any, Callable, NamedTuple
from typing import Any, Callable

import jax
from jax import lax
Expand All @@ -38,10 +38,24 @@
)
import numpy as np


def _check_consumed_value(eqn, consumed):
"""Extra check for use with assert_consumed_value_p"""
expected = eqn.params['value']
if not np.all(consumed == expected):
if np.all(expected):
raise AssertionError(f"Expected key to be consumed in {eqn}")
elif not np.any(expected):
raise AssertionError(f"Expected key to not be consumed in {eqn}")
else:
raise AssertionError(f"Expected {expected}, got {consumed} in {eqn}")


# The behavior of most primitives can be described via simple signatures.
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature] = {}

key_reuse_signatures[consume_p] = KeyReuseSignature([Sink(0)], [], [Forward(0, 0)])
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[prng.reuse_key_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): should fold_in sink its input key?
Expand All @@ -50,6 +64,7 @@
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature([], [Source(0)])
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature([Sink(0)], [Source(0)])
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature([Sink(0)], [])
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.copy_p] = KeyReuseSignature([], [], [Forward(0, 0)])
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature([], [], [Forward(0, 0)])
Expand All @@ -67,19 +82,15 @@
key_reuse_signatures_dynamic: dict[core.Primitive, Callable[..., KeyReuseSignature]] = {}

# The default signature will Sink all key inputs, and not Source any.
def unknown_signature(eqn, args_consumed):
def unknown_signature(eqn):
def is_key(var: core.Atom):
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
return KeyReuseSignature(
sinks=[Sink(idx, True) for idx, var in enumerate(eqn.invars) if is_key(var)],
sources=[],
)

def get_jaxpr_type_signature(
jaxpr: core.Jaxpr,
consumed_inputs: list[bool | np.ndarray] | None = None,
forwarded_inputs: dict[int, int] | None = None,
) -> KeyReuseSignature:
def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
"""Parse the jaxpr to determine key reuse signature"""
consumed: dict[core.Atom, bool | np.ndarray] = {}
forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs.
Expand Down Expand Up @@ -122,24 +133,18 @@ def is_consumed(var: core.Atom):
return False
return consumed.get(var, False)

if forwarded_inputs:
for i, j in forwarded_inputs.items():
forwards[jaxpr.invars[i]] = jaxpr.invars[j]

if consumed_inputs:
for var, mask in util.safe_zip(jaxpr.invars, consumed_inputs):
if not isinstance(var, core.Literal):
source(var, mask)

for eqn in jaxpr.eqns:
if eqn.primitive in key_reuse_signatures:
signature = key_reuse_signatures[eqn.primitive]
elif eqn.primitive in key_reuse_signatures_dynamic:
args_consumed = [is_consumed(var) for var in eqn.invars]
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn, args_consumed)
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
else:
args_consumed = [is_consumed(var) for var in eqn.invars]
signature = unknown_signature(eqn, args_consumed)
signature = unknown_signature(eqn)

if eqn.primitive == assert_consumed_value_p:
# This is a special case that goes beyond normal key reuse logic.
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))

for in_idx, out_idx in signature.forwards:
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]

Expand Down Expand Up @@ -187,8 +192,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
#----------------------------------------------------------------------------------
# key reuse rules for particular primitives:

def _slice_signature(eqn, args_consumed):
del args_consumed # unused here
def _slice_signature(eqn):
in_aval = eqn.invars[0].aval
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
return KeyReuseSignature([], [], [Forward(0, 0)])
Expand All @@ -204,35 +208,13 @@ def _slice_signature(eqn, args_consumed):

key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature

def _pjit_key_type_signature(eqn, args_consumed):
jaxpr = eqn.params['jaxpr']
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
if var in eqn.invars[:i]}
sig = get_jaxpr_type_signature(jaxpr.jaxpr)
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
# Double consumption detected: re-trace with context for better errors.
get_jaxpr_type_signature(jaxpr.jaxpr, args_consumed, forwarded_inputs)
return sig
def _pjit_key_type_signature(eqn):
return get_jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)

key_reuse_signatures_dynamic[pjit.pjit_p] = _pjit_key_type_signature

def _assert_consumed_value_key_type_signature(eqn, args_consumed):
actual = args_consumed[0]
expected = eqn.params['value']
if not np.all(actual == expected):
if np.all(expected):
raise AssertionError(f"Expected key to be consumed in {eqn}")
elif not np.any(expected):
raise AssertionError(f"Expected key to not be consumed in {eqn}")
else:
raise AssertionError(f"Expected {expected}, got {actual} in {eqn}")
return KeyReuseSignature([], [], [Forward(0, 0)])

key_reuse_signatures_dynamic[assert_consumed_value_p] = _assert_consumed_value_key_type_signature

def _cond_key_type_signature(eqn, args_consumed):
signatures = [get_jaxpr_type_signature(branch.jaxpr, consumed_inputs=args_consumed[1:])
for branch in eqn.params['branches']]
def _cond_key_type_signature(eqn):
signatures = [get_jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
sinks = defaultdict(list)
sources = defaultdict(list)
for sig in signatures:
Expand All @@ -249,11 +231,11 @@ def _cond_key_type_signature(eqn, args_consumed):

key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature

def _scan_key_type_signature(eqn, args_consumed):
def _scan_key_type_signature(eqn):
jaxpr = eqn.params['jaxpr'].jaxpr
num_consts = eqn.params['num_consts']
num_carry = eqn.params['num_carry']
signature = get_jaxpr_type_signature(jaxpr, args_consumed)
signature = get_jaxpr_type_signature(jaxpr)

# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
Expand All @@ -278,13 +260,12 @@ def _scan_key_type_signature(eqn, args_consumed):

key_reuse_signatures_dynamic[jax.lax.scan_p] = _scan_key_type_signature

def _while_key_type_signature(eqn, args_consumed):
def _while_key_type_signature(eqn):
cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr
cond_nconsts = eqn.params['cond_nconsts']
body_jaxpr = eqn.params['body_jaxpr'].jaxpr
body_nconsts = eqn.params['body_nconsts']

# TODO(jakevdp): pass args_consumed here?
cond_signature = get_jaxpr_type_signature(cond_jaxpr)
body_signature = get_jaxpr_type_signature(body_jaxpr)

Expand Down Expand Up @@ -320,21 +301,14 @@ def _while_key_type_signature(eqn, args_consumed):

key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature

def _remat_key_type_signature(eqn, args_consumed):
def _remat_key_type_signature(eqn):
# The assumption here is that the non-differentiated pass contains all relevant
# key usage, and the differentiated pass
# 1) will only consume keys that are already consumed in the non-differentiated pass
# 2) will never create keys
# Therefore, the differentiated pass is a no-op.
if eqn.params['differentiated']:
return KeyReuseSignature([], [])
jaxpr = eqn.params['jaxpr']
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
if var in eqn.invars[:i]}
sig = get_jaxpr_type_signature(jaxpr)
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
# Double consumption detected: re-trace with context for better errors.
get_jaxpr_type_signature(jaxpr, args_consumed, forwarded_inputs)
return sig
return get_jaxpr_type_signature(eqn.params['jaxpr'])

key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
8 changes: 4 additions & 4 deletions tests/key_reuse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def f():
key = jax.random.key(0)
return jax.random.uniform(key) + jax.random.uniform(key)

with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f)

def test_reuse_after_split(self):
Expand All @@ -350,7 +350,7 @@ def f_bad():
_ = jax.random.split(key)
return jax.random.uniform(key)

with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad)

def f_bad_2():
Expand Down Expand Up @@ -418,15 +418,15 @@ def f_bad(key, condition):
r1 = jax.lax.cond(condition, jax.random.uniform, jax.random.normal, key)
return r1 + jax.random.uniform(key)

with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad, key, True)

# Check where only one branch consumes the key
def f_bad_2(key, condition):
r1 = jax.lax.cond(condition, jax.random.uniform, lambda key: 1.0, key)
return r1 + jax.random.uniform(key)

with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
with self.assertRaisesRegex(KeyReuseError, self.pjit_error):
self.check_key_reuse(f_bad_2, key, True)

def test_simple_scan(self):
Expand Down

0 comments on commit 0203d15

Please sign in to comment.