Skip to content


Add experimental rematerialization decorator
Browse files Browse the repository at this point in the history
We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See #1749 for more.

Co-authored-by: Dougal Maclaurin <>
  • Loading branch information
mattjj and dougalm committed Nov 26, 2019
1 parent 22b7c96 commit 57dd913
Show file tree
Hide file tree
Showing 11 changed files with 410 additions and 29 deletions.
1 change: 0 additions & 1 deletion docs/jax.lax.rst
Expand Up @@ -104,7 +104,6 @@ Operators
Expand Down
14 changes: 13 additions & 1 deletion jax/abstract_arrays.py
Expand Up @@ -22,7 +22,7 @@
from . import core
from . import ad_util
from . import dtypes
from . util import prod
from . util import prod, partialmethod

def concretization_err_msg(fun):
Expand Down Expand Up @@ -145,6 +145,9 @@ def _len(self, ignored_tracer):
def strip_weak_type(self):
  """Return this aval unchanged unless it is flagged as weakly typed.

  Returns:
    ``self`` when ``self.weak_type`` is falsy; otherwise a fresh
    ``ShapedArray`` constructed from only this aval's ``shape`` and ``dtype``.
  """
  return ShapedArray(self.shape, self.dtype) if self.weak_type else self

def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)

class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
Expand Down Expand Up @@ -185,6 +188,15 @@ def str_short(self):
def strip_weak_type(self):
  """Return this aval unchanged unless it is flagged as weakly typed.

  Returns:
    ``self`` when ``self.weak_type`` is falsy; otherwise a fresh
    ``ConcreteArray`` constructed from only this aval's ``val``.
  """
  return ConcreteArray(self.val) if self.weak_type else self

# Scalar-conversion hooks: each one is _forward_to_value with the matching
# builtin pre-bound, so the conversion is applied to the concrete `val`.
_bool = _nonzero = partialmethod(_forward_to_value, bool)
_float = partialmethod(_forward_to_value, float)
_int = partialmethod(_forward_to_value, int)
if six.PY2:
  # `long` only exists on Python 2, hence the guard and the F821 suppression.
  _long = partialmethod(_forward_to_value, long)  # noqa: F821
_complex = partialmethod(_forward_to_value, complex)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)

class AbstractToken(core.AbstractValue):
  """Marker subclass of core.AbstractValue; adds no attributes or methods."""

# Module-level instance of AbstractToken.
abstract_token = AbstractToken()
Expand Down
13 changes: 12 additions & 1 deletion jax/
Expand Up @@ -55,7 +55,7 @@
from .lib import xla_bridge as xb
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ShapedArray, raise_to_shaped
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
Expand Down Expand Up @@ -1978,3 +1978,14 @@ def abstractify(x):
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
return tree_unflatten(out_tree(), out)

def checkpoint(fun, concrete=False):
  """Wrap ``fun`` so that calls to it are routed through ``pe.remat_call``.

  Args:
    fun: the callable to wrap. Its positional and keyword arguments are
      flattened together into one list of leaves before the call.
    concrete: forwarded unchanged to ``pe.remat_call`` as its ``concrete``
      parameter.

  Returns:
    A function accepting the same arguments as ``fun``. It carries ``fun``'s
    name and docstring, and its outputs are rebuilt into the output structure
    reported by the flattened function.
  """
  import functools  # local import keeps this block self-contained

  @functools.wraps(fun)
  def fun_remat(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
    return tree_unflatten(out_tree(), out_flat)
  return fun_remat
# Alias: `remat` and `checkpoint` are the same function object.
remat = checkpoint
3 changes: 2 additions & 1 deletion jax/core.py
Expand Up @@ -597,7 +597,8 @@ def call_bind(primitive, f, *args, **params):

def call_impl(f, *args, **params):
  """Run the wrapped function ``f`` on ``args``, discarding ``params``.

  Args:
    f: object exposing ``call_wrapped``.
    *args: positional arguments forwarded to ``f.call_wrapped``.
    **params: accepted but deliberately not forwarded (see comment below).

  Returns:
    The result of ``f.call_wrapped(*args)``.
  """
  del params  # params parameterize the call primitive, not the function
  return f.call_wrapped(*args)

call_p = Primitive('call')
Expand Down
84 changes: 83 additions & 1 deletion jax/interpreters/
Expand Up @@ -140,6 +140,9 @@ def unpair_pval(pval):
return (aval_1, const_1), (aval_2, const_2)

def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
if all(ct is zero for ct in cotangents_in):
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)

def write_cotangent(v, ct):
# assert v not in primal_env
if ct is not None:
Expand All @@ -159,13 +162,46 @@ def write_primal(v, val):
primal_env[v] = val

primal_env = {}
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)

def is_linear(var):
if type(var) is Literal:
return False
return primal_env.get(var, undefined_primal) is undefined_primal

linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
write_primal(eqn.outvars[0], ans)
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
if any(is_linear(v) for v in it.chain(eqn.invars, const_vars, bound_vars)):
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
map(write_primal, eqn.outvars, ans)

ct_env = {}
map(write_cotangent, jaxpr.outvars, cotangents_in)
for eqn in jaxpr.eqns[::-1]:
for eqn in linear_eqns[::-1]:
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
Expand All @@ -187,6 +223,51 @@ def write_primal(v, val):
cotangents_out = map(read_cotangent, jaxpr.invars)
return freevar_cts, cotangents_out

def _eval_primals(jaxpr, consts, freevar_vals, args):
  """Evaluate the equations of ``jaxpr`` whose inputs all have known values.

  Values equal to ``undefined_primal`` are treated as unknown: they are never
  stored in the environment, and a first-order equation with any unknown input
  is skipped. Equations carrying a bound sub-jaxpr are evaluated by binding
  the equation's primitive to a flattened wrapper around this same function.

  NOTE(review): the bare ``map(...)`` calls are used for their side effects,
  so this relies on ``map`` being an eager variant in this module - confirm.

  Args:
    jaxpr: the jaxpr to evaluate.
    consts: values for ``jaxpr.constvars``.
    freevar_vals: values for ``jaxpr.freevars``.
    args: values for ``jaxpr.invars``.

  Returns:
    The values read for ``jaxpr.outvars``; an output that was never computed
    reads as ``undefined_primal``.
  """
  primal_env = {}

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    return primal_env.get(v, undefined_primal)

  def write_primal(v, val):
    # Unknown values are simply left out of the environment.
    if val is not undefined_primal:
      primal_env[v] = val

  def is_linear(var):
    # A non-literal variable is linear when no primal value is recorded for it.
    if type(var) is Literal:
      return False
    return primal_env.get(var, undefined_primal) is undefined_primal

  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  map(write_primal, jaxpr.freevars, freevar_vals)
  map(write_primal, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    if not eqn.bound_subjaxprs:
      # First-order equation: evaluate only when every input is known.
      if not any(is_linear(v) for v in eqn.invars):
        in_vals = map(read_primal, eqn.invars)
        ans = eqn.primitive.bind(*in_vals, **eqn.params)
        if eqn.primitive.multiple_results:
          map(write_primal, eqn.outvars, ans)
        else:
          write_primal(eqn.outvars[0], ans)
    else:
      # Exactly one bound sub-jaxpr is expected; evaluate it recursively.
      (subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
      sub_consts = map(read_primal, const_vars)
      sub_freevar_vals = map(read_primal, bound_vars)
      in_vals = map(read_primal, eqn.invars)
      all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
      fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
      fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
      out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
      ans = tree_unflatten(out_tree(), out_flat)
      map(write_primal, eqn.outvars, ans)
  return map(read_primal, jaxpr.outvars)

class UndefinedPrimal(object):
  """Type of the ``undefined_primal`` sentinel; its repr is ``'_'``."""

  def __repr__(self):
    return '_'

# Sentinel instance; code in this module compares against it by identity.
undefined_primal = UndefinedPrimal()
Expand Down Expand Up @@ -460,6 +541,7 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)

def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
Expand Down
110 changes: 107 additions & 3 deletions jax/interpreters/partial_eval.py
Expand Up @@ -24,7 +24,7 @@

from .. import core
from .. import linear_util as lu
from ..abstract_arrays import ShapedArray, ConcreteArray
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
from ..linear_util import thunk, transformation, transformation_with_aux
from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
Expand Down Expand Up @@ -104,6 +104,8 @@ def process_primitive(self, primitive, tracers, params):
return out_tracer

def process_call(self, call_primitive, f, tracers, params):
if call_primitive in call_partial_eval_rules:
return call_partial_eval_rules[call_primitive](self, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
Expand Down Expand Up @@ -188,7 +190,6 @@ def todo(x):
return out_tracers
return out, todo

def _mapped_aval(aval):
if aval is core.abstract_unit:
return aval
Expand All @@ -207,6 +208,8 @@ def _unmapped_aval(size, aval):
raise TypeError(aval)

map_primitives = set()
custom_partial_eval_rules = {}
call_partial_eval_rules = {}

def partial_eval(f, trace, pvs):
Expand Down Expand Up @@ -450,4 +453,105 @@ def fun(*vals):
def _split_aval(unknown, aval):
return (abstract_unit, aval) if unknown else (aval, abstract_unit)

custom_partial_eval_rules = {}

# Primitive named 'remat_call'; `remat_call` binds it via core.call_bind.
remat_call_p = core.Primitive('remat_call')
remat_call = partial(core.call_bind, remat_call_p)
# Flag the primitive as producing a sequence of outputs, not a single value.
remat_call_p.multiple_results = True

def _remat_partial_eval(trace, f, tracers, params):
  """Partial-evaluation rule for ``remat_call_p``.

  NOTE(review): bare ``map(...)`` results are passed on as sequences, so this
  relies on ``map`` being an eager variant in this module - confirm.

  Args:
    trace: the trace whose ``instantiate_const``, ``new_instantiated_const``
      and ``full_raise`` methods are used to build tracers.
    f: the wrapped function being called.
    tracers: input tracers for the call.
    params: primitive parameters; must contain the key ``'concrete'``.

  Returns:
    A list of ``JaxprTracer`` outputs that all share a single ``remat_call_p``
    equation recipe.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  instantiated_tracers = map(trace.instantiate_const, tracers)
  if not concrete:
    instantiated_tracers = [
        JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe)
        if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers]

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
  fun, aux = partial_eval(f, trace, in_pvs)
  out_flat = remat_call_p.bind(fun, *in_consts, **params)
  out_pvs, jaxpr, env = aux()
  out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
  out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  in_avals = [raise_to_shaped(pv) for pv in in_pvs]
  out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
               for pv, const in zip(out_pvs, out_pval_consts1)]
  typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
  in_unknowns = [t.pval[0] is not None for t in tracers]
  jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
  num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)

  # First, we revise the jaxpr to be staged out not to output too much.
  typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of reverse-mode ad, all the primal outputs should
  # be concrete (thus not recomputed).
  to_compute = [not uk and type(pv) is not ConcreteArray
                for uk, pv in zip(out_unknowns, out_pvs)]
  jaxpr_1 = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
  _, in_consts = unzip2(t.pval for t in tracers)
  # `-num_res or None` keeps the whole list when there are no residuals.
  out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1)(*in_consts)[:-num_res or None]
  out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)

  # Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
  const_tracers = map(trace.new_instantiated_const, consts)
  bound_subjaxpr = (jaxpr, const_tracers, map(trace.full_raise, env))
  out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
  eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
                       (bound_subjaxpr,), params)
  for t in out_tracers:
    t.recipe = eqn
  return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval

# Rudimentary dead-code elimination over a TypedJaxpr. Outputs whose flag in
# `outputs` is falsy are replaced by the pair (core.unitvar,
# core.abstract_unit); equations are then scanned in reverse against the set
# of needed variables, and a new core.TypedJaxpr is returned.
# NOTE(review): this listing looks truncated - the body of the `if` inside the
# loop and the final argument of the core.TypedJaxpr(...) call are not visible
# here. Confirm against the complete file before relying on it.
def _dce_jaxpr(typed_jaxpr, outputs):
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan or other higher-order primitives.
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)

# Walk the equations last-to-first, starting from the kept output variables.
needed_vars = set(new_outvars)
new_eqns = []
for eqn in jaxpr.eqns[::-1]:
if set(eqn.outvars) & needed_vars:
# Equations were visited in reverse, so flip back to the original order.
new_eqns = new_eqns[::-1]

new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars,
new_outvars, new_eqns)
return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,

def _reconstruct_pval(pval1, const2, unknown):
pv1, const1 = pval1
if unknown or pv1 is None:
return pval1
if type(pv1) is ConcreteArray:
return PartialVal((None, pv1.val))
return PartialVal((None, const2))
26 changes: 26 additions & 0 deletions jax/interpreters/
Expand Up @@ -762,6 +762,32 @@ def _device_put_impl(x, device=None):
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])

def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
                            backend=None, device=None, concrete=None):
  """Translate a ``remat_call`` into a call to a separately built subcomputation.

  Args:
    c: the enclosing computation builder.
    jaxpr: the jaxpr to lower into the subcomputation.
    axis_env: passed through to ``jaxpr_subcomp``.
    const_nodes: nodes in ``c`` supplying the subcomputation's constants.
    freevar_nodes: nodes in ``c`` supplying its free variables.
    in_nodes: nodes in ``c`` supplying its arguments.
    backend: passed through to ``jaxpr_subcomp``.
    device: unused.
    concrete: unused.

  Returns:
    The result of ``c.Call`` on the built subcomputation, with operands in
    the order constants, free variables, arguments.
  """
  # This looks a lot like _xla_call_translation_rule, except for a widget we use
  # to foil CSE.
  del device, concrete  # Unused.
  subc = xb.make_computation_builder("remat_call_subcomputation")
  consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
  freevars = [subc.ParameterWithShape(c.GetShape(n)) for n in freevar_nodes]
  args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
  # Only the arguments are passed through _foil_cse, not consts or freevars.
  args = [_foil_cse(subc, x) for x in args]
  out_nodes = jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, freevars, *args)
  subc = subc.Build(subc.Tuple(*out_nodes))
  return c.Call(subc, list(const_nodes) + list(freevar_nodes) + list(in_nodes))
call_translations[pe.remat_call_p] = _remat_translation_rule

# Wraps `x` in a Select whose predicate compares an RngNormal sample against
# the largest finite float32; the Select yields `x` when the predicate holds
# and a zero array with x's dimensions and dtype otherwise.
# NOTE(review): the RngNormal(...) call below looks truncated - no final
# argument or closing parenthesis is visible in this listing. Confirm against
# the complete file.
def _foil_cse(c, x):
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
# Build a zero array matching x's dimensions and dtype for the other branch.
xla_shape = c.GetShape(x)
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)

### lazy constants

class DeviceConstant(DeviceArray):
Expand Down

0 comments on commit 57dd913

Please sign in to comment.