Skip to content

Commit

Permalink
[hcb] Add support for remat2 to host_callback
Browse files Browse the repository at this point in the history
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
  • Loading branch information
gnecula committed Dec 15, 2021
1 parent 2c7db52 commit 3021d3e
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 41 deletions.
15 changes: 9 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

* Breaking changes:
* The host_callback primitives have been simplified to drop the
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#7839`).
special autodiff handling for hcb.id_tap and id_print.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
Additionally, added documentation for how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`#8678`).

* Bug fixes:
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).

* New features:
* add `jax.block_until_ready` ({jax-issue}`#8941)
Expand Down
26 changes: 0 additions & 26 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,32 +321,6 @@ def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
else:
return jaxpr

outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
if type(param) is core.Jaxpr:
if jaxpr_uses_outfeed(param):
return True
elif type(param) is core.ClosedJaxpr:
if jaxpr_uses_outfeed(param.jaxpr):
return True
return False

def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
if prim in outfeed_primitives:
return True
for param in params.values():
if isinstance(param, tuple):
if any(unsafe_map(_param_uses_outfeed, param)):
return True
elif _param_uses_outfeed(param):
return True
return False


def jaxpr_replicas(jaxpr) -> int:
"""The number of replicas needed for a jaxpr.
Expand Down
26 changes: 26 additions & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,6 +1678,32 @@ def call_impl(f: lu.WrappedFun, *args, **params):
named_call_p: CallPrimitive = CallPrimitive('named_call')
named_call_p.def_impl(call_impl)

outfeed_primitives: Set[Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
if type(param) is Jaxpr:
if jaxpr_uses_outfeed(param):
return True
elif type(param) is ClosedJaxpr:
if jaxpr_uses_outfeed(param.jaxpr):
return True
return False

def primitive_uses_outfeed(prim: Primitive, params: Dict) -> bool:
if prim in outfeed_primitives:
return True
for param in params.values():
if isinstance(param, tuple):
if any(unsafe_map(_param_uses_outfeed, param)):
return True
elif _param_uses_outfeed(param):
return True
return False

# ------------------- Map -------------------

def mapped_aval(size: int, axis: int, aval: AbstractValue) -> AbstractValue:
Expand Down
22 changes: 18 additions & 4 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,20 @@ def power3_with_cotangents(x):
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)
If you use :func:`ad_checkpoint.checkpoint` to rematerialize the residuals
for the backward pass, then the callbacks from the primal computation will
be called twice::
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)
The callbacks are, in order from: the primal computation of the inner ``power3``,
the primal computation of the outer ``power3``, and the rematerialization
of the residuals for the inner ``power3``.
Behavior under jax.vmap
-----------------------
Expand Down Expand Up @@ -900,7 +914,7 @@ def _id_tap_dep_masking_rule(operands, operands_logical_shapes):
"""
outside_call_p = core.Primitive("outside_call")
outside_call_p.multiple_results = True
dispatch.outfeed_primitives.add(outside_call_p)
core.outfeed_primitives.add(outside_call_p)


def _outside_call_abstract_eval(*args_a: pe.AbstractValue,
Expand Down Expand Up @@ -1385,7 +1399,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
"""Rewrite a Jaxpr to thread the token, if needed."""
assert has_input_token or not has_output_token

if not has_input_token and not dispatch.jaxpr_uses_outfeed(jaxpr):
if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr):
return jaxpr

mk_new_var = core.gensym([jaxpr])
Expand All @@ -1407,7 +1421,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
lax.create_token_p, {}, source_info_util.current()))

for eqn in jaxpr.eqns:
if not dispatch.primitive_uses_outfeed(eqn.primitive, eqn.params):
if not core.primitive_uses_outfeed(eqn.primitive, eqn.params):
eqns.append(eqn)
else:
output_token_var = mk_new_var(last_token_var.aval)
Expand Down Expand Up @@ -1445,7 +1459,7 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
eqn.params,
["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
if dispatch.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
if core.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
_rewrite_while_outfeed_cond(eqn, eqns, input_token_var, output_token_var,
input_itoken_var, output_itoken_var,
mk_new_var)
Expand Down
2 changes: 1 addition & 1 deletion jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def write(x: Atom, b: bool) -> None:
used_outs = map(read, eqn.outvars)
# If any outputs are used, then we need to keep a version of the eqn and
# potentially mark some inputs as used. Otherwise mark all inputs as unused.
if any(used_outs):
if any(used_outs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params):
# If there's a rule for modifying the eqn and computing used inputs, apply
# it. Otherwise, keep the eqn unmodified and mark all inputs as used.
rule = dce_rules.get(eqn.primitive)
Expand Down
30 changes: 26 additions & 4 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from absl.testing import parameterized

import jax
from jax import ad_checkpoint
from jax import core
from jax.config import config
from jax import dtypes
Expand Down Expand Up @@ -1458,6 +1459,19 @@ def power3_with_cotangents(x):
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
( [4. 9.] [2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()

print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: x,x^2
( 27. 729. )
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()

def test_tap_pmap(self):
if len(local_devices()) < 2:
Expand Down Expand Up @@ -2024,8 +2038,8 @@ def loss(k):
use_result=use_result, use_remat=use_remat, grad_func=grad_func)
for use_result in [True, False]
for grad_func in ["grad", "value_and_grad"]
for use_remat in ["old", "none"]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="old"):
for use_remat in ["old", "new", "none"]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
def f(x):
id_print_result = hcb.id_print(x, output_stream=testing_stream)
if use_result:
Expand All @@ -2034,6 +2048,8 @@ def f(x):
grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad
if use_remat == "old":
trans_f = jax.remat(f)
elif use_remat == "new":
trans_f = ad_checkpoint.checkpoint(f)
else:
assert use_remat == "none"
trans_f = f
Expand Down Expand Up @@ -2068,8 +2084,14 @@ def f(x):
2.
2."""
else:
# TODO: we should see two callbacks
expected = ""
if use_remat == "old":
# TODO: we should see two callbacks
expected = ""
else:
# Good: we see two callbacks, whether or not we use the result.
expected = """
2.
2."""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)

def test_tap_named_call(self):
Expand Down

0 comments on commit 3021d3e

Please sign in to comment.