Skip to content

Commit

Permalink
[new-remat] add _scan_partial_eval_custom rule for new remat
Browse files Browse the repository at this point in the history
Also enable scan-of-remat tests which weren't passing before.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
  • Loading branch information
mattjj and sharadmv committed Jun 18, 2022
1 parent 4f5115c commit 83a8dc4
Show file tree
Hide file tree
Showing 8 changed files with 428 additions and 37 deletions.
16 changes: 16 additions & 0 deletions jax/_src/ad_checkpoint.py
Expand Up @@ -388,6 +388,22 @@ def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = remat_vmap

# TODO(mattjj,sharadmv): test this more
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
  """Dead-code-elimination rule for `remat_p` equations.

  Args:
    used_outputs: one flag per entry of `eqn.outvars`, True where the output
      is needed downstream.
    eqn: the remat equation; its body is read from `eqn.params['jaxpr']`.

  Returns:
    A pair `(used_inputs, new_eqn)`. `used_inputs` is whatever `pe.dce_jaxpr`
    reports for the pruned body. `new_eqn` is None when no input is used, no
    output is used and the pruned body has no effects; otherwise it is a copy
    of `eqn` restricted to the used inputs/outputs and carrying the pruned
    body.
  """
  pruned_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)

  # Nothing consumed, nothing produced, no effects: drop the equation.
  if not (any(used_inputs) or any(used_outputs) or pruned_jaxpr.effects):
    return used_inputs, None

  kept_invars = [v for v, keep in zip(eqn.invars, used_inputs) if keep]
  kept_outvars = [v for v, keep in zip(eqn.outvars, used_outputs) if keep]
  pruned_params = dict(eqn.params, jaxpr=pruned_jaxpr)
  pruned_eqn = pe.new_jaxpr_eqn(
      kept_invars, kept_outvars, eqn.primitive, pruned_params,
      pruned_jaxpr.effects, eqn.source_info)
  return used_inputs, pruned_eqn
pe.dce_rules[remat_p] = remat_dce


def checkpoint_name(x, name):
return name_p.bind(x, name=name)
Expand Down
105 changes: 103 additions & 2 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -46,6 +46,7 @@
from jax._src.util import (
cache,
extend_name_stack,
partition_list,
safe_map,
safe_zip,
split_list,
Expand Down Expand Up @@ -829,6 +830,107 @@ def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
return used_inputs, new_eqn

# TODO(mattjj): de-duplicate code with _scan_partial_eval
def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
# Custom partial-evaluation rule for a scan equation: splits `eqn` into a
# "known" equation and a "staged" equation connected by residual variables.
#
# Args:
#   saveable: forwarded unchanged to pe.partial_eval_jaxpr_custom.
#   unks_in: one bool per entry of eqn.invars, True where the input is unknown.
#   inst_in: one bool per entry of eqn.invars, True where the input is already
#     instantiated.
#   eqn: the scan equation; reads params 'jaxpr', 'num_consts', 'num_carry',
#     'length' and 'linear'.
#
# Returns:
#   (eqn_known, eqn_staged, unks_out, inst_out, new_vars), where new_vars lists
#   the previously uninstantiated input Vars followed by the intensive and
#   extensive residual variables created here.
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
# Body outputs are the num_carry carries followed by num_ys remaining outputs.
num_ys = len(jaxpr.out_avals) - num_carry

# Fixpoint (currently trivial on 'inst_in')
# Re-run partial evaluation of the body, OR-ing the carry-out unknown flags
# into the carry-in flags, until the two agree.
# NOTE(review): the bound of 1 + len(carry_uk) rounds presumably suffices
# because every non-converged round marks at least one more carry unknown --
# confirm.
const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry])
for _ in range(1 + len(carry_uk)):
unks_in = const_uk + carry_uk + xs_uk
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=unks_in, in_inst=[True] * len(unks_in),
ensure_out_unknowns=carry_uk + [False] * num_ys,
ensure_out_inst=True, saveable=saveable)
carry_uk_out , ys_uk = split_list(unks_out, [num_carry])
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk , carry_uk_out )
# for-else: only reached when the loop ran out of rounds without a `break`.
else:
assert False, "Fixpoint not reached"
# Both halves are closed over the original body's constants.
jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)

# Ensure residuals are all moved to the back.
# TODO(mattjj): make jaxpr_staged only take instantiated inputs
# The first num_res inputs of jaxpr_staged are taken to be the residuals;
# their avals are recorded before the binders are reordered.
res_avals = jaxpr_staged.in_avals[:num_res]
jaxpr_staged = pe.move_binders_to_back(
jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals))

# Instantiate all inputs (b/c jaxpr_staged takes all inputs).
# new_inst collects the inputs that are Vars and were not instantiated on
# entry; they are reported back to the caller through new_vars.
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_in = [True] * len(inst_in)

# As an optimization, hoist loop-invariant residuals out of the loop rather
# than using extensive outputs for them. See _scan_partial_eval for comments.
# Number of known (i.e. not unknown) entries in each input group.
num_const_known = len(const_uk) - sum(const_uk)
num_carry_known = len(carry_uk) - sum(carry_uk)
num_xs_known = len( xs_uk) - sum( xs_uk)
# Second partial evaluation of jaxpr_known: its consts are flagged False and
# its carry/xs inputs True, and its outputs are requested instantiated for the
# known outputs but not for the trailing num_res residuals.
jaxpr_known_hoist, jaxpr_known_loop, loop_dep, _ = \
pe.partial_eval_jaxpr_nounits(
jaxpr_known,
[False] * num_const_known + [True] * (num_carry_known + num_xs_known),
[True] * (len(unks_out) - sum(unks_out)) + [False] * num_res)
# jaxpr_known_hoist produces intensive residuals followed by the constants for
# jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts.
# loop_dep_res holds the last num_res flags of loop_dep (one per residual);
# residuals whose flag is False are moved to the front of jaxpr_staged and
# counted as intensive.
_, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res])
jaxpr_staged = pe.move_binders_to_front(
jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res))
num_intensive_res = len(loop_dep_res) - sum(loop_dep_res)
del loop_dep, num_carry_known, num_xs_known

# Create residual variables.
# Residual avals are split by loop_dep_res; the extensive ones are passed
# through core.unmapped_aval with the scan length and axis argument 0.
intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals)
ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a)
for a in ext_avals_mapped]
newvar = core.gensym()
intensive_res = _map(newvar, intensive_avals)
extensive_res = _map(newvar, ext_avals)

# Create known eqn, which is a call_p combining evaluation of
# jaxpr_known_hoist and a scan of jaxpr_known_loop.
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk]
params_known = dict(eqn.params, jaxpr=jaxpr_known_loop,
num_consts=len(const_uk)-sum(const_uk),
num_carry=len(carry_uk)-sum(carry_uk),
linear=tuple(linear_known))

# `known` is traced below to build the body of the closed_call equation. Its
# `ins_known` parameter and local `intensive_res` are separate from the
# outer variables of the same names.
@lu.wrap_init
def known(*ins_known):
consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known])
out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist)
intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
return [*intensive_res, *out_loop]
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in ins_known])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
# Output binders: intensive residuals, then the known scan outputs, then the
# extensive residuals.
# NOTE(review): this assumes out_loop is ordered as known outputs followed by
# extensive residuals -- confirm against jaxpr_known_loop.
eqn_known = pe.new_jaxpr_eqn(
ins_known, [*intensive_res, *out_binders_known, *extensive_res],
core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects,
eqn.source_info)

# Staged eqn: the original scan primitive running jaxpr_staged. Intensive
# residuals are prepended (and counted in num_consts), extensive residuals are
# appended after the original inputs; residuals are always marked non-linear.
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) +
[False] * len(extensive_res))
params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
num_consts=len(intensive_res) + eqn.params['num_consts'],
linear=tuple(linear_staged))
eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res],
out_binders_staged, eqn.primitive,
params_staged, jaxpr_staged.effects,
eqn.source_info)

new_vars = [*new_inst, *intensive_res, *extensive_res]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars

def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
avals = [x.aval for x in in_atoms]
Expand Down Expand Up @@ -899,8 +1001,7 @@ def scan_bind(*args, **params):
# Rule registrations for scan_p. The custom partial-eval rule is registered
# once: the earlier "not implemented" placeholder was overwritten immediately
# by the real rule, so that dead assignment is dropped.
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule

Expand Down
6 changes: 5 additions & 1 deletion jax/interpreters/ad.py
Expand Up @@ -589,7 +589,11 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):


def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
if isinstance(call_jaxpr, core.ClosedJaxpr):
call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
else:
consts = ()
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
Expand Down
4 changes: 3 additions & 1 deletion jax/interpreters/mlir.py
Expand Up @@ -997,7 +997,6 @@ def f_lowered(ctx, *args, **params):
return f_lowered



def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
avals_out, tokens_in, *args):
if isinstance(call_jaxpr, core.Jaxpr):
Expand Down Expand Up @@ -1041,6 +1040,9 @@ def _named_call_lowering(ctx, *args, name, backend=None,
# Lowering for core.closed_call_p reuses the named-call lowering under a fixed
# name. Registered once; the identical second registration was redundant.
register_lowering(core.closed_call_p,
                  partial(_named_call_lowering, name="core_closed_call"))


def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
Expand Down
18 changes: 11 additions & 7 deletions jax/interpreters/partial_eval.py
Expand Up @@ -915,7 +915,7 @@ def fun(*known_vals_in):
assert ([v.aval.strip_weak_type() for v in jaxpr_known.outvars] ==
[a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns)
if not uk] + [a.strip_weak_type() for a in res_avals])
# check jaxpr_unknown has input type corresponding to unknown inputs plus res
# check jaxpr_unknown has input type corresponding to res plus unknown inputs
assert ([v.aval for v in jaxpr_unknown.invars] ==
res_avals + [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk])
# check jaxpr_unknown has output type corresponding to unknown outputs
Expand Down Expand Up @@ -1092,6 +1092,7 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom:

known_eqns, staged_eqns = [], []
map(write, in_unknowns, in_inst, jaxpr.invars)
map(partial(write, False, True), jaxpr.constvars)
for eqn in jaxpr.eqns:
unks_in, inst_in = unzip2(map(read, eqn.invars))
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
Expand Down Expand Up @@ -1277,17 +1278,20 @@ def _default_dce_rule(


def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
                        ) -> Tuple[List[bool], Optional[JaxprEqn]]:
  """Dead-code-elimination rule for call-like primitives.

  Args:
    used_outputs: one flag per entry of `eqn.outvars`, True where the output
      is needed downstream.
    eqn: the call equation; its body is read from `eqn.params['call_jaxpr']`.

  Returns:
    A pair `(used_inputs, new_eqn)`. `used_inputs` is the per-input usage
    reported by `dce_jaxpr` for the pruned body. `new_eqn` is None when no
    input is used, no output is used and the pruned body has no effects;
    otherwise it is `eqn` restricted to the used inputs/outputs with the
    pruned body (and primitive-specific params updated, if an updater is
    registered in `call_param_updaters`).
  """
  new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
  new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
  update_params = call_param_updaters.get(eqn.primitive)
  if update_params:
    new_params = update_params(new_params, used_inputs, 0)
  # The elision check must run before any equation is built and returned;
  # otherwise this branch is unreachable.
  if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
    return used_inputs, None
  else:
    new_eqn = new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [v for v, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
    return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule
Expand Down

0 comments on commit 83a8dc4

Please sign in to comment.