ReduceScatter translation and abstract eval.
PiperOrigin-RevId: 387152857
xnning authored and jax authors committed Jul 27, 2021
1 parent 717540d commit e7f0307
Showing 4 changed files with 212 additions and 0 deletions.
162 changes: 162 additions & 0 deletions jax/_src/lax/parallel.py
@@ -1129,6 +1129,168 @@ def _all_gather_batched_collective(frame, vals_in, dims_in, all_gather_dimension
batching.collective_rules[all_gather_p] = _all_gather_batched_collective
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')


def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
                                axis_index_groups, axis_size, tiled):
  index = _index_in_group(axis_name, axis_index_groups)
  scatter_dim_input_size = x.shape[scatter_dimension]
  if tiled and scatter_dim_input_size % axis_size != 0:
    raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
                     f"{scatter_dim_input_size} must be divisible by "
                     f"shard count {axis_size}")
  elif not tiled and scatter_dim_input_size != axis_size:
    raise ValueError(f"reduce_scatter operand scatter dimension size "
                     f"{scatter_dim_input_size} must match shard count "
                     f"{axis_size}")
  scatter_dim_output_size = scatter_dim_input_size // axis_size

  outs = reducer(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
  outs = lax.dynamic_slice_in_dim(
      outs,
      start_index=index * scatter_dim_output_size,
      slice_size=scatter_dim_output_size,
      axis=scatter_dimension)
  if not tiled:
    outs = lax.squeeze(outs, [scatter_dimension])
  return outs
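
# Illustrative sketch, not part of this change: a NumPy reference for the
# decomposition above -- reduce across the group members, then each member
# keeps its own slice of the scatter dimension. The helper name
# `_np_reduce_scatter` is hypothetical and only meant to make the semantics
# concrete; it assumes NumPy is available.
import numpy as np

def _np_reduce_scatter(xs, scatter_dimension=0, tiled=False):
  """xs: list of equal-shaped per-member arrays; returns one output per member."""
  axis_size = len(xs)
  # Step 1: the all-reduce (here a plain sum over group members).
  summed = np.sum(np.stack(xs), axis=0)
  # Step 2: member i takes the i-th chunk along scatter_dimension.
  chunk = summed.shape[scatter_dimension] // axis_size
  outs = []
  for i in range(axis_size):
    out = np.take(summed, np.arange(i * chunk, (i + 1) * chunk),
                  axis=scatter_dimension)
    if not tiled:
      # Untiled: the chunk has size 1 along scatter_dimension; squeeze it away.
      out = np.squeeze(out, axis=scatter_dimension)
    outs.append(out)
  return outs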


def _reduce_scatter_translation_rule(prim, reducer, c, x, *, scatter_dimension,
                                     axis_name, axis_index_groups, axis_size,
                                     tiled, axis_env, platform):
  # TODO(b/194706412): Enable this for TPU?
  if platform == "gpu":
    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    x = xops.ReduceScatter(
        x,
        computation,
        scatter_dimension=scatter_dimension,
        shard_count=axis_size,
        replica_groups=xc.make_replica_groups(replica_groups))
    if not tiled:
      new_shape = list(c.get_shape(x).dimensions())
      del new_shape[scatter_dimension]
      x = xops.Reshape(x, new_shape)
    return x
  else:
    return xla.lower_fun(
        _reduce_scatter_via_reducer, multiple_results=False, parallel=True)(
            c,
            x,
            reducer=reducer,
            scatter_dimension=scatter_dimension,
            axis_name=axis_name,
            axis_index_groups=axis_index_groups,
            axis_size=axis_size,
            tiled=tiled,
            axis_env=axis_env,
            platform=platform)
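
# Note, not part of this change: on the "gpu" platform the rule above emits a
# single XLA ReduceScatter op (reshaping away scatter_dimension when untiled);
# on other platforms it lowers the _reduce_scatter_via_reducer fallback
# (psum followed by a dynamic slice) through xla.lower_fun.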


def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
                                  axis_index_groups, axis_size, tiled):
  if not isinstance(axis_name, (list, tuple)):
    axis_name = (axis_name,)
  x_aval = core.raise_to_shaped(x)
  new_shape = list(x_aval.shape)
  scatter_dim_input_size = x_aval.shape[scatter_dimension]
  if tiled:
    if scatter_dim_input_size % axis_size != 0:
      raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
                       f"{scatter_dim_input_size} must be divisible by "
                       f"shard count {axis_size}")
    new_shape[scatter_dimension] = scatter_dim_input_size // axis_size
  else:
    if scatter_dim_input_size != axis_size:
      raise ValueError(f"reduce_scatter operand scatter dimension size "
                       f"{scatter_dim_input_size} must match shard count "
                       f"{axis_size}")
    del new_shape[scatter_dimension]

  new_named_shape = {
      name: size
      for name, size in x_aval.named_shape.items()
      if name not in axis_name
  }
  return x_aval.update(shape=new_shape, named_shape=new_named_shape)
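
# Illustrative examples, not part of this change: what the shape rule above
# produces, assuming axis_size == 4 and scatter_dimension == 0.
#   tiled=True : operand shape (8, 3) -> output shape (2, 3)   (8 // 4 == 2)
#   tiled=False: operand shape (4, 3) -> output shape (3,)     (dim 0 removed)
# In both cases axis_name is also removed from the named shape.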


reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
xla.parallel_translations[reduce_scatter_p] = partial(
    _reduce_scatter_translation_rule, lax.add_p, psum)
pxla.multi_host_supported_collectives.add(reduce_scatter_p)


def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
                 tiled=False):
  """Compute an all-reduce sum over the axis ``axis_name``, and scatter the result.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    scatter_dimension: a positional axis into which the all-reduce result along
      ``axis_name`` will be scattered.
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run reduce-scatter over the
      first two and the last two replicas). Groups must cover all axis indices
      exactly once, and all groups must be the same size.
    tiled: when ``False``, the size of the dimension in ``scatter_dimension``
      must match the size of axis ``axis_name`` (or the group size if
      ``axis_index_groups`` is given). After scattering the all-reduce result
      along ``scatter_dimension``, the output is squeezed by removing
      ``scatter_dimension``. When ``True``, the size of the dimension in
      ``scatter_dimension`` must be divisible by the size of axis ``axis_name``
      (or the group size if ``axis_index_groups`` is given), and
      ``scatter_dimension`` is preserved.

  Returns:
    Array(s) with the same shape as ``x``, except that the size of the
    dimension in position ``scatter_dimension`` is divided by the size of axis
    ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(16).reshape(4, 4)
  >>> print(x)
  [[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]
  >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [24 28 32 36]

  If using ``tiled=True``:

  >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
  >>> print(y)
  [[24]
   [28]
   [32]
   [36]]

  An example of using ``axis_index_groups``:

  >>> def f(x):
  ...   return jax.lax.psum_scatter(
  ...       x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
  >>> y = jax.pmap(f, axis_name='i')(x)
  >>> print(y)
  [[ 8 10]
   [20 22]
   [12 14]
   [16 18]]
  """
  axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
  bind = partial(
      reduce_scatter_p.bind,
      axis_name=axis_name,
      scatter_dimension=scatter_dimension,
      axis_index_groups=axis_index_groups,
      axis_size=axis_size,
      tiled=tiled)
  return tree_util.tree_map(bind, x)
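
# Illustrative sketch, not part of this change: for the untiled case with the
# default scatter_dimension=0 and no axis_index_groups, psum_scatter is
# semantically a full psum followed by each member taking its own row, e.g.
#
#   lax.psum_scatter(x, 'i') == lax.psum(x, 'i')[lax.axis_index('i')]
#
# The dedicated primitive lets the GPU lowering above emit a fused XLA
# ReduceScatter instead of materializing the full psum result.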


def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
  axis_pos = list(axis_env.names).index(axis_name)
  nreplicas = axis_env.nreps // prod(axis_env.sizes)
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -932,6 +932,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"psum",
"pmax",
"pgather",
"reduce_scatter",
"axis_index",
"pdot",
"all_gather",
1 change: 1 addition & 0 deletions jax/lax/__init__.py
@@ -349,6 +349,7 @@
pshuffle,
psum,
psum_p,
psum_scatter,
pswapaxes,
pdot,
xeinsum,
48 changes: 48 additions & 0 deletions tests/pmap_test.py
@@ -159,6 +159,54 @@ def testGatherTiled(self):
    ans = f(x)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testReduceScatter(self):
    f = pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')

    device_count = xla_bridge.device_count()
    shape = (device_count, device_count)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    expected = np.sum(x, axis=0)
    ans = f(x)
    for i, actual in enumerate(ans):
      self.assertAllClose(actual, expected[i])

  def testReduceScatterTiled(self):
    f = pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')

    device_count = xla_bridge.device_count()
    shape = (device_count, 4 * device_count)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    expected = np.sum(x, axis=0)
    ans = f(x)
    scatter_len = len(expected) // device_count
    for i, actual in enumerate(ans):
      self.assertAllClose(actual,
                          expected[i * scatter_len:(i + 1) * scatter_len])

  def testReduceScatterReplicaGroupsTiled(self):
    replicas = xla_bridge.device_count()
    if replicas % 2 != 0:
      raise SkipTest("Test requires an even number of devices")
    axis_index_groups = [[i for i in range(jax.device_count()) if i % 2 == 0],
                         [i for i in range(jax.device_count()) if i % 2 != 0]]
    f = lambda x: lax.psum_scatter(
        x, 'i', axis_index_groups=axis_index_groups, tiled=True)
    f = pmap(f, axis_name='i')

    shape = (replicas, 4 * replicas)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    ans = f(x)

    group_1_result = np.sum(x[0::2, :], axis=0)
    group_2_result = np.sum(x[1::2, :], axis=0)
    # The result is scattered over (replicas // 2) devices.
    scatter_len = len(group_1_result) * 2 // replicas

    for i, actual in enumerate(ans):
      expected = group_1_result if i % 2 == 0 else group_2_result
      self.assertAllClose(
          actual, expected[i // 2 * scatter_len:(i // 2 + 1) * scatter_len])

  @ignore_slow_all_to_all_warning()
  def testTrees(self):
    ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
