From 54e3b7611af9443fb944433fc9e6e111b0a7a150 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 4 Dec 2023 10:26:25 -0800 Subject: [PATCH] Add support for unrolling to `lax.fori_loop` PiperOrigin-RevId: 587767613 --- jax/_src/lax/control_flow/loops.py | 17 +++++++++++++---- tests/lax_control_flow_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 49da39f8b221..275593919403 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for the loop primitives.""" +from __future__ import annotations from collections.abc import Sequence from functools import partial @@ -1842,7 +1843,8 @@ def scanned_fun(loop_carry, _): return scanned_fun @api_boundary -def fori_loop(lower, upper, body_fun, init_val): +def fori_loop(lower, upper, body_fun, init_val, + *, unroll: int | None = None): """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. The `Haskell-like type signature`_ in brief is @@ -1887,6 +1889,8 @@ def fori_loop(lower, upper, body_fun, init_val): upper: an integer representing the loop index upper bound (exclusive) body_fun: function of type ``(int, a) -> a``. init_val: initial loop carry value of type ``a``. + unroll: An optional integer that determines how much to unroll the loop. + Only applicable if the loop bounds are statically known. Returns: Loop value from the final iteration, of type ``a``. @@ -1934,18 +1938,23 @@ def fori_loop(lower, upper, body_fun, init_val): use_scan = False if use_scan: + if unroll is None: + unroll = 1 if config.disable_jit.value and upper_ == lower_: # non-jit implementation of scan does not support length=0 return init_val (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val), - None, length=upper_ - lower_) + None, length=upper_ - lower_, unroll=unroll) return result + if unroll is not None: + raise ValueError("Can only use `unroll` in `fori_loop` if the loop bounds " + "are statically known.") if lower_dtype != dtype: - lower = lax.convert_element_type(lower, dtype) + lower = lax.convert_element_type(lower, dtype) # type: ignore if upper_dtype != dtype: - upper = lax.convert_element_type(upper, dtype) + upper = lax.convert_element_type(upper, dtype) # type: ignore _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun), (lower, upper, init_val)) return result diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 6f08e1317882..589047f17a2e 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -499,6 +499,32 @@ def testForiLoopScalarLimits(self): result = lax.fori_loop(0, np.int16(10), body, init) self.assertEqual(result, init + 10) + def test_fori_loop_supports_unrolling(self): + """Test that we can unroll static fori_loops.""" + body = lambda i, c: c + 1 + init = jnp.float32(10) + + result = lax.fori_loop(np.int16(0), 10, body, init, + unroll=3) + self.assertEqual(result, init + 10) + + result = lax.fori_loop(0, np.int16(10), body, init, + unroll=2) + self.assertEqual(result, init + 10) + + def test_fori_loop_with_dynamic_indices_cannot_unroll(self): + """Test that we can't unroll dynamic fori_loops.""" + body = lambda i, c: c + 1 + init = jnp.float32(10) + + @jax.jit + def f(upper): + return lax.fori_loop(np.int16(0), upper, body, init, + unroll=3) + + with self.assertRaisesRegex(ValueError, "Can only use `unroll`"): + f(10) + def testForiLoopBatched(self): def body_fun(i, loop_carry): x, y = loop_carry