Fold standalone linalg.fill ops into flow.tensor.splat ops (#5614)
This allows us to use DMA instead of kernels for pure data
fills. This is another step towards better performance: it
further decreases the number of dispatches for MobileNetV2
from 94 to 76 and reduces the kernel execution latency by
3ms (17ms -> 14ms) on a Galaxy S20 (Mali G77).
antiagainst committed Jul 22, 2021
1 parent 268a305 commit 323108e
Showing 10 changed files with 112 additions and 47 deletions.
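For context, the core rewrite in this commit turns a standalone fill into a splat roughly as follows. This is a minimal sketch with illustrative SSA names (%value, %init, %d0, and %d1 are placeholders); the actual pattern is exercised end-to-end in the new convert_to_flow_tensor_ops_after.mlir test further below.

// Before: a standalone fill is dispatched as a kernel.
%fill = linalg.fill(%value, %init) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// After: the fill becomes a splat the runtime can service with DMA.
// %d0 and %d1 carry the dynamic dimensions of %init.
%fill = flow.tensor.splat %value : tensor<?x?xf32>{%d0, %d1}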
50 changes: 42 additions & 8 deletions iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -251,24 +251,56 @@ struct SubTensorToTensorSlice
}
};

/// Converts linalg.fill ops into flow.tensor.splat ops.
///
/// This is expected to improve performance because we can use DMA
/// functionality for the fill instead of dispatching kernels.
struct LinalgFillToFlowTensorSplat final
: public OpRewritePattern<linalg::FillOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::FillOp fillOp,
PatternRewriter &rewriter) const override {
if (fillOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
// Don't convert linalg.fill ops that were fused together with other ops.
return failure();
}

SmallVector<Value, 4> dynamicDims =
getDynamicDimValues(rewriter, fillOp.getLoc(), fillOp.output());
rewriter.replaceOpWithNewOp<TensorSplatOp>(
fillOp, fillOp.output().getType(), fillOp.value(), dynamicDims);
return success();
}
};

/// Converts operations that can map to flow.tensor.* operations.
struct ConvertToFlowTensorOpsPass
: public ConvertToFlowTensorOpsBase<ConvertToFlowTensorOpsPass> {
ConvertToFlowTensorOpsPass(bool runBefore) {
runBeforeDispatchRegionFormation = runBefore;
}
ConvertToFlowTensorOpsPass(const ConvertToFlowTensorOpsPass &that) {
runBeforeDispatchRegionFormation = that.runBeforeDispatchRegionFormation;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect, memref::MemRefDialect,
mlir::StandardOpsDialect>();
}
ConvertToFlowTensorOpsPass() = default;
ConvertToFlowTensorOpsPass(const ConvertToFlowTensorOpsPass &pass) {}
void runOnOperation() override {
FuncOp funcOp = getOperation();
MLIRContext *context = funcOp->getContext();
context->allowUnregisteredDialects(true);
RewritePatternSet patterns(&getContext());
patterns.insert<
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
if (runBeforeDispatchRegionFormation) {
patterns.insert<
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
} else {
patterns.insert<LinalgFillToFlowTensorSplat>(context);
}
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
@@ -277,8 +309,10 @@ struct ConvertToFlowTensorOpsPass
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass() {
return std::make_unique<ConvertToFlowTensorOpsPass>();
std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass(
bool runBeforeDispatchRegionFormation) {
return std::make_unique<ConvertToFlowTensorOpsPass>(
runBeforeDispatchRegionFormation);
}

} // namespace Flow
@@ -255,6 +255,12 @@ static bool isDispatchableOp(Operation *op) {
!isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(op)) {
return false;
}

// Mark linalg.fill as non-dispatchable so that linalg.fill ops that cannot be
// fused with some root op are left out of dispatch region formation and can
// be picked up by the DMA op conversion instead.
if (isa<linalg::FillOp>(op)) return false;

return !isAlwaysClonedIntoDispatchOp(op);
}

8 changes: 6 additions & 2 deletions iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -108,13 +108,17 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) {
passManager.addNestedPass<FuncOp>(createInterchangeGenericOpsPass());
passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
passManager.addNestedPass<FuncOp>(createFusionOfTensorOpsPass());
passManager.addNestedPass<FuncOp>(
IREE::Flow::createConvertToFlowTensorOpsPass());
passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
passManager.addNestedPass<FuncOp>(
IREE::Flow::createConvertToFlowTensorOpsPass(
/*runBeforeDispatchRegionFormation=*/true));
passManager.addNestedPass<FuncOp>(
IREE::Flow::createDispatchLinalgOnTensorsPass());
passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
passManager.addNestedPass<FuncOp>(
IREE::Flow::createConvertToFlowTensorOpsPass(
/*runBeforeDispatchRegionFormation=*/false));
// NOTE: required because the current dispatch-linalg-on-tensors pass
// creates a lot of dead IR that needs to be cleaned up.
passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
10 changes: 6 additions & 4 deletions iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -61,10 +61,12 @@ std::unique_ptr<Pass> createFusionOfTensorOpsPass();
/// the most inner loops.
std::unique_ptr<OperationPass<FuncOp>> createInterchangeGenericOpsPass();

// Convert operations to equivalent flow.tensor.* ops. This is run after
// dispatch region creation to catch operations that were left outside of
// dispatch regions and could be represented as flow.tensor.* ops.
std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass();
// Convert operations to equivalent flow.tensor.* ops.
// `runBeforeDispatchRegionFormation` controls whether to run before dispatch
// region creation. If run after, it will catch operations that were left
// outside of dispatch regions and could be represented as flow.tensor.* ops.
std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass(
bool runBeforeDispatchRegionFormation = true);

// Promote I1 tensor constants to I8 tensors to match later operations.
std::unique_ptr<OperationPass<FuncOp>> createPromoteI1ToI8Pass();
4 changes: 4 additions & 0 deletions iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -25,6 +25,10 @@ def ConvertToFlowTensorOps :
Pass<"iree-flow-convert-to-flow-tensor-ops-pass", "FuncOp"> {
let summary = "Convert operations to equivalent flow.tensor.* operations";
let constructor = "mlir::iree_compiler::IREE::Flow::createConvertToFlowTensorOpsPass()";
let options = [
Option<"runBeforeDispatchRegionFormation", "run-before-dispatch-region-formation",
"bool", /*default=*/"true", "Run the pass before dispatch region formation">
];
}
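The option is exposed as a textual pass flag; the new lit test further below exercises the post-dispatch-region-formation mode with a RUN line of this shape (reproduced from that test):

// RUN: iree-opt -iree-flow-convert-to-flow-tensor-ops-pass='run-before-dispatch-region-formation=false' -canonicalize -cse -split-input-file %s | IreeFileCheck %s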

def DeduplicateExecutables :
3 changes: 2 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -19,7 +19,8 @@ iree_lit_test_suite(
[
"conv1x1_to_matmul.mlir",
"conv2d_to_img2col.mlir",
"convert_to_flow_tensor_ops.mlir",
"convert_to_flow_tensor_ops_after.mlir",
"convert_to_flow_tensor_ops_before.mlir",
"deduplicate_executables.mlir",
"dispatch_linalg_on_tensors.mlir",
"dispatch_linalg_on_tensors_elementwise.mlir",
3 changes: 2 additions & 1 deletion iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -16,7 +16,8 @@ iree_lit_test_suite(
SRCS
"conv1x1_to_matmul.mlir"
"conv2d_to_img2col.mlir"
"convert_to_flow_tensor_ops.mlir"
"convert_to_flow_tensor_ops_after.mlir"
"convert_to_flow_tensor_ops_before.mlir"
"deduplicate_executables.mlir"
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_elementwise.mlir"
@@ -0,0 +1,33 @@
// RUN: iree-opt -iree-flow-convert-to-flow-tensor-ops-pass='run-before-dispatch-region-formation=false' -canonicalize -cse -split-input-file %s | IreeFileCheck %s

func @turn_fill_into_splat(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor.extract %arg1[] : tensor<f32>
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%1)[%arg2, %arg4]
%4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
%5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
%6 = linalg.fill(%0, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
return %7 : tensor<?x?xf32>
}

// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
// CHECK: func @turn_fill_into_splat
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK: %[[VAL:.+]] = tensor.extract %[[ARG1]][]
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]], %[[D1]]]
// CHECK: %[[SPLAT:.+]] = flow.tensor.splat %[[VAL]] : tensor<?x?xf32>{%[[RD0]], %[[RD1]]}
// CHECK: flow.tensor.update %[[ARG0]], %[[SPLAT]]
@@ -321,8 +321,11 @@ func @always_fuse_reshape

// -----

func @fuse_tensor_update_with_fill(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
// A subsequent pass is expected to convert linalg.fill and flow.tensor.update into DMA ops.
func @dont_fuse_tensor_update_with_fill(
%arg0: tensor<?x?xf32>, %arg1: tensor<f32>,
%arg2: index, %arg3: index, %arg4: index, %arg5: index)
-> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor.extract %arg1[] : tensor<f32>
@@ -336,29 +339,9 @@ func @fuse_tensor_update_with_fill(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %
return %7 : tensor<?x?xf32>
}

// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
// CHECK: func @fuse_tensor_update_with_fill
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]], %[[D1]]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK-SAME: [%[[RD1]], %[[RD0]], %[[C1]]]
// CHECK-SAME: (%[[ARG1]], %[[RD0]], %[[RD1]])
// CHECK-DAG: %[[VAL:.+]] = tensor.extract
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
// CHECK: %[[RETURN:.+]] = linalg.fill(%[[VAL]], %[[INIT]])
// CHECK: flow.dispatch.tensor.store %[[RETURN]], {{.*}}
// CHECK-NEXT: flow.return
// CHECK: flow.tensor.update %[[ARG0]], %[[RESULT]]
// CHECK: func @dont_fuse_tensor_update_with_fill
// CHECK: linalg.fill
// CHECK: flow.tensor.update

// -----

@@ -479,6 +462,7 @@ func @depthwise_conv2d(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96x

// -----

// A subsequent pass is expected to convert linalg.fill into DMA ops.
func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32> {
%cst = constant 0.000000e+00 : f32
%0 = linalg.init_tensor [1, 225, 225, 3] : tensor<1x225x225x3xf32>
@@ -490,12 +474,8 @@ func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32
// CHECK: func @subtensor_insert
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x224x224x3xf32>)
//
// CHECK: %[[FILL:.+]] = flow.dispatch.workgroups[{{.+}}]() : () -> tensor<1x225x225x3xf32> =
// CHECK-NEXT: (%[[OUTPUT:.+]]: !flow.dispatch.tensor<writeonly:1x225x225x3xf32>) {
// CHECK: linalg.init_tensor
// CHECK-NEXT: %[[TENSOR:.+]] = linalg.fill
// CHECK-NEXT: flow.dispatch.tensor.store %[[TENSOR]], %[[OUTPUT]], {{.*}}
// CHECK-NEXT: flow.return
// CHECK-NOT: flow.dispatch.workgroups
// CHECK: %[[FILL:.+]] = linalg.fill
//
// CHECK: %[[PAD:.+]] = flow.dispatch.workgroups[{{.+}}](%[[INPUT]], %[[FILL]]) : (tensor<1x224x224x3xf32>, tensor<1x225x225x3xf32>) -> %[[FILL]] =
// CHECK-NEXT: (%[[SRC:.+]]: !flow.dispatch.tensor<readonly:1x224x224x3xf32>, %[[DST:.+]]: !flow.dispatch.tensor<readwrite:1x225x225x3xf32>) {
