Skip to content

Commit

Permalink
Add support for unrolling to lax.fori_loop
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587767613
  • Loading branch information
sharadmv authored and jax authors committed Dec 4, 2023
1 parent 5942e15 commit 54e3b76
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
17 changes: 13 additions & 4 deletions jax/_src/lax/control_flow/loops.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the loop primitives."""
from __future__ import annotations

from collections.abc import Sequence
from functools import partial
Expand Down Expand Up @@ -1842,7 +1843,8 @@ def scanned_fun(loop_carry, _):
return scanned_fun

@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
def fori_loop(lower, upper, body_fun, init_val,
*, unroll: int | None = None):
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
The `Haskell-like type signature`_ in brief is
Expand Down Expand Up @@ -1887,6 +1889,8 @@ def fori_loop(lower, upper, body_fun, init_val):
upper: an integer representing the loop index upper bound (exclusive)
body_fun: function of type ``(int, a) -> a``.
init_val: initial loop carry value of type ``a``.
unroll: An optional integer that determines how much to unroll the loop.
Only applicable if the loop bounds are statically known.
Returns:
Loop value from the final iteration, of type ``a``.
Expand Down Expand Up @@ -1934,18 +1938,23 @@ def fori_loop(lower, upper, body_fun, init_val):
use_scan = False

if use_scan:
if unroll is None:
unroll = 1
if config.disable_jit.value and upper_ == lower_:
# non-jit implementation of scan does not support length=0
return init_val

(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
None, length=upper_ - lower_)
None, length=upper_ - lower_, unroll=unroll)
return result
if unroll is not None:
raise ValueError("Can only use `unroll` in `fori_loop` if the loop bounds "
"are statically known.")

if lower_dtype != dtype:
lower = lax.convert_element_type(lower, dtype)
lower = lax.convert_element_type(lower, dtype) # type: ignore
if upper_dtype != dtype:
upper = lax.convert_element_type(upper, dtype)
upper = lax.convert_element_type(upper, dtype) # type: ignore
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
(lower, upper, init_val))
return result
Expand Down
26 changes: 26 additions & 0 deletions tests/lax_control_flow_test.py
Expand Up @@ -499,6 +499,32 @@ def testForiLoopScalarLimits(self):
result = lax.fori_loop(0, np.int16(10), body, init)
self.assertEqual(result, init + 10)

def test_fori_loop_supports_unrolling(self):
"""Test that we can unroll static fori_loops."""
body = lambda i, c: c + 1
init = jnp.float32(10)

result = lax.fori_loop(np.int16(0), 10, body, init,
unroll=3)
self.assertEqual(result, init + 10)

result = lax.fori_loop(0, np.int16(10), body, init,
unroll=2)
self.assertEqual(result, init + 10)

def test_fori_loop_with_dynamic_indices_cannot_unroll(self):
"""Test that we can't unroll dynamic fori_loops."""
body = lambda i, c: c + 1
init = jnp.float32(10)

@jax.jit
def f(upper):
return lax.fori_loop(np.int16(0), upper, body, init,
unroll=3)

with self.assertRaisesRegex(ValueError, "Can only use `unroll`"):
f(10)

def testForiLoopBatched(self):
def body_fun(i, loop_carry):
x, y = loop_carry
Expand Down

0 comments on commit 54e3b76

Please sign in to comment.