Skip to content

Commit

Permalink
Always treat all mesh axes controlled by xmap as MANUAL
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 430192736
  • Loading branch information
apaszke authored and jax authors committed Feb 22, 2022
1 parent a65841f commit 2641f06
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 23 deletions.
10 changes: 6 additions & 4 deletions jax/experimental/maps.py
Expand Up @@ -731,10 +731,12 @@ def make_xmap_callable(fun: lu.WrappedFun,
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
tiling_method = pxla.TilingMethod.MANUAL
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
tiling_method = pxla.TilingMethod.VECTORIZE
tiling_method = pxla.TileVectorize()
return pxla.lower_mesh_computation(
f, name, mesh,
mesh_in_axes, mesh_out_axes, donated_invars,
Expand Down Expand Up @@ -1519,6 +1521,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
xla.check_backend_matches(backend, ctx.module_context.platform)
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))

resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr)
f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
Expand All @@ -1528,7 +1531,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# NOTE: Sharding constraints are handled entirely by vtile_manual!
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
f = pxla.vtile_manual(f, mesh, mesh_in_axes, mesh_out_axes)
f = pxla.vtile_manual(f, tuple(manual_mesh_axes), mesh, mesh_in_axes, mesh_out_axes)

# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
Expand All @@ -1539,7 +1542,6 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,

# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
Expand Down
48 changes: 29 additions & 19 deletions jax/interpreters/pxla.py
Expand Up @@ -37,9 +37,8 @@
import itertools as it
import operator as op
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
import enum
import sys

from absl import logging
Expand Down Expand Up @@ -2047,20 +2046,20 @@ def vtile_by_mesh(fun: lu.WrappedFun,
full_to_shard_p = core.Primitive('full_to_shard')

@full_to_shard_p.def_abstract_eval
def _full_to_shard_abstract_eval(x, axes, mesh):
def _full_to_shard_abstract_eval(x, axes, mesh, **_):
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
return tile_aval_nd(mesh.shape, axes, x)

def _manual_proto(aval, axes, mesh):
def _manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh):
"""Create an OpSharding proto that declares all mesh axes from `manual_axes_set` as manual
and all others as replicated.
"""
named_mesh_shape = mesh.shape
mesh_shape = list(named_mesh_shape.values())
axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}

manual_axes = list(axes)
replicated_axes = list(axis for axis in mesh.axis_names if axis not in axes)
manual_axes = list(sorted(manual_axes_set, key=str))
replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)

tad_perm = ([axis_order[a] for a in replicated_axes] +
[axis_order[a] for a in manual_axes])
Expand All @@ -2077,48 +2076,59 @@ def _manual_proto(aval, axes, mesh):
return proto

@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
# TODO: Can we short-circuit for replicated values? Probably not.
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims)
manual_proto = _manual_proto(aval_in, axes, mesh)
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
return mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, unspecified_dims=unspecified_dims),

shard_to_full_p = core.Primitive('shard_to_full')

@shard_to_full_p.def_abstract_eval
def _shard_to_full_abstract_eval(x, axes, mesh):
def _shard_to_full_abstract_eval(x, axes, mesh, **_):
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
return untile_aval_nd(mesh.shape, axes, x)

@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
manual_proto = _manual_proto(aval_in, axes, mesh)
manual_proto = _manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=unspecified_dims)
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto()
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),

@lu.transformation
def vtile_manual(mesh: Mesh,
def vtile_manual(manual_axes: FrozenSet[MeshAxisName],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
*args):
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh)
tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes)
for arg, axes in zip(args, in_axes)]
tiled_outs = yield tiled_args, {}
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh)
outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes)
for out, axes in zip(tiled_outs, out_axes)]
yield outs

TilingMethod = enum.Enum("TilingMethod", ["VECTORIZE", "MANUAL"])

@dataclasses.dataclass(frozen=True)
class TileVectorize:
  """Tiling method that selects vectorized tiling.

  Carries no data. `lower_mesh_computation` maps it to `vtile_by_mesh`, and
  `make_xmap_callable` passes it when
  `config.experimental_xmap_spmd_lowering_manual` is off.
  """

@dataclasses.dataclass(frozen=True)
class TileManual:
  """Tiling method that selects manual tiling over a set of mesh axes.

  `lower_mesh_computation` maps it to `vtile_manual`, forwarding
  `manual_axes` as that transformation's first argument.
  """
  # Mesh axes to treat as manual. xmap passes every mesh axis that appears in
  # its evaluation plan's physical axis resources.
  manual_axes: FrozenSet[MeshAxisName]

# The tiling strategies `lower_mesh_computation` dispatches on via isinstance;
# any other non-None value raises NotImplementedError there.
TilingMethod = Union[TileVectorize, TileManual]


@profiler.annotate_function
def lower_mesh_computation(
Expand Down Expand Up @@ -2152,17 +2162,17 @@ def lower_mesh_computation(
if spmd_lowering:
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
if tiling_method is not None:
if tiling_method is TilingMethod.VECTORIZE:
if isinstance(tiling_method, TileVectorize):
tiling_transform = vtile_by_mesh
elif tiling_method is TilingMethod.MANUAL:
tiling_transform = vtile_manual
elif isinstance(tiling_method, TileManual):
tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore
else:
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_axes)
fun = tiling_transform(fun, mesh, in_axes, out_axes)
in_jaxpr_avals = global_in_avals
else:
assert tiling_method is TilingMethod.VECTORIZE
assert isinstance(tiling_method, TileVectorize)
in_jaxpr_avals = in_tiled_avals
with core.extend_axis_env_nd(mesh.shape.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
Expand Down
15 changes: 15 additions & 0 deletions tests/xmap_test.py
Expand Up @@ -717,13 +717,28 @@ def testBasic(self):
x = jnp.arange(20, dtype=jnp.float32)
self.assertAllClose(fx(x), f(x))

@jtu.with_mesh([('x', 2)])
def testReplicated(self):
  """xmap must match the plain function when no argument is mapped over 'i'.

  Axis 'i' is declared only through `axis_sizes` and bound to mesh axis 'x';
  `in_axes` and `out_axes` name no axes.
  """
  # TODO(apaszke): This seems to be failing if I try to have a replicated and a mapped argument?
  f = lambda x: jnp.sin(jnp.cos(x) + x) * x
  fx = xmap(f, in_axes=[...], out_axes=[...], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
  x = jnp.arange(20, dtype=jnp.float32)
  self.assertAllClose(fx(x), f(x))

@jtu.with_mesh([('x', 2), ('y', 1)])
def testInPJit(self):
  """An xmap mapped over 'i' (mesh axis 'x') must work when called inside pjit.

  The enclosing pjit partitions its input over mesh axis 'y'; the result is
  compared against the same expression computed without xmap or pjit.
  """
  f = xmap(lambda x: jnp.sin(x) + x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
  h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
  x = jnp.arange(20, dtype=jnp.float32)
  self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)

@jtu.with_mesh([('x', 2), ('y', 1)])
def testInPJitReplicated(self):
  """An xmap with no mapped arguments must work when called inside pjit.

  Axis 'i' is declared only through `axis_sizes` and bound to mesh axis 'x';
  `in_axes` and `out_axes` are empty. The enclosing pjit partitions its input
  over mesh axis 'y'.
  """
  f = xmap(lambda x: jnp.sin(x) + x, in_axes={}, out_axes={}, axis_sizes={'i': 4}, axis_resources={'i': 'x'})
  h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
  x = jnp.arange(20, dtype=jnp.float32)
  self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)

@jtu.with_mesh([('x', 2), ('y', 1)])
def testNestedConstraint(self):
# TODO(b/219691408): Using P('y') instead of P() causes an XLA crash!
Expand Down

0 comments on commit 2641f06

Please sign in to comment.