Skip to content

Commit

Permalink
[Pallas/Mosaic] Lower lax.fori_loops to *rolled* loops.
Browse files Browse the repository at this point in the history
Note that this is a breaking change! Current uses of `lax.fori_loop` inside of kernels
should instead pass `unroll=True` (loops were being unrolled by default and
we are switching that with this change).

PiperOrigin-RevId: 589017485
  • Loading branch information
sharadmv authored and jax authors committed Dec 8, 2023
1 parent 1701b73 commit 7af1c14
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
54 changes: 50 additions & 4 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -1433,6 +1433,51 @@ def _for_lowering_rule(
lowering_rules[for_loop.for_p] = _for_lowering_rule


def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
jaxpr: jax_core.Jaxpr, start: int,
num_steps: int, consts, *args,
has_loop_index: bool,
unroll: int):
def _run_body(i, args):
if has_loop_index:
lowering_context = ctx.lowering_context.replace(
block_shapes=ctx.block_shapes)
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
else:
del i
lowering_context = ctx.lowering_context.replace(
block_shapes=ctx.block_shapes[:len(consts)]
+ ctx.block_shapes[len(consts) + 1:],
)
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
return args
if num_steps == unroll:
# No need for an scf.For. We can just unroll completely
for i in range(start, start + num_steps):
args = _run_body(
ir_constant(i, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32"))),
args,
)
return args
if unroll != 1:
raise NotImplementedError(
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
if len(args) > 0:
raise NotImplementedError("Rolled loops don't support arguments")
lbd = ir_constant(0, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32")))
ubd = ir_constant(
num_steps, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32"))
)
step = ir_constant(1, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32")))
for_op = scf.ForOp(lbd, ubd, step, args)
with ir.InsertionPoint(for_op.body):
iv = for_op.induction_variable
inner_args = for_op.inner_iter_args
inner_out = _run_body(iv, inner_args)
scf.YieldOp(inner_out)
return for_op.results


def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
jaxpr: jax_core.Jaxpr, start: int,
num_steps: int, consts, *args,
Expand Down Expand Up @@ -1469,7 +1514,7 @@ def _scan_lowering_rule(
num_extensive = len(args) - num_consts - num_carry
if num_extensive: raise NotImplementedError
if reverse: raise NotImplementedError
del linear, num_extensive, unroll, reverse
del linear, num_extensive, reverse

jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts
if jaxpr_consts: raise NotImplementedError
Expand All @@ -1483,9 +1528,10 @@ def _scan_lowering_rule(
loop_index_start, *args = args
else:
loop_index_start = 0
out = _lower_jaxpr_to_unrolled_for_loop(ctx, jaxpr, loop_index_start, length,
consts, *args,
has_loop_index=has_loop_index)
out = _lower_jaxpr_to_for_loop(
ctx, jaxpr, loop_index_start, length,
consts, *args, has_loop_index=has_loop_index,
unroll=unroll)
if has_loop_index:
out = [ir_constant(length,
mlir_type=mlir.dtype_to_ir_type(jnp.dtype('int32'))),
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/tpu/flash_attention.py
Expand Up @@ -366,7 +366,7 @@ def start_new_sequence():
@pl.when(should_run)
def run():
@functools.partial(
lax.fori_loop, 0, block_k_major // block_k, init_val=None
lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True
)
def body(i, _):
m_prev = m_scratch_ref[batch_idx]
Expand Down

0 comments on commit 7af1c14

Please sign in to comment.