Skip to content

Commit

Permalink
improve remat transpose caching (cf. #9661)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 25, 2022
1 parent 563e0c6 commit 78cf4df
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
47 changes: 23 additions & 24 deletions jax/interpreters/ad.py
Expand Up @@ -28,7 +28,7 @@
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_aval, zeros_like_p, Zero)
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, cache)
as_hashable_function, weakref_lru_cache)
from jax.tree_util import register_pytree_node
from jax import linear_util as lu
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
Expand Down Expand Up @@ -586,7 +586,8 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):

def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
if config.jax_experimental_name_stack:
new_params = params
Expand All @@ -603,31 +604,29 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):

def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
cotangent_in_avals, reduce_axes):
# backward_pass can only transpose linear computations, but the call_jaxpr embedded in
# remat contains primal (non-linear) equations too. Hence, we have to eliminate those
# (in this case via partial_eval) before we call into backward_pass again.
typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
call_jaxpr = _close_jaxpr(call_jaxpr)
unknowns = map(is_undefined_primal, primals_in)
primal_jaxpr, tangent_jaxpr, out_unknowns = \
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore

def do_transpose(primals_in, cotangents_in):
# NOTE: This is passing in undefined primals in place of tangent arguments, but it
# should all work out, because we're only computing the primal part here.
residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
# Now that we have a purely linear jaxpr, we can transpose it
cotangents_out = backward_pass(
tangent_jaxpr.jaxpr, reduce_axes, False, (), primals_in + residuals, cotangents_in)
# backward_pass will return cotangents computed for all invars, but some of them
# are residuals appended by partial eval, so we need to skip those before we return.
return cotangents_out[:len(primals_in)]

flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def)
flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params)
primal_jaxpr, tangent_jaxpr, _ = \
pe.partial_eval_jaxpr(call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore
args, in_tree_def = tree_flatten((primals_in, cotangents_in))
transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
tangent_jaxpr, reduce_axes)
flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree_def)
flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params)
return tree_unflatten(out_tree(), flat_cotangents_out)
primitive_transposes[pe.remat_call_p] = remat_transpose

def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_in, cotangents_in):
  """Recompute residuals with ``primal_jaxpr``, then transpose ``tangent_jaxpr``.

  Args:
    primal_jaxpr: closed jaxpr for the primal (non-linear) part; it is turned
      into a callable with ``core.jaxpr_as_fun`` and applied to ``primals_in``.
    tangent_jaxpr: closed jaxpr for the linear part; its ``.jaxpr`` is handed
      to ``backward_pass``.
    reduce_axes: forwarded unchanged to ``backward_pass``.
    primals_in: primal inputs of the rematerialized call.
    cotangents_in: cotangents for the outputs of the rematerialized call.

  Returns:
    The first ``len(primals_in)`` entries of what ``backward_pass`` returns,
    i.e. one cotangent per entry of ``primals_in``.
  """
  n_cotangents = len(cotangents_in)
  # Evaluate the primal part. The leading ``n_cotangents`` outputs are dropped;
  # whatever follows is treated as the residuals the linear part consumes.
  primal_outputs = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)
  residuals = primal_outputs[n_cotangents:]
  # The linear jaxpr is transposed with the residuals appended to the primals.
  linear_inputs = (*primals_in, *residuals)
  all_cotangents = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                                 linear_inputs, cotangents_in)
  # backward_pass yields a cotangent per input, residuals included; callers
  # only want the ones that correspond to ``primals_in``.
  return all_cotangents[:len(primals_in)]

@weakref_lru_cache
def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr:
  """Wrap ``jaxpr`` in a ``core.ClosedJaxpr`` built with an empty list.

  The function is decorated with ``weakref_lru_cache``.
  NOTE(review): presumably this makes repeated calls with the same ``jaxpr``
  object return the same ``ClosedJaxpr``, so caches keyed on the closed jaxpr
  further down the line can hit -- confirm against ``weakref_lru_cache``.
  """
  no_consts: list = []
  return core.ClosedJaxpr(jaxpr, no_consts)

@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
results = yield args, kwargs
Expand Down Expand Up @@ -680,7 +679,7 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate):
inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _jvp_jaxpr(jaxpr, tuple(nonzeros), inst)

@cache()
@weakref_lru_cache
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
Expand Down
11 changes: 10 additions & 1 deletion tests/api_test.py
Expand Up @@ -4328,10 +4328,19 @@ def test_linearize_caching(self):
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_lin = jax.linearize(identity, 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(10):
for _ in range(20):
f_lin(1.).block_until_ready()
self.assertEqual(count[0], 1) # cached after first execution

def test_vjp_caching(self):
  """Calling a checkpointed function's VJP repeatedly must not recompile.

  Regression test for https://github.com/google/jax/issues/9661
  """
  doubled = jax.checkpoint(jax.jit(lambda x: 2 * x))
  f_vjp = jax.vjp(doubled, 1.)[1]
  with jtu.count_jit_and_pmap_compiles() as count:
    for _ in range(20):
      cotangent = f_vjp(1.)[0]
      cotangent.block_until_ready()
  # Two compilations are expected in total, however many times the VJP is
  # applied: eval_jaxpr on fwd, backward_pass on bwd.
  self.assertEqual(count[0], 2)


class JaxprTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 78cf4df

Please sign in to comment.