Relax SliceOp so that it can be a root op.
hanhanW committed Nov 17, 2020
1 parent 9f8cb9d commit 3dcdd2f
Showing 6 changed files with 159 additions and 58 deletions.
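
In effect, after this change an mhlo.slice can sit at the root of a dispatch region and have its consumer fused behind it, while it is still not fused into its producer's region. Below is a minimal sketch of the intended dispatch structure, mirroring the dispatch-region test added later in this commit; the workload, shapes, and value names are illustrative.

%c0 = constant 0 : index
%0 = flow.dispatch.region[%c0 : index](%a = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
  %p = mhlo.add %a, %a : tensor<3x4xi32>  // producer stays in its own region
  flow.return %p : tensor<3x4xi32>
}
%1 = flow.dispatch.region[%c0 : index](%a = %0 : tensor<3x4xi32>, %b = %arg1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
  %s = "mhlo.slice"(%a) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
  %m = mhlo.multiply %s, %b : tensor<1x2xi32>  // consumer fused behind the root slice
  flow.return %m : tensor<1x2xi32>
}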
84 changes: 51 additions & 33 deletions iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -665,43 +665,55 @@ LogicalResult PadOpConversion::apply(

namespace {
/// Converts mhlo.slice operation to linalg.subview + linalg.copy
struct SliceOpConversion
: public ConvertToLinalgBufferOp<SliceOpConversion, mhlo::SliceOp> {
using ConvertToLinalgBufferOp<SliceOpConversion,
mhlo::SliceOp>::ConvertToLinalgBufferOp;
struct SliceOpConversion : public OpConversionPattern<mhlo::SliceOp> {
SliceOpConversion(MLIRContext *context,
TensorToBufferMap const &resultTensorToBufferMap,
PatternBenefit benefit = 1)
: OpConversionPattern<mhlo::SliceOp>(context, benefit),
resultTensorToBufferMap(resultTensorToBufferMap) {}

LogicalResult apply(mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
ArrayRef<Value> resultBuffers,
ConversionPatternRewriter &rewriter) const;
};
} // namespace
LogicalResult matchAndRewrite(
mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
if (!argType || !argType.hasRank()) {
return op.emitError("expected known-rank args");
}

LogicalResult SliceOpConversion::apply(
mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
if (!argType || !argType.hasRank()) {
return op.emitError("expected known-rank args");
}
auto resultType = op.getResult().getType().cast<ShapedType>();
auto memrefType =
MemRefType::get(resultType.getShape(), resultType.getElementType());
Value fakeBuffer = rewriter.create<AllocOp>(loc, memrefType);
SmallVector<Value, 3> offsets, sizes, strides;
for (int i = 0, e = argType.getRank(); i < e; ++i) {
Value startIndex = rewriter.create<ConstantIndexOp>(
loc, op.start_indices().getValue<int64_t>(i));
offsets.push_back(startIndex);
Value size = rewriter.create<DimOp>(loc, fakeBuffer, i);
sizes.push_back(size);
Value stride = rewriter.create<ConstantIndexOp>(
loc, op.strides().getValue<int64_t>(i));
strides.push_back(stride);
}
auto subViewOp = rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets,
sizes, strides);

SmallVector<Value, 3> offsets, sizes, strides;
for (int i = 0, e = argType.getRank(); i < e; ++i) {
Value startIndex = rewriter.create<ConstantIndexOp>(
loc, op.start_indices().getValue<int64_t>(i));
offsets.push_back(startIndex);
Value size = rewriter.create<DimOp>(loc, resultBuffers[0], i);
sizes.push_back(size);
Value stride = rewriter.create<ConstantIndexOp>(
loc, op.strides().getValue<int64_t>(i));
strides.push_back(stride);
Value bufferForResult = resultTensorToBufferMap.lookup(op.getResult());
if (bufferForResult) {
rewriter.create<linalg::CopyOp>(loc, subViewOp, bufferForResult);
rewriter.replaceOp(op, bufferForResult);
} else {
rewriter.replaceOp(op, subViewOp.getResult());
}

return success();
}
auto subViewOp =
rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets, sizes, strides);
rewriter.create<linalg::CopyOp>(loc, subViewOp, resultBuffers[0]);

return success();
}
private:
TensorToBufferMap const &resultTensorToBufferMap;
};
} // namespace

//===----------------------------------------------------------------------===//
// mhlo.reduce_window conversion patterns and utility functions.
@@ -1236,7 +1248,8 @@ struct HALInterfaceLoadTensorOpEraser final
// in the original computation the loaded tensor value goes through a chain
// of view-like operations and is used as an operand to a store tensor
// operation.
if (Value outputBuffer = resultTensorToBufferMap.lookup(loadOp.result())) {
Value outputBuffer = resultTensorToBufferMap.lookup(loadOp.result());
if (outputBuffer && outputBuffer.getType() == buffer.getType()) {
rewriter.create<linalg::CopyOp>(loadOp.getLoc(), buffer, outputBuffer);
rewriter.replaceOp(loadOp, outputBuffer);
} else {
@@ -1373,6 +1386,11 @@ static LogicalResult createAndPropagateBufferUsedForResultTensor(
resultTensorToBufferMap.insert(std::make_pair(tensor, buffer));
continue;
}
if (auto sliceOp = tensor.getDefiningOp<mhlo::SliceOp>()) {
tensor = sliceOp.operand();
if (resultTensorToBufferMap.count(tensor)) break;
resultTensorToBufferMap.insert(std::make_pair(tensor, buffer));
}
break;
}
return success();
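
On the buffer side, the reworked SliceOpConversion emits a strided subview of the input and, when the result tensor maps to an interface buffer, a linalg.copy into it; otherwise the subview itself replaces the op. A rough sketch of the copy case, following the updated slice.mlir checks below; the affine map and value names are taken from the stride-part case and are illustrative.

#map = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
%out = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x2xi32>
%in = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
%sub = subview %in[1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #map>
linalg.copy(%sub, %out) : memref<1x2xi32, #map>, memref<1x2xi32>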
57 changes: 34 additions & 23 deletions iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
@@ -1,4 +1,4 @@
// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -cse %s | IreeFileCheck %s
// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-pipeline -canonicalize %s | IreeFileCheck %s

module {
// CHECK-LABEL: @slice_whole_buffer
@@ -25,19 +25,10 @@ module {
// -----

module {
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 4)>
// CHECK: @slice_whole_stride
// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x4xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
// CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
// CHECK-DAG: %[[ONE:.+]] = constant 1 : index
// CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x4xi32>
// CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x4xi32>
// CHECK: subview %[[IN]]
// CHECK-SAME: [%[[ONE]], %[[ZERO]]]
// CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
// CHECK-SAME: [%[[ONE]], %[[ONE]]]
// CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #[[MAP]]>
// CHECK: subview %[[IN]][1, 0] [1, 4] [1, 1] : memref<3x4xi32> to memref<1x4xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_whole_stride()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x4xi32>)} {
@@ -60,19 +51,10 @@ module {
// -----

module {
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
// CHECK: @slice_stride_part
// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x2xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
// CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
// CHECK-DAG: %[[ONE:.+]] = constant 1 : index
// CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x2xi32>
// CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x2xi32>
// CHECK: subview %[[IN]]
// CHECK-SAME: [%[[ONE]], %[[ONE]]]
// CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
// CHECK-SAME: [%[[ONE]], %[[ONE]]]
// CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #map>
// CHECK: subview %[[IN]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_stride_part()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x2xi32>)} {
@@ -91,3 +73,32 @@ module {
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}

// -----

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @slice_stride_part
// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<1x2xi32>
// CHECK: %[[IN0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4xi32>
// CHECK: %[[SUBVIEW:.+]] = subview %[[IN0]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP0]]>
// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<1x2xi32>
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[SUBVIEW]], %[[IN1]]
// CHECK-SAME: outs(%[[OUT]]
module {
func @slice_stride_part() {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 {operand_result_index = 0 : i32} : tensor<3x4xi32>
%1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 {operand_result_index = 1 : i32} : tensor<1x2xi32>
%2 = "mhlo.slice"(%0) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
%3 = mhlo.add %2, %1 : tensor<1x2xi32>
hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<1x2xi32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
}
3 changes: 2 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -33,7 +33,8 @@ namespace {
// from this exclusion list eventually.
bool isUnsupportedFusionOp(Operation *op) {
return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp,
mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
mhlo::SliceOp,
mhlo::TorchIndexSelectOp>(op);
}

@@ -203,14 +203,35 @@ bool isDispatchRegionMergable(DispatchRegionOp &regionOp) {
// TODO(b/144530470): replace with tablegen attributes/interfaces.
if (isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp,
mhlo::DotOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op)) {
mhlo::TorchIndexSelectOp>(op)) {
return false;
}
}
}
return regionOp.body().getBlocks().size() == 1;
}

// Returns true if |rhs| contains ops that can only be root ops and would lose
// that property if the two dispatch regions were merged.
bool rhsHasRootOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
auto &rhsBlock = rhs.body().front();
auto lhsArgs = llvm::to_vector<8>(lhs.args());
auto rhsArgs = llvm::to_vector<8>(rhs.args());
for (int rhsOpIdx = 0; rhsOpIdx < rhsArgs.size(); ++rhsOpIdx) {
for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
++lhsResultIdx) {
if (rhsArgs[rhsOpIdx] == lhs.getResult(lhsResultIdx)) {
for (auto *user : rhsBlock.getArgument(rhsOpIdx).getUsers()) {
if (isa<mhlo::SliceOp>(user)) {
return true;
}
}
}
}
}
return false;
}

// Merges |rhs| into |lhs| and returns the new |lhs| op.
// Precondition: !areDispatchRegionsTransitivelyDependent
DispatchRegionOp mergeDispatchRegions(DispatchRegionOp &lhs,
@@ -345,6 +366,11 @@ LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock) {
LLVM_DEBUG(llvm::dbgs()
<< " -REGION CONTAINS NON-TRIVIAL CONTROL FLOW-\n");
}
if (rhsHasRootOnlyOp(lhs, rhs)) {
LLVM_DEBUG(llvm::dbgs()
<< " -RHS REGION HAS ROOT OP-\n");
continue;
}
mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
if (!mergableRegions[i]) {
return failure();
@@ -132,3 +132,33 @@ module {
// CHECK-LABEL: func @dominate
// CHECK: flow.dispatch.region
// CHECK-NOT: flow.dispatch.region

// -----

// Tests that an op which can only be a root op is fused with its consumer but
// not with its producer. The test uses a dummy workload to exercise the
// root-only-op handling.
module {
func @rootOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%c0 = constant 0 : index
%0 = flow.dispatch.region[%c0 : index](%arg2 = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
%3 = mhlo.add %arg2, %arg2 : tensor<3x4xi32>
flow.return %3 : tensor<3x4xi32>
}
%1 = flow.dispatch.region[%c0 : index](%arg2 = %0 : tensor<3x4xi32>) -> tensor<1x2xi32> {
%3 = "mhlo.slice"(%arg2) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
flow.return %3 : tensor<1x2xi32>
}
%2 = flow.dispatch.region[%c0 : index](%arg2 = %1 : tensor<1x2xi32>, %arg3 = %arg1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
%3 = mhlo.multiply %arg2, %arg3 : tensor<1x2xi32>
flow.return %3 : tensor<1x2xi32>
}
return %2 : tensor<1x2xi32>
}
}
// CHECK-LABEL: func @rootOnlyOp
// CHECK: flow.dispatch.region
// CHECK-NEXT: mhlo.add
// CHECK: flow.dispatch.region
// CHECK-NEXT: mhlo.slice
// CHECK-NEXT: mhlo.multiply
15 changes: 15 additions & 0 deletions iree/test/e2e/regression/slice_add.mlir
@@ -0,0 +1,15 @@
// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s
// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)
// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=vulkan-spirv -function-input="3x4xi32=[[1,2,3,4],[5,6,7,8],[9,10,11,12]]" -function-input="1x2xi32=10" %s | IreeFileCheck %s)

// CHECK: EXEC @slice_stride_part
// CHECK: 1x2xi32=[16 17]
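// Slicing rows [1, 2) and columns [1, 3) of %arg0 yields [[6, 7]];
// adding the 1x2 splat 10 from %arg1 gives [[16, 17]].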
func @slice_stride_part(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%1 = "mhlo.slice"(%arg0) {
start_indices = dense<[1, 1]> : tensor<2xi64>,
limit_indices = dense<[2, 3]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi32>) -> tensor<1x2xi32>
%2 = mhlo.add %1, %arg1 : tensor<1x2xi32>
return %2 : tensor<1x2xi32>
}
