Skip to content

Commit

Permalink
[Mosaic] Add sin and clamp lowering rules and support multiple br…
Browse files Browse the repository at this point in the history
…anches in `cond`. Add a pallas_call test using scan/cond. Improve the error message for lowering exceptions and add a `LoweringException` type.

PiperOrigin-RevId: 568945255
  • Loading branch information
emilyfertig authored and jax authors committed Sep 27, 2023
1 parent 87af945 commit c62c6fc
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 23 deletions.
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/BUILD
Expand Up @@ -35,6 +35,7 @@ py_library_providing_imports_info(
deps = [
":core",
":kernel_regeneration_util",
":lowering",
":pallas_call_registration",
":primitives",
],
Expand Down
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/__init__.py
Expand Up @@ -21,6 +21,7 @@
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import device_id
Expand Down
90 changes: 67 additions & 23 deletions jax/_src/pallas/mosaic/lowering.py
Expand Up @@ -232,6 +232,7 @@ def lower_jaxpr_to_transform_func(
body.func_op.verify()
return body.func_op


def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
Expand Down Expand Up @@ -319,6 +320,10 @@ def body_func(*args):
return body.func_op


class LoweringException(Exception):
pass


def jaxpr_subcomp(
ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
) -> Sequence[ir.Value]:
Expand Down Expand Up @@ -362,7 +367,21 @@ def write_env(var: jax_core.Var, val):
[v.aval for v in eqn.outvars],
block_shapes,
)
ans = lowering_rules[eqn.primitive](rule_context, *invals, **eqn.params)
try:
ans = lowering_rules[eqn.primitive](
rule_context, *invals, **eqn.params
)
except LoweringException:
raise # We only add the extra info to the innermost exception.
except Exception as e:
raise LoweringException(
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
f" {rule_context}\nWith inval"
f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith"
" inval"
f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn"
f" jaxpr:\n{jaxpr}"
) from e
else:
raise NotImplementedError(
"Unimplemented primitive in Pallas TPU lowering: "
Expand Down Expand Up @@ -829,6 +848,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
raise NotImplementedError
if any(d is None for d in new_sizes):
raise NotImplementedError
if not ctx.avals_in[0].shape:
return vector.BroadcastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result


Expand Down Expand Up @@ -875,13 +896,13 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):


def _bcast(x, y, x_aval, y_aval, out_aval):
if isinstance(x, (np.ndarray, np.uint32, int, float)):
if isinstance(x, (np.ndarray, np.number, int, float)):
if hasattr(y, "type") and y.type == ir.IndexType.get():
mlir_type = y.type
else:
mlir_type = mlir.dtype_to_ir_type(x_aval.dtype)
x = ir_constant(x, mlir_type)
if isinstance(y, (np.ndarray, np.uint32, int, float)):
if isinstance(y, (np.ndarray, np.number, int, float)):
if hasattr(x, "type") and x.type == ir.IndexType.get():
mlir_type = x.type
else:
Expand Down Expand Up @@ -1045,7 +1066,8 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
if not isinstance(x, ir.Value) and x == 2.:
return math.Exp2Op(y).result
raise NotImplementedError("Only 2^x supported")
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
return math.PowFOp(x, y).result


lowering_rules[lax.pow_p] = _pow_lowering_rule
Expand Down Expand Up @@ -1075,17 +1097,25 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
neg_x = arith.NegFOp(x).result
exp_neg_x = math.ExpOp(neg_x).result
aval_out = ctx.avals_out[0]
out_type = ir.VectorType.get(
aval_out.shape, mlir.dtype_to_ir_type(aval_out.dtype)
)
one = vector.BroadcastOp(out_type, ir_constant(1.0))
out_type = aval_to_ir_type(aval_out)
if aval_out.shape == ():
one = ir_constant(1.0, mlir_type=out_type)
else:
one = vector.BroadcastOp(out_type, ir_constant(1.0))
denom = arith.AddFOp(one, exp_neg_x).result
return arith.DivFOp(one, denom).result


lowering_rules[lax.logistic_p] = _logistic_lowering_rule


def _sin_lowering_rule(ctx: LoweringRuleContext, x):
return math.SinOp(x).result


lowering_rules[lax.sin_p] = _sin_lowering_rule


def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
return math.TanhOp(x).result

Expand Down Expand Up @@ -1179,6 +1209,20 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):

lowering_rules[lax.select_n_p] = _select_n_lowering_rule


def _clamp(min, operand, max):
res = jnp.maximum(operand, min)
return jnp.minimum(res, max)


def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
"""Compute minimum_p(maximum_p(min, operand), max)."""
return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)


lowering_rules[lax.clamp_p] = _clamp_lowering_rule


def _for_lowering_rule(
ctx: LoweringRuleContext,
*args,
Expand Down Expand Up @@ -1211,7 +1255,6 @@ def _for_lowering_rule(


lowering_rules[for_loop.for_p] = _for_lowering_rule
skip_mlir_conversions.add(for_loop.for_p)


def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
Expand Down Expand Up @@ -1277,35 +1320,36 @@ def _scan_lowering_rule(


def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
del linear
if len(branches) > 2:
raise NotImplementedError
pred, *args = args
index, *args = args
out_types = map(aval_to_ir_type, ctx.avals_out)
pred = arith.TruncIOp(
aval_to_ir_type(jax_core.ShapedArray((), jnp.bool_)), pred
pred = arith.CmpIOp(
arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
).result
# Specialize to singleton `if`s
singleton = len(out_types) == 1
if singleton:
out_types = out_types[0]
if_op = scf.IfOp(pred, out_types, hasElse=True)
lowering_context = ctx.lowering_context.replace(
block_shapes=ctx.block_shapes[1:],
)
with ir.InsertionPoint(if_op.then_block):
out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
# TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
# if/else.
if len(branches) > 2:
out = _cond_lowering_rule(
ctx,
arith.SubIOp(index, ir_constant(1, index.type)).result,
*args,
branches=branches[1:],
linear=linear,
)
else:
out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
scf.YieldOp(out)
with ir.InsertionPoint(if_op.else_block):
out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
scf.YieldOp(out)
if singleton:
return if_op.result
return if_op.results


lowering_rules[lax.cond_p] = _cond_lowering_rule
skip_mlir_conversions.add(lax.cond_p)


def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
Expand Down

0 comments on commit c62c6fc

Please sign in to comment.