Skip to content

Commit

Permalink
Added associative_scan. (#2170)
Browse files Browse the repository at this point in the history
* Added `associative_scan`.

* Fixed problem where base case of associative scan could fail

* remove jax.numpy dependence in associative_scan

Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
dpiponi and mattjj committed May 21, 2020
1 parent 5c1de28 commit c459280
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Expand Up @@ -297,6 +297,7 @@
scan_p,
while_loop,
while_p,
associative_scan,
)
from .lax_fft import (
fft,
Expand Down
128 changes: 128 additions & 0 deletions jax/lax/lax_control_flow.py
Expand Up @@ -1935,3 +1935,131 @@ def _linear_solve_batching_rule(args, dims, **kwargs):
xla.lower_fun_initial_style(_custom_linear_solve_impl)
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule


def _interleave(a, b):
"""Given two Tensors of static shape, interleave them along the first axis."""
# TODO(mattjj)
import jax.numpy as np
# [a b c ...] [d e f ...] -> [a d b e c f ...]
half_num_elems = b.shape[0]

if a.shape[0] > b.shape[0]:
return np.concatenate(
[np.reshape(np.stack([a[: -1], b], axis=1),
(2 * half_num_elems,) + a.shape[1:]),
a[-1:]], axis=0)
else:
return np.reshape(np.stack([a, b], axis=1),
(2 * half_num_elems,) + a.shape[1:])

def associative_scan(fn, elems):
"""Perform a scan with an associative binary operation, in parallel.
Args:
fn: Python callable implementing an associative binary operation with
signature `r = fn(a, b)`. This must satisfy associativity:
`fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are
(possibly nested structures of) `Tensor`(s), matching `elems`. Each
`Tensor` has a leading batch dimension in place of `num_elems`; the `fn`
is expected to map over this dimension. The result `r` has the same shape
(and structure) as the two inputs `a` and `b`.
elems: A (possibly nested structure of) `Tensor`(s), each with leading
dimension `num_elems`, which must be known statically.
Returns:
result: A (possibly nested structure of) `Tensor`(s) of the same shape
and structure as `elems`, in which the `k`th element is the result of
recursively applying `fn` to combine the first `k` elements of
`elems`. For example, given `elems = [a, b, c, ...]`, the result
would be `[a, fn(a, b), fn(fn(a, b), c), ...]`.
#### Examples
```python
# Example 1: Partials sums of numbers.
np.associative_scan(operator.add, np.arange(0, 4))
# ==> [ 0, 1, 3, 6]
# Example 2: Partial products of random matrices.
np.associative_scan(np.matmul, matrices)
```
"""
elems_flat, tree = tree_flatten(elems)

def lowered_fn(a_flat, b_flat):
# Lower `fn` to operate on flattened sequences of elems.
a = tree_unflatten(tree, a_flat)
b = tree_unflatten(tree, b_flat)
c = fn(a, b)
c_flat, _ = tree_flatten(c)
return c_flat

# Check that all inputs have a consistent leading dimension `num_elems`.
num_elems = int(elems_flat[0].shape[0])

if not all(int(elem.shape[0]) == num_elems for elem in elems_flat[1:]):
raise ValueError('Input `Tensor`s must have the same first dimension.'
' (saw: {})'.format([elems.shape for elem in elems_flat]))

if num_elems < 2:
return elems

# Summary of algorithm:
#
# Consider elements of `_scan(elems)` at odd indices. That's the same as first
# summing successive pairs of elements of `elems` and performing a scan on
# that half sized tensor. We perform the latter scan by recursion.
#
# Now consider the even elements of `_scan(elems)`. These can be computed
# from the odd elements of `_scan(elems)` by adding each odd element of
# `_scan(elems)` to the matching even element in the original `elems`.
#
# We return the odd and even elements interleaved.
#
# For the base case of the recursion we return the first element
# of `elems` followed by the sum of the first two elements computed as
# a (small two-down-to-one) reduction step.
def _scan(elems):
"""Perform scan on `elems`."""

num_elems = elems[0].shape[0]

reduced_elems = lowered_fn([elem[0:-1:2] for elem in elems],
[elem[1::2] for elem in elems])

if reduced_elems[0].shape[0] == 1:
# Base case has either 2 or 3 elements.
if num_elems == 2:
return [lax.concatenate([elem[0:1], reduced_elem], dimension=0)
for (reduced_elem, elem) in zip(reduced_elems, elems)]
elif num_elems == 3:
reduced_reduced_elems = lowered_fn(
reduced_elems,
[elem[2:3] for elem in elems])
return [
lax.concatenate([elem[0:1], reduced_elem, reduced_reduced_elem],
dimension=0)
for (reduced_reduced_elem, reduced_elem, elem)
in zip(reduced_reduced_elems, reduced_elems, elems)]

# Recursively compute scan for partially reduced tensors.
odd_elems = _scan(reduced_elems)

if num_elems % 2 == 0:
results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
[elem[2::2] for elem in elems])
else:
results = lowered_fn([odd_elem for odd_elem in odd_elems],
[elem[2::2] for elem in elems])

# The first element of a scan is the same as the first element
# of the original `elems`.
even_elems = [lax.concatenate([elem[0:1], result], dimension=0)
for (elem, result) in zip(elems, results)]
return tuple(_map(_interleave, even_elems, odd_elems))

scans = _scan(elems_flat)

return tree_unflatten(tree, scans)
22 changes: 22 additions & 0 deletions tests/lax_control_flow_test.py
Expand Up @@ -16,6 +16,7 @@
import collections
from functools import partial
import itertools
import operator
import re
from typing import Callable
from unittest import SkipTest
Expand Down Expand Up @@ -1988,6 +1989,27 @@ def f(x, n): return lax.fori_loop(0, n, lambda _, x: x + 1, x)
x, n = jnp.arange(3), jnp.arange(4)
api.vmap(api.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash

def testAssociativeScanUnstructured1000(self):
data = np.arange(1000)
expected = np.cumsum(data)
result = lax.associative_scan(operator.add, data)
self.assertAllClose(result, expected, check_dtypes=False)

def testAssociativeScanStructured3(self):
pair = collections.namedtuple('pair', ('first', 'second'))
data = pair(first=np.array([0., 1., 2.]),
second=np.array([0., 10., 20.]))

def fn(a, b):
return pair(first=a.first + b.first,
second=a.second + b.second)

result = lax.associative_scan(fn, elems=data)
self.assertAllClose(result.first, np.array([0., 1., 3.]),
check_dtypes=False)
self.assertAllClose(result.second, np.array([0., 10., 30.]),
check_dtypes=False)


if __name__ == '__main__':
absltest.main()

0 comments on commit c459280

Please sign in to comment.