add custom-policy partial eval and dce rules for pmap
Also add a failing test for xmap.
mattjj committed Jul 29, 2022
1 parent 560c936 commit e0c1e6c
Showing 6 changed files with 116 additions and 7 deletions.
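For context before the per-file diffs: the new rules let a custom-policy jax.remat be applied around jax.pmap, where the previous registration was the "not implemented" placeholder. A minimal sketch of that pattern, not part of the diff; checkpoint_dots is just a stand-in policy and a single device is assumed.

import jax
import jax.numpy as jnp

# Illustrative sketch, not part of this commit: remat with a custom save
# policy around pmap now traces and differentiates, exercising the new
# xla_pmap partial-eval and DCE rules added below.
f = jax.remat(jax.pmap(lambda x: jnp.sin(jnp.sin(x))),
              policy=jax.checkpoint_policies.checkpoint_dots)
print(jax.grad(lambda x: f(x).sum())(jnp.arange(1.)))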
3 changes: 2 additions & 1 deletion jax/_src/lax/lax.py
@@ -1428,7 +1428,8 @@ def zeros_like_array(x: Array) -> Array:
for t in itertools.chain(
dtypes.python_scalar_dtypes.keys(), array_types,
device_array.device_array_types,
[pxla.ShardedDeviceArray, pxla.pmap_lib.ShardedDeviceArray]):
[pxla.ShardedDeviceArray, pxla._ShardedDeviceArray,
pxla.pmap_lib.ShardedDeviceArray]):
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[device_array._DeviceArray] = zeros_like_array
ad_util.jaxval_zeros_likers[device_array.Buffer] = zeros_like_array
2 changes: 1 addition & 1 deletion jax/experimental/maps.py
@@ -1075,7 +1075,7 @@ def _xmap_partial_eval_custom_params_updater(
assert params_known['spmd_in_axes'] is None is params_known['spmd_out_axes']
assert params_staged['spmd_in_axes'] is None is params_staged['spmd_out_axes']

# pruned inputs to jaxpr_known according to unks_in
# prune inputs to jaxpr_known according to unks_in
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
in_axes_known, _ = pe.partition_list(unks_in, params_known['in_axes'])
if num_res == 0:
12 changes: 9 additions & 3 deletions jax/interpreters/partial_eval.py
@@ -374,6 +374,7 @@ def const_out_axes_thunk():
staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
staged_params = dict(staged_params, in_axes=staged_in_axes,
out_axes=tuple(staged_out_axes), call_jaxpr=call_jaxpr)
del staged_params['out_axes_thunk']
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
out_avals = [unmapped_aval(params['axis_size'], params['axis_name'], ax, a)
for ax, a in zip(staged_out_axes, out_avals_mapped)]
@@ -1357,11 +1358,15 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
Sequence[bool], int, dict, dict],
Tuple[dict, dict]]
ResAvalUpdater = Callable[[Dict[str, Any], AbstractValue], AbstractValue]
def _default_res_aval_updater(
params: Dict[str, Any], aval: AbstractValue) -> AbstractValue:
return aval

def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
eqn: JaxprEqn
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
@@ -1371,12 +1376,13 @@ def call_partial_eval_custom_rule(
_, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
params_known, params_staged = params_updater(
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
params_staged)
residuals = [newvar(res_aval(params_known, var.aval))
for var in jaxpr_staged.invars[:num_res]]
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
@@ -1891,7 +1897,7 @@ def process_map(self, map_primitive, f, tracers, params):
if update_params:
new_params = update_params(new_params, [True] * len(tracers), len(consts))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, new_params['call_jaxpr'].effects, source_info)
new_params, jaxpr.effects, source_info)
self.frame.add_eqn(eqn)
return out_tracers

50 changes: 48 additions & 2 deletions jax/interpreters/pxla.py
@@ -73,7 +73,7 @@
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)

@@ -1671,10 +1671,56 @@ def __call__(self, *args):
xla_pmap = xla_pmap_p.bind
xla_pmap_p.def_impl(xla_pmap_impl)

def _pmap_partial_eval_custom_params_updater(
unks_in, inst_in, kept_outs_known, kept_outs_staged, num_res, params_known,
params_staged):
# prune inputs to jaxpr_known according to unks_in
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
in_axes_known, _ = partition_list(unks_in, params_known['in_axes'])
_, out_axes_known = partition_list(kept_outs_known, params_known['out_axes'])
out_axes_known = out_axes_known + [0] * num_res
new_params_known = dict(params_known, in_axes=tuple(in_axes_known),
out_axes=tuple(out_axes_known),
donated_invars=tuple(donated_invars_known))

# added num_res new inputs to jaxpr_staged, pruning according to inst_in
_, donated_invars_staged = partition_list(inst_in, params_staged['donated_invars'])
donated_invars_staged = [False] * num_res + donated_invars_staged
_, in_axes_staged = partition_list(inst_in, params_staged['in_axes'])
in_axes_staged = [0] * num_res + in_axes_staged
_, out_axes_staged = partition_list(kept_outs_staged, params_staged['out_axes'])
new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
out_axes=tuple(out_axes_staged),
donated_invars=tuple(donated_invars_staged))
return new_params_known, new_params_staged
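The pruning above relies on partition_list, which splits a list into (entries whose flag is False, entries whose flag is True). A small illustrative example with made-up axis values, not part of the diff:

from jax._src.util import partition_list

unks_in = [False, True, False]     # which pmap inputs are unknown
in_axes = [0, None, 0]             # hypothetical in_axes for those inputs
in_axes_known, in_axes_unk = partition_list(unks_in, in_axes)
# in_axes_known == [0, 0]  -> kept in params_known (the evaluated-now pmap)
# in_axes_unk   == [None]  -> dropped from params_known (unknown inputs)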

def _pmap_partial_eval_custom_res_maker(params_known, aval):
return core.unmapped_aval(params_known['axis_size'], core.no_axis_name, 0, aval)
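The residual maker gives each residual the abstract value it has outside the pmap: core.unmapped_aval adds the mapped axis back, matching the out_axes=0 assigned to residuals above. An illustrative example, not part of the diff:

import numpy as np
from jax import core

aval_inside = core.ShapedArray((3,), np.dtype('float32'))   # per-device aval
aval_outside = core.unmapped_aval(8, core.no_axis_name, 0, aval_inside)
# aval_outside.shape == (8, 3): axis_size=8 is prepended at position 0, so the
# residual is a valid output of the known xla_pmap equation.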

def _pmap_dce_rule(used_outputs, eqn):
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
new_params = dict(eqn.params, call_jaxpr=new_jaxpr, in_axes=tuple(in_axes),
out_axes=tuple(out_axes))
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
return used_inputs, None
else:
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn


# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'pmap')
partial(pe.call_partial_eval_custom_rule,
'call_jaxpr', _pmap_partial_eval_custom_params_updater,
res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
ad.call_transpose_param_updaters[xla_pmap_p] = \
ad.call_transpose_param_updaters[xla.xla_call_p]
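With these registrations in place, dead outputs of a pmap can be pruned through pe.dce_jaxpr, and custom-policy partial evaluation dispatches to call_partial_eval_custom_rule instead of raising. A rough usage sketch of the DCE path, assuming the (jaxpr, used_outputs) -> (new_jaxpr, used_inputs) interface used in the rule above; not part of the diff.

import jax
import jax.numpy as jnp
from jax.interpreters import partial_eval as pe

f = jax.pmap(lambda x: (jnp.sin(x), jnp.cos(x)))
jaxpr = jax.make_jaxpr(f)(jnp.arange(1.)).jaxpr
# Keep only the first output: _pmap_dce_rule rewrites the xla_pmap equation,
# dropping the cos computation along with its out_axes entry.
pruned_jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, [True, False])
print(pruned_jaxpr)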
49 changes: 49 additions & 0 deletions tests/pmap_test.py
@@ -1999,6 +1999,55 @@ def g(x):
return x
jax.grad(f)(3.) # doesn't fail

@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', jax.remat),
('_new', new_checkpoint),
])
def test_remat_of_pmap(self, remat):
f = remat(jax.pmap(lambda x: jnp.sin(jnp.sin(x))))
jtu.check_grads(f, (jnp.arange(1.),), order=2, modes=["rev"])

x = jnp.arange(1.)
jaxpr = jax.make_jaxpr(jax.linearize(f, x)[1])(x)
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))

@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', jax.remat),
('_new', new_checkpoint),
])
def test_remat_of_pmap_policy(self, remat):
g = jax.pmap(lambda x: jnp.sin(jnp.sin(x)))
x = jnp.arange(1.)

save_cos = lambda prim, *_, **__: str(prim) == 'cos'
f = remat(g, policy=save_cos)
_, f_vjp = jax.vjp(f, x)
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 0)

save_sin = lambda prim, *_, **__: str(prim) == 'sin'
f = remat(g, policy=save_sin)
_, f_vjp = jax.vjp(f, x)
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 0)
self.assertEqual(jaxpr_text.count(' cos '), 2)

save_nothing = lambda prim, *_, **__: False
f = remat(g, policy=save_nothing)
_, f_vjp = jax.vjp(f, x)
jaxpr = f_vjp.args[0].func.args[1]
jaxpr_text = str(jaxpr)
self.assertEqual(jaxpr_text.count(' sin '), 1)
self.assertEqual(jaxpr_text.count(' cos '), 2)


class CppPmapTest(PythonPmapTest):

7 changes: 7 additions & 0 deletions tests/xmap_test.py
@@ -728,6 +728,13 @@ def testNewCheckpoint(self):
f = checkpoint(xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...]))
self.assertAllClose(jax.grad(lambda x: f(x).sum())(jnp.arange(3.)), jnp.ones(3))

def testNewCheckpointNonlinearWithPolicy(self):
raise SkipTest("fails!") # TODO(mattjj,apaszke): residual outvars problem
f = checkpoint(xmap(lambda x: jnp.sin(jnp.sin(x)), in_axes=['i', ...],
out_axes=['i', ...]),
policy=lambda prim, *_, **__: str(prim) == 'sin')
jax.grad(lambda x: f(x).sum())(jnp.arange(3.)) # TODO crashes!


class XMapTestSPMD(SPMDTestMixin, XMapTest):
"""Re-executes all basic tests with the SPMD partitioner enabled"""
