Skip to content

Commit

Permalink
[remat] Change remat lowering to XLA::Conditional (#2391)
Browse files Browse the repository at this point in the history
* [remat] Change remat lowering to XLA::Conditional

`jax.remat` creates rematerializing passes that don't have data dependencies on
the actual loss-computing forward pass. This means that the XLA scheduler was
free to schedule the remat forward pass before the loss-computing pass,
defeating the goal of saving accelerator memory with `jax.remat`.

In practice, it sometimes did for my workloads.

This change expresses the lowering of remat_call(f) as:
Conditional(true, inputs, f, inputs, dummy_f).

In the common case of `jax.grad(jax.remat(f))`, the content of the
lowered remat_call are both the forwards & backwards; that is, the
incoming cotangents are part of the args.

Additionally, Conditional (AFAIK) is un-inlineable in the sense that it
doesn't execute until all its inputs (e.g. cotangents!) are available.

Downsides:

- AFAICT, we can no longer interleave computation in/outside the
  rematerialized block.
- Potentially, lower performance. I do not observe this in my tests.

* provide no replication info for subcomputation params
  • Loading branch information
trevorcai committed Mar 11, 2020
1 parent 2dfeaeb commit 620bf43
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,31 +973,38 @@ def _device_put_impl(x, device=None):
def _remat_translation_rule(c, axis_env, in_nodes,
name_stack, backend, name, call_jaxpr,
device=None, concrete=None):
# This looks a lot like _xla_call_translation_rule, except for a widget we use
# to foil CSE.
"""Lower remat to a Conditional which always returns true. This:
1. Circumvents common subexpression elimination.
2. In common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
occur after the primal blocks, because cotangent is an input to the
Conditional."""
del device, concrete # Unused.
subc = xb.make_computation_builder("remat_call_subcomputation")
args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
args = _foil_cse(subc, args)
out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'remat')), *args)
subc = subc.Build(subc.Tuple(*out_nodes))
return c.Call(subc, list(in_nodes))
call_translations[pe.remat_call_p] = _remat_translation_rule

def _foil_cse(c, args):
# Fake condition which always selects True branch.
rng = c.RngUniform(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
[])
pred = c.Lt(rng, c.Constant(onp.array(2, dtype=onp.float32)))
outs = []
for x in args:
xla_shape = c.GetShape(x)
if xla_shape.is_tuple():
assert not xla_shape.tuple_shapes()
outs.append(x)
else:
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
outs.append(c.Select(pred, x, zero))
return outs

true_op = c.Tuple(*in_nodes)
remat_subc = xb.make_computation_builder("remat_call_subcomputation")
input_op = remat_subc.ParameterWithShape(c.GetShape(true_op), replicated=[])
args = [remat_subc.GetTupleElement(input_op, i) for i in range(len(in_nodes))]
out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
extend_name_stack(name_stack, wrap_name(name, 'remat')),
*args)
out_node_shapes = [remat_subc.GetShape(o) for o in out_nodes]
remat_subc = remat_subc.Build(remat_subc.Tuple(*out_nodes))

false_op = true_op
dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")
dummy_subc.ParameterWithShape(c.GetShape(false_op), replicated=[])

def zeros(xla_shape):
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = dummy_subc.Constant(onp.array(0, dtype=dtype))
return dummy_subc.Broadcast(zero, shape)
out_nodes = [zeros(s) for s in out_node_shapes]
dummy_subc = dummy_subc.Build(dummy_subc.Tuple(*out_nodes))

return c.Conditional(pred, true_op, remat_subc, false_op, dummy_subc)
call_translations[pe.remat_call_p] = _remat_translation_rule

0 comments on commit 620bf43

Please sign in to comment.