Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite (10): vector.extract_strided…
Browse files Browse the repository at this point in the history
…_slice

PiperOrigin-RevId: 569081032
  • Loading branch information
tlongeri authored and jax authors committed Sep 28, 2023
1 parent fb90d3e commit a37c292
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -12,6 +12,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -1115,6 +1116,70 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
<< op.getResult(0).getType();
}

LogicalResult vector_extract_strided_slice_rule(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK_EQ(layouts_out.size(), 1);
if (!layouts_in.front().has_value()) {
return op.emitOpError("Expected non-null input layout");
}
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const VectorLayout &layout_in = *layouts_in.front();
const VectorLayout &layout_out = *layouts_out.front();
if (!layout_in.hasNaturalTopology(ctx.target_shape)) {
return op.emitOpError("Not implemented: Unsupported input layout");
}
if (layout_out != layout_in) {
return op.emitOpError("Not implemented: Unsupported output layout");
}
vector::ExtractStridedSliceOp extract_strided_slice_op =
cast<vector::ExtractStridedSliceOp>(op);
const ArrayRef<int64_t> tiled_dims =
extract_strided_slice_op.getVector().getType().getShape().take_back(2);
if (tiled_dims[0] % layout_in.tiling()[0] != 0 ||
tiled_dims[1] % layout_in.tiling()[1] != 0) {
return op.emitOpError(
"Not implemented: Extract strides slices only works with operands with "
"sizes that are multiples of the native tiling");
}

auto I64ArrayToSmallVector = [&](const ArrayAttr array_attr) {
return llvm::map_to_vector(array_attr, [](Attribute attr) {
return cast<IntegerAttr>(attr).getValue().getSExtValue();
});
};

// We currently only support zero-offset, tile-aligned slices. This implies
// the output layout is merely a slice of the input layout, without needing to
// modify physical any of the vregs' layouts.
const SmallVector<int64_t> offsets =
I64ArrayToSmallVector(extract_strided_slice_op.getOffsets());
for (const int64_t offset : ArrayRef<int64_t>(offsets).take_back(2)) {
if (offset != 0) {
return extract_strided_slice_op.emitOpError(
"Not implemented: Only tile-aligned slices supported");
}
}

const SmallVector<int64_t> slice_sizes =
I64ArrayToSmallVector(extract_strided_slice_op.getSizes());
const SmallVector<int64_t> slice_tiled_shape =
layout_in.tileArrayShape(slice_sizes, ctx.target_shape);
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_tiles,
disassemble(ctx, layout_in, extract_strided_slice_op.getVector()));
const xla::Array<Value> dst_tiles =
input_tiles.Slice(offsets, slice_tiled_shape);
const VectorType dst_ty = extract_strided_slice_op.getResult().getType();
extract_strided_slice_op.replaceAllUsesWith(
assemble(ctx, dst_ty, layout_out, dst_tiles).getOperation());
extract_strided_slice_op.erase();
return success();
}

LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -1329,6 +1394,8 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::StoreOp::getOperationName(), tpu_store_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},
{vector::ExtractStridedSliceOp::getOperationName(),
vector_extract_strided_slice_rule},
{vector::StoreOp::getOperationName(), vector_store_rule}};
return *rules;
}
Expand Down

0 comments on commit a37c292

Please sign in to comment.