[Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
The existing lowering path supports only while_loops that can be converted to fori_loops.
That form is significantly easier to optimize and unroll, but it cannot express a large class of interesting loops.

This patch draws on the Pallas -> Triton while_loop lowering rule to support such loops in Pallas.
Pattern matching against fori_loop is still attempted first, and loops that match are lowered through that path, since it is likely more straightforward to optimize than a general "while".
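
For illustration only -- hypothetical kernels, not taken from this change: the first loop below has the index-versus-bound shape that the fori pattern-match recognizes, while the second exits on a data-dependent condition and can only lower as a general while.

import jax
import jax.numpy as jnp

# Convertible: the condition compares a loop index against a fixed bound,
# and the body advances that index by exactly one.
def fori_style_sum(x, n):
  def body(carry):
    i, s = carry
    return i + 1, s + x[i]
  init = (jnp.int32(0), jnp.float32(0.0))
  return jax.lax.while_loop(lambda c: c[0] < n, body, init)

# Not convertible: the exit condition depends on the running sum, so there
# is no fixed trip count to extract. (Assumes positive entries in `x` so
# the sum eventually reaches `limit`.)
def data_dependent_sum(x, limit):
  def body(carry):
    i, s = carry
    return i + 1, s + x[i]
  init = (jnp.int32(0), jnp.float32(0.0))
  return jax.lax.while_loop(lambda c: c[1] < limit, body, init)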

PiperOrigin-RevId: 626089180
jax authors committed Apr 18, 2024
1 parent 6ca69f3 commit 9c9e805
Showing 2 changed files with 274 additions and 8 deletions.
89 changes: 81 additions & 8 deletions jax/_src/pallas/mosaic/lowering.py
@@ -1833,28 +1833,23 @@ def _scan_lowering_rule(
skip_mlir_conversions.add(lax.scan_p)


-def _while_lowering_rule(
+def _lower_while_via_fori(
    ctx: LoweringRuleContext,
    *args,
+    fori_jaxpr,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
-  jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
-      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
-  )
-  if jaxpr is None:
-    raise NotImplementedError(err)
-
  _, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
  (lb, ub), args = carry[:2], carry[2:]
  for_out = _lower_jaxpr_to_for_loop(
      ctx.replace(
          block_shapes=ctx.block_shapes[: body_nconsts + 1]
          + ctx.block_shapes[body_nconsts + 2 :],
      ),
-      jaxpr,
+      fori_jaxpr,
      lb,
      arith.subi(ub, lb),
      body_consts,
@@ -1865,6 +1860,84 @@ def _while_lowering_rule(
  return [ub, ub, *for_out]
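
As a plain-Python reference for the shape this helper assumes (a sketch, not the lowering itself): the carry is laid out as (lb, ub, *rest), the matched condition is an index-below-bound test, and the matched body advances the index by one, so the index equals ub on exit -- hence `[ub, ub, *for_out]`.

# Reference semantics only; `rest` and `body_fn` stand in for the rest of
# the carry and the matched loop body.
def fori_reference(lb, ub, rest, body_fn):
  i = lb
  while i < ub:              # the shape the fori pattern-match accepts
    rest = body_fn(i, rest)
    i += 1                   # the matched body advances the index by one
  return (ub, ub, rest)      # on exit i == ub, mirroring [ub, ub, *for_out]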


def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
  # First try to lower via a simpler fori loop, which may optimize better.
  fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  if fori_jaxpr is not None:
    return _lower_while_via_fori(
        ctx,
        *args,
        fori_jaxpr=fori_jaxpr,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )

  # If conversion to fori fails, fall back to an ordinary while loop.
  cond_consts, body_consts, carry = split_list(
      args, [cond_nconsts, body_nconsts]
  )
  cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
      split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
  )
  cond_const_types = [a.type for a in cond_consts]
  body_const_types = [a.type for a in body_consts]
  carry_types = [a.type for a in carry]
  all_types = [*cond_const_types, *body_const_types, *carry_types]
  while_op = scf.WhileOp(all_types, args)

  before_block = while_op.before.blocks.append(*all_types)
  cond_consts_, _, carry_ = split_list(
      before_block.arguments,
      [cond_nconsts, body_nconsts],
  )
  cond_args = [*cond_consts_, *carry_]
  with ir.InsertionPoint.at_block_begin(before_block):
    [cond] = jaxpr_subcomp(
        ctx.lowering_context.replace(
            block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
        ),
        cond_jaxpr.jaxpr,
        *cond_args,
    )
    scf.condition(cond, before_block.arguments)

  after_block = while_op.after.blocks.append(*all_types)
  cond_consts_, body_consts_, carry_ = split_list(
      after_block.arguments,
      [cond_nconsts, body_nconsts],
  )
  all_args = [*cond_consts_, *body_consts_, *carry_]
  cond_const_args, body_const_args, carry_args = split_list(
      all_args, [cond_nconsts, body_nconsts]
  )
  with ir.InsertionPoint.at_block_begin(after_block):
    loop_out = jaxpr_subcomp(
        ctx.lowering_context.replace(
            block_shapes=[*body_const_block_shapes, *carry_block_shapes],
        ),
        body_jaxpr.jaxpr,
        *body_const_args,
        *carry_args,
    )
    all_handles = [*cond_const_args, *body_const_args, *loop_out]
    if all_handles:
      scf.yield_(all_handles)

  all_out = list(while_op.results_)
  return all_out[cond_nconsts + body_nconsts :]


lowering_rules[lax.while_p] = _while_lowering_rule
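
In plain Python, the scf.while structure assembled above behaves like the sketch below (reference semantics only, not the emitted IR): both regions receive the full (cond consts, body consts, carry) tuple; the before region computes the condition and the after region runs the body.

# Reference semantics for the generated scf.while (sketch only).
def scf_while_reference(cond_consts, body_consts, carry, cond_fn, body_fn):
  while cond_fn(*cond_consts, *carry):       # "before" block: scf.condition
    carry = body_fn(*body_consts, *carry)    # "after" block: scf.yield_
  # The op also forwards both const prefixes as results; the rule slices
  # them off, mirroring `all_out[cond_nconsts + body_nconsts:]`.
  return carry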

def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
193 changes: 193 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
@@ -1751,6 +1751,199 @@ def body(i, _):
    )(*(jnp.array([[x]]) for x in (2, 6)))
    np.testing.assert_array_equal(r, 4)

  def test_non_range_while_loop(self):
    """Tests lowering of a while_loop which cannot reduce to a fori_loop."""

    def kernel(x_ref, r_ref):
      @pl.when(pl.program_id(0) == 0)
      def _():
        pl.store(r_ref, (0, 0), 0)

      def cond(state):
        i, s = state
        return jnp.logical_and(i < 1024, s < 1024)

      def body(state):
        i, s = state
        sl = jax.lax.div(i, 128)
        l = jax.lax.rem(i, 128)
        v = pl.load(x_ref, (0, sl, l))
        return i + 1, s + v

      i = jnp.int32(0)
      s = pl.load(r_ref, (0, 0))

      i, s = jax.lax.while_loop(cond, body, (i, s))
      pl.store(r_ref, (0, 0), s)

    x = jnp.arange(4096)
    x = jnp.reshape(x, [4, 8, 128])

    r = pl.pallas_call(
        kernel,
        grid=(4,),
        out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
        out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
        in_specs=[
            pl.BlockSpec(
                lambda i: (i, 0, 0),
                block_shape=(1, 8, 128),
                memory_space=pltpu.SMEM,
            )
        ],
    )(x)
    # Block 0 holds values 0..1023, so the loop adds v == i while s < 1024;
    # the first n with 0 + 1 + ... + n >= 1024 is n = 45, and sum(0..45) == 1035.
    np.testing.assert_array_equal(r, [[1035]])

  def test_vector_carry_while_loop(self):
    """Tests lowering of a while_loop which carries a vector quantity."""

    def kernel(x_ref, r_ref):

      def cond(v):
        return v[0, 0] < 16

      def body(v):
        return v * 2

      r_ref[:] = jax.lax.while_loop(cond, body, x_ref[:])

    x = jnp.full((8, 128), 3, dtype=jnp.int32)
    fn = pl.pallas_call(
        kernel,
        grid=(1,),
        in_specs=[pl.BlockSpec(lambda i: (0, 0), (8, 128))],
        out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)),
        out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
    )
    r = fn(x)
    reduced = jnp.sum(r)
    # 3 -> 6 -> 12 -> 24
    np.testing.assert_array_equal(reduced, 1024 * 24)

  @parameterized.named_parameters(
      ('1x128', (1, 128)),
      ('2x128', (2, 128)),
      ('4x128', (4, 128)),
      ('8x128', (8, 128)),
      ('8x256', (8, 256)),
  )
  def test_while_loop_carry_memref(self, shape):
    """Tests a while loop carrying a memref."""

    # TODO(hmckenzie): Investigate further why this occurs.
    if shape == (1, 128):
      self.skipTest('memref<1x128> inexplicably doubles to 2x128.')

    def kernel(out_ref, bound):
      def cond(i):
        return i < bound

      def body(i):
        out_ref[0, i] = 2
        return i + 1

      jax.lax.while_loop(cond, body, 0)

    x = jnp.asarray([1, 1, 1, 1])
    x = jnp.pad(x, (0, np.prod(shape) - 4), constant_values=0)
    x = jnp.reshape(x, shape)
    kernel = partial(kernel, bound=x.shape[1])

    fn = pl.pallas_call(
        kernel,
        grid=(1,),
        out_specs=[
            pl.BlockSpec(
                lambda i: (0, 0), block_shape=shape, memory_space=pltpu.SMEM
            ),
        ],
        out_shape=[
            jax.ShapeDtypeStruct(shape, jnp.int32),
        ],
    )
    y = fn()[0]
    np.testing.assert_array_equal(y[0, 0], 2)
    np.testing.assert_array_equal(y[0, 1], 2)
    np.testing.assert_array_equal(y[0, 2], 2)
    np.testing.assert_array_equal(y[0, 3], 2)

  def test_nested_while_loop(self):
    """Tests lowering a nested while_loop."""

    def kernel(in_key_ref, out_segment_count, out_size_ref, key_count):
      # Compute the length of contiguous segments of keys.

      def inner_cond(carry):
        i, prev_key = carry
        sl = jax.lax.div(i, 128)
        l = jax.lax.rem(i, 128)
        key = jax.lax.cond(
            i < key_count, lambda i: in_key_ref[sl, l], lambda i: -1, i
        )
        return jnp.logical_and(i < key_count, key == prev_key)

      def inner_body(carry):
        i, key = carry
        return i + 1, key

      def outer_cond(carry):
        i, _ = carry
        return i < key_count

      def outer_body(carry):
        i, next_out_idx = carry
        sl = jax.lax.div(i, 128)
        l = jax.lax.rem(i, 128)
        key = in_key_ref[sl, l]
        end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key))

        sl = jax.lax.div(next_out_idx, 128)
        l = jax.lax.rem(next_out_idx, 128)
        out_size_ref[sl, l] = end - i
        return end, next_out_idx + 1

      _, count = jax.lax.while_loop(outer_cond, outer_body, (0, 0))
      out_segment_count[0, 0] = count

    keys = [4, 4, 4, 3, 2, 2, 7, 7, 7, 7]
    keys = jnp.asarray(keys)
    real_keys = keys.shape[0]
    key_count = 1024
    keys = jnp.pad(keys, (0, key_count - real_keys), constant_values=32768)
    keys = jnp.reshape(keys, (8, 128))
    kernel_fn = partial(kernel, key_count=key_count)

    fn = pl.pallas_call(
        kernel_fn,
        grid=(1,),
        in_specs=[
            # keys.
            pl.BlockSpec(
                lambda i: (0, 0),
                block_shape=(8, 128),
                memory_space=pltpu.SMEM,
            ),
        ],
        out_specs=[
            # Segments found.
            pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
            # Segment sizes.
            pl.BlockSpec(block_shape=(8, 128), memory_space=pltpu.SMEM),
        ],
        out_shape=[
            jax.ShapeDtypeStruct((1, 1), jnp.int32),
            jax.ShapeDtypeStruct((8, 128), jnp.int32),
        ],
    )
    count, sizes = fn(keys)
    # Segments: [4, 4, 4], [3], [2, 2], [7, 7, 7, 7], then the padding keys.
    np.testing.assert_equal(count[0, 0], jnp.asarray(5))
    np.testing.assert_equal(sizes[0, 0], jnp.asarray(3))
    np.testing.assert_equal(sizes[0, 1], jnp.asarray(1))
    np.testing.assert_equal(sizes[0, 2], jnp.asarray(2))
    np.testing.assert_equal(sizes[0, 3], jnp.asarray(4))
    np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys))


class PallasCallPipelineTest(parameterized.TestCase):

