Skip to content

Commit

Permalink
Add translation rule for optimization barrier.
Browse files Browse the repository at this point in the history
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using the
`jax_remat_opt_barrier` config flag.
  • Loading branch information
pschuh committed Feb 14, 2022
1 parent 51c7d3b commit 7ce911b
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 3 deletions.
7 changes: 7 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,10 @@ def _update_disable_jit_thread_local(val):
default=False,
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'))

# Temporary flag for the rollout of the optimization-barrier-based remat
# lowering. It defaults to True (barrier lowering enabled); set it to False to
# turn that lowering off if it causes problems.
# TODO(parkers): Remove if there are no complaints.
config.define_bool_state(
name='jax_remat_opt_barrier',
default=True,
help=('Enables using optimization-barrier op for lowering remat.'))
45 changes: 44 additions & 1 deletion jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,6 +3084,12 @@ def body(carry):
carry_res = while_loop(cond, body, carry_init)
return carry_res[1]


def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
  """Evaluates `jaxpr` on operands that first pass through the barrier.

  Args:
    *args: operands of the remat computation.
    jaxpr: the jaxpr to evaluate; it is evaluated with an empty constants
      tuple.

  Returns:
    Whatever `core.eval_jaxpr` returns for `jaxpr` applied to the barriered
    operands.
  """
  barriered_operands = _optimization_barrier(args)
  return core.eval_jaxpr(jaxpr, (), *barriered_operands)


def _remat_translation_rule(*args,
call_jaxpr: Optional[core.Jaxpr] = None,
jaxpr: Optional[core.Jaxpr] = None,
Expand All @@ -3104,6 +3110,8 @@ def _remat_translation_rule(*args,
if differentiated and prevent_cse:
if platform == "gpu":
translation_rule = _remat_translation_using_while
elif platform == "tpu" and config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
else:
translation_rule = _remat_translation_using_cond
else:
Expand All @@ -3122,4 +3130,39 @@ def _remat_translation_rule(*args,
mlir.register_lowering(remat_primitive,
mlir.lower_fun(partial(_remat_translation_rule,
platform=platform),
multiple_results=True))
multiple_results=True),
platform=platform)


def _optimization_barrier_abstract_eval(*args):
return args


def _optimization_barrier_translation_rule(ctx, avals_in, avals_out, *args):
  """XLA translation rule for the optimization barrier primitive.

  Packs the operands into a single tuple, applies `xops.OptimizationBarrier`
  to that tuple, and unpacks the result element by element.

  Args:
    ctx: translation context; only `ctx.builder` is used.
    avals_in: unused.
    avals_out: unused.
    *args: the XLA operands to pass through the barrier.

  Returns:
    A list with one tuple element of the barrier output per operand, in
    operand order.
  """
  operand_tuple = xops.Tuple(ctx.builder, args)
  barrier = xops.OptimizationBarrier(operand_tuple)
  return [xops.GetTupleElement(barrier, idx) for idx, _ in enumerate(args)]


def _optimization_barrier_lowering_rule(ctx, *args):
  """MLIR lowering rule for the optimization barrier primitive.

  Args:
    ctx: lowering context; `ctx.avals_in` supplies the input abstract values.
    *args: the IR operands, as accepted by `mlir.flatten_lowering_ir_args`.

  Returns:
    The results of a single `mhlo.OptimizationBarrierOp`, regrouped so that
    each group holds as many results as its input aval has IR types.
  """
  # IR types grouped per input aval; the op itself takes the flat list.
  types_per_aval = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_types = util.flatten(types_per_aval)

  flat_operands = mlir.flatten_lowering_ir_args(args)
  op = mhlo.OptimizationBarrierOp(flat_types, flat_operands)

  # Restore the per-aval grouping on the way out.
  group_sizes = _map(len, types_per_aval)
  return util.unflatten(op.results, group_sizes)


def _optimization_barrier(arg):
  """Binds `optimization_barrier_p` over the leaves of a pytree.

  Args:
    arg: a pytree of operands.

  Returns:
    A pytree with the same structure as `arg`, whose leaves are the outputs of
    the primitive bound on the flattened leaves of `arg`.
  """
  leaves, structure = tree_flatten(arg)
  barriered_leaves = optimization_barrier_p.bind(*leaves)
  return tree_unflatten(structure, barriered_leaves)


# The optimization barrier primitive. It is bound with N operands and, being
# marked multiple_results, produces a sequence of outputs.
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
# The impl rule is xla.apply_primitive partially applied to this primitive.
optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
# Register both lowering paths: the XLA translation rule and the MLIR lowering
# rule.
xla.register_translation(optimization_barrier_p,
_optimization_barrier_translation_rule)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"reduce_precision",
"schur",
"name",
"optimization_barrier",

# Not high priority?
"after_all",
Expand Down Expand Up @@ -2325,7 +2326,7 @@ def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal:
tf_impl_with_avals[ad_checkpoint.remat_p] = \
_convert_jax_impl(partial(lax_control_flow._remat_translation_rule,
# TODO: jax2tf cannot discriminate by platform
platform="tpu"),
platform="cpu"),
multiple_results=True,
extra_name_stack="checkpoint")

Expand Down
4 changes: 3 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3894,10 +3894,12 @@ def f(x):
c = api.xla_computation(f)(2.)
self.assertNotIn('while', c.as_hlo_text())
self.assertNotIn('conditional', c.as_hlo_text())
self.assertNotIn('opt-barrier', c.as_hlo_text())

c = api.xla_computation(grad(f))(2.)
text = c.as_hlo_text()
self.assertTrue('while' in text or 'conditional' in text)
self.assertTrue('while' in text or 'conditional' in text
or 'opt-barrier' in text)

def test_no_cse_widget_with_prevent_cse_false(self):
@partial(api.remat, prevent_cse=False)
Expand Down

0 comments on commit 7ce911b

Please sign in to comment.