Add scan implementation using for and tests
sharadmv committed Jun 29, 2022
1 parent e1ba52b commit 7901359
Showing 2 changed files with 105 additions and 13 deletions.
88 changes: 82 additions & 6 deletions jax/_src/lax/control_flow/for_loop.py
@@ -15,24 +15,27 @@
from functools import partial
import operator

from typing import Any, Callable, Dict, Generic, List, Sequence, Tuple, TypeVar
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar

from jax import core
from jax import lax
from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import masking
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
treedef_tuple, PyTreeDef)
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
from jax._src import ad_util
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src.util import safe_map, safe_zip, split_list
import jax.numpy as jnp

from jax._src.lax.control_flow import loops
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr

## JAX utilities

@@ -433,7 +436,8 @@ def val_to_ref_aval(x) -> ShapedArrayRef:
raise Exception(f"can't make ref from {x}")
return ShapedArrayRef(aval.shape, aval.dtype)

def for_loop(nsteps: int, body: Callable[[Array, Ref[S]], None], init_state: S) -> S:
def for_loop(nsteps: int, body: Callable[[Array, Ref[S]], None], init_state: S,
*, reverse: bool = False) -> S:
"""A for-loop combinator that allows read/write semantics in the loop body.
`for_loop` is a higher-order function that enables writing loops that can be
@@ -476,12 +480,81 @@ def for_loop(nsteps, body, init_state):
jaxpr = _hoist_consts_to_refs(jaxpr)
which_linear = (False,) * (len(consts) + len(flat_state))
out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps),
reverse=False, which_linear=which_linear)
reverse=reverse, which_linear=which_linear)
# Consts are `Ref`s so they are both inputs and outputs. We remove them from
# the outputs.
out_flat = out_flat[len(consts):]
return tree_unflatten(state_tree, out_flat)
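
# Illustrative use of `for_loop` with the new `reverse` keyword (a sketch, not
# part of this commit). It assumes the Ref conventions used by the `scan` body
# below: reads via indexing (`ref[i]`, `ref[()]`) and writes via `ref_set`.
def _for_loop_reverse_example(xs):
  def body(i, refs):
    x_ref, y_ref = refs
    ref_set(y_ref, (i,), x_ref[i] * 2.)   # write 2 * xs[i] into output slot i
  # `reverse=True` only flips the iteration order; the result is identical here.
  _, ys = for_loop(xs.shape[0], body, (xs, jnp.zeros_like(xs)), reverse=True)
  return ys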

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
         init: Carry,
         xs: X,
         length: Optional[int] = None,
         reverse: bool = False,
         unroll: int = 1) -> Tuple[Carry, Y]:
  if unroll != 1:
    raise NotImplementedError("Unroll not implemented")
  if not callable(f):
    raise TypeError("scan: f argument should be a callable.")
  xs_flat, xs_tree = tree_flatten(xs)

  try:
    lengths = [x.shape[0] for x in xs_flat]
  except AttributeError as err:
    msg = "scan got value with no leading axis to scan over: {}."
    raise ValueError(
      msg.format(', '.join(str(x) for x in xs_flat
                           if not hasattr(x, 'shape')))) from err

  if length is not None:
    length = int(length)
    if not all(length == l for l in lengths):
      msg = ("scan got `length` argument of {} which disagrees with "
             "leading axis sizes {}.")
      raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
  else:
    unique_lengths = set(lengths)
    if len(unique_lengths) > 1:
      msg = "scan got values with different leading axis sizes: {}."
      raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
    elif len(unique_lengths) == 0:
      msg = "scan got no values to scan over and `length` not provided."
      raise ValueError(msg)
    else:
      length, = unique_lengths

  x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
  x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
  x_avals = tuple(map(core.ShapedArray, x_shapes, x_dtypes))

  def _create_jaxpr(init):
    init_flat = tree_leaves(init)
    _, in_tree = tree_flatten((init, xs))

    carry_avals = tuple(map(_abstractify, init_flat))
    jaxpr, _, out_tree = _initial_style_jaxpr(
        f, in_tree, carry_avals + x_avals, "scan")
    return jaxpr, out_tree
  jaxpr, out_tree = _create_jaxpr(init)
  _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals)
  ys = tree_map(lambda aval: jnp.zeros([length, *aval.shape], aval.dtype),
                ys_avals)
  def for_body(i, refs):
    carry_refs, xs_refs, ys_refs = refs
    carry = tree_map(lambda x: x[()], carry_refs)
    x = tree_map(lambda x: x[i], xs_refs)
    carry, y = f(carry, x)
    tree_map(lambda c_ref, c: ref_set(c_ref, (), c), carry_refs, carry)
    tree_map(lambda y_ref, y: ref_set(y_ref, (i,), y), ys_refs, y)
  assert isinstance(length, int)
  init, _, ys = for_loop(length, for_body, (init, xs, ys), reverse=reverse)
  return init, ys
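
# Illustrative use of the `scan` defined above (a sketch, not part of this
# commit); it is similar to the reverse cumulative-sum test added in
# tests/lax_control_flow_test.py.
def _scan_cumsum_example(xs, reverse=False):
  def step(carry, x):
    carry = carry + x
    return carry, carry               # (new carry, per-step output)
  total, partials = scan(step, 0., xs, reverse=reverse)
  return total, partials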


@for_p.def_abstract_eval
def _for_abstract_eval(*avals, jaxpr, **__):
return list(avals)
@@ -545,7 +618,7 @@ def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):

### Testing utility

def discharged_for_loop(nsteps, body, init_state):
def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
"""A `for_loop` implementation that discharges its body right away.
Potentially useful for testing and benchmarking.
@@ -560,8 +633,11 @@ def discharged_for_loop(nsteps, body, init_state):
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)

def fori_body(i, carry):
i = jnp.int32(i)
if reverse:
i = nsteps - i - 1
out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts,
jnp.int32(i), *carry)
i, *carry)
return out_flat
out_flat = loops.fori_loop(0, nsteps, fori_body, flat_state)
return tree_unflatten(state_tree, out_flat)
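
The new `reverse` flag is also threaded through `discharged_for_loop`, which evaluates the discharged body under a plain `fori_loop` and serves as a reference path for testing. A minimal sketch of checking it against `for_loop` (illustrative only; it assumes the module path in the file header above and that `ref_set` is reachable as a module-level name, as the code above uses it):

import jax.numpy as jnp
from jax._src.lax.control_flow import for_loop as fl

def body(i, refs):
  x_ref, y_ref = refs
  fl.ref_set(y_ref, (i,), x_ref[i] + 1.)

xs = jnp.arange(4.)
init = (xs, jnp.zeros_like(xs))
_, y1 = fl.for_loop(4, body, init, reverse=True)
_, y2 = fl.discharged_for_loop(4, body, init, reverse=True)
# Both paths should produce the same result; discharged_for_loop is the
# eagerly-discharged reference counterpart of for_loop.
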
30 changes: 23 additions & 7 deletions tests/lax_control_flow_test.py
@@ -70,6 +70,8 @@ def scan_with_new_checkpoint2(f, *args, **kwargs):
return new_checkpoint(partial(lax.scan, f, **kwargs),
policy=checkpoint_policies.everything_saveable)(*args)

def scan_with_for(f, *args, **kwargs):
  return for_loop.scan(f, *args, **kwargs)

COND_IMPLS = [
(lax.cond, 'cond'),
@@ -84,6 +86,14 @@ def scan_with_new_checkpoint2(f, *args, **kwargs):
(scan_with_new_checkpoint2, 'new_checkpoint2'),
]

SCAN_IMPLS_WITH_FOR = [
    (lax.scan, 'unroll1'),
    (partial(lax.scan, unroll=2), 'unroll2'),
    (scan_with_new_checkpoint , 'new_checkpoint'),
    (scan_with_new_checkpoint2, 'new_checkpoint2'),
    (scan_with_for, 'for'),
]


def while_loop_reference(cond, body, carry):
while cond(carry):
@@ -1454,11 +1464,12 @@ def cond(x):
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
jit_scan, jit_f, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl,
"impl_name": scan_name}
for jit_scan in [False, True]
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
def testScanImpl(self, jit_scan, jit_f, scan):
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanImpl(self, jit_scan, jit_f, scan, impl_name):
rng = self.rng()

d = rng.randn(2)
@@ -1480,20 +1491,25 @@ def f(c, a):

ans = scan(f, c, as_)
expected = scan_reference(f, c, as_)
rtol = {np.float64: 1.4e-15}
atol = {np.float64: 8e-15}
if impl_name == "for":
  rtol[np.float32] = 8e-5
  atol[np.float32] = 3e-5
self.assertAllClose(
ans,
expected,
check_dtypes=False,
rtol={np.float64: 1.4e-15},
atol={np.float64: 8e-15})
rtol=rtol,
atol=atol)

@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
jit_scan, jit_f, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def testScanJVP(self, jit_scan, jit_f, scan):
rng = self.rng()

@@ -2141,7 +2157,7 @@ def body(i, x):
@parameterized.named_parameters(
{"testcase_name": f"_{scan_name}",
"scan": scan_impl}
for scan_impl, scan_name in SCAN_IMPLS)
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
def test_scan_reverse(self, scan):
def cumsum(x, reverse):
return scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
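# Illustrative check (not part of the diff): for x = [1., 2., 3.],
# cumsum(x, reverse=False) gives [1., 3., 6.], while
# cumsum(x, reverse=True) gives [6., 5., 3.].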
