From 78cf4df21b61486fdc1c314c8da1d52b849deeca Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 25 Mar 2022 16:28:35 -0700 Subject: [PATCH] improve remat transpose caching (cf. #9661) --- jax/interpreters/ad.py | 47 +++++++++++++++++++++--------------------- tests/api_test.py | 11 +++++++++- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 9bbf9c8b0672..6af9fd2abbe9 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -28,7 +28,7 @@ from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval, zeros_like_p, Zero) from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, - as_hashable_function, cache) + as_hashable_function, weakref_lru_cache) from jax.tree_util import register_pytree_node from jax import linear_util as lu from jax._src.api_util import flatten_fun, flatten_fun_nokwargs @@ -586,7 +586,8 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents): def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False) + fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, + reduce_axes, False) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) if config.jax_experimental_name_stack: new_params = params @@ -603,31 +604,29 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals, reduce_axes): - # backward_pass can only transpose linear computations, but the call_jaxpr embedded in - # remat contains primal (non-linear) equations too. Hence, we have to eliminate those - # (in this case via partial_eval) before we call into backward_pass again. - typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, []) + call_jaxpr = _close_jaxpr(call_jaxpr) unknowns = map(is_undefined_primal, primals_in) - primal_jaxpr, tangent_jaxpr, out_unknowns = \ - pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore - - def do_transpose(primals_in, cotangents_in): - # NOTE: This is passing in undefined primals in place of tangent arguments, but it - # should all work out, because we're only computing the primal part here. - residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):] - # Now that we have a purely linear jaxpr, we can transpose it - cotangents_out = backward_pass( - tangent_jaxpr.jaxpr, reduce_axes, False, (), primals_in + residuals, cotangents_in) - # backward_pass will return cotangents computed for all invars, but some of them - # are residuals appended by partial eval, so we need to skip those before we return. - return cotangents_out[:len(primals_in)] - - flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in)) - flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def) - flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params) + primal_jaxpr, tangent_jaxpr, _ = \ + pe.partial_eval_jaxpr(call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore + args, in_tree_def = tree_flatten((primals_in, cotangents_in)) + transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr, + tangent_jaxpr, reduce_axes) + flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree_def) + flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params) return tree_unflatten(out_tree(), flat_cotangents_out) primitive_transposes[pe.remat_call_p] = remat_transpose +def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes, + primals_in, cotangents_in): + res = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):] + cotangents_out = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (), + (*primals_in, *res), cotangents_in) + return cotangents_out[:len(primals_in)] + +@weakref_lru_cache +def _close_jaxpr(jaxpr: core.Jaxpr) -> core.ClosedJaxpr: + return core.ClosedJaxpr(jaxpr, []) + @lu.transformation_with_aux def nonzero_outputs(*args, **kwargs): results = yield args, kwargs @@ -680,7 +679,7 @@ def jvp_jaxpr(jaxpr, nonzeros, instantiate): inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate return _jvp_jaxpr(jaxpr, tuple(nonzeros), inst) -@cache() +@weakref_lru_cache def _jvp_jaxpr(jaxpr, nonzeros, instantiate): assert len(jaxpr.in_avals) == len(nonzeros) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) diff --git a/tests/api_test.py b/tests/api_test.py index ab4a2e4c1428..9999e3be9545 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4328,10 +4328,19 @@ def test_linearize_caching(self): identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_lin = jax.linearize(identity, 1.) with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 - for _ in range(10): + for _ in range(20): f_lin(1.).block_until_ready() self.assertEqual(count[0], 1) # cached after first execution + def test_vjp_caching(self): + # https://github.com/google/jax/issues/9661 + identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) + _, f_vjp = jax.vjp(identity, 1.) + with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + for _ in range(20): + f_vjp(1.)[0].block_until_ready() + self.assertEqual(count[0], 2) # eval_jaxpr on fwd, backward_pass on bwd + class JaxprTest(jtu.JaxTestCase):