Skip to content

Commit

Permalink
Reverts 8e96c49
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606423581
  • Loading branch information
hawkinsp authored and jax authors committed Feb 13, 2024
1 parent 7aadabd commit 8b1da12
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 162 deletions.
43 changes: 4 additions & 39 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -776,8 +776,7 @@ def _indexer_to_start_size(
else _index_to_start_size(next(indices_iter), cast_to_index)
for s in ref_block_shape
)
next_index = next(indices_iter, None)
assert next_index is None, (indexer.indices, ref_block_shape)
assert next(indices_iter, None) is None
new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
if not squeeze)
return tuple(starts), tuple(sizes), tuple(squeeze_dims), new_ref_block_shape
Expand Down Expand Up @@ -1661,12 +1660,9 @@ def _run_body(i, args):
raise NotImplementedError(
f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
lbd = ir_constant(0, mlir_type=mlir.dtype_to_ir_type(jnp.dtype("int32")))
if isinstance(num_steps, int):
ubd = ir_constant(
num_steps, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
)
else:
ubd = num_steps
ubd = ir_constant(
num_steps, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
)
step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
for_op = scf.ForOp(lbd, ubd, step, args)
with ir.InsertionPoint(for_op.body):
Expand Down Expand Up @@ -1744,37 +1740,6 @@ def _scan_lowering_rule(
skip_mlir_conversions.add(lax.scan_p)


def _while_lowering_rule(
ctx: LoweringRuleContext,
*args,
cond_nconsts,
cond_jaxpr,
body_nconsts,
body_jaxpr,
):
jaxpr = pallas_utils.pattern_match_while_to_fori_loop(
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
)
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
(lb, ub), args = carry[:2], carry[2:]
for_out = _lower_jaxpr_to_for_loop(
ctx.replace(
block_shapes=ctx.block_shapes[: body_nconsts + 1]
+ ctx.block_shapes[body_nconsts + 2 :],
),
jaxpr,
lb,
ub,
body_consts,
*args,
has_loop_index=True,
unroll=1,
)
return [ub, ub, *for_out]


lowering_rules[lax.while_p] = _while_lowering_rule

def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
index, *args = args
out_types = map(aval_to_ir_type, ctx.avals_out)
Expand Down
28 changes: 3 additions & 25 deletions jax/_src/pallas/mosaic/pipeline.py
Expand Up @@ -143,13 +143,6 @@ def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndice
return tuple(reversed(next_indices))


def _replace_nones_in_block_spec(block_spec: core.BlockSpec) -> core.BlockSpec:
"""Replaces Nones in a block spec shape with 1s."""
block_shape = cast(tuple[int, ...], block_spec.block_shape)
block_shape = tuple([1 if dim is None else dim for dim in block_shape])
return dataclasses.replace(block_spec, block_shape=block_shape)


def _run_block_spec(
block_spec: core.BlockSpec, indices: GridIndices
) -> tuple[Union[slice, indexing.Slice], ...]:
Expand Down Expand Up @@ -466,10 +459,6 @@ def emit_pipeline_with_allocations(
in_specs, out_specs, in_out_specs
)
del in_specs, out_specs, should_accumulate_out, in_out_specs
pipeline_specs_with_nones = pipeline_specs
pipeline_specs = jax.tree_util.tree_map(
_replace_nones_in_block_spec, pipeline_specs_with_nones
)

def make_pipeline_refs(
*ref_args: PipelineRefs,
Expand Down Expand Up @@ -572,7 +561,7 @@ def make_in_out_existing_allocations(spec, ref):

def pipeline(
*ref_args: PipelineRefs,
scratchs: Union[PipelineRefs, None] = None,
scratchs: PipelineRefs = None,
allocations: Union[
None,
tuple[PipelineArg[PipelineBuffers], PipelineArg[PipelineAllocations]],
Expand Down Expand Up @@ -615,7 +604,7 @@ def init_buffer_ref(_, buffer_ref):

zero_indices = (jnp.array(0, dtype=jnp.int32),) * len(grid)
last_indices = tuple(
[jnp.asarray(dim_size - 1, dtype=jnp.int32) for dim_size in grid]
[jnp.array(dim_size - 1, dtype=jnp.int32) for dim_size in grid]
)
indices = zero_indices
pipeline_buffers: PipelineArg[PipelineBuffers] = tree_util.tree_map(
Expand Down Expand Up @@ -816,39 +805,28 @@ def run_epilogue():
with tpu_primitives.trace("ep_kernel"):

def grab_body_ref(
spec_with_nones,
spec,
allocation,
buffers,
existing_allocation,
in_out_existing_allocation=None,
):
if existing_allocation is None:
buffer_slice = tuple([
0 if dim is None else slice(None)
for dim in spec_with_nones.block_shape
])
return allocation.vmem_ref.at[buffers.current, *buffer_slice]
return allocation.vmem_ref.at[buffers.current]
dma_slice = _run_block_spec(spec, indices)
dma_slice = tuple([
0 if dim is None else _slice
for dim, _slice in zip(spec_with_nones.block_shape, dma_slice)
])
if in_out_existing_allocation is None:
return existing_allocation.at[dma_slice]
return in_out_existing_allocation.at[dma_slice]

in_args = tree_util.tree_map(
grab_body_ref,
pipeline_specs_with_nones.input,
pipeline_specs.input,
pipeline_allocations.input,
pipeline_buffers.input,
pipeline_existing_allocations.input,
)
out_args = tree_util.tree_map(
grab_body_ref,
pipeline_specs_with_nones.out,
pipeline_specs.out,
pipeline_allocations.out,
pipeline_buffers.out,
Expand Down
73 changes: 2 additions & 71 deletions jax/_src/pallas/utils.py
Expand Up @@ -14,11 +14,11 @@

"""Pallas utility functions."""
import math
import numpy as np

from jax import lax
from jax._src import core as jax_core
from jax._src.util import split_list
import jax.numpy as jnp
import numpy as np


def when(condition):
Expand Down Expand Up @@ -90,72 +90,3 @@ def pattern_match_scan_to_fori_loop(
# expect a loop index as an argument.
has_loop_index = False
return jaxpr, has_loop_index


def pattern_match_while_to_fori_loop(
cond_jaxpr: jax_core.Jaxpr,
cond_nconsts: int,
body_jaxpr: jax_core.Jaxpr,
body_nconsts: int,
) -> tuple[jax_core.Jaxpr, bool]:
# Try to pattern match to fori loop.
if cond_nconsts:
raise NotImplementedError("Conditional jaxpr can't contain consts.")
_, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts])
cond_in_avals = [v.aval for v in cond_invars]
if len(cond_in_avals) < 2:
raise NotImplementedError("Conditional jaxpr have only two carry args.")
# Check that the first two carry values are scalar ints
a1, a2 = cond_in_avals[:2]
if a1.shape or a1.dtype not in (jnp.int32, jnp.int64):
raise NotImplementedError(
"First conditional jaxpr carry arg is not a scalar int."
)
if a2.shape or a2.dtype not in (jnp.int32, jnp.int64):
raise NotImplementedError(
"Second conditional jaxpr carry arg is not a scalar int."
)
# Check that the only eqn in the cond checks the loop index condition
v1, v2 = cond_invars[:2]
outvar = cond_jaxpr.jaxpr.outvars[0]
assert outvar.aval.dtype == jnp.bool_
if len(cond_jaxpr.jaxpr.eqns) != 1:
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
eqn = cond_jaxpr.jaxpr.eqns[0]
if eqn.primitive != lax.lt_p:
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
if eqn.outvars != [outvar]:
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
if eqn.invars != [v1, v2]:
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
# Check that the carry is updated in the body appropriately
_, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts])
v1, v2 = body_invars[:2]
vo1, vo2 = body_jaxpr.jaxpr.outvars[:2]
# Upper bound should be constant
if v2 is not vo2:
raise NotImplementedError("Loop upper bound is not constant.")
# Check that we increment the loop index in the body
for i, eqn in enumerate(body_jaxpr.jaxpr.eqns):
if eqn.primitive is lax.add_p:
if eqn.invars[0] is v1:
if isinstance(eqn.invars[1], jax_core.Literal):
if eqn.invars[1].val == 1:
if eqn.outvars[0] == vo1:
eqn_index = i
break
else:
raise NotImplementedError("Loop index not incremented in body.")
jaxpr = body_jaxpr.jaxpr
new_invars = (
*jaxpr.invars[:body_nconsts],
jaxpr.invars[body_nconsts],
*jaxpr.invars[body_nconsts + 2 :],
)
new_outvars = tuple(jaxpr.outvars[2:])
jaxpr = jaxpr.replace(
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1 :],
invars=new_invars,
outvars=new_outvars,
)
return jaxpr
27 changes: 0 additions & 27 deletions tests/pallas/pallas_call_tpu_test.py
Expand Up @@ -220,33 +220,6 @@ def single_inst(i, _):

np.testing.assert_allclose(out, expected)

def test_scalar_interpreter_dynamic_loop(self):
loop_end = jnp.array([5], jnp.int32)

def body(loop_end_ref, out_ref):
out_ref[...] = jnp.zeros_like(out_ref)

def loop_body(i, carry):
del i, carry
out_ref[...] += 1

lax.fori_loop(0, loop_end_ref[0], loop_body, None)

out = pl.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
out_specs=pl.BlockSpec(lambda *_: (0, 0), (8, 128)),
grid=1,
),
interpret=self.interpret,
debug=False,
)(loop_end)

expected_out = jnp.ones((8, 128), jnp.float32) * 5
np.testing.assert_allclose(out, expected_out)

def test_vmap_scalar_prefetch_1sized(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]
Expand Down

0 comments on commit 8b1da12

Please sign in to comment.