From 57dd913834a54dee921047af7c78be0374e83c47 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 22 Nov 2019 10:53:11 -0800 Subject: [PATCH] Add experimental rematerialization decorator We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin --- docs/jax.lax.rst | 1 - jax/abstract_arrays.py | 14 ++- jax/api.py | 13 ++- jax/core.py | 3 +- jax/interpreters/ad.py | 84 ++++++++++++++++- jax/interpreters/partial_eval.py | 110 +++++++++++++++++++++- jax/interpreters/xla.py | 26 +++++ jax/lax/lax.py | 25 ++--- jax/lax/lax_control_flow.py | 4 +- jax/linear_util.py | 2 +- tests/api_test.py | 157 +++++++++++++++++++++++++++++++ 11 files changed, 410 insertions(+), 29 deletions(-) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 14aaadf37511..6d3122725109 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -104,7 +104,6 @@ Operators scatter scatter_add select - shaped_identity shift_left shift_right_arithmetic shift_right_logical diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index 28f151011015..1c0353c2cf5f 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -22,7 +22,7 @@ from . import core from . import ad_util from . import dtypes -from . util import prod +from . util import prod, partialmethod def concretization_err_msg(fun): @@ -145,6 +145,9 @@ def _len(self, ignored_tracer): def strip_weak_type(self): return ShapedArray(self.shape, self.dtype) if self.weak_type else self +def _forward_to_value(self, fun, ignored_tracer, *args): + return fun(self.val, *args) + class ConcreteArray(ShapedArray): __slots__ = ['val'] array_abstraction_level = 0 @@ -185,6 +188,15 @@ def str_short(self): def strip_weak_type(self): return ConcreteArray(self.val) if self.weak_type else self + _bool = _nonzero = partialmethod(_forward_to_value, bool) + _float = partialmethod(_forward_to_value, float) + _int = partialmethod(_forward_to_value, int) + if six.PY2: + _long = partialmethod(_forward_to_value, long) # noqa: F821 + _complex = partialmethod(_forward_to_value, complex) + _hex = partialmethod(_forward_to_value, hex) + _oct = partialmethod(_forward_to_value, oct) + class AbstractToken(core.AbstractValue): pass abstract_token = AbstractToken() diff --git a/jax/api.py b/jax/api.py index 212df9677267..09b9e8516dd9 100644 --- a/jax/api.py +++ b/jax/api.py @@ -55,7 +55,7 @@ from .lib import xla_bridge as xb from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices, host_id, host_ids, host_count) -from .abstract_arrays import ShapedArray, raise_to_shaped +from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped from .interpreters import partial_eval as pe from .interpreters import xla from .interpreters import pxla @@ -1978,3 +1978,14 @@ def abstractify(x): out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat)) out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out] return tree_unflatten(out_tree(), out) + + +def checkpoint(fun, concrete=False): + @wraps(fun) + def fun_remat(*args, **kwargs): + args_flat, in_tree = tree_flatten((args, kwargs)) + flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) + out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete) + return tree_unflatten(out_tree(), out_flat) + return fun_remat +remat = checkpoint diff --git a/jax/core.py b/jax/core.py index 42c7f2e3f5c2..7c7c23aea0e5 100644 --- a/jax/core.py +++ b/jax/core.py @@ -597,7 +597,8 @@ def call_bind(primitive, f, *args, **params): def call_impl(f, *args, **params): - return f.call_wrapped(*args, **params) + del params # params parameterize the call primitive, not the function + return f.call_wrapped(*args) call_p = Primitive('call') diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 88f92b3f95a6..9eba7b5bca98 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -140,6 +140,9 @@ def unpair_pval(pval): return (aval_1, const_1), (aval_2, const_2) def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in): + if all(ct is zero for ct in cotangents_in): + return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars) + def write_cotangent(v, ct): # assert v not in primal_env if ct is not None: @@ -159,13 +162,46 @@ def write_primal(v, val): primal_env[v] = val primal_env = {} + write_primal(core.unitvar, core.unit) map(write_primal, jaxpr.constvars, consts) map(write_primal, jaxpr.freevars, freevar_vals) map(write_primal, jaxpr.invars, args) + def is_linear(var): + if type(var) is Literal: + return False + else: + return primal_env.get(var, undefined_primal) is undefined_primal + + linear_eqns = [] + for eqn in jaxpr.eqns: + if not eqn.bound_subjaxprs: + if any(is_linear(v) for v in eqn.invars): + linear_eqns.append(eqn) + else: + in_vals = map(read_primal, eqn.invars) + ans = eqn.primitive.bind(*in_vals, **eqn.params) + if eqn.primitive.multiple_results: + map(write_primal, eqn.outvars, ans) + else: + write_primal(eqn.outvars[0], ans) + else: + (subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs + if any(is_linear(v) for v in it.chain(eqn.invars, const_vars, bound_vars)): + linear_eqns.append(eqn) + sub_consts = map(read_primal, const_vars) + sub_freevar_vals = map(read_primal, bound_vars) + in_vals = map(read_primal, eqn.invars) + all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals)) + fun = hashable_partial(wrap_init(_eval_primals), subjaxpr) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) + out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params) + ans = tree_unflatten(out_tree(), out_flat) + map(write_primal, eqn.outvars, ans) + ct_env = {} map(write_cotangent, jaxpr.outvars, cotangents_in) - for eqn in jaxpr.eqns[::-1]: + for eqn in linear_eqns[::-1]: invals = map(read_primal, eqn.invars) if eqn.primitive.multiple_results: cts_in = map(read_cotangent, eqn.outvars) @@ -187,6 +223,51 @@ def write_primal(v, val): cotangents_out = map(read_cotangent, jaxpr.invars) return freevar_cts, cotangents_out +def _eval_primals(jaxpr, consts, freevar_vals, args): + primal_env = {} + + def read_primal(v): + if type(v) is Literal: + return v.val + else: + return primal_env.get(v, undefined_primal) + + def write_primal(v, val): + if val is not undefined_primal: + primal_env[v] = val + + def is_linear(var): + if type(var) is Literal: + return False + else: + return primal_env.get(var, undefined_primal) is undefined_primal + + write_primal(core.unitvar, core.unit) + map(write_primal, jaxpr.constvars, consts) + map(write_primal, jaxpr.freevars, freevar_vals) + map(write_primal, jaxpr.invars, args) + for eqn in jaxpr.eqns: + if not eqn.bound_subjaxprs: + if not any(is_linear(v) for v in eqn.invars): + in_vals = map(read_primal, eqn.invars) + ans = eqn.primitive.bind(*in_vals, **eqn.params) + if eqn.primitive.multiple_results: + map(write_primal, eqn.outvars, ans) + else: + write_primal(eqn.outvars[0], ans) + else: + (subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs + sub_consts = map(read_primal, const_vars) + sub_freevar_vals = map(read_primal, bound_vars) + in_vals = map(read_primal, eqn.invars) + all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals)) + fun = hashable_partial(wrap_init(_eval_primals), subjaxpr) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) + out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params) + ans = tree_unflatten(out_tree(), out_flat) + map(write_primal, eqn.outvars, ans) + return map(read_primal, jaxpr.outvars) + class UndefinedPrimal(object): def __repr__(self): return '_' undefined_primal = UndefinedPrimal() @@ -460,6 +541,7 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct): out_flat = primitive.bind(fun, *all_args, **params) return tree_unflatten(out_tree(), out_flat) primitive_transposes[core.call_p] = partial(call_transpose, call_p) +primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p) def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct): all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct)) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 45cb51070fde..6952636c7276 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -24,7 +24,7 @@ from .. import core from .. import linear_util as lu -from ..abstract_arrays import ShapedArray, ConcreteArray +from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped from ..linear_util import thunk, transformation, transformation_with_aux from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval, @@ -104,6 +104,8 @@ def process_primitive(self, primitive, tracers, params): return out_tracer def process_call(self, call_primitive, f, tracers, params): + if call_primitive in call_partial_eval_rules: + return call_partial_eval_rules[call_primitive](self, f, tracers, params) if call_primitive in map_primitives: return self.process_map(call_primitive, f, tracers, params) in_pvs, in_consts = unzip2([t.pval for t in tracers]) @@ -188,7 +190,6 @@ def todo(x): return out_tracers return out, todo - def _mapped_aval(aval): if aval is core.abstract_unit: return aval @@ -207,6 +208,8 @@ def _unmapped_aval(size, aval): raise TypeError(aval) map_primitives = set() +custom_partial_eval_rules = {} +call_partial_eval_rules = {} def partial_eval(f, trace, pvs): @@ -450,4 +453,105 @@ def fun(*vals): def _split_aval(unknown, aval): return (abstract_unit, aval) if unknown else (aval, abstract_unit) -custom_partial_eval_rules = {} + +remat_call_p = core.Primitive('remat_call') +remat_call = partial(core.call_bind, remat_call_p) +remat_call_p.def_custom_bind(remat_call) +remat_call_p.def_impl(core.call_impl) +remat_call_p.multiple_results = True + +def _remat_partial_eval(trace, f, tracers, params): + concrete = params['concrete'] + + # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of + # the function being called, not just for the unknown parts. To do that, we + # instantiate all the input tracers as constants in the jaxpr being formed. + # Those tracers might have concrete avals, and doing abstract interpretation + # on concrete avals engenders a tradeoff: it allows data-dependent Python + # control flow to work, but it can in some cases lead to redundant FLOPs (done + # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the + # `concrete` parameter to switch this behavior, and if `concrete` is False + # then we raise the avals to the Shaped level. + instantiated_tracers = map(trace.instantiate_const, tracers) + if not concrete: + instantiated_tracers = [ + JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe) + if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers] + + # Using the instantiated tracers, run call_bind like JaxprTrace.process_call. + in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers) + fun, aux = partial_eval(f, trace, in_pvs) + out_flat = remat_call_p.bind(fun, *in_consts, **params) + out_pvs, jaxpr, env = aux() + out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) + out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)] + + # Since we traced with everything marked as unknown, but we need to know which + # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns. + in_avals = [raise_to_shaped(pv) for pv in in_pvs] + out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const)) + for pv, const in zip(out_pvs, out_pval_consts1)] + typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) + in_unknowns = [t.pval[0] is not None for t in tracers] + jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False) + num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals) + + # First, we revise the jaxpr to be staged out not to output too much. + typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns) + + # Next, we need values for the outputs that should be known. Since consts + # weren't passed through Python for evaluation, we need to evaluate jaxpr_1, + # minus the residual outputs that we don't need. When `concrete=True`, as an + # optimization we can avoid redoing *some* redundant FLOPs, namely those that + # produced concrete avals at the output, simply by using those as computed + # values. For the use case of reverse-mode ad, all the primal outputs should + # be concrete (thus not recomputed). + to_compute = [not uk and type(pv) is not ConcreteArray + for uk, pv in zip(out_unknowns, out_pvs)] + jaxpr_1 = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res) + _, in_consts = unzip2(t.pval for t in tracers) + out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1)(*in_consts)[:-num_res or None] + out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns) + + # Now that we have out_pvals, the rest is just like JaxprTrace.process_call. + const_tracers = map(trace.new_instantiated_const, consts) + bound_subjaxpr = (jaxpr, const_tracers, map(trace.full_raise, env)) + out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals] + eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p, + (bound_subjaxpr,), params) + for t in out_tracers: + t.recipe = eqn + return out_tracers +call_partial_eval_rules[remat_call_p] = _remat_partial_eval + +def _dce_jaxpr(typed_jaxpr, outputs): + # This dead-code elimination is pretty rudimentary, and in particular doesn't + # nontrivially DCE through scan or other higher-order primitives. + jaxpr = typed_jaxpr.jaxpr + outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals + out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit) + for var, aval, output in zip(outvars, out_avals, outputs)] + new_outvars, new_out_avals = unzip2(out_pairs) + + needed_vars = set(new_outvars) + new_eqns = [] + for eqn in jaxpr.eqns[::-1]: + if set(eqn.outvars) & needed_vars: + new_eqns.append(eqn) + needed_vars.update(eqn.invars) + new_eqns = new_eqns[::-1] + + new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars, + new_outvars, new_eqns) + return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals, + new_out_avals) + +def _reconstruct_pval(pval1, const2, unknown): + pv1, const1 = pval1 + if unknown or pv1 is None: + return pval1 + else: + if type(pv1) is ConcreteArray: + return PartialVal((None, pv1.val)) + else: + return PartialVal((None, const2)) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 0a67b5b806bb..607ce116c059 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -762,6 +762,32 @@ def _device_put_impl(x, device=None): ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent]) +def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes, + backend=None, device=None, concrete=None): + # This looks a lot like _xla_call_translation_rule, except for a widget we use + # to foil CSE. + del device, concrete # Unused. + subc = xb.make_computation_builder("remat_call_subcomputation") + consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes] + freevars = [subc.ParameterWithShape(c.GetShape(n)) for n in freevar_nodes] + args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes] + args = [_foil_cse(subc, x) for x in args] + out_nodes = jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, freevars, *args) + subc = subc.Build(subc.Tuple(*out_nodes)) + return c.Call(subc, list(const_nodes) + list(freevar_nodes) + list(in_nodes)) +call_translations[pe.remat_call_p] = _remat_translation_rule + +def _foil_cse(c, x): + rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)), + c.Constant(onp.array(1, dtype=onp.float32)), + []) + pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max)) + xla_shape = c.GetShape(x) + shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype() + zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape) + return c.Select(pred, x, zero) + + ### lazy constants class DeviceConstant(DeviceArray): diff --git a/jax/lax/lax.py b/jax/lax/lax.py index d4897b621ffa..0c84b1dce1cd 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -1035,9 +1035,6 @@ def sort_key_val(keys, values, dimension=-1): def tie_in(x, y): return tie_in_p.bind(x, y) -def shaped_identity(x): - return shaped_identity_p.bind(x, shape=x.shape) - def full(shape, fill_value, dtype=None): """Returns an array of `shape` filled with `fill_value`. @@ -1472,7 +1469,7 @@ def zeros_like_array(x): def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None): prim = Primitive(name) prim.def_impl(partial(xla.apply_primitive, prim)) - prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule)) + prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule)) xla.translations[prim] = translation_rule or partial(standard_translate, name) return prim @@ -1480,17 +1477,17 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None): def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None): prim = Primitive(name) prim.def_impl(partial(xla.apply_primitive, prim)) - prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule)) + prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule)) xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name) return prim -def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs): +def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs): assert all(isinstance(arg, UnshapedArray) for arg in args), args least_specialized = _max( map(type, args), key=operator.attrgetter('array_abstraction_level')) if least_specialized is ConcreteArray: - return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)) + return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs)) elif least_specialized is ShapedArray: return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)) elif least_specialized is UnshapedArray: @@ -3933,7 +3930,7 @@ def _sort_batch_rule(batched_args, batch_dims, dimension): def _sort_key_val_abstract_eval(keys, values, dimension): - return keys, values + return raise_to_shaped(keys), raise_to_shaped(values) def _sort_key_val_jvp(primals, tangents, dimension): # NOTE(mattjj): this re-sorts three times, but if we had a variadic @@ -3991,7 +3988,7 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension): new_dimension = dimension + (keys_bdim <= dimension) return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim) else: - raise Exception # unreachable + assert False # unreachable sort_key_val_p = Primitive('sort_key_val') sort_key_val_p.multiple_results = True @@ -4013,21 +4010,13 @@ def _tie_in_batch_rule(batched_args, batch_dims): tie_in_p = Primitive('tie_in') tie_in_p.def_impl(lambda x, y: y) -tie_in_p.def_abstract_eval(lambda x, y: y) +tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y)) xla.translations[tie_in_p] = lambda c, x, y: y ad.deflinear(tie_in_p, _tie_in_transpose_rule) batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule masking.shape_rules[tie_in_p] = lambda shape_exprs: shape_exprs[1] masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1] -shaped_identity_p = Primitive('shape_id') -shaped_identity_p.def_impl(lambda x, shape: x) -shaped_identity_p.def_abstract_eval(lambda x, shape: x) -xla.translations[shaped_identity_p] = lambda c, x, shape: x -ad.deflinear(shaped_identity_p, lambda t, shape: [shaped_identity(t)]) -batching.primitive_batchers[shaped_identity_p] = \ - lambda a, d, shape: (shaped_identity(a[0]), d[0]) - ### constants diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index 7badc7663b1c..7ec583c7670e 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -1084,7 +1084,7 @@ def linearize_and_solve(x, b): def _root_abstract_eval(*args, **kwargs): - return args[sum(kwargs['const_lengths']):] + return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):]) def _root_impl(*args, **kwargs): @@ -1253,7 +1253,7 @@ def custom_linear_solve( def _linear_solve_abstract_eval(*args, **kwargs): - return args[sum(kwargs['const_lengths']):] + return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):]) def _custom_linear_solve_impl(*args, **kwargs): diff --git a/jax/linear_util.py b/jax/linear_util.py index 8cf970ff14ba..b462774edbdf 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -148,8 +148,8 @@ def call_wrapped(self, *args, **kwargs): gen = gen(*(gen_args + tuple(args)), **kwargs) args, kwargs = next(gen) stack.append((gen, out_store)) + gen = None - del gen ans = self.f(*args, **dict(self.params, **kwargs)) del args while stack: diff --git a/tests/api_test.py b/tests/api_test.py index 4cd5d116296a..b77c3c400842 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1303,6 +1303,163 @@ def test_grad_of_jit_compilation_caching(self): self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False) self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False) + def test_remat_basic(self): + @api.remat + def g(x): + return lax.sin(x), 3. + + def f(x): + x, _ = g(x) + return x + + ans = f(2.) + expected = onp.sin(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans, f_lin = api.linearize(f, 2.) + expected = onp.sin(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = f_lin(3.) + expected = onp.cos(2.) * 3. + self.assertAllClose(ans, expected, check_dtypes=False) + + jaxpr = api.make_jaxpr(f_lin)(3.) + self.assertIn('sin', str(jaxpr)) + + def test_remat_grad_python_control_flow(self): + @partial(api.remat, concrete=True) + def g(x): + if x > 0: + return lax.sin(x), 3. + else: + return lax.cos(x), 4. + + def f(x): + x, _ = g(x) + return x + + ans = f(2.) + expected = onp.sin(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(f)(2.) + expected = onp.cos(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_jit(self): + @api.remat + def g(x): + return lax.sin(lax.sin(x)) + + def f_(x): + return g(x) + f = api.jit(f_) + + ans = f(2.) + expected = onp.sin(onp.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(f)(2.) + expected = onp.cos(onp.sin(2.)) * onp.cos(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(f_))(2.) + expected = onp.cos(onp.sin(2.)) * onp.cos(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_vmap(self): + @api.remat + def g(x): + return lax.sin(lax.sin(x)) + + x = onp.arange(3.) + + ans = api.vmap(g)(x) + expected = onp.sin(onp.sin(x)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jacfwd(g)(x) + expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jacrev(g)(x) + expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order_autodiff(self): + def f(x): + return lax.cos(lax.sin(x)) + g = api.remat(f) + + ans = api.grad(api.grad(g))(3.) + expected = api.grad(api.grad(f))(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_scan(self): + to_scan = lambda c, x: (np.sin(c), None) + + def f_noremat(x): + y, _ = lax.scan(to_scan, x, onp.arange(3.)) + return y + + def f_yesremat(x): + y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.)) + return y + + ans = f_yesremat(4.) + expected = f_noremat(4.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(f_yesremat)(4.) + expected = api.grad(f_noremat)(4.) + self.assertAllClose(ans, expected, check_dtypes=False) + + jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.) + scan_eqn, = jaxpr.eqns + self.assertIn(' sin ', str(scan_eqn.params['jaxpr'])) + + jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.) + scan_eqn, = jaxpr.eqns + self.assertIn(' cos ', str(scan_eqn.params['jaxpr'])) + + def test_remat_no_redundant_flops(self): + # see https://github.com/google/jax/pull/1749#issuecomment-558267584 + + @api.jit + def g(x): + return f(2., x) + + @api.remat + def f(x, y): + return np.sin(x) * y + + # We swap out sin_p's impl rule to count how many times it's invoked + called = [] + sin_impl = lax.sin_p.impl + try: + lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x)) + api.grad(g)(3.) + finally: + lax.sin_p.def_impl(sin_impl) + num_calls = len(called) + self.assertEqual(num_calls, 1) + + def test_remat_binomial_checkpointing(self): + def binom_checkpoint(funs): + if len(funs) == 1: + return funs[0] + else: + f1 = binom_checkpoint(funs[:len(funs)//2]) + f2 = binom_checkpoint(funs[len(funs)//2:]) + return api.remat(lambda x: f1(f2(x))) + + f1 = binom_checkpoint([np.sin, np.sin, np.sin, np.sin]) + f2 = lambda x: np.sin(np.sin(np.sin(np.sin(x)))) + x = 4. + self.assertAllClose(f1(x), f2(x), check_dtypes=False) + self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False) + if __name__ == '__main__': absltest.main()