From a37c292d02cca22b155b4674d1c35336f6879cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 27 Sep 2023 23:47:51 -0700 Subject: [PATCH] [Mosaic] apply_vector_layout C++ rewrite (10): vector.extract_strided_slice PiperOrigin-RevId: 569081032 --- .../tpu/transforms/apply_vector_layout.cc | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 6817316a7d7e..d9de3d8f2675 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -12,6 +12,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1115,6 +1116,70 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, << op.getResult(0).getType(); } +LogicalResult vector_extract_strided_slice_rule( + RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_in.size(), 1); + CHECK_EQ(layouts_out.size(), 1); + if (!layouts_in.front().has_value()) { + return op.emitOpError("Expected non-null input layout"); + } + if (!layouts_out.front().has_value()) { + return op.emitOpError("Expected non-null output layout"); + } + const VectorLayout &layout_in = *layouts_in.front(); + const VectorLayout &layout_out = *layouts_out.front(); + if (!layout_in.hasNaturalTopology(ctx.target_shape)) { + return op.emitOpError("Not implemented: Unsupported input layout"); + } + if (layout_out != layout_in) { + return op.emitOpError("Not implemented: Unsupported output layout"); + } + vector::ExtractStridedSliceOp extract_strided_slice_op = + cast(op); + const ArrayRef tiled_dims = + extract_strided_slice_op.getVector().getType().getShape().take_back(2); + if (tiled_dims[0] % layout_in.tiling()[0] != 0 || + tiled_dims[1] % layout_in.tiling()[1] != 0) { + return op.emitOpError( + "Not implemented: Extract strides slices only works with operands with " + "sizes that are multiples of the native tiling"); + } + + auto I64ArrayToSmallVector = [&](const ArrayAttr array_attr) { + return llvm::map_to_vector(array_attr, [](Attribute attr) { + return cast(attr).getValue().getSExtValue(); + }); + }; + + // We currently only support zero-offset, tile-aligned slices. This implies + // the output layout is merely a slice of the input layout, without needing to + // modify physical any of the vregs' layouts. + const SmallVector offsets = + I64ArrayToSmallVector(extract_strided_slice_op.getOffsets()); + for (const int64_t offset : ArrayRef(offsets).take_back(2)) { + if (offset != 0) { + return extract_strided_slice_op.emitOpError( + "Not implemented: Only tile-aligned slices supported"); + } + } + + const SmallVector slice_sizes = + I64ArrayToSmallVector(extract_strided_slice_op.getSizes()); + const SmallVector slice_tiled_shape = + layout_in.tileArrayShape(slice_sizes, ctx.target_shape); + FAILUREOR_ASSIGN_OR_RETURN( + const xla::Array input_tiles, + disassemble(ctx, layout_in, extract_strided_slice_op.getVector())); + const xla::Array dst_tiles = + input_tiles.Slice(offsets, slice_tiled_shape); + const VectorType dst_ty = extract_strided_slice_op.getResult().getType(); + extract_strided_slice_op.replaceAllUsesWith( + assemble(ctx, dst_ty, layout_out, dst_tiles).getOperation()); + extract_strided_slice_op.erase(); + return success(); +} + LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -1329,6 +1394,8 @@ const llvm::StringMap &rules() { {tpu::StoreOp::getOperationName(), tpu_store_rule}, {tpu::TraceOp::getOperationName(), tpu_trace_rule}, {vector::LoadOp::getOperationName(), vector_load_rule}, + {vector::ExtractStridedSliceOp::getOperationName(), + vector_extract_strided_slice_rule}, {vector::StoreOp::getOperationName(), vector_store_rule}}; return *rules; }