diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
index 4118639b9887..07118d94625f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -251,24 +251,56 @@ struct SubTensorToTensorSlice
   }
 };
 
+/// Converts linalg.fill ops into flow.tensor.splat ops.
+///
+/// This is expected to improve performance because we can use DMA
+/// functionality for the fill instead of dispatching a kernel.
+struct LinalgFillToFlowTensorSplat final
+    : public OpRewritePattern<linalg::FillOp> {
+  using OpRewritePattern<linalg::FillOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::FillOp fillOp,
+                                PatternRewriter &rewriter) const override {
+    if (fillOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+      // Don't convert linalg.fill ops that were fused together with other ops.
+      return failure();
+    }
+
+    SmallVector<Value, 4> dynamicDims =
+        getDynamicDimValues(rewriter, fillOp.getLoc(), fillOp.output());
+    rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
+        fillOp, fillOp.output().getType(), fillOp.value(), dynamicDims);
+    return success();
+  }
+};
+
 /// Converts operations that can map to flow.tensor.* operations.
 struct ConvertToFlowTensorOpsPass
     : public ConvertToFlowTensorOpsBase<ConvertToFlowTensorOpsPass> {
+  ConvertToFlowTensorOpsPass(bool runBefore) {
+    runBeforeDispatchRegionFormation = runBefore;
+  }
+  ConvertToFlowTensorOpsPass(const ConvertToFlowTensorOpsPass &that) {
+    runBeforeDispatchRegionFormation = that.runBeforeDispatchRegionFormation;
+  }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::Flow::FlowDialect>();
   }
-  ConvertToFlowTensorOpsPass() = default;
-  ConvertToFlowTensorOpsPass(const ConvertToFlowTensorOpsPass &pass) {}
   void runOnOperation() override {
     FuncOp funcOp = getOperation();
     MLIRContext *context = funcOp->getContext();
     context->allowUnregisteredDialects(true);
     RewritePatternSet patterns(&getContext());
-    patterns.insert<
-        LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
-        LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
-        SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
+    if (runBeforeDispatchRegionFormation) {
+      patterns.insert<
+          LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
+          LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
+          SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
+    } else {
+      patterns.insert<LinalgFillToFlowTensorSplat>(context);
+    }
     IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
       return signalPassFailure();
     }
@@ -277,8 +309,10 @@ struct ConvertToFlowTensorOpsPass
 };
 
 }  // namespace
 
-std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass() {
-  return std::make_unique<ConvertToFlowTensorOpsPass>();
+std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass(
+    bool runBeforeDispatchRegionFormation) {
+  return std::make_unique<ConvertToFlowTensorOpsPass>(
+      runBeforeDispatchRegionFormation);
 }
 
 }  // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index ed1a3446eb64..ab3862200ed8 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -255,6 +255,12 @@ static bool isDispatchableOp(Operation *op) {
       !isa(op)) {
     return false;
  }
+
+  // Mark linalg.fill as non-dispatchable so that linalg.fill ops that cannot
+  // be fused with some root op are left out of dispatch region formation and
+  // can be picked up by the DMA op conversion afterwards.
+  if (isa<linalg::FillOp>(op)) return false;
+
   return !isAlwaysClonedIntoDispatchOp(op);
 }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index f8e9c7838baf..6185eac7f6ae 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -108,13 +108,17 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager) {
   passManager.addNestedPass<FuncOp>(createInterchangeGenericOpsPass());
   passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
   passManager.addNestedPass<FuncOp>(createFusionOfTensorOpsPass());
-  passManager.addNestedPass<FuncOp>(
-      IREE::Flow::createConvertToFlowTensorOpsPass());
   passManager.addNestedPass<FuncOp>(mlir::createCSEPass());
   passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
+  passManager.addNestedPass<FuncOp>(
+      IREE::Flow::createConvertToFlowTensorOpsPass(
+          /*runBeforeDispatchRegionFormation=*/true));
   passManager.addNestedPass<FuncOp>(
       IREE::Flow::createDispatchLinalgOnTensorsPass());
   passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
+  passManager.addNestedPass<FuncOp>(
+      IREE::Flow::createConvertToFlowTensorOpsPass(
+          /*runBeforeDispatchRegionFormation=*/false));
   // NOTE: required because the current dispatch-linalg-on-tensors pass
   // creates a lot of dead IR that needs to be cleaned up.
   passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index b5cb6054e1b3..5f4403afaad0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -61,10 +61,12 @@ std::unique_ptr<Pass> createFusionOfTensorOpsPass();
 /// the most inner loops.
 std::unique_ptr<OperationPass<FuncOp>> createInterchangeGenericOpsPass();
 
-// Convert operations to equivalent flow.tensor.* ops. This is run after
-// dispatch region creation to catch operations that were left outside of
-// dispatch regions and could be represented as flow.tensor.* ops.
-std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass();
+// Convert operations to equivalent flow.tensor.* ops.
+// `runBeforeDispatchRegionFormation` controls whether to run before dispatch
+// region creation. If run after, it will catch operations that were left
+// outside of dispatch regions and could be represented as flow.tensor.* ops.
+std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass(
+    bool runBeforeDispatchRegionFormation = true);
 
 // Promote I1 tensor constants to I8 tensors to match later operations.
 std::unique_ptr<OperationPass<FuncOp>> createPromoteI1ToI8Pass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 189d4860bea9..c8fa51c1e842 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -25,6 +25,10 @@ def ConvertToFlowTensorOps :
     Pass<"iree-flow-convert-to-flow-tensor-ops-pass", "FuncOp"> {
   let summary = "Convert operations to equivalent flow.tensor.* operations";
   let constructor = "mlir::iree_compiler::IREE::Flow::createConvertToFlowTensorOpsPass()";
+  let options = [
+    Option<"runBeforeDispatchRegionFormation", "run-before-dispatch-region-formation",
+           "bool", /*default=*/"true", "Run the pass before dispatch region formation">
+  ];
 }
 
 def DeduplicateExecutables :
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index de5f9551ff00..b8e2e4fd6ccd 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -19,7 +19,8 @@ iree_lit_test_suite(
     [
         "conv1x1_to_matmul.mlir",
         "conv2d_to_img2col.mlir",
-        "convert_to_flow_tensor_ops.mlir",
+        "convert_to_flow_tensor_ops_after.mlir",
+        "convert_to_flow_tensor_ops_before.mlir",
         "deduplicate_executables.mlir",
         "dispatch_linalg_on_tensors.mlir",
         "dispatch_linalg_on_tensors_elementwise.mlir",
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index a53d36d6a6d4..1d3a4d1ee4b2 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -16,7 +16,8 @@ iree_lit_test_suite(
   SRCS
     "conv1x1_to_matmul.mlir"
     "conv2d_to_img2col.mlir"
-    "convert_to_flow_tensor_ops.mlir"
+    "convert_to_flow_tensor_ops_after.mlir"
+    "convert_to_flow_tensor_ops_before.mlir"
     "deduplicate_executables.mlir"
     "dispatch_linalg_on_tensors.mlir"
     "dispatch_linalg_on_tensors_elementwise.mlir"
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_after.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_after.mlir
new file mode 100644
index 000000000000..f126e546b543
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_after.mlir
@@ -0,0 +1,33 @@
+// RUN: iree-opt -iree-flow-convert-to-flow-tensor-ops-pass='run-before-dispatch-region-formation=false' -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+
+func @turn_fill_into_splat(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = tensor.extract %arg1[] : tensor<f32>
+  %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %3 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%1)[%arg2, %arg4]
+  %4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
+  %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+  %6 = linalg.fill(%0, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %7 = flow.tensor.update %arg0, %6[%arg2, %arg3] : tensor<?x?xf32>{%1, %2} -> tensor<?x?xf32>{%3, %4}
+  return %7 : tensor<?x?xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK: func @turn_fill_into_splat
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[VAL:.+]] = tensor.extract %[[ARG1]][]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
+// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]], %[[D1]]]
+// CHECK: %[[SPLAT:.+]] = flow.tensor.splat %[[VAL]] : tensor<?x?xf32>{%[[RD0]], %[[RD1]]}
+// CHECK: flow.tensor.update %[[ARG0]], %[[SPLAT]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_before.mlir
similarity index 100%
rename from iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
rename to iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_before.mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 837dc6a47cd6..8382262e3f13 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -321,8 +321,11 @@ func @always_fuse_reshape
 
 // -----
 
-func @fuse_tensor_update_with_fill(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: index,
-                                   %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
+// A subsequent pass is expected to convert linalg.fill and flow.tensor.update into DMA ops.
+func @dont_fuse_tensor_update_with_fill(
+    %arg0: tensor<?x?xf32>, %arg1: tensor<f32>,
+    %arg2: index, %arg3: index, %arg4: index, %arg5: index)
+-> tensor<?x?xf32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0 = tensor.extract %arg1[] : tensor<f32>
@@ -336,29 +339,9 @@ func @fuse_tensor_update_with_fill(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %
   return %7 : tensor<?x?xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
-// CHECK: func @fuse_tensor_update_with_fill
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
-// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]], %[[D1]]]
-// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
-// CHECK-SAME: [%[[RD1]], %[[RD0]], %[[C1]]]
-// CHECK-SAME: (%[[ARG1]], %[[RD0]], %[[RD1]])
-// CHECK-DAG: %[[VAL:.+]] = tensor.extract
-// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
-// CHECK: %[[RETURN:.+]] = linalg.fill(%[[VAL]], %[[INIT]])
-// CHECK: flow.dispatch.tensor.store %[[RETURN]], {{.*}}
-// CHECK-NEXT: flow.return
-// CHECK: flow.tensor.update %[[ARG0]], %[[RESULT]]
+// CHECK: func @dont_fuse_tensor_update_with_fill
+// CHECK: linalg.fill
+// CHECK: flow.tensor.update
 
 // -----
 
@@ -479,6 +462,7 @@ func @depthwise_conv2d(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96x
 
 // -----
 
+// A subsequent pass is expected to convert linalg.fill into DMA ops.
 func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32> {
   %cst = constant 0.000000e+00 : f32
   %0 = linalg.init_tensor [1, 225, 225, 3] : tensor<1x225x225x3xf32>
@@ -490,12 +474,8 @@ func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32>
 // CHECK: func @subtensor_insert
 // CHECK-SAME: (%[[INPUT:.+]]: tensor<1x224x224x3xf32>)
 //
-// CHECK: %[[FILL:.+]] = flow.dispatch.workgroups[{{.+}}]() : () -> tensor<1x225x225x3xf32> =
-// CHECK-NEXT: (%[[OUTPUT:.+]]: !flow.dispatch.tensor<writeonly:1x225x225x3xf32>) {
-// CHECK: linalg.init_tensor
-// CHECK-NEXT: %[[TENSOR:.+]] = linalg.fill
-// CHECK-NEXT: flow.dispatch.tensor.store %[[TENSOR]], %[[OUTPUT]], {{.*}}
-// CHECK-NEXT: flow.return
+// CHECK-NOT: flow.dispatch.workgroups
+// CHECK: %[[FILL:.+]] = linalg.fill
 //
 // CHECK: %[[PAD:.+]] = flow.dispatch.workgroups[{{.+}}](%[[INPUT]], %[[FILL]]) : (tensor<1x224x224x3xf32>, tensor<1x225x225x3xf32>) -> %[[FILL]] =
 // CHECK-NEXT: (%[[SRC:.+]]: !flow.dispatch.tensor<readonly:1x224x224x3xf32>, %[[DST:.+]]: !flow.dispatch.tensor<readwrite:1x225x225x3xf32>) {
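
For reference, a minimal sketch (not part of the patch) of the rewrite that the new LinalgFillToFlowTensorSplat pattern performs on a standalone fill. The function and value names below are illustrative only, and static shapes are used so the flow.tensor.splat needs no dynamic dimension operands; with dynamic shapes the pattern forwards the dimension values from getDynamicDimValues into the splat's `{...}` operand list, as exercised by @turn_fill_into_splat above.

// Before: a linalg.fill not nested inside any flow.dispatch.workgroups region
// (fills fused into dispatch regions are intentionally skipped).
func @standalone_fill(%value: f32) -> tensor<4x8xf32> {
  %init = linalg.init_tensor [4, 8] : tensor<4x8xf32>
  %fill = linalg.fill(%value, %init) : f32, tensor<4x8xf32> -> tensor<4x8xf32>
  return %fill : tensor<4x8xf32>
}

// After -iree-flow-convert-to-flow-tensor-ops-pass='run-before-dispatch-region-formation=false'
// plus -canonicalize (which drops the now-dead linalg.init_tensor):
func @standalone_fill(%value: f32) -> tensor<4x8xf32> {
  %splat = flow.tensor.splat %value : tensor<4x8xf32>
  return %splat : tensor<4x8xf32>
}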