Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public) is on the list, so the allowlist no longer serves any purpose.

PiperOrigin-RevId: 555124144
hawkinsp authored and jax authors committed Aug 9, 2023
1 parent 1bd5fd2 commit c9cf6b4
Showing 3 changed files with 0 additions and 45 deletions.
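For context, a minimal sketch of the kind of program the allowlist guarded (illustration only, not part of this commit; it assumes a stock JAX install and a hypothetical axis name "i"). psum stands in for any of the collectives registered in pxla.multi_host_supported_collectives; since every public collective was already on the list, the check deleted below could never raise.

import jax
import jax.numpy as jnp

def normalize(x):
  # psum is one of the allowlisted collectives; it sums x across the mapped axis.
  return x / jax.lax.psum(x, 'i')

# On a multi-host deployment each process would call jax.distributed.initialize()
# first; on a single host this runs as-is over the local devices.
f = jax.pmap(normalize, axis_name='i')
x = jnp.ones(jax.local_device_count())
print(f(x))  # each shard becomes 1 / device_count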
23 changes: 0 additions & 23 deletions jax/_src/interpreters/pxla.py
@@ -633,15 +633,6 @@ def stage_parallel_callable(
  assert len(out_sharded_avals) == len(pci.out_axes), (
      len(out_sharded_avals), len(pci.out_axes))

  # TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
  # for now raise an error
  if pci.devices is not None:
    is_multi_host_pmap = len(pci.local_devices) != len(pci.devices)
  else:
    is_multi_host_pmap = xb.process_count(pci.backend) > 1
  if is_multi_host_pmap:
    check_multihost_collective_allowlist(jaxpr)

  replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
  num_local_shards = replicas.num_local_replicas
  num_global_shards = replicas.num_global_replicas
@@ -1036,17 +1027,6 @@ def _get_pmap_sharding(devices, specs):
  return [sharding_impls.PmapSharding(devices, spec) for spec in specs]


multi_host_supported_collectives: set[core.Primitive] = set()


def check_multihost_collective_allowlist(jaxpr):
  used_collectives = set(xla.jaxpr_collectives(jaxpr))
  if not used_collectives.issubset(multi_host_supported_collectives):
    bad_collectives = used_collectives - multi_host_supported_collectives
    msg = "using collectives that aren't supported for multi-host: {}"
    raise TypeError(msg.format(", ".join(map(str, bad_collectives))))


class InputsHandler:
  __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")

@@ -1982,7 +1962,6 @@ def lower_sharding_computation(
  da_object = _create_da_object(tuple(device_assignment))

  if not da_object.is_fully_addressable:
    check_multihost_collective_allowlist(jaxpr)
    if inline and config.jax_spmd_mode != 'allow_all':
      raise RuntimeError(
          "Running operations on `Array`s that are not fully addressable by this "
@@ -2144,8 +2123,6 @@ def lower_mesh_computation(
for aval, o in safe_zip(out_jaxpr_avals, out_shardings)]

  _sanitize_mesh_jaxpr(jaxpr)
  if mesh.is_multi_process:
    check_multihost_collective_allowlist(jaxpr)
  jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

# 2. Build up the HLO
12 changes: 0 additions & 12 deletions jax/_src/interpreters/xla.py
@@ -248,18 +248,6 @@ def xla_destructure(c, ans):
  return [xops.GetTupleElement(ans, i) for i in range(num_elements)]


# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming


def jaxpr_collectives(jaxpr):
  """Generates all the collective primitives anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if eqn.primitive in _collective_primitives:
      yield eqn.primitive
  for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)


### translation tables

MYPY = False
10 changes: 0 additions & 10 deletions jax/_src/lax/parallel.py
@@ -809,7 +809,6 @@ def broadcast_positional(ct, arg):
mlir.register_lowering(
    psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p)
batching.axis_primitive_batchers[psum_p] = \
  partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
@@ -845,7 +844,6 @@ def pos_reduce(x):
xla.register_collective_primitive(pmax_p)
mlir.register_lowering(
    pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
pxla.multi_host_supported_collectives.add(pmax_p)
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
batching.axis_primitive_batchers[pmax_p] = \
  partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)
@@ -859,7 +857,6 @@ def pos_reduce(x):
xla.register_collective_primitive(pmin_p)
mlir.register_lowering(
    pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
pxla.multi_host_supported_collectives.add(pmin_p)
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
batching.axis_primitive_batchers[pmin_p] = \
  partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)
@@ -928,7 +925,6 @@ def _collective_batcher(prim, args, dims, **params):
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
xla.register_collective_primitive(ppermute_p)
mlir.register_lowering(ppermute_p, _ppermute_lowering)
pxla.multi_host_supported_collectives.add(ppermute_p)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
batching.axis_primitive_batchers[ppermute_p] = _ppermute_batcher
core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name')
@@ -1081,7 +1077,6 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
xla.register_collective_primitive(all_to_all_p)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
pxla.multi_host_supported_collectives.add(all_to_all_p)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
batching.axis_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
core.axis_substitution_rules[all_to_all_p] = partial(_subst_all_names_in_param, 'axis_name')
@@ -1294,7 +1289,6 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
xla.register_collective_primitive(all_gather_p)
mlir.register_lowering(all_gather_p, _all_gather_lowering)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
pxla.multi_host_supported_collectives.add(all_gather_p)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
@@ -1465,7 +1459,6 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
mlir.register_lowering(
    reduce_scatter_p,
    partial(_reduce_scatter_lowering, lax.add_p, psum))
pxla.multi_host_supported_collectives.add(reduce_scatter_p)
core.axis_substitution_rules[reduce_scatter_p] = \
    partial(_subst_all_names_in_param, 'axis_name')

@@ -1585,7 +1578,6 @@ def _axis_index_abstract_eval(*, axis_name):
xla.register_collective_primitive(axis_index_p)
mlir.register_lowering(axis_index_p, _axis_index_lowering)
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
pxla.multi_host_supported_collectives.add(axis_index_p)
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')

# Axis index doesn't get any arguments, so that the default bind would have no
@@ -1691,8 +1683,6 @@ def _pdot_transpose_rhs(g, x, y, *, axis_name, pos_contract, pos_batch, precisio
preferred_element_type=None)
ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs)

pxla.multi_host_supported_collectives.add(pdot_p)


def _pgather_impl(src, idx, *, axes):
  assert all(isinstance(axis, int) for axis in axes)
