From 236a445b49ec7ec3dc6937a1958c60ed831af77f Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 22 Jun 2022 12:36:13 -0700 Subject: [PATCH] Add `for_loop` primitive and impl rule Co-authored-by: Matthew Johnson --- jax/_src/lax/control_flow/for_loop.py | 138 ++++++++++++++++++++++++-- tests/lax_control_flow_test.py | 91 +++++++++++++++-- 2 files changed, 217 insertions(+), 12 deletions(-) diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index d520e0109115..7718c11faaa0 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -14,30 +14,38 @@ """Module for the `for_loop` primitive.""" from functools import partial -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Callable, Dict, Generic, List, Sequence, Tuple, TypeVar from jax import core from jax import lax from jax import linear_util as lu +from jax.api_util import flatten_fun_nokwargs from jax.interpreters import ad +from jax.interpreters import mlir from jax.interpreters import partial_eval as pe +from jax.interpreters import xla +from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten, + treedef_tuple, PyTreeDef) from jax._src import ad_util from jax._src import pretty_printer as pp -from jax._src import util +from jax._src.util import safe_map, safe_zip, split_list +import jax.numpy as jnp ## JAX utilities -map, unsafe_map = util.safe_map, map -zip, unsafe_zip = util.safe_zip, zip +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip ## Helpful type aliases -Ref = Any +S = TypeVar('S') +T = TypeVar('T') +class Ref(Generic[T]): pass Array = Any ## State effect class StateEffect: pass -State = StateEffect +State = StateEffect() ## get/swap/addupdate implementations @@ -57,6 +65,7 @@ def _get_impl(ref: Ref, *idx: int): def ref_get(ref: Ref, idx: Tuple[int]) -> Array: """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" + idx = map(jnp.int32, idx) return get_p.bind(ref, *idx) # `swap` mutates a `Ref`, setting its value and returns its previous value. @@ -84,6 +93,7 @@ def _swap_impl(ref: Ref, value: Array, *idx: int): def ref_swap(ref: Ref, idx: Tuple[int], value: Array) -> Array: """Sets a `Ref`'s value and returns the original value.""" + idx = map(jnp.int32, idx) return swap_p.bind(ref, value, *idx) def ref_set(ref: Ref, idx: Tuple[int], value: Array) -> None: @@ -168,6 +178,10 @@ def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue, f"Ref shape: {ref_aval.shape}. " f"Value shape: {val_aval.shape}. " f"Indices: {idx}. ") + if ref_aval.dtype != val_aval.dtype: + raise ValueError("Invalid dtype for `swap`. " + f"Ref dtype: {ref_aval.dtype}. " + f"Value shape: {val_aval.dtype}. ") return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State} swap_p.def_effectful_abstract_eval(_swap_abstract_eval) @@ -372,3 +386,115 @@ def write(v: core.Var, val: Any) -> None: ref_vals = map( read, [v for v in jaxpr.invars if type(v.aval) is ShapedArrayRef]) return out_vals + ref_vals + +## `for_loop` implementation + +for_p = core.Primitive('for') +for_p.multiple_results = True + +### Tracing utilities + +def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr: + num_consts = len(jaxpr.constvars) + + # Note that this function is meant for use w/ `for_loop` since it assumes + # that the index is the first argument and preserves this after hoisting + # consts. + def _hoist(i, *consts_args): + const_refs, args = split_list(consts_args, [num_consts]) + # We immediately read the const values out of the `Ref`s. + consts = [r[()] for r in const_refs] + return core.eval_jaxpr(jaxpr, consts, i, *args) + assert all(isinstance(var.aval, core.ShapedArray) for var in jaxpr.constvars) + const_avals = [ShapedArrayRef(var.aval.shape, var.aval.dtype) for var in # pytype: disable=attribute-error + jaxpr.constvars] + i_aval, *arg_avals = [var.aval for var in jaxpr.invars] + in_avals = [i_aval, *const_avals, *arg_avals] + hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(_hoist), in_avals) + assert not consts, "All consts should have been converted to refs" + return hoisted_jaxpr + +def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef, + state_avals: Sequence[core.AbstractValue] + ) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]: + f, out_tree_thunk = flatten_fun_nokwargs( + lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree))) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( + f, state_avals) + return jaxpr, consts, out_tree_thunk() + +def val_to_ref_aval(x) -> ShapedArrayRef: + aval = core.raise_to_shaped(core.get_aval(x)) + if type(aval) is not core.ShapedArray: + raise Exception(f"can't make ref from {x}") + return ShapedArrayRef(aval.shape, aval.dtype) + +def for_loop(nsteps: int, body: Callable[[Array, Ref[S]], None], init_state: S) -> S: + """A for-loop combinator that allows read/write semantics in the loop body. + + `for_loop` is a higher-order function that enables writing loops that can be + staged out in JIT-ted JAX computations. Unlike `jax.lax.fori_loop`, it allows + mutation in its body using `Ref`s. + + `for_loop` will initialize `Ref`s with the values in `init_state`. Each + iteration, `body` will be called with the current `Ref`s, which can be read + from and written to using `ref_get` and `ref_set`. + + `for_loop` is semantically equivalent to the following Python code: + + ```python + def for_loop(nsteps, body, init_state): + refs = tree_map(make_ref, init_state) + for i in range(nsteps): + body(i, refs) + return tree_map(ref_get, refs) + ``` + + Args: + nsteps: Number of iterations + body: A callable that takes in the iteration number as its first argument + and `Ref`s corresponding to `init_state` as its second argument. + `body` is free to read from and write to its `Ref`s. `body` should + not return anything. + init_state: A Pytree of JAX-compatible values used to initialize the `Ref`s + that will be passed into the for loop body. + Returns: + A Pytree of values representing the output of the for loop. + """ + flat_state, state_tree = tree_flatten(init_state) + state_avals = map(val_to_ref_aval, flat_state) + idx_aval = core.ShapedArray((), jnp.dtype("int32")) + jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( + body, state_tree, [idx_aval, *state_avals]) + if out_tree != tree_structure(None): + raise Exception("`body` should not return anything.") + # Remove constvars from jaxpr and turn them into `Ref`s + jaxpr = _hoist_consts_to_refs(jaxpr) + which_linear = (False,) * (len(consts) + len(flat_state)) + out_flat = for_p.bind(*consts, *flat_state, jaxpr=jaxpr, nsteps=int(nsteps), + reverse=False, which_linear=which_linear) + # Consts are `Ref`s so they are both inputs and outputs. We remove them from + # the outputs. + out_flat = out_flat[len(consts):] + return tree_unflatten(state_tree, out_flat) + +@for_p.def_abstract_eval +def _for_abstract_eval(*avals, jaxpr, **__): + return list(avals) + +def _for_impl(*args, jaxpr, nsteps, reverse, which_linear): + del which_linear + discharged_jaxpr, consts = discharge_state(jaxpr, ()) + def cond(carry): + i, _ = carry + return i < nsteps + def body(carry): + i, state = carry + i_ = nsteps - i - 1 if reverse else i + next_state = core.eval_jaxpr(discharged_jaxpr, consts, i_, *state) + return i + 1, next_state + _, state = lax.while_loop(cond, body, (jnp.int32(0), list(args))) + return state +mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True)) +for_p.def_impl(partial(xla.apply_primitive, for_p)) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 82b830191987..d7c7b6a17eb2 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2567,8 +2567,8 @@ def test_addupdate_checks_for_correct_shapes(self): def test_can_represent_get_and_swap_in_jaxprs(self): def body(x): - x[()] = 1 - x[()] = 2 + x[()] = jnp.int32(1) + x[()] = jnp.int32(2) return (x[()],) jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) @@ -2581,7 +2581,7 @@ def body(x): def test_can_represent_addupdate_in_jaxprs(self): def body(x): - for_loop.ref_addupdate(x, (), 1) + for_loop.ref_addupdate(x, (), jnp.int32(1)) return (x[()],) jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) @@ -2599,7 +2599,7 @@ def body(x_ref): def test_set_custom_pretty_printing_rule(self): def body(x_ref): - x_ref[()] = 2 + x_ref[()] = jnp.int32(2) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) @@ -2607,7 +2607,7 @@ def body(x_ref): def test_swap_custom_pretty_printing_rule(self): def body(x_ref): - x = for_loop.ref_swap(x_ref, (), 2) + x = for_loop.ref_swap(x_ref, (), jnp.int32(2)) return [x] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) @@ -2615,7 +2615,7 @@ def body(x_ref): def test_addupdate_custom_pretty_printing_rule(self): def body(x_ref): - for_loop.ref_addupdate(x_ref, (), 2) + for_loop.ref_addupdate(x_ref, (), jnp.int32(2)) return [] jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic( lu.wrap_init(body), [for_loop.ShapedArrayRef((), jnp.int32)]) @@ -2803,5 +2803,84 @@ def f(a_ref): self.assertTrue((b == inval + 1).all()) self.assertTrue((refval == inval).all()) + def test_for_loop_impl_trivial(self): + out = for_loop.for_loop(5, lambda i, _: None, None) + self.assertEqual(out, None) + + def test_for_loop_can_write_to_ref(self): + def body(_, x_ref): + x_ref[()] = jnp.float32(1.) + out = for_loop.for_loop(1, body, jnp.float32(0.)) + self.assertEqual(out, 1.) + + def body2(i, x_ref): + x_ref[()] = jnp.float32(i) + out = for_loop.for_loop(2, body2, jnp.float32(0.)) + self.assertEqual(out, 1.) + + def body3(i, x_ref): + x_ref[()] = jnp.float32(i) * 2. + out = for_loop.for_loop(2, body3, jnp.float32(0.)) + self.assertEqual(out, 2.) + + def test_for_loop_can_write_to_multiple_refs(self): + def body(_, refs): + x_ref, y_ref = refs + x_ref[()] = jnp.float32(1.) + y_ref[()] = jnp.float32(2.) + x, y = for_loop.for_loop(1, body, (jnp.float32(0.), jnp.float32(0.))) + self.assertEqual(x, 1.) + self.assertEqual(y, 2.) + + def test_for_loop_can_read_from_ref(self): + def body(_, x_ref): + x_ref[()] + x = for_loop.for_loop(1, body, jnp.float32(0.)) + self.assertEqual(x, 0.) + + def test_for_loop_can_read_from_and_write_to_ref(self): + def body(_, x_ref): + x = x_ref[()] + x_ref[()] = x + jnp.float32(1.) + x = for_loop.for_loop(5, body, jnp.float32(0.)) + self.assertEqual(x, 5.) + + def test_for_loop_can_read_from_and_write_to_refs(self): + def body2(_, refs): + x_ref, y_ref = refs + x = x_ref[()] + y_ref[()] = x + 1. + x_ref[()] = x + 1. + x, y = for_loop.for_loop(5, body2, (0., 0.)) + self.assertEqual(x, 5.) + self.assertEqual(y, 5.) + + def test_for_loop_can_read_from_and_write_to_ref_slice(self): + def body(i, x_ref): + x = x_ref[i] + x_ref[i] = x + jnp.float32(1.) + x = for_loop.for_loop(4, body, jnp.ones(4, jnp.float32)) + np.testing.assert_allclose(x, 2 * jnp.ones(4, jnp.float32)) + + def body2(i, x_ref): + x = x_ref[i, 0] + x_ref[i, 1] = x + x_ref[i, 1] + x = for_loop.for_loop(4, body2, jnp.arange(8.).reshape((4, 2))) + np.testing.assert_allclose( + x, jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]])) + + def test_for_loop_can_implement_cumsum(self): + def cumsum(x): + def body(i, refs): + x_ref, accum_ref = refs + accum_ref[i + 1] = accum_ref[i] + x_ref[i] + accum = jnp.zeros(x.shape[0] + 1, x.dtype) + _, accum_out = for_loop.for_loop(x.shape[0], body, (x, accum)) + return accum_out[1:] + + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (8,)) + np.testing.assert_allclose(cumsum(x), jnp.cumsum(x)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())