diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 9de960c02066..9aec47e3adb6 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -731,10 +731,12 @@ def make_xmap_callable(fun: lu.WrappedFun, for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics) ] in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics] + tiling_method: pxla.TilingMethod if config.experimental_xmap_spmd_lowering_manual: - tiling_method = pxla.TilingMethod.MANUAL + manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) + tiling_method = pxla.TileManual(manual_mesh_axes) else: - tiling_method = pxla.TilingMethod.VECTORIZE + tiling_method = pxla.TileVectorize() return pxla.lower_mesh_computation( f, name, mesh, mesh_in_axes, mesh_out_axes, donated_invars, @@ -1519,6 +1521,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, xla.check_backend_matches(backend, ctx.module_context.platform) plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes, in_positional_semantics) + manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) resource_call_jaxpr = plan.subst_axes_with_resources(call_jaxpr) f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ()))) @@ -1528,7 +1531,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, # NOTE: Sharding constraints are handled entirely by vtile_manual! mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes) mesh = resource_env.physical_mesh - f = pxla.vtile_manual(f, mesh, mesh_in_axes, mesh_out_axes) + f = pxla.vtile_manual(f, tuple(manual_mesh_axes), mesh, mesh_in_axes, mesh_out_axes) # NOTE: We don't extend the resource env with the mesh shape, because those # resources are already in scope! It's the outermost xmap that introduces @@ -1539,7 +1542,6 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. - manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext) sub_ctx = ctx.module_context.replace( name_stack=xla.extend_name_stack(ctx.module_context.name_stack, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index f248a541d8b9..24cbf7ee8b0b 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -37,9 +37,8 @@ import itertools as it import operator as op import threading -from typing import (Any, Callable, Dict, List, NamedTuple, Optional, +from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast) -import enum import sys from absl import logging @@ -2047,11 +2046,11 @@ def vtile_by_mesh(fun: lu.WrappedFun, full_to_shard_p = core.Primitive('full_to_shard') @full_to_shard_p.def_abstract_eval -def _full_to_shard_abstract_eval(x, axes, mesh): +def _full_to_shard_abstract_eval(x, axes, mesh, **_): # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes! return tile_aval_nd(mesh.shape, axes, x) -def _manual_proto(aval, axes, mesh): +def _manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[MeshAxisName], mesh: Mesh): """Create an OpSharding proto that declares all mesh axes from `axes` as manual and all others as replicated. """ @@ -2059,8 +2058,8 @@ def _manual_proto(aval, axes, mesh): mesh_shape = list(named_mesh_shape.values()) axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)} - manual_axes = list(axes) - replicated_axes = list(axis for axis in mesh.axis_names if axis not in axes) + manual_axes = list(sorted(manual_axes_set, key=str)) + replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set) tad_perm = ([axis_order[a] for a in replicated_axes] + [axis_order[a] for a in manual_axes]) @@ -2077,29 +2076,29 @@ def _manual_proto(aval, axes, mesh): return proto @partial(mlir.register_lowering, full_to_shard_p) -def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh): +def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]): # TODO: Can we short-circuit for replicated values? Probably not. aval_in, = ctx.avals_in aval_out, = ctx.avals_out sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto() unspecified_dims = set(range(aval_in.ndim)) - set(axes.values()) sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims) - manual_proto = _manual_proto(aval_in, axes, mesh) + manual_proto = _manual_proto(aval_in, manual_axes, mesh) result_type, = mlir.aval_to_ir_types(aval_out) return mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, unspecified_dims=unspecified_dims), shard_to_full_p = core.Primitive('shard_to_full') @shard_to_full_p.def_abstract_eval -def _shard_to_full_abstract_eval(x, axes, mesh): +def _shard_to_full_abstract_eval(x, axes, mesh, **_): # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes! return untile_aval_nd(mesh.shape, axes, x) @partial(mlir.register_lowering, shard_to_full_p) -def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh): +def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, manual_axes: FrozenSet[MeshAxisName]): aval_in, = ctx.avals_in aval_out, = ctx.avals_out - manual_proto = _manual_proto(aval_in, axes, mesh) + manual_proto = _manual_proto(aval_in, manual_axes, mesh) result_type, = mlir.aval_to_ir_types(aval_out) unspecified_dims = set(range(aval_in.ndim)) - set(axes.values()) sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=unspecified_dims) @@ -2107,18 +2106,29 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh): return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims), @lu.transformation -def vtile_manual(mesh: Mesh, +def vtile_manual(manual_axes: FrozenSet[MeshAxisName], + mesh: Mesh, in_axes: Sequence[ArrayMapping], out_axes: Sequence[ArrayMapping], *args): - tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh) + tiled_args = [full_to_shard_p.bind(arg, axes=axes, mesh=mesh, manual_axes=manual_axes) for arg, axes in zip(args, in_axes)] tiled_outs = yield tiled_args, {} - outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh) + outs = [shard_to_full_p.bind(out, axes=axes, mesh=mesh, manual_axes=manual_axes) for out, axes in zip(tiled_outs, out_axes)] yield outs -TilingMethod = enum.Enum("TilingMethod", ["VECTORIZE", "MANUAL"]) + +@dataclasses.dataclass(frozen=True) +class TileVectorize: + pass + +@dataclasses.dataclass(frozen=True) +class TileManual: + manual_axes: FrozenSet[MeshAxisName] + +TilingMethod = Union[TileVectorize, TileManual] + @profiler.annotate_function def lower_mesh_computation( @@ -2152,17 +2162,17 @@ def lower_mesh_computation( if spmd_lowering: # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice! if tiling_method is not None: - if tiling_method is TilingMethod.VECTORIZE: + if isinstance(tiling_method, TileVectorize): tiling_transform = vtile_by_mesh - elif tiling_method is TilingMethod.MANUAL: - tiling_transform = vtile_manual + elif isinstance(tiling_method, TileManual): + tiling_transform = lambda f, *args: vtile_manual(f, tiling_method.manual_axes, *args) # type: ignore else: raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}") assert not callable(out_axes) fun = tiling_transform(fun, mesh, in_axes, out_axes) in_jaxpr_avals = global_in_avals else: - assert tiling_method is TilingMethod.VECTORIZE + assert isinstance(tiling_method, TileVectorize) in_jaxpr_avals = in_tiled_avals with core.extend_axis_env_nd(mesh.shape.items()): with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 189443b3f32b..4dcd4adba1c5 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -717,6 +717,14 @@ def testBasic(self): x = jnp.arange(20, dtype=jnp.float32) self.assertAllClose(fx(x), f(x)) + @jtu.with_mesh([('x', 2)]) + def testReplicated(self): + # TODO(apaszke): This seems to be failing if I try to have a replicated and a mapped argument? + f = lambda x: jnp.sin(jnp.cos(x) + x) * x + fx = xmap(f, in_axes=[...], out_axes=[...], axis_sizes={'i': 4}, axis_resources={'i': 'x'}) + x = jnp.arange(20, dtype=jnp.float32) + self.assertAllClose(fx(x), f(x)) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testInPJit(self): f = xmap(lambda x: jnp.sin(x) + x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'}) @@ -724,6 +732,13 @@ def testInPJit(self): x = jnp.arange(20, dtype=jnp.float32) self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x) + @jtu.with_mesh([('x', 2), ('y', 1)]) + def testInPJitReplicated(self): + f = xmap(lambda x: jnp.sin(x) + x, in_axes={}, out_axes={}, axis_sizes={'i': 4}, axis_resources={'i': 'x'}) + h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None) + x = jnp.arange(20, dtype=jnp.float32) + self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x) + @jtu.with_mesh([('x', 2), ('y', 1)]) def testNestedConstraint(self): # TODO(b/219691408): Using P('y') instead of P() causes an XLA crash!