add optional 'forward' argument to lax.scan (#2921)
* add optional 'forward' argument to lax.scan

* switch to reverse; revise disable-jit case

* fix jaxpr.rst

* fix loops.py

Co-authored-by: James Bradbury <jekbradbury@gmail.com>
mattjj and jekbradbury committed May 5, 2020
1 parent 3e52237 commit 3cd409e
Showing 5 changed files with 49 additions and 31 deletions.
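As a quick illustration of the renamed argument (a hedged usage sketch, not part of this commit; the names discounted_returns, rewards, and gamma are invented for the example): passing reverse=True runs the scan from the last element of the leading axis back to the first, which is convenient for backward recurrences such as discounted returns.

import jax.numpy as jnp
from jax import lax

def discounted_returns(rewards, gamma=0.99):
  # Accumulate from the end of the sequence toward the front.
  def step(carry, r):
    g = r + gamma * carry
    return g, g
  _, returns = lax.scan(step, 0.0, rewards, reverse=True)
  return returns  # returns[t] = sum over k of gamma**k * rewards[t + k]

print(discounted_returns(jnp.array([1.0, 0.0, 2.0])))  # approximately [2.9602, 1.98, 2.0]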
6 changes: 3 additions & 3 deletions docs/jaxpr.rst
@@ -375,16 +375,16 @@ For the example consider the function ``func11`` below
...
>>> print(make_jaxpr(func11)(onp.ones(16), 5.))
{ lambda c ; a b.
- let d e = scan[ forward=True
- jaxpr={ lambda ; f a b c.
+ let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
length=16
linear=(False, False, False, False)
num_carry=1
- num_consts=1 ] b 0.0 a c
+ num_consts=1
+ reverse=False ] b 0.0 a c
in (d, e) }

The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
2 changes: 1 addition & 1 deletion jax/experimental/loops.py
@@ -488,7 +488,7 @@ def build_output_vals(self, scope, carried_state_names, carried_tree,
arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals,
init_vals, [arange_val]),
- forward=True, length=arange_val.shape[0],
+ reverse=False, length=arange_val.shape[0],
jaxpr=body_typed_jaxpr,
num_consts=len(body_const_vals),
num_carry=len(init_vals),
53 changes: 29 additions & 24 deletions jax/lax/lax_control_flow.py
@@ -45,7 +45,7 @@
split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_leaves,
- tree_multimap)
+ tree_map, tree_multimap)
from jax import ad_util

xops = xla_client.ops
@@ -823,7 +823,7 @@ def cond_bind(*args, true_jaxpr, false_jaxpr, linear):

### scan

- def scan(f, init, xs, length=None):
+ def scan(f, init, xs, length=None, reverse=False):
"""Scan a function over leading array axes while carrying along state.
The type signature in brief is
@@ -883,6 +883,9 @@ def scan(f, init, xs, length=None):
length: optional integer specifying the number of loop iterations, which
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
be used to perform scans where no input ``xs`` are needed).
+ reverse: optional boolean specifying whether to run the scan iteration
+ forward (the default) or in reverse, equivalent to reversing the leading
+ axes of the arrays in both ``xs`` and in ``ys``.
Returns:
A pair of type ``(c, [b])`` where the first element represents the final
@@ -921,13 +924,15 @@ def scan(f, init, xs, length=None):
if jax.api._jit_is_disabled():
carry = init
ys = []
- for i in range(length):
+ maybe_reversed = reversed if reverse else lambda x: x
+ for i in maybe_reversed(range(length)):
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
ys.append(y)
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
else jax.numpy.stack((y, *ys)))
- return carry, tree_multimap(stack, *ys)
+ ys = tree_multimap(stack, *maybe_reversed(ys))
+ return carry, ys

carry_avals = tuple(_map(_abstractify, init_flat))
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
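The disable-jit branch above doubles as a readable reference for the semantics of reverse. A standalone sketch of the same idea (assuming xs is a plain Python list rather than a pytree of arrays; this is not the actual implementation):

def reference_scan(f, init, xs, reverse=False):
  # Visit indices back-to-front when reverse=True, front-to-back otherwise.
  indices = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
  carry, ys = init, [None] * len(xs)
  for i in indices:
    carry, ys[i] = f(carry, xs[i])  # each y lands at the same position as its x
  return carry, ys

Scanning in reverse this way matches reversing xs, scanning forward, and then reversing the stacked ys, which is the equivalence stated in the new docstring text.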
@@ -944,12 +949,12 @@ def scan(f, init, xs, length=None):
init_tree, carry_avals)

out = scan_p.bind(*itertools.chain(consts, in_flat),
- forward=True, length=length, jaxpr=jaxpr,
+ reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=len(init_flat),
linear=(False,) * (len(consts) + len(in_flat)))
return tree_unflatten(out_tree, out)

- def _scan_impl(*args, forward, length, num_consts, num_carry, jaxpr, linear):
+ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
consts, init, xs = split_list(args, [num_consts, num_carry])
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
@@ -960,7 +965,7 @@ def cond_fun(vals):

def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
- i_ = i if forward else length - i - 1
+ i_ = length - i - 1 if reverse else i
x = _map(partial(_index_array, i_), x_avals, xs)
out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
carry_out, y_updates = split_list(out_flat, [num_carry])
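The substantive change in the loop body is the index arithmetic: when reverse is set, iteration i indexes position length - i - 1 of the stacked inputs. A hypothetical check of that mapping, not part of the diff:

length, reverse = 4, True
print([length - i - 1 if reverse else i for i in range(length)])  # [3, 2, 1, 0]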
@@ -993,13 +998,13 @@ def _update_array(i, aval, xs, x):
else:
return lax.dynamic_update_index_in_dim(xs, x, i, 0)

- def _scan_abstract_eval(*args, forward, length, num_consts, num_carry, jaxpr, linear):
+ def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
if aval is not core.abstract_unit else aval for aval in y_avals]
return carry_avals + ys_avals

- def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
+ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear):
num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry
@@ -1043,7 +1048,7 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,

out_flat = scan_p.bind(
*(consts + consts_dot + init + init_dot + xs + xs_dot),
- forward=forward, length=length, jaxpr=jaxpr_jvp_rearranged,
+ reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
linear=jaxpr_jvp_linear)

@@ -1057,10 +1062,10 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
def _prune_zeros(ts):
return [t for t in ts if t is not ad_util.zero]

- def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
+ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear):
if trace.master.trace_type is pe.StagingJaxprTrace:
params = {"forward": forward, "length": length, "num_consts": num_consts,
params = {"reverse": reverse, "length": length, "num_consts": num_consts,
"num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
return trace.default_process_primitive(scan_p, tracers, params)

@@ -1126,7 +1131,7 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
[lin or uk for uk, lin
in zip(unknowns[num_consts:], linear[num_consts:])])
out_flat = scan_p.bind(
- *in_consts, forward=forward, length=length, jaxpr=jaxpr_1_opt,
+ *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]
@@ -1148,7 +1153,7 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
[False] * len(ext_res_tracers))
eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
out_tracers, scan_p,
- dict(forward=forward, length=length, jaxpr=jaxpr_2_opt,
+ dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2)))
for t in out_tracers: t.recipe = eqn
@@ -1160,7 +1165,7 @@ def _promote_aval_rank(sz, aval):
else:
return ShapedArray((sz,) + aval.shape, aval.dtype)

- def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, linear):
+ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
@@ -1194,7 +1199,7 @@ def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, l
[False] * num_eres)

outs = scan_p.bind(
- *(ires + ct_consts + ct_carry + ct_ys + eres), forward=not forward,
+ *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
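The transpose rule flips the direction (reverse=not reverse) because cotangents flow through the carry in the opposite order to the primal scan. A small hedged check, not taken from this commit: the VJP of a cumulative sum built from a forward scan is itself a reversed cumulative count.

import jax
import jax.numpy as jnp
from jax import lax

def cumsum_via_scan(x):
  # Forward scan that emits the running total at every step.
  _, ys = lax.scan(lambda c, xi: (c + xi, c + xi), 0.0, x)
  return ys

x = jnp.arange(4.0)
_, vjp_fn = jax.vjp(cumsum_via_scan, x)
ct, = vjp_fn(jnp.ones(4))
print(ct)  # [4. 3. 2. 1.] -- the backward pass runs the scan in reverse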
@@ -1231,7 +1236,7 @@ def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstract
return core.TypedJaxpr(jaxpr, consts, in_avals, _map(raise_to_shaped, out_avals))


- def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
+ def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
num_carry, linear):
num_ys = len(jaxpr.out_avals) - num_carry
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -1268,22 +1273,22 @@ def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
else x for x, d in zip(xs, xs_bdims)]
new_args = new_consts + new_init + new_xs

- outs = scan_p.bind(*new_args, forward=forward, length=length, jaxpr=jaxpr_batched,
+ outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
num_consts=num_consts, num_carry=num_carry, linear=linear)
carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
return outs, carry_bdims + ys_bdims

- def _scan_shape_rule(shapes, forward, length, jaxpr,
+ def _scan_shape_rule(shapes, reverse, length, jaxpr,
num_consts, num_carry, linear):
const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
return init_shexprs + ys_shapes

- def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
+ def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, reverse, length,
jaxpr, num_consts, num_carry, linear):
- out_shape = _scan_shape_rule(shape_exprs, forward, length, jaxpr,
+ out_shape = _scan_shape_rule(shape_exprs, reverse, length, jaxpr,
num_consts, num_carry, linear)
dynamic_length = length.evaluate(shape_envs.logical)
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
@@ -1292,7 +1297,7 @@ def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
out_vals = scan_p.bind(
*itertools.chain([dynamic_length] + consts, [0], init, xs),
- forward=forward, length=max_length, jaxpr=masked_jaxpr,
+ reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
return out_vals[1:], out_shape
@@ -1314,7 +1319,7 @@ def masked(*args):
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)

- def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
+ def scan_bind(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
if not core.skip_checks:
assert len(linear) == len(args)
consts, init, xs = split_list(args, [num_consts, num_carry])
Expand All @@ -1326,7 +1331,7 @@ def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
assert all(_map(typematch, init_avals, carry_avals))
core.check_jaxpr(jaxpr.jaxpr)
- return core.Primitive.bind(scan_p, *args, forward=forward, length=length,
+ return core.Primitive.bind(scan_p, *args, reverse=reverse, length=length,
jaxpr=jaxpr, num_consts=num_consts,
num_carry=num_carry, linear=linear)

6 changes: 3 additions & 3 deletions tests/api_test.py
@@ -1798,16 +1798,16 @@ def body(carry, aelems):
# TODO(#2640): update docs/jaxpr.rst to reflect new jaxpr
self.assertMultiLineStrippedEqual("""
{ lambda c ; a b.
- let d e = scan[ forward=True
- jaxpr={ lambda ; f a b c.
+ let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
length=16
linear=(False, False, False, False)
num_carry=1
- num_consts=1 ] b 0.0 a c
+ num_consts=1
+ reverse=False ] b 0.0 a c
in (d, e) }
""", str(jaxpr))

13 changes: 13 additions & 0 deletions tests/lax_control_flow_test.py
@@ -1903,6 +1903,19 @@ def body(i, x):
.format(too_big, api.device_count(), jtu.device_under_test())),
lambda: f_loop(np.ones(too_big)))

+ def test_scan_reverse(self):
+ def cumsum(x, reverse):
+ return lax.scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
+
+ x = onp.array([3, 1, 4, 1, 5, 9])
+ self.assertAllClose(onp.cumsum(x), cumsum(x, False), check_dtypes=False)
+ self.assertAllClose(onp.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
+
+ with api.disable_jit():
+ self.assertAllClose(onp.cumsum(x), cumsum(x, False), check_dtypes=False)
+ with api.disable_jit():
+ self.assertAllClose(onp.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
+

if __name__ == '__main__':
absltest.main()
