Skip to content

Commit

Permalink
add while_loop custom-policy partial eval rule
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jul 29, 2022
1 parent 22bc535 commit 7f3aa12
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 14 deletions.
6 changes: 3 additions & 3 deletions jax/_src/lax/control_flow/conditionals.py
Expand Up @@ -446,7 +446,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
unks_out: List[bool] = [False] * len(eqn.outvars)
for jaxpr in branches:
_, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
unks_out = map(operator.or_, unks_out, unks_out_)

Expand All @@ -458,7 +458,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=True,
ensure_out_unknowns=unks_out, ensure_out_inst=True,
saveable=saveable)
branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
Expand All @@ -481,7 +481,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# passing in_inst argument to partial_eval_jaxpr_custom above).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_in = [True] * len(inst_in)
del inst_in

# Create residual variables.
newvar = core.gensym()
Expand Down
71 changes: 67 additions & 4 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -828,14 +828,14 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
unks_in = const_uk + carry_uk + xs_uk
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True,
ensure_out_unknowns=carry_uk + [False] * num_ys,
ensure_out_inst=True, saveable=saveable)
carry_uk_out, ys_uk = split_list(unks_out, [num_carry])
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk , carry_uk_out )
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
Expand Down Expand Up @@ -1309,6 +1309,70 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
return util.merge_lists(carry_uk, out_known, out_tracers)

# TODO(mattjj): de-duplicate code with _while_partial_eval
def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  """Custom-policy partial evaluation rule for ``while_p`` equations.

  Splits the equation into a "known" while loop, which only carries the
  carry components that stay known across iterations, and a staged equation,
  which is the input equation unchanged.

  Args:
    saveable: residual-saving policy. Ignored: no residuals are saved here
      (the nested partial evaluations use ``ad_checkpoint.nothing_saveable``).
    unks_in: one flag per entry of ``eqn.invars`` (ordered as cond consts,
      body consts, initial carry); truthy where that input is unknown.
    inst_in: one flag per entry of ``eqn.invars``; falsy where the input has
      not been instantiated yet.
    eqn: the while-loop equation being partially evaluated.

  Returns:
    A 5-tuple ``(eqn_known, eqn_staged, unks_out, inst_out, new_inst)``:
    the known-side while equation, the staged equation (``eqn`` itself), the
    per-output unknown flags (the fixpoint carry flags), an all-``True`` list
    of per-output instantiated flags, and the input variables that have to be
    newly instantiated because the staged equation consumes every input.
  """
  del saveable  # We can't save any residuals anyway (w/o dynamic shapes)!
  params = eqn.params
  cond_jaxpr = params['cond_jaxpr']
  cond_nconsts = params['cond_nconsts']
  body_jaxpr = params['body_jaxpr']
  body_nconsts = params['body_nconsts']

  cond_consts_uk, body_consts_uk, carry_uk = split_list(
      unks_in, [cond_nconsts, body_nconsts])

  # Iterate to a fixpoint on which carry components are unknown. Nothing needs
  # to be tracked for 'inst_in': every input is made available and DCE can
  # prune the unused ones afterwards.
  converged = False
  for _ in range(1 + len(carry_uk)):
    body_known, _, carry_uk_next, _, num_res = pe.partial_eval_jaxpr_custom(
        body_jaxpr.jaxpr,
        in_unknowns=body_consts_uk + carry_uk,
        in_inst=True,
        ensure_out_unknowns=carry_uk,
        ensure_out_inst=True,
        saveable=ad_checkpoint.nothing_saveable)
    if carry_uk_next == carry_uk:
      converged = True
      break
    # A carry component is unknown if it was unknown coming in or going out.
    carry_uk = _map(operator.or_, carry_uk, carry_uk_next)
  assert converged, "Fixpoint not reached"
  assert not num_res
  body_jaxpr_known = core.ClosedJaxpr(body_known, body_jaxpr.consts)

  # Known part of cond_fun: this essentially prunes inputs on the known side.
  cond_known, _, (cond_uk,), _, _ = pe.partial_eval_jaxpr_custom(
      cond_jaxpr.jaxpr,
      cond_consts_uk + carry_uk,
      in_inst=True,
      ensure_out_unknowns=False,
      ensure_out_inst=True,
      saveable=ad_checkpoint.nothing_saveable)
  assert not cond_uk  # only possible with old-style remat
  cond_jaxpr_known = core.ClosedJaxpr(cond_known, cond_jaxpr.consts)

  # Assemble the known equation from the known inputs and known carry outputs.
  known_invars = partition_list(unks_in, eqn.invars)[0]
  known_outvars = partition_list(carry_uk, eqn.outvars)[0]
  num_known_cond_consts = len(cond_consts_uk) - sum(cond_consts_uk)
  num_known_body_consts = len(body_consts_uk) - sum(body_consts_uk)
  params_known = dict(cond_jaxpr=cond_jaxpr_known,
                      body_jaxpr=body_jaxpr_known,
                      cond_nconsts=num_known_cond_consts,
                      body_nconsts=num_known_body_consts)
  effects_known = core.join_effects(cond_jaxpr_known.effects,
                                    body_jaxpr_known.effects)
  eqn_known = pe.new_jaxpr_eqn(known_invars, known_outvars, while_p,
                               params_known, effects_known, eqn.source_info)

  # The staged equation takes all inputs, so every not-yet-instantiated
  # variable input has to be instantiated.
  new_inst = []
  for invar, already_inst in zip(eqn.invars, inst_in):
    if type(invar) is core.Var and not already_inst:
      new_inst.append(invar)

  # Staged eqn is the input eqn itself; all of its outputs are instantiated.
  return eqn_known, eqn, carry_uk, [True] * len(carry_uk), new_inst

def _while_transpose_error(*_, **kwargs):
raise ValueError("Reverse-mode differentiation does not work for "
"lax.while_loop or lax.fori_loop. "
Expand All @@ -1323,8 +1387,7 @@ def _while_transpose_error(*_, **kwargs):
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'while_loop')
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom


def _pred_bcast_select_mhlo(
Expand Down
4 changes: 3 additions & 1 deletion jax/interpreters/partial_eval.py
Expand Up @@ -1241,11 +1241,13 @@ def _remat_partial_eval(trace, _, f, tracers, params):
def partial_eval_jaxpr_custom(
jaxpr: Jaxpr,
in_unknowns: Sequence[bool],
in_inst: Sequence[bool],
in_inst: Union[bool, Sequence[bool]],
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
Expand Down
44 changes: 44 additions & 0 deletions tests/api_test.py
Expand Up @@ -5207,6 +5207,50 @@ def sin_jvp(primals, tangents):
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)

@parameterized.named_parameters(
    {"testcase_name": name, "remat": remat_impl}
    for name, remat_impl in (
        ('', api.remat),
        ('_new', new_checkpoint),
    ))
def test_remat_of_while_loop(self, remat):
  """Linearizing a remat-wrapped while_loop gives the right tangent."""
  def keep_going(state):
    count, _ = state
    return count < 3

  def step_body(state):
    count, val = state
    return count + 1, jnp.sin(val)

  def f(x):
    # Three iterations, so the loop computes sin(sin(sin(x))).
    return lax.while_loop(keep_going, step_body, (0, x))[1]

  f_lin = jax.linearize(remat(f), 3.)[1]
  y_dot = f_lin(1.0)
  triple_sin = lambda x: jnp.sin(jnp.sin(jnp.sin(x)))
  expected = jax.grad(triple_sin)(3.)
  self.assertArraysAllClose(y_dot, expected, check_dtypes=False)

  # The linearized function's jaxpr should contain both sin and cos.
  jaxpr_text = str(api.make_jaxpr(jax.linearize(remat(f), 4.)[1])(1.))
  self.assertIn(' sin ', jaxpr_text)
  self.assertIn(' cos ', jaxpr_text)

def test_remat_of_while_loop_policy(self):
  """A save-cos policy still leaves sin and cos in the linearized jaxpr."""
  def keep_going(state):
    count, _ = state
    return count < 3

  def step_body(state):
    count, val = state
    return count + 1, jnp.sin(val)

  def f(x):
    return lax.while_loop(keep_going, step_body, (0, x))[1]

  # even with a policy, we can't save residuals (w/o dynamic shapes)!
  def save_cos(prim, *_, **__):
    return str(prim) == 'cos'

  g = new_checkpoint(f, policy=save_cos)
  jaxpr_text = str(api.make_jaxpr(jax.linearize(g, 4.)[1])(1.))
  self.assertIn(' sin ', jaxpr_text)
  self.assertIn(' cos ', jaxpr_text)


class JaxprTest(jtu.JaxTestCase):

Expand Down
24 changes: 18 additions & 6 deletions tests/lax_control_flow_test.py
Expand Up @@ -105,6 +105,15 @@ def scan_with_for(f, *args, **kwargs):
]


def while_loop_new_checkpoint(cond_fun, body_fun, init_val):
  """Run ``lax.while_loop`` on ``init_val`` under ``new_checkpoint``.

  Takes the same three arguments as ``lax.while_loop``.
  """
  checkpointed_loop = new_checkpoint(
      partial(lax.while_loop, cond_fun, body_fun))
  return checkpointed_loop(init_val)

# (implementation, name) pairs for parameterizing while_loop tests: the plain
# lax.while_loop and the same loop wrapped in new_checkpoint. The name is used
# in the generated test-case names.
WHILE_LOOP_IMPLS = [
    (lax.while_loop, 'while_loop'),
    (while_loop_new_checkpoint, 'new_checkpoint'),
]


def while_loop_reference(cond, body, carry):
while cond(carry):
carry = body(carry)
Expand Down Expand Up @@ -2007,13 +2016,16 @@ def testWhileJVP(self, jit_loop=True, jit_body=False, jit_cond=True):
jtu.check_grads(loop, (x,), order=2, modes=["fwd"])

@parameterized.named_parameters(
{"testcase_name": "_jit_loop={}_jit_body={}_jit_cond={}".format(
jit_loop, jit_body, jit_cond),
"jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond}
{"testcase_name": "_jit_loop={}_jit_body={}_jit_cond={}_impl={}".format(
jit_loop, jit_body, jit_cond, while_name),
"jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond,
"while_loop": while_impl}
for jit_loop in [False, True]
for jit_body in [False, True]
for jit_cond in [False, True])
def testWhileLinearize(self, jit_loop=True, jit_body=False, jit_cond=True):
for jit_cond in [False, True]
for while_impl, while_name in WHILE_LOOP_IMPLS)
def testWhileLinearize(self, while_loop, jit_loop=True, jit_body=False,
jit_cond=True):
cond = lambda x: x[0, 2] <= 8
body = lambda x: x * x

Expand All @@ -2022,7 +2034,7 @@ def testWhileLinearize(self, jit_loop=True, jit_body=False, jit_cond=True):
if jit_body:
body = jax.jit(body)

loop = partial(lax.while_loop, cond, body)
loop = partial(while_loop, cond, body)
if jit_loop:
loop = jax.jit(loop)

Expand Down

0 comments on commit 7f3aa12

Please sign in to comment.