diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index d297438bd7dd..6774de08df9e 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1473,31 +1473,48 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, partial(_subst_all_names_in_param, 'axis_name') -def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False): - """Compute an all-reduce sum over the axis ``axis_name``, and scatter the result. +def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, + tiled=False): + """ + Like ``psum(x, axis_name)`` but each device retains only part of the result. + + For example, ``psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)`` + computes the same value as ``psum(x, axis_name)[axis_index(axis_name)]``, but + it is more efficient. Thus the ``psum`` result is left scattered along the + mapped axis. + + One efficient algorithm for computing ``psum(x, axis_name)`` is to perform a + ``psum_scatter`` followed by an ``all_gather``, essentially evaluating + ``all_gather(psum_scatter(x, axis_name))``. So we can think of + ``psum_scatter`` as "the first half" of a ``psum``. Args: x: array(s) with a mapped axis named ``axis_name``. - axis_name: hashable Python object used to name a pmapped axis (see the + axis_name: hashable Python object used to name a mapped axis (see the :func:`jax.pmap` documentation for more details). - scatter_dimension: a positional axis into which the all reduce result along + scatter_dimension: a positional axis into which the all-reduce result along ``axis_name`` will be scattered. - axis_index_groups: optional list of lists containing axis indices (e.g. for - an axis of size 4, [[0, 1], [2, 3]] would run reduce-scatter over the - first two and the last two replicas). Groups must cover all axis indices - exactly once, and all groups must be the same size. - tiled: when ``False``, the size of dimension in ``scatter_dimension`` must - match the size of axis ``axis_name`` (or the group size if - ``axis_index_groups`` is given). After scattering the all reduce result - along ``scatter_dimension``, the output is sequeezed by removing - ``scatter_dimension``. When ``True``, the size of dimension in - ``scatter_dimension` must be dividible by the size of axis ``axis_name`` - (or the group size if ``axis_index_groups`` is given), - and ``scatter_dimension`` is preserved. + axis_index_groups: optional list of lists of integers containing axis + indices. For example, for an axis of size 4, + ``axis_index_groups=[[0, 1], [2, 3]]`` would run reduce-scatter over the + first two and the last two axis indices. Groups must cover all axis + indices exactly once, and all groups must be the same size. + tiled: boolean representing whether to use rank-preserving 'tiled' behavior. + When ``False`` (the default value), the size of dimension in + ``scatter_dimension`` must match the size of axis ``axis_name`` (or the + group size if ``axis_index_groups`` is given). After scattering the + all-reduce result along ``scatter_dimension``, the output is sequeezed by + removing ``scatter_dimension``, so the result has lower rank than the + input. When ``True``, the size of dimension in ``scatter_dimension`` must + be dividible by the size of axis ``axis_name`` (or the group size if + ``axis_index_groups`` is given), and the ``scatter_dimension`` axis is + preserved (so the result has the same rank as the input). Returns: Array(s) with the similar shape as ``x``, except the size of dimension in - position``scatter_dimension`` is divided by the size of axis ``axis_name``. + position ``scatter_dimension`` is divided by the size of axis ``axis_name`` + (when ``tiled=True``), or the dimension in position ``scatter_dimension`` is + eliminated (when ``tiled=False``). For example, with 4 XLA devices available: