Skip to content

Commit

Permalink
improve psum_scatter docstring (formatting and content)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 15, 2023
1 parent 9b683e3 commit f33ef3f
Showing 1 changed file with 34 additions and 17 deletions.
51 changes: 34 additions & 17 deletions jax/_src/lax/parallel.py
Expand Up @@ -1473,31 +1473,48 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
partial(_subst_all_names_in_param, 'axis_name')


def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False):
"""Compute an all-reduce sum over the axis ``axis_name``, and scatter the result.
def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
tiled=False):
"""
Like ``psum(x, axis_name)`` but each device retains only part of the result.
For example, ``psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)``
computes the same value as ``psum(x, axis_name)[axis_index(axis_name)]``, but
it is more efficient. Thus the ``psum`` result is left scattered along the
mapped axis.
One efficient algorithm for computing ``psum(x, axis_name)`` is to perform a
``psum_scatter`` followed by an ``all_gather``, essentially evaluating
``all_gather(psum_scatter(x, axis_name))``. So we can think of
``psum_scatter`` as "the first half" of a ``psum``.
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
axis_name: hashable Python object used to name a mapped axis (see the
:func:`jax.pmap` documentation for more details).
scatter_dimension: a positional axis into which the all reduce result along
scatter_dimension: a positional axis into which the all-reduce result along
``axis_name`` will be scattered.
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would run reduce-scatter over the
first two and the last two replicas). Groups must cover all axis indices
exactly once, and all groups must be the same size.
tiled: when ``False``, the size of dimension in ``scatter_dimension`` must
match the size of axis ``axis_name`` (or the group size if
``axis_index_groups`` is given). After scattering the all reduce result
along ``scatter_dimension``, the output is sequeezed by removing
``scatter_dimension``. When ``True``, the size of dimension in
``scatter_dimension` must be dividible by the size of axis ``axis_name``
(or the group size if ``axis_index_groups`` is given),
and ``scatter_dimension`` is preserved.
axis_index_groups: optional list of lists of integers containing axis
indices. For example, for an axis of size 4,
``axis_index_groups=[[0, 1], [2, 3]]`` would run reduce-scatter over the
first two and the last two axis indices. Groups must cover all axis
indices exactly once, and all groups must be the same size.
tiled: boolean representing whether to use rank-preserving 'tiled' behavior.
When ``False`` (the default value), the size of dimension in
``scatter_dimension`` must match the size of axis ``axis_name`` (or the
group size if ``axis_index_groups`` is given). After scattering the
all-reduce result along ``scatter_dimension``, the output is sequeezed by
removing ``scatter_dimension``, so the result has lower rank than the
input. When ``True``, the size of dimension in ``scatter_dimension`` must
be dividible by the size of axis ``axis_name`` (or the group size if
``axis_index_groups`` is given), and the ``scatter_dimension`` axis is
preserved (so the result has the same rank as the input).
Returns:
Array(s) with the similar shape as ``x``, except the size of dimension in
position``scatter_dimension`` is divided by the size of axis ``axis_name``.
position ``scatter_dimension`` is divided by the size of axis ``axis_name``
(when ``tiled=True``), or the dimension in position ``scatter_dimension`` is
eliminated (when ``tiled=False``).
For example, with 4 XLA devices available:
Expand Down

0 comments on commit f33ef3f

Please sign in to comment.