From e66f4e94c4841cca762b09a6147f12963f14cec1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 7 Nov 2023 07:15:23 -0800 Subject: [PATCH] [Mosaic] Add support for extracting the first element of a vector as a scalar PiperOrigin-RevId: 580169469 --- jax/_src/pallas/mosaic/lowering.py | 6 ++- .../tpu/transforms/apply_vector_layout.cc | 43 +++++++++++++++++++ .../tpu/transforms/infer_vector_layout.cc | 18 ++++++++ jaxlib/mosaic/python/apply_vector_layout.py | 18 ++++++++ 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 42832f8c6626..e019c83311c3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -921,6 +921,10 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions): def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): del dimensions # Unused. + (aval_in,) = ctx.avals_in + (aval_out,) = ctx.avals_out + if not aval_out.shape: + return vector.ExtractOp(x, [], [0] * len(aval_in.shape)).result return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result @@ -1489,9 +1493,7 @@ def _slice_lowering_rule( (aval_out,) = ctx.avals_out if strides is None: strides = [1] * len(start_indices) - sizes = np.array(limit_indices) - np.array(start_indices) - op = vector.ExtractStridedSliceOp( aval_to_ir_type(aval_out), x, start_indices, sizes, strides ) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index a4ec1f5db40a..6390dd5e38a1 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1986,6 +1986,48 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, } } +LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + vector::ExtractOp 
extract_op = cast(op); + if (extract_op.hasDynamicPosition()) { + return op.emitOpError("Not implemented: dynamic indices"); + } + CHECK_EQ(layouts_in.size(), 1); + CHECK_EQ(layouts_out.size(), 1); + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null input layout"); + } + const VectorLayout &layout_in = *layouts_in.front(); + if (layouts_out.front().has_value()) { + return op.emitOpError("Not implemented: Only scalar results supported"); + } + if (layout_in.bitwidth() != 32) { + return op.emitOpError( + "Not implemented: Only 32-bit vector.extract supported"); + } + if (layout_in.offsets() != LayoutOffsets{0, 0}) { + return op.emitOpError("Not implemented: Unsupported layout"); + } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + for (int64_t i : extract_op.getStaticPosition()) { + if (i != 0) { + return op.emitOpError("Not implemented: Only 0 indices supported"); + } + } + FAILUREOR_ASSIGN_OR_RETURN( + const xla::Array vregs, + disassemble(ctx, builder, layout_in, extract_op.getVector())); + CHECK_GT(vregs.num_elements(), 0); + extract_op.replaceAllUsesWith( + builder + .create(op.getLoc(), *vregs.data(), + ArrayRef{0, 0}) + .getResult()); + extract_op.erase(); + return success(); +} + LogicalResult vector_contract_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -2867,6 +2909,7 @@ const llvm::StringMap &rules() { {tpu::TraceOp::getOperationName(), tpu_trace_rule}, {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, {vector::ContractionOp::getOperationName(), vector_contract_rule}, + {vector::ExtractOp::getOperationName(), vector_extract_rule}, {vector::LoadOp::getOperationName(), vector_load_rule}, {vector::MultiDimReductionOp::getOperationName(), vector_multi_reduction_rule}, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 73b76b68cdc8..f73bf3662495 100644 --- 
a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -261,6 +261,10 @@ class VectorLayoutInferer { if (infer(op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op)) { + if (infer(op).failed()) { + return failure(); + } } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); @@ -833,6 +837,20 @@ class VectorLayoutInferer { return inferMatmul(op); } + LogicalResult infer(vector::ExtractOp op) { + TPU_CHECK_OP(!op.hasDynamicPosition(), "dynamic indices not supported"); + TPU_CHECK_OP( + op.getSourceVectorType().getElementTypeBitWidth() == kNativeBitwidth, + "Only 32-bit types supported"); + auto layout = getLayout(op.getVector()); + TPU_CHECK_OP(layout.has_value(), "missing vector layout"); + setLayout(op, + VectorLayout(kNativeBitwidth, {0, 0}, layout->tiling(), + layout->implicit_dim()), + kNoLayout); + return success(); + } + LogicalResult infer(vector::LoadOp op) { auto src_ty = op.getMemRefType(); auto res_ty = op.getVectorType(); diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index d7fb8c9abf6f..667ce49f1758 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -2501,6 +2501,24 @@ def _vector_broadcast_rule(ctx: RewriteContext, op: vector.BroadcastOp, # pylin return ctx.replace(op, assemble(dst_ty, layout_out, dst_tiles)) +@_register_rule("vector.extract") +def _vector_extract_rule(ctx: RewriteContext, op: vector.ExtractOp, # pylint: disable=missing-function-docstring + layout_in: Layout, layout_out: VectorLayout): + if layout_out is not None: + raise NotImplementedError("Vector results of extract unsupported") + if layout_in.bitwidth != 32: + raise NotImplementedError("Only 32-bit vector.extract supported") + if layout_in.offsets != (0, 0): + raise NotImplementedError("Unsupported layout") + if len(op.operands) > 1: + raise 
NotImplementedError("Dynamic indices not supported") + idx = ir.DenseI64ArrayAttr(op.attributes["static_position"]) + if any(i != 0 for i in idx): + raise NotImplementedError("Only 0 indices supported") + vregs = disassemble(layout_in, op.vector) + ctx.replace(op, vector.ExtractOp(vregs.flat[0], [], [0, 0])) + + @_register_rule("vector.load") def _vector_load_rule( # pylint: disable=missing-function-docstring ctx: RewriteContext, op: vector.LoadOp,