From c62c6fc1ab147cac7342644679132317ddc16a73 Mon Sep 17 00:00:00 2001
From: Emily Fertig
Date: Wed, 27 Sep 2023 13:33:04 -0700
Subject: [PATCH] [Mosaic] Add `sin` and `clamp` lowering rules and support
 multiple branches in `cond`.

Add a pallas_call test using scan/cond. Improve the error message for
lowering exceptions and add a `LoweringException` type.

PiperOrigin-RevId: 568945255
---
 jax/_src/pallas/mosaic/BUILD       |  1 +
 jax/_src/pallas/mosaic/__init__.py |  1 +
 jax/_src/pallas/mosaic/lowering.py | 90 ++++++++++++++++++++++--------
 3 files changed, 69 insertions(+), 23 deletions(-)

diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD
index ff6afd5e3dd7..7f408823d402 100644
--- a/jax/_src/pallas/mosaic/BUILD
+++ b/jax/_src/pallas/mosaic/BUILD
@@ -35,6 +35,7 @@ py_library_providing_imports_info(
     deps = [
         ":core",
         ":kernel_regeneration_util",
+        ":lowering",
         ":pallas_call_registration",
         ":primitives",
     ],
diff --git a/jax/_src/pallas/mosaic/__init__.py b/jax/_src/pallas/mosaic/__init__.py
index 908df30b7fcc..23db94c7299d 100644
--- a/jax/_src/pallas/mosaic/__init__.py
+++ b/jax/_src/pallas/mosaic/__init__.py
@@ -21,6 +21,7 @@
 from jax._src.pallas.mosaic.core import TPUMemorySpace
 from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
 from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
+from jax._src.pallas.mosaic.lowering import LoweringException
 from jax._src.pallas.mosaic.primitives import async_copy
 from jax._src.pallas.mosaic.primitives import async_remote_copy
 from jax._src.pallas.mosaic.primitives import device_id
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index 165591c9bfc4..b030a9f2ed7e 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -232,6 +232,7 @@ def lower_jaxpr_to_transform_func(
   body.func_op.verify()
   return body.func_op
 
+
 def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
   def f_lowered(ctx: LoweringRuleContext, *args, **params):
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
@@ -319,6 +320,10 @@ def body_func(*args):
   return body.func_op
 
 
+class LoweringException(Exception):
+  pass
+
+
 def jaxpr_subcomp(
     ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
 ) -> Sequence[ir.Value]:
@@ -362,7 +367,21 @@ def write_env(var: jax_core.Var, val):
           [v.aval for v in eqn.outvars],
           block_shapes,
       )
-      ans = lowering_rules[eqn.primitive](rule_context, *invals, **eqn.params)
+      try:
+        ans = lowering_rules[eqn.primitive](
+            rule_context, *invals, **eqn.params
+        )
+      except LoweringException:
+        raise  # We only add the extra info to the innermost exception.
+      except Exception as e:
+        raise LoweringException(
+            f"Exception while lowering eqn:\n  {eqn}\nWith context:\n "
+            f" {rule_context}\nWith inval"
+            f" shapes={map(lambda t: getattr(t, 'shape', None), invals)}\nWith"
+            " inval"
+            f" types={map(lambda t: getattr(t, 'type', None), invals)}\nIn"
+            f" jaxpr:\n{jaxpr}"
+        ) from e
     else:
       raise NotImplementedError(
           "Unimplemented primitive in Pallas TPU lowering: "
@@ -829,6 +848,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions):
     raise NotImplementedError
   if any(d is None for d in new_sizes):
     raise NotImplementedError
+  if not ctx.avals_in[0].shape:
+    return vector.BroadcastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
   return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result
 
 
@@ -875,13 +896,13 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
 
 
 def _bcast(x, y, x_aval, y_aval, out_aval):
-  if isinstance(x, (np.ndarray, np.uint32, int, float)):
+  if isinstance(x, (np.ndarray, np.number, int, float)):
     if hasattr(y, "type") and y.type == ir.IndexType.get():
       mlir_type = y.type
     else:
       mlir_type = mlir.dtype_to_ir_type(x_aval.dtype)
     x = ir_constant(x, mlir_type)
-  if isinstance(y, (np.ndarray, np.uint32, int, float)):
+  if isinstance(y, (np.ndarray, np.number, int, float)):
    if hasattr(x, "type") and x.type == ir.IndexType.get():
      mlir_type = x.type
    else:
@@ -1045,7 +1066,8 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
 def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
   if not isinstance(x, ir.Value) and x == 2.:
     return math.Exp2Op(y).result
-  raise NotImplementedError("Only 2^x supported")
+  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
+  return math.PowFOp(x, y).result
 
 
 lowering_rules[lax.pow_p] = _pow_lowering_rule
@@ -1075,10 +1097,11 @@
 def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
   neg_x = arith.NegFOp(x).result
   exp_neg_x = math.ExpOp(neg_x).result
   aval_out = ctx.avals_out[0]
-  out_type = ir.VectorType.get(
-      aval_out.shape, mlir.dtype_to_ir_type(aval_out.dtype)
-  )
-  one = vector.BroadcastOp(out_type, ir_constant(1.0))
+  out_type = aval_to_ir_type(aval_out)
+  if aval_out.shape == ():
+    one = ir_constant(1.0, mlir_type=out_type)
+  else:
+    one = vector.BroadcastOp(out_type, ir_constant(1.0))
   denom = arith.AddFOp(one, exp_neg_x).result
   return arith.DivFOp(one, denom).result
@@ -1086,6 +1109,13 @@
 lowering_rules[lax.logistic_p] = _logistic_lowering_rule
 
 
+def _sin_lowering_rule(ctx: LoweringRuleContext, x):
+  return math.SinOp(x).result
+
+
+lowering_rules[lax.sin_p] = _sin_lowering_rule
+
+
 def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
   return math.TanhOp(x).result
 
@@ -1179,6 +1209,20 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
 
 lowering_rules[lax.select_n_p] = _select_n_lowering_rule
 
+
+def _clamp(min, operand, max):
+  res = jnp.maximum(operand, min)
+  return jnp.minimum(res, max)
+
+
+def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
+  """Compute minimum_p(maximum_p(min, operand), max)."""
+  return lower_fun(_clamp, multiple_results=False)(ctx, min, operand, max)
+
+
+lowering_rules[lax.clamp_p] = _clamp_lowering_rule
+
+
 def _for_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
@@ -1211,7 +1255,6 @@ def _for_lowering_rule(
 
 
 lowering_rules[for_loop.for_p] = _for_lowering_rule
-skip_mlir_conversions.add(for_loop.for_p)
 
 
 def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
@@ -1277,35 +1320,36 @@ def _scan_lowering_rule(
 
 
 def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):
-  del linear
-  if len(branches) > 2:
-    raise NotImplementedError
-  pred, *args = args
+  index, *args = args
   out_types = map(aval_to_ir_type, ctx.avals_out)
-  pred = arith.TruncIOp(
-      aval_to_ir_type(jax_core.ShapedArray((), jnp.bool_)), pred
+  pred = arith.CmpIOp(
+      arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
   ).result
-  # Specialize to singleton `if`s
-  singleton = len(out_types) == 1
-  if singleton:
-    out_types = out_types[0]
   if_op = scf.IfOp(pred, out_types, hasElse=True)
   lowering_context = ctx.lowering_context.replace(
       block_shapes=ctx.block_shapes[1:],
   )
   with ir.InsertionPoint(if_op.then_block):
-    out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
+    # TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
+    # if/else.
+    if len(branches) > 2:
+      out = _cond_lowering_rule(
+          ctx,
+          arith.SubIOp(index, ir_constant(1, index.type)).result,
+          *args,
+          branches=branches[1:],
+          linear=linear,
+      )
+    else:
+      out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
     scf.YieldOp(out)
   with ir.InsertionPoint(if_op.else_block):
     out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args)
     scf.YieldOp(out)
-  if singleton:
-    return if_op.result
   return if_op.results
 
 
 lowering_rules[lax.cond_p] = _cond_lowering_rule
-skip_mlir_conversions.add(lax.cond_p)
 
 
 def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
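
---
Illustrative usage (an editor's sketch, not part of the patch; requires a
TPU backend): the kernel below exercises what this change adds, `jnp.sin`
(the new `lax.sin_p` rule), `jnp.clip` (which lowers through the new
`lax.clamp_p` rule), `x ** y` beyond the old `2.**x` special case (now
lowered via `math.PowFOp`), and a three-branch `lax.switch`, i.e. a `cond`
with more than two branches, driven by `lax.scan`, mirroring the scan/cond
test the commit message mentions. The kernel body, names, and shapes are
hypothetical; only `LoweringException` and the lowered primitives come from
the patch.

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
# Exported by this patch's change to jax/_src/pallas/mosaic/__init__.py.
from jax._src.pallas.mosaic import LoweringException


def kernel(x_ref, o_ref):
  # Hypothetical branches, one per newly supported lowering.
  def sin_branch(v):
    return jnp.sin(v)              # new lax.sin_p rule (math.SinOp)

  def clamp_branch(v):
    return jnp.clip(v, -1.0, 1.0)  # new lax.clamp_p rule (maximum, then minimum)

  def pow_branch(v):
    return v ** 2.0                # lax.pow_p now lowers via math.PowFOp

  def body(carry, i):
    # A three-way switch traces to a cond with three branches; the lowering
    # now emits it as a cascade of scf.IfOps instead of raising.
    return lax.switch(i, (sin_branch, clamp_branch, pow_branch), carry), None

  out, _ = lax.scan(body, x_ref[...], jnp.arange(3, dtype=jnp.int32))
  o_ref[...] = out


x = jnp.linspace(-2.0, 2.0, 8 * 128, dtype=jnp.float32).reshape(8, 128)
try:
  y = pl.pallas_call(
      kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x)
except LoweringException as e:
  # With this patch, a primitive the TPU lowering cannot handle surfaces
  # here with the offending eqn, rule context, inval shapes/types, and jaxpr.
  print(e)

Because `cond` still lowers to `scf.IfOp`, an N-way `switch` becomes a
cascade of N-1 ifs via the recursive call in `_cond_lowering_rule`; the
TODO(b/300272065) in the patch tracks replacing that cascade with a single
`scf.IndexSwitchOp`.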