Skip to content

Commit

Permalink
Merge pull request #403 from google/improve-loop-construct-docstrings
Browse files Browse the repository at this point in the history
improve loop construct docs, remove foreach_loop
  • Loading branch information
mattjj committed Feb 18, 2019
2 parents b322833 + 13834ee commit 70b13ce
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
69 changes: 46 additions & 23 deletions jax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,35 @@ def sort_key_val(keys, values, dimension=-1):
return sorted_keys, sorted_values

def while_loop(cond_fun, body_fun, init_val):
"""Call `body_fun` repeatedly in a loop while `cond_fun` is True.
Arguments:
cond_fun: pure function of type `T -> Bool`.
body_fun: pure function of type `T -> T`.
init_val: value of type `T`, a type that can be a scalar, array, or any
(nested) Python tuple/list/dict thereof.
Returns:
The output from the final iteration of body_fun, of type `T`.
The semantics of `while_loop` are given by this Python implementation:
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
Unlike that pure Python version, `while_loop` is a JAX primitive and is
lowered to a single XLA While HLO. That makes it useful for reducing
compilation times for jit-compiled functions, since native Python loop
constructs in an `@jit` function are unrolled, leading to large XLA
computations.
Another difference from using Python-native loop constructs is that
`while_loop` is not (yet) reverse-mode differentiable because XLA computations
require static bounds on memory requirements.
"""
init_val_flat, in_tree = pytree_to_jaxtupletree(init_val)
flat_body_fun, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(body_fun), (in_tree,))
flat_cond_fun, _ = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(cond_fun), (in_tree,))
Expand Down Expand Up @@ -618,33 +647,27 @@ def fori_loop(lower, upper, body_fun, init_val):
Returns:
Loop value from the final iteration, of type T.
"""
# state: (upper limit, index, loop value)
# The `lt` and `add` functions are added to the namespace programmatically.
_, _, result = while_loop(
lambda upper_i_x: lt(upper_i_x[1], upper_i_x[0]),
lambda upper_i_x: (upper_i_x[0],
add(upper_i_x[1], onp.array(1, _dtype(upper_i_x[1]))),
body_fun(upper_i_x[1], upper_i_x[2])),
(upper, lower, init_val))
return result

def foreach_loop(sequence, body_fun, init_val):
"""Loop over `sequence` by reduction to `while_loop`.
The semantics of `fori_loop` are given by this Python implementation:
Arguments:
sequence: tuple of loop items, each of type U
body_fun: function of type (U, T) -> T, where T is the type of `init_val`
init_val: initial loop value, of type T
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
Returns:
Loop value from the final iteration, of type T.
Unlike that pure Python version, `fori_loop` is implemented in terms of a call
to `while_loop`. See the docstring for `while_loop` for more information.
"""
_, result = fori_loop(
0, len(sequence),
lambda i, seq_val: (seq_val[0], body_fun(seq_val[0][i], seq_val[1])),
(sequence, init_val))
def while_cond_fun(loop_carry):
i, _ = loop_carry
return lt(i, upper)

def while_body_fun(loop_carry):
i, x = loop_carry
return add(i, _const(i, 1)), body_fun(i, x)

_, result = while_loop(while_cond_fun, while_body_fun, (lower, init_val))
return result


Expand Down
18 changes: 0 additions & 18 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,24 +1335,6 @@ def body_fun(i, state):
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)
self.assertAllClose(cfun(x, num), onp.sum(x[:num]), check_dtypes=False)

def testForeachLoopBasic(self):
def sum_squares(xs):
def body_fun(x, y):
return y + x * x
return lax.foreach_loop(xs, body_fun, 0)

sum_squares_jit = api.jit(sum_squares)

xs = onp.array([1, 2, 3, 4])
self.assertEqual(sum_squares(xs[:1]), 1)
self.assertEqual(sum_squares(xs[:1]), sum_squares_jit(xs[:1]))
self.assertEqual(sum_squares(xs[:2]), 5)
self.assertEqual(sum_squares(xs[:2]), sum_squares_jit(xs[:2]))
self.assertEqual(sum_squares(xs[:3]), 14)
self.assertEqual(sum_squares(xs[:3]), sum_squares_jit(xs[:3]))
self.assertEqual(sum_squares(xs[:4]), 30)
self.assertEqual(sum_squares(xs[:4]), sum_squares_jit(xs[:4]))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
Expand Down

0 comments on commit 70b13ce

Please sign in to comment.