Add a partial_eval_jaxpr_custom_rule for xmap
Additionally fix a bug in the partial_eval rule for xmap.

PiperOrigin-RevId: 428738277
apaszke authored and jax authors committed Feb 15, 2022
1 parent 0d9990e commit c551bed
Showing 5 changed files with 71 additions and 24 deletions.
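
Before this change, differentiating through a checkpoint-wrapped xmap raised NotImplementedError (the old testNewCheckpointError below asserted exactly that). A minimal sketch of the usage this commit enables, mirroring the updated testNewCheckpoint; the gradient of the summed identity map should be a vector of ones:

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

# checkpoint(xmap(...)) previously hit the not-implemented
# partial_eval_jaxpr_custom rule for xmap once it was differentiated.
f = jax.checkpoint(xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...]))
print(jax.grad(lambda x: f(x).sum())(jnp.arange(3.)))  # [1. 1. 1.]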
7 changes: 2 additions & 5 deletions jax/_src/ad_checkpoint.py
@@ -335,11 +335,8 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval

def remat_partial_eval_custom_params_updater(_, __, params_known, params_staged):
jaxpr_known = params_known.pop('call_jaxpr')
jaxpr_staged = params_staged.pop('call_jaxpr')
return (dict(params_known, jaxpr=jaxpr_known),
dict(params_staged, jaxpr=jaxpr_staged, differentiated=True))
def remat_partial_eval_custom_params_updater(_, __, ___, ____, params_known, params_staged):
return params_known, dict(params_staged, differentiated=True)
pe.partial_eval_jaxpr_custom_rules[remat_p] = \
partial(pe.call_partial_eval_custom_rule, 'jaxpr',
remat_partial_eval_custom_params_updater)
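The remat updater now conforms to the widened ParamsUpdater signature defined in partial_eval.py below: it receives unks_in, kept_outs_known, kept_outs_staged, and num_res but ignores all four, and it no longer needs to pop 'call_jaxpr' and rename it, because call_partial_eval_custom_rule now stores the split jaxprs under the registered parameter name ('jaxpr' for remat). A named-argument sketch of the same behavior, for readability only (the committed code uses underscore placeholders):

def remat_params_updater(unks_in, kept_outs_known, kept_outs_staged, num_res,
                         params_known, params_staged):
  # remat leaves the known-side params untouched and only marks the
  # staged-out call as already differentiated.
  return params_known, dict(params_staged, differentiated=True)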
61 changes: 53 additions & 8 deletions jax/experimental/maps.py
@@ -1098,8 +1098,49 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
self.frame.eqns.append(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore

def _xmap_partial_eval_custom_params_updater(
unks_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
assert params_known['spmd_in_axes'] is None and params_known['spmd_out_axes'] is None
assert params_staged['spmd_in_axes'] is None and params_staged['spmd_out_axes'] is None

# pruned inputs to jaxpr_known according to unks_in
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
in_axes_known, _ = pe.partition_list(unks_in, params_known['in_axes'])
if num_res == 0:
residual_axes = []
else:
residual_axes = [
AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
user_repr=f'<internal: {sort_named_shape}>')
for named_shape in (v.aval.named_shape for v in params_known['call_jaxpr'].outvars[:-num_res])
# We sort here to make the iteration order deterministic
for sort_named_shape in [sorted(named_shape, key=str)]
]
_, out_axes_known = pe.partition_list(kept_outs_known, params_known['out_axes'])
new_params_known = dict(params_known,
in_axes=tuple(in_axes_known),
out_axes=(*out_axes_known, *residual_axes),
donated_invars=tuple(donated_invars_known))
assert len(new_params_known['in_axes']) == len(params_known['call_jaxpr'].invars)
assert len(new_params_known['out_axes']) == len(params_known['call_jaxpr'].outvars)

# added num_res new inputs to jaxpr_staged
donated_invars_staged = (*(False for _ in range(num_res)), *params_staged['donated_invars'])
_, out_axes_staged = pe.partition_list(kept_outs_staged, params_staged['out_axes'])
new_params_staged = dict(params_staged,
in_axes=(*residual_axes, *params_staged['in_axes']),
out_axes=tuple(out_axes_staged),
donated_invars=donated_invars_staged)
assert len(new_params_staged['in_axes']) == len(params_staged['call_jaxpr'].invars)
assert len(new_params_staged['out_axes']) == len(params_staged['call_jaxpr'].outvars)
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xmap_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'xmap')
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
_xmap_partial_eval_custom_params_updater)
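
The updater above splits the original xmap params between the known and staged calls: the known xmap keeps in_axes and donated_invars only for its known inputs and appends one out_axes entry per residual it now returns, while the staged xmap prepends in_axes entries for those residuals and never donates them. A self-contained toy illustration of the partitioning step (the local partition_list stands in for pe.partition_list, which splits a list by a boolean mask):

def partition_list(mask, xs):
  return ([x for m, x in zip(mask, xs) if not m],
          [x for m, x in zip(mask, xs) if m])

unks_in = [False, True, False]      # which xmap inputs are unknown (staged out)
in_axes = ['ax0', 'ax1', 'ax2']     # stand-ins for per-input AxisNamePos maps
donated = [True, False, False]

in_axes_known, _ = partition_list(unks_in, in_axes)   # ['ax0', 'ax2']
donated_known, _ = partition_list(unks_in, donated)   # [True, False]
# out_axes for the known call become (*kept_out_axes, *residual_axes);
# in_axes for the staged call become (*residual_axes, *in_axes), since the
# staged call still receives every original input after the residuals.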


@lu.transformation_with_aux
@@ -1155,13 +1196,17 @@ def new_out_axes_thunk():
assert not any(const_units)
num_consts = len(const_units)
out_axes_no_units = [a for a, u in zip(out_axes, axes_units) if not u]
const_axes = [
AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
user_repr=f'<internal: {sort_named_shape}>')
for named_shape in out_named_shapes()[-num_consts:]
# We sort here to make the iteration order deterministic
for sort_named_shape in [sorted(named_shape, key=str)]
]
const_axes: Sequence[AxisNamePos]
if num_consts == 0:
const_axes = ()
else:
const_axes = [
AxisNamePos(zip(sort_named_shape, range(len(sort_named_shape))),
user_repr=f'<internal: {sort_named_shape}>')
for named_shape in out_named_shapes()[-num_consts:]
# We sort here to make the iteration order deterministic
for sort_named_shape in [sorted(named_shape, key=str)]
]
if not const_axes_s: # NOTE: This can be called multiple times
const_axes_s.store(const_axes)
assert const_axes_s.val == const_axes
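
The num_consts == 0 guard here is most likely the bug fix mentioned in the commit message: in Python, lst[-0:] is lst[0:], so with zero constants the old out_named_shapes()[-num_consts:] slice selected every output's named shape instead of none, producing spurious const_axes. A quick illustration of the slicing pitfall:

outs = ['o1', 'o2', 'o3']
print(outs[-0:])   # ['o1', 'o2', 'o3'] -- takes everything, not nothing
print(outs[-2:])   # ['o2', 'o3']       -- only behaves as intended when n > 0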
18 changes: 11 additions & 7 deletions jax/interpreters/partial_eval.py
@@ -1026,7 +1026,8 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
raise NotImplementedError(msg)


ParamsUpdater = Callable[[List[bool], int, dict, dict], Tuple[dict, dict]]
ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
int, dict, dict], Tuple[dict, dict]]

def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
@@ -1040,13 +1041,16 @@ def call_partial_eval_custom_rule(
# by convention, _partial_eval_jaxpr_custom drops units on known outputs
known_units_out = [v.aval is core.abstract_unit for v in jaxpr.outvars]
dropped_outs_known = map(op.or_, unks_out, known_units_out)
kept_outs_known = [not d for d in dropped_outs_known]
out_binders_known, _ = partition_list(dropped_outs_known, eqn.outvars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
kept_outs_staged = inst_out
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
params_known = dict(eqn.params, call_jaxpr=jaxpr_known)
params_staged = dict(eqn.params, call_jaxpr=jaxpr_staged)
params_known, params_staged = params_updater(unks_in, num_res, params_known, params_staged)
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
params_known, params_staged = params_updater(
unks_in, kept_outs_known, kept_outs_staged, num_res, params_known, params_staged)
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
@@ -1057,13 +1061,13 @@ def
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
partial_eval_jaxpr_custom_rules[core.call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, x, y: (x, y))
lambda _, __, ___, ____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, x, y: (x, y))
lambda _, __, ___, ____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, p1, p2: (p1, dict(p2, differentiated=True)))
lambda _, __, ___, ____, p1, p2: (p1, dict(p2, differentiated=True)))


# TODO(mattjj): unify with dce code below
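
For call-like primitives that carry per-input bookkeeping, the updater typically has real work to do under the new signature. A hedged sketch for a hypothetical primitive my_call_p (the name, and the assumption that it carries a donated_invars param, are illustrative only; the xla.py change below is the real version for the XLA call primitive):

from functools import partial
from typing import Sequence, Tuple

def my_params_updater(unks_in: Sequence[bool],
                      kept_outs_known: Sequence[bool],
                      kept_outs_staged: Sequence[bool],
                      num_res: int, params_known: dict, params_staged: dict
                      ) -> Tuple[dict, dict]:
  # The known call only receives the known inputs, so prune per-input params.
  donated_known = tuple(d for unk, d in zip(unks_in, params_known['donated_invars'])
                        if not unk)
  # The staged call gains num_res residual inputs at the front; never donate them.
  donated_staged = (False,) * num_res + tuple(params_staged['donated_invars'])
  return (dict(params_known, donated_invars=donated_known),
          dict(params_staged, donated_invars=donated_staged))

# Hypothetical registration, following the pattern used above:
# partial_eval_jaxpr_custom_rules[my_call_p] = partial(
#     call_partial_eval_custom_rule, 'call_jaxpr', my_params_updater)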
4 changes: 3 additions & 1 deletion jax/interpreters/xla.py
@@ -859,7 +859,9 @@ def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,


def _xla_call_partial_eval_custom_params_updater(
unks_in: List[bool], num_res: int, params_known: dict, params_staged: dict
unks_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
5 changes: 2 additions & 3 deletions tests/xmap_test.py
@@ -659,10 +659,9 @@ def testLowerCompileArgTypeMismatch(self):
"called with:\n.*int32.*",
lambda: f_exe(x_i32))

def testNewCheckpointError(self):
def testNewCheckpoint(self):
f = checkpoint(xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...]))
with self.assertRaisesRegex(NotImplementedError, 'xmap'):
jax.grad(f)(jnp.arange(3.))
self.assertAllClose(jax.grad(lambda x: f(x).sum())(jnp.arange(3.)), jnp.ones(3))


class XMapTestSPMD(SPMDTestMixin, XMapTest):
