From 7af1c149f56f1c27722490816a44c480db1c0963 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Thu, 7 Dec 2023 22:54:32 -0800
Subject: [PATCH] [Pallas/Mosaic] Lower `lax.fori_loop`s to *rolled* loops.

Note that this is a breaking change! Existing uses of `lax.fori_loop` inside
kernels should now pass `unroll=True`: loops were previously unrolled by
default, and this change makes rolled loops the default.

PiperOrigin-RevId: 589017485
---
 jax/_src/pallas/mosaic/lowering.py            | 54 +++++++++++++++++--
 .../pallas/ops/tpu/flash_attention.py         |  2 +-
 2 files changed, 51 insertions(+), 5 deletions(-)

diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 13865d848aca..286fdef8b874 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -1433,6 +1433,51 @@ def _for_lowering_rule(
 lowering_rules[for_loop.for_p] = _for_lowering_rule
 
 
+def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
+                             jaxpr: jax_core.Jaxpr, start: int,
+                             num_steps: int, consts, *args,
+                             has_loop_index: bool,
+                             unroll: int):
+  def _run_body(i, args):
+    if has_loop_index:
+      lowering_context = ctx.lowering_context.replace(
+          block_shapes=ctx.block_shapes)
+      args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
+    else:
+      del i
+      lowering_context = ctx.lowering_context.replace(
+          block_shapes=ctx.block_shapes[:len(consts)]
+          + ctx.block_shapes[len(consts) + 1:],
+      )
+      args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
+    return args
+  if num_steps == unroll:
+    # No need for an scf.For. We can just unroll completely
+    for i in range(start, start + num_steps):
+      args = _run_body(
+          ir_constant(i, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32"))),
+          args,
+      )
+    return args
+  if unroll != 1:
+    raise NotImplementedError(
+        f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
Got {unroll=}.") + if len(args) > 0: + raise NotImplementedError("Rolled loops don't support arguments") + lbd = ir_constant(0, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32"))) + ubd = ir_constant( + num_steps, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32")) + ) + step = ir_constant(1, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32"))) + for_op = scf.ForOp(lbd, ubd, step, args) + with ir.InsertionPoint(for_op.body): + iv = for_op.induction_variable + inner_args = for_op.inner_iter_args + inner_out = _run_body(iv, inner_args) + scf.YieldOp(inner_out) + return for_op.results + + def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: int, num_steps: int, consts, *args, @@ -1469,7 +1514,7 @@ def _scan_lowering_rule( num_extensive = len(args) - num_consts - num_carry if num_extensive: raise NotImplementedError if reverse: raise NotImplementedError - del linear, num_extensive, unroll, reverse + del linear, num_extensive, reverse jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts if jaxpr_consts: raise NotImplementedError @@ -1483,9 +1528,10 @@ def _scan_lowering_rule( loop_index_start, *args = args else: loop_index_start = 0 - out = _lower_jaxpr_to_unrolled_for_loop(ctx, jaxpr, loop_index_start, length, - consts, *args, - has_loop_index=has_loop_index) + out = _lower_jaxpr_to_for_loop( + ctx, jaxpr, loop_index_start, length, + consts, *args, has_loop_index=has_loop_index, + unroll=unroll) if has_loop_index: out = [ir_constant(length, mlir_type=mlir.dtype_to_ir_type(jnp.dtype('int32'))), diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 7fbf006626d6..de55a20d788f 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -366,7 +366,7 @@ def start_new_sequence(): @pl.when(should_run) def run(): @functools.partial( - lax.fori_loop, 0, block_k_major // block_k, init_val=None + lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True ) def body(i, _): m_prev = m_scratch_ref[batch_idx]