def _interleave(a, b):
  """Interleave two arrays along their leading axis.

  `[a0, a1, a2, ...]` and `[b0, b1, b2, ...]` become
  `[a0, b0, a1, b1, a2, b2, ...]`.

  Args:
    a: Array whose leading dimension is either equal to, or exactly one
      larger than, the leading dimension of `b`.
    b: Array with the same trailing dimensions as `a`.

  Returns:
    An array of leading dimension `a.shape[0] + b.shape[0]` whose even
    positions hold the rows of `a` and whose odd positions hold the rows of
    `b`.
  """
  # Imported inside the function rather than at module level; the original
  # carried an unexplained TODO on this import.
  import jax.numpy as jnp

  num_pairs = b.shape[0]
  paired_shape = (2 * num_pairs,) + a.shape[1:]

  if a.shape[0] > num_pairs:
    # `a` has one unpaired trailing row: interleave the paired prefix and
    # append the leftover row at the end.
    paired = jnp.reshape(jnp.stack([a[:-1], b], axis=1), paired_shape)
    return jnp.concatenate([paired, a[-1:]], axis=0)

  # Stacking on axis 1 puts a[i] next to b[i]; flattening the first two axes
  # then yields the interleaved order.
  return jnp.reshape(jnp.stack([a, b], axis=1), paired_shape)


def associative_scan(fn, elems):
  """Perform a scan with an associative binary operation, in parallel.

  Args:
    fn: Python callable implementing an associative binary operation with
      signature `r = fn(a, b)`. This must satisfy associativity:
      `fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are
      (possibly nested structures of) arrays matching `elems`. Each array
      has a leading batch dimension in place of `num_elems`; `fn` is
      expected to map over this dimension. The result `r` has the same shape
      (and structure) as the two inputs `a` and `b`.
    elems: A (possibly nested structure of) array(s), each with leading
      dimension `num_elems`, which must be known statically.

  Returns:
    A (possibly nested structure of) array(s) of the same shape and structure
    as `elems`, in which the `k`th element is the result of recursively
    applying `fn` to combine the first `k` elements of `elems`. For example,
    given `elems = [a, b, c, ...]`, the result would be
    `[a, fn(a, b), fn(fn(a, b), c), ...]`. If `elems` contains no arrays, or
    `num_elems < 2`, `elems` is returned unchanged.

  Raises:
    ValueError: If the arrays in `elems` do not all share the same leading
      dimension.

  #### Examples

  ```python
  # Example 1: Partial sums of numbers.

  lax.associative_scan(operator.add, np.arange(0, 4))
  # ==> [ 0, 1, 3, 6]

  # Example 2: Partial products of random matrices.

  lax.associative_scan(np.matmul, matrices)
  ```
  """
  elems_flat, tree = tree_flatten(elems)

  # A structure without any array leaves has nothing to scan over.
  if not elems_flat:
    return elems

  def lowered_fn(a_flat, b_flat):
    # Lower `fn` to operate on flattened sequences of elems.
    a = tree_unflatten(tree, a_flat)
    b = tree_unflatten(tree, b_flat)
    c_flat, _ = tree_flatten(fn(a, b))
    return c_flat

  # Check that all inputs have a consistent leading dimension `num_elems`.
  num_elems = int(elems_flat[0].shape[0])

  if any(int(elem.shape[0]) != num_elems for elem in elems_flat[1:]):
    # Report the shape of every leaf. `elems` itself may be a tuple/dict and
    # therefore has no `.shape`.
    raise ValueError('Input `Tensor`s must have the same first dimension.'
                     ' (saw: {})'.format([elem.shape for elem in elems_flat]))

  if num_elems < 2:
    return elems

  # Summary of algorithm:
  #
  # Consider elements of `_scan(elems)` at odd indices. That's the same as
  # first summing successive pairs of elements of `elems` and performing a
  # scan on that half sized tensor. We perform the latter scan by recursion.
  #
  # Now consider the even elements of `_scan(elems)`. These can be computed
  # from the odd elements of `_scan(elems)` by adding each odd element of
  # `_scan(elems)` to the matching even element in the original `elems`.
  #
  # We return the odd and even elements interleaved.
  #
  # For the base case of the recursion we return the first element
  # of `elems` followed by the sum of the first two elements computed as
  # a (small two-down-to-one) reduction step.
  def _scan(elems):
    """Perform scan on the flat list of arrays `elems` (length >= 2)."""
    num_elems = elems[0].shape[0]

    # Combine successive pairs (0, 1), (2, 3), ...; a trailing unpaired
    # element (odd `num_elems`) is left out here and handled below.
    reduced_elems = lowered_fn([elem[0:-1:2] for elem in elems],
                               [elem[1::2] for elem in elems])

    # Base cases: exactly one pair was formed, i.e. 2 or 3 elements.
    if num_elems == 2:
      return [lax.concatenate([elem[0:1], reduced_elem], dimension=0)
              for reduced_elem, elem in zip(reduced_elems, elems)]
    if num_elems == 3:
      reduced_reduced_elems = lowered_fn(reduced_elems,
                                         [elem[2:3] for elem in elems])
      return [
          lax.concatenate([elem[0:1], reduced_elem, reduced_reduced_elem],
                          dimension=0)
          for reduced_reduced_elem, reduced_elem, elem
          in zip(reduced_reduced_elems, reduced_elems, elems)]

    # Recursively compute scan for partially reduced tensors. These are the
    # results at odd output positions.
    odd_elems = _scan(reduced_elems)

    # For even `num_elems` the last odd result is also the last output, so
    # it has no following even position to feed.
    if num_elems % 2 == 0:
      prefixes = [odd_elem[:-1] for odd_elem in odd_elems]
    else:
      prefixes = list(odd_elems)
    results = lowered_fn(prefixes, [elem[2::2] for elem in elems])

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [lax.concatenate([elem[0:1], result], dimension=0)
                  for elem, result in zip(elems, results)]
    return [_interleave(even_elem, odd_elem)
            for even_elem, odd_elem in zip(even_elems, odd_elems)]

  scans = _scan(elems_flat)

  return tree_unflatten(tree, scans)