[export] Adapt several collective lowering rules for multi-platform lowering

This fixes a few more places where the lowering rules used module_context.platform,
which is not supported for multi-platform lowering.
gnecula committed Oct 13, 2023
1 parent e088a8e commit a59ada0
Showing 4 changed files with 143 additions and 76 deletions.
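
The pattern adopted throughout this change is easiest to see in isolation: instead of a rule reading ctx.module_context.platform (ambiguous when one module is lowered for several platforms at once), a copy of the rule is registered per platform with the platform baked in. The sketch below is a self-contained illustration with made-up names; only the shape of the toy register_lowering mirrors the real mlir.register_lowering(prim, rule, platform=...) entry point.

from functools import partial
from typing import Callable

_lowering_registry: dict[tuple[str, str], Callable] = {}

def register_lowering(prim_name: str, rule: Callable, platform: str = "default"):
  _lowering_registry[(prim_name, platform)] = rule

def _all_gather_rule(ctx, x, *, platform=None):
  # The platform is a closed-over argument, so the rule stays valid even when
  # one module targets several platforms.
  return f"all_gather lowered for {platform or 'generic'}"

register_lowering("all_gather", _all_gather_rule)            # generic fallback
for p in ("cuda", "rocm", "tpu"):
  register_lowering("all_gather", partial(_all_gather_rule, platform=p),
                    platform=p)

# Lowering for a concrete platform picks the specific rule when there is one:
rule = _lowering_registry.get(("all_gather", "tpu"),
                              _lowering_registry[("all_gather", "default")])
print(rule(ctx=None, x=None))  # -> all_gather lowered for tpu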
12 changes: 8 additions & 4 deletions jax/_src/interpreters/mlir.py
@@ -1505,9 +1505,13 @@ def lower_multi_platform(ctx: LoweringRuleContext,
rule_args: the args of the lowering rules.
rule_kwargs: the kwargs of the lowering rules.
"""
assert isinstance(ctx.module_context.lowering_parameters.platforms, tuple)
platforms = ctx.module_context.lowering_parameters.platforms
platforms_with_specific_rules = util.flatten(
platforms: Sequence[str]
if ctx.module_context.lowering_parameters.is_multi_platform:
assert ctx.module_context.lowering_parameters.platforms is not None
platforms = ctx.module_context.lowering_parameters.platforms
else:
platforms = (ctx.module_context.platform,)
platforms_with_specific_rules: Sequence[str] = util.flatten(
[ps for ps, _ in rules if ps is not None])
platforms_with_default_rule = [p for p in platforms
if p not in platforms_with_specific_rules]
@@ -1517,7 +1521,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
rule_index = len(kept_rules)
if ps is not None:
# Keep only rules that mention the platforms of interest
interesting_ps = [p for p in platforms if p in ps]
interesting_ps = [p for p in platforms if p in ps] # type: ignore
if interesting_ps:
for p in interesting_ps:
assert p not in platform_to_kept_rules_idx
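
As a reading aid for the lower_multi_platform hunk above, here is a small, runnable sketch of the rule-selection scheme (names are illustrative, not the actual JAX internals): rules arrive as (platforms, rule) pairs, with platforms=None marking the default rule used for any target platform that has no specific rule.

from typing import Callable, Optional, Sequence

def pick_rules(platforms: Sequence[str],
               rules: Sequence[tuple[Optional[Sequence[str]], Callable]]):
  platforms_with_specific_rules = [
      p for ps, _ in rules if ps is not None for p in ps]
  platforms_with_default_rule = [
      p for p in platforms if p not in platforms_with_specific_rules]
  platform_to_rule: dict[str, Callable] = {}
  for ps, rule in rules:
    targets = (platforms_with_default_rule if ps is None
               else [p for p in platforms if p in ps])
    for p in targets:
      assert p not in platform_to_rule  # at most one rule per platform
      platform_to_rule[p] = rule
  return platform_to_rule

rules = [(("tpu", "cuda", "rocm"), lambda: "platform-specific collective"),
         (None, lambda: "generic fallback")]
print({p: r() for p, r in pick_rules(("cpu", "tpu"), rules).items()})
# -> {'tpu': 'platform-specific collective', 'cpu': 'generic fallback'}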
1 change: 0 additions & 1 deletion jax/_src/interpreters/pxla.py
@@ -1352,7 +1352,6 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
else:
raise TypeError(aval)


def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
env.sizes + (size,))
146 changes: 84 additions & 62 deletions jax/_src/lax/parallel.py
@@ -725,10 +725,15 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
for arg, named_shape in zip(args, named_shapes)]

def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
if axis_index_groups is not None and ctx.module_context.platform == "tpu":
# TODO(necula): clean this up when we have module_context.platforms
if ctx.module_context.lowering_parameters.is_multi_platform:
for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms)
else:
for_tpu = (ctx.module_context.platform == "tpu")
if axis_index_groups is not None and for_tpu:
len_0 = len(axis_index_groups[0])
if any(len(g) != len_0 for g in axis_index_groups):
raise ValueError("axis_index_groups must all be the same size")
raise ValueError("axis_index_groups must all be the same size for TPU lowering")
named_axes, positional_axes = axes_partition = [], []
for axis in axes:
axes_partition[isinstance(axis, int)].append(axis)
@@ -1175,7 +1180,8 @@ def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, a
raise AssertionError("Unexpected call to _all_gather_impl")

def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
axis_index_groups, axis_size, tiled):
axis_index_groups, axis_size, tiled,
platform=None):
# TODO(jekbradbury): enable for all_gather_dimension > 0
x_aval, = ctx.avals_in
out_aval, = ctx.avals_out
@@ -1184,9 +1190,8 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if (ctx.module_context.platform == 'tpu' or
ctx.module_context.platform in ('cuda', 'rocm')
and all_gather_dimension == 0):
if (platform == 'tpu' or
(platform in ('cuda', 'rocm') and all_gather_dimension == 0)):
if not tiled:
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, 1)
@@ -1282,6 +1287,10 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
all_gather_p.def_impl(_all_gather_impl)
mlir.register_lowering(all_gather_p, _all_gather_lowering)
for p in ("cuda", "rocm", "tpu"):
mlir.register_lowering(all_gather_p,
partial(_all_gather_lowering, platform=p),
platform=p)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
@@ -1313,63 +1322,68 @@ def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
return outs


def _reduce_scatter_lowering(prim, reducer, ctx, x,
*, scatter_dimension, axis_name,
axis_index_groups, axis_size, tiled):
if ctx.module_context.platform in ("tpu", "cuda", "rocm"):
x_aval, = ctx.avals_in
aval_out, = ctx.avals_out
scalar_aval = x_aval.update(shape=())
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
scatter_out_shape = list(x_aval.shape)
scatter_out_shape[scatter_dimension] //= axis_size
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
op = hlo.ReduceScatterOp(
mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
x,
scatter_dimension=mlir.i64_attr(scatter_dimension),
replica_groups=_replica_groups_hlo(replica_groups),
**other_args)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_block):
lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
reducer_ctx = ctx.replace(primitive=None,
avals_in=[scalar_aval] * 2,
avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.ReturnOp(util.flatten(out_nodes))
def _reduce_scatter_lowering(
prim, reducer, ctx, x,
*, scatter_dimension, axis_name,
axis_index_groups, axis_size, tiled):
x_aval, = ctx.avals_in
aval_out, = ctx.avals_out
scalar_aval = x_aval.update(shape=())
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
scatter_out_shape = list(x_aval.shape)
scatter_out_shape[scatter_dimension] //= axis_size
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
op = hlo.ReduceScatterOp(
mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
x,
scatter_dimension=mlir.i64_attr(scatter_dimension),
replica_groups=_replica_groups_hlo(replica_groups),
**other_args)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_block):
lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
reducer_ctx = ctx.replace(primitive=None,
avals_in=[scalar_aval] * 2,
avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.ReturnOp(util.flatten(out_nodes))

if tiled:
return op.results
else:
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
if tiled:
return op.results
else:
return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
ctx, x,
reducer=reducer,
scatter_dimension=scatter_dimension,
axis_name=axis_name,
axis_index_groups=axis_index_groups,
axis_size=axis_size,
tiled=tiled)
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results

def _reduce_scatter_lowering_via_reducer(
prim, reducer, ctx, x,
*, scatter_dimension, axis_name,
axis_index_groups, axis_size, tiled):

return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
ctx, x,
reducer=reducer,
scatter_dimension=scatter_dimension,
axis_name=axis_name,
axis_index_groups=axis_index_groups,
axis_size=axis_size,
tiled=tiled)


def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
@@ -1449,9 +1463,17 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective

mlir.register_lowering(
reduce_scatter_p,
partial(_reduce_scatter_lowering, lax.add_p, psum))
partial(_reduce_scatter_lowering_via_reducer, lax.add_p, psum))
reduce_scatter_lowering_for_psum = partial(_reduce_scatter_lowering,
lax.add_p, psum)
for p in ("tpu", "cuda", "rocm"):
mlir.register_lowering(
reduce_scatter_p, reduce_scatter_lowering_for_psum,
platform=p)

core.axis_substitution_rules[reduce_scatter_p] = \
partial(_subst_all_names_in_param, 'axis_name')

60 changes: 51 additions & 9 deletions jax/experimental/jax2tf/tests/multi_platform_export_test.py
@@ -13,12 +13,18 @@
# limitations under the License.
"""Tests for multi-platform and cross-platform JAX export."""

import math
import re
from typing import Literal
from typing import Callable, Sequence

from absl import logging
from absl.testing import absltest

import numpy as np

import jax
from jax import lax
from jax._src import pjit
from jax._src import test_util as jtu
from jax.experimental.export import export
# TODO(necula): Move the primitive harness out of jax2tf so that we can move
@@ -46,7 +52,6 @@ def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
"random_",
)


class PrimitiveTest(jtu.JaxTestCase):

@classmethod
@@ -88,8 +93,21 @@ def test_prim(self, harness: primitive_harness.Harness):
for l in harness.jax_unimplemented:
if l.filter(dtype=harness.dtype):
unimplemented_platforms = unimplemented_platforms.union(l.devices)
if (_skip_cuda_lowering_unless_have_gpus.search(harness.fullname)
and all(d.platform != "gpu" for d in self.devices)):
unimplemented_platforms.add("gpu")

logging.info("Harness is not implemented on %s", unimplemented_platforms)

self.export_and_compare_to_native(
func_jax, *args,
unimplemented_platforms=unimplemented_platforms)

def export_and_compare_to_native(
self, func_jax: Callable,
*args: jax.Array,
unimplemented_platforms: set[str] = set(),
skip_run_on_platforms: set[str] = set()):
devices = [
d
for d in self.__class__.devices
@@ -99,14 +117,9 @@ def test_prim(self, harness: primitive_harness.Harness):
# lowering_platforms uses "cuda" instead of "gpu"
lowering_platforms: list[str] = [
p if p != "gpu" else "cuda"
for p in {"cpu", "gpu", "tpu"} - unimplemented_platforms
for p in ("cpu", "gpu", "tpu")
if p not in unimplemented_platforms
]
if (
"cuda" in lowering_platforms
and _skip_cuda_lowering_unless_have_gpus.search(harness.fullname)
and all(d.platform != "gpu" for d in devices)
):
lowering_platforms.remove("cuda")

if len(lowering_platforms) <= 1:
self.skipTest(
@@ -117,6 +130,9 @@
exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args)

for device in devices:
if device.platform in skip_run_on_platforms:
logging.info("Skipping running on %s", device)
continue
device_args = jax.tree_util.tree_map(
lambda x: jax.device_put(x, device), args
)
Expand All @@ -127,6 +143,32 @@ def test_prim(self, harness: primitive_harness.Harness):
self.assertAllClose(native_res, exported_res)
# TODO(necula): Check HLO equivalence for the ultimate test.

def test_psum_scatter(self):
f = jax.jit(jax.pmap(lambda x: lax.psum_scatter(x, 'i'),
axis_name='i',
devices=jax.devices()[:1]))

shape = (1, 1, 8)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
self.export_and_compare_to_native(f, x)

# The lowering rule for all_gather has special cases for bool.
@jtu.parameterized_filterable(
kwargs=[
dict(dtype=dtype)
for dtype in [np.bool_, np.float32]],
)
def test_all_gather(self, *, dtype):
f = jax.jit(jax.pmap(lambda x: lax.all_gather(x, 'i'),
axis_name='i',
devices=jax.devices()[:1]))

shape = (1, 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
if dtype == np.bool_:
x = (x % 2).astype(np.bool_)
self.export_and_compare_to_native(f, x)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
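
For reference, a minimal usage sketch of the multi-platform export path these tests exercise. The export.export(..., lowering_platforms=...) call mirrors the test code above; the call_exported helper and the chosen platforms and tolerance are assumptions for illustration, not part of this commit.

import numpy as np
import jax
from jax.experimental.export import export

def f(x):
  return jax.numpy.sin(x) * 2.0

x = np.arange(8, dtype=np.float32)
# Lower once for several platforms ("cuda" rather than "gpu", as in the test);
# the exported artifact can then run on any available device of those platforms.
exp = export.export(f, lowering_platforms=("cpu", "cuda"))(x)
exported_res = export.call_exported(exp)(x)
np.testing.assert_allclose(exported_res, f(x), rtol=1e-6)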
