[mlir][transform] Use LoopInvariant passes in the tileConsumerAndFuseProducersUsingSCF #163222
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: lonely eagle (linuxlonelyeagle)
Changes
In tileConsumerAndFuseProducersUsingSCF, the loop-invariant code motion and loop-invariant subset hoisting utilities (`moveLoopInvariantCode` and `hoistLoopInvariantSubsets`) are used to reposition the extract ops within the loops, so that the fuseOp is created at the correct position. Otherwise, the fuseOp would only ever be created in the innermost loop. Besides that, hoisting the extractOp is also an optimization.
Full diff: https://github.com/llvm/llvm-project/pull/163222.diff
4 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770fb4b279..4684ad5dd84ae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -25,6 +25,7 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
@@ -1316,7 +1317,15 @@ getUntiledProducerFromSliceSource(OpOperand *source,
ArrayRef<LoopLikeOpInterface> loops) {
std::optional<OpOperand *> destinationIterArg;
assert(!loops.empty() && "expected non empty loops container");
+
+ // The `extractOp` may not reside within the innermost loop, calculate the
+ // distance between it and the last LoopLikeInterfaceOp. Adding this
+ // `distance` to `loopIt` yields the start of the loop.
auto loopIt = loops.rbegin();
+ auto parentLoop = source->getOwner()->getParentOfType<LoopLikeOpInterface>();
+ const LoopLikeOpInterface *it = llvm::find(loops, parentLoop);
+ int64_t distance = std::distance(loops.begin(), it);
+ loopIt += (loops.size() - distance - 1);
while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
auto iterArg = cast<BlockArgument>(source->get());
auto loop = *loopIt;
@@ -1347,7 +1356,6 @@ mlir::scf::tileAndFuseProducerOfSlice(
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(candidateSliceOp);
-
// 2. Clone the fused producer
// 2a. Compute the destination operands to use for the cloned operation.
SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
@@ -1750,6 +1758,15 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
replacements};
}
+ // The extract_slice op is created in the innermost loop by default. Using
+ // `moveLoopInvariantCode` and `hoistLoopInvariantSubsets` improves the
+ // position of the extract_slice op within the loops, allowing the fuse Op to
+ // be created in the correct loop.
+ for (LoopLikeOpInterface loop : loops)
+ (void)moveLoopInvariantCode(loop);
+ for (LoopLikeOpInterface loop : loops)
+ (void)hoistLoopInvariantSubsets(rewriter, loop);
+
// Since the loop gets potentially replaced during fusion, we need to track
// the mutation of replacement values. To do this, we attach a listener to
// update the replacements as they happen.
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 8a0390a4379cf..805ccc614c00f 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -28,9 +28,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
+// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
-// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
@@ -141,6 +141,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -151,7 +152,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
-// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
@@ -444,6 +444,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
+// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
@@ -458,7 +459,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
// CHECK-SAME: outs(%[[SLICE_ARG4]] :
-// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
@@ -688,3 +688,44 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
+// -----
+
+func.func @pooling_ncw_max_fill_fuse(%input: tensor<?x?x?xf32>, %fake: tensor<?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %res = linalg.pooling_ncw_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %fake: tensor<?x?x?xf32>, tensor<?xf32>)
+ outs(%fill: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %res : tensor<?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(
+ %arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %tiled_pool, %loops0:4 = transform.structured.fuse %0 {tile_sizes = [1, 16, 1, 1], apply_cleanup = true}
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @pooling_ncw_max_fill_fuse(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME: %[[FAKE:.*]]: tensor<?xf32>,
+// CHECK-SAME: %[[INIT:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
+// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
+// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG2:.+]] = %[[ITERARG1]])
+// CHECK: %[[FILL_EXTRACT:.*]] = tensor.extract_slice %[[ITERARG2]]{{\[}}%[[IV0]], %[[IV1]], %[[IV2]]]
+// CHECK: %[[TILED_FILL:.*]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[FILL_EXTRACT]] : tensor<1x?x1xf32>) -> tensor<1x?x1xf32>
+// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG3:.*]] = %[[ITERARG2]], %[[ITERARG4:.*]] = %[[TILED_FILL]])
+// CHECK: %[[TILED_INPUT:.*]] = tensor.extract_slice %[[INPUT]]{{\[}}%[[IV0]], %[[IV1]]
+// CHECK: %[[TILED_FAKE:.*]] = tensor.extract_slice %[[FAKE]]{{\[}}%[[IV3]]]
+// CHECK: linalg.pooling_ncw_max
+// CHECK-SAME: ins(%[[TILED_INPUT]], %[[TILED_FAKE]] :
+// CHECK-SAME: outs(%[[ITERARG4]] :
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 3c0ada9d2cabc..1df1e1dcf7d58 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -47,7 +48,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
-// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir
index 8fc8f3245be15..f5370cd86dd9f 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK: %[[RESULT:.+]]:2 = scf.forall (%[[IV:[a-zA-Z0-9]+]]) =
// CHECK-SAME: shared_outs(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
@@ -47,7 +48,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME: outs(%[[FILL0_TILE]] :
-// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME: outs(%[[INIT1_TILE]] :
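The iterator arithmetic that the diff adds to getUntiledProducerFromSliceSource can be hard to follow in reverse-iterator form. The following is a minimal standalone sketch of the same index computation, not the MLIR code itself: `LoopNest`, `startAtParent`, and `enclosingLoops` are hypothetical names, loops are modeled as plain strings listed outermost first, and the fallback for a parent loop that is not in the list is an assumption of this sketch (the diff does not guard that case).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-in for ArrayRef<LoopLikeOpInterface>: loops are listed
// outermost first and identified by name only.
using LoopNest = std::vector<std::string>;

// Returns a reverse iterator positioned at `parentLoop`, so that advancing it
// toward rend() visits the parent loop and then each enclosing loop.
// Assumption of this sketch: if `parentLoop` is not in `loops`, start at the
// innermost loop instead.
LoopNest::const_reverse_iterator startAtParent(const LoopNest &loops,
                                               const std::string &parentLoop) {
  assert(!loops.empty() && "expected non empty loops container");
  auto it = std::find(loops.begin(), loops.end(), parentLoop);
  if (it == loops.end())
    return loops.rbegin();
  // Index of the parent loop, counted from the outermost loop.
  std::ptrdiff_t distance = std::distance(loops.begin(), it);
  // rbegin() refers to index size - 1, so skipping (size - distance - 1)
  // entries lands on index `distance`.
  std::ptrdiff_t size = static_cast<std::ptrdiff_t>(loops.size());
  return loops.rbegin() + (size - distance - 1);
}

// Collects the loops visited when walking from the parent loop outward.
std::vector<std::string> enclosingLoops(const LoopNest &loops,
                                        const std::string &parentLoop) {
  std::vector<std::string> visited;
  for (auto it = startAtParent(loops, parentLoop); it != loops.rend(); ++it)
    visited.push_back(*it);
  return visited;
}
```

With a three-deep nest `{"i", "j", "k"}` and a slice whose parent is `"j"`, the walk starts at `"j"` and then visits `"i"`, skipping the innermost loop `"k"` that no longer contains the slice.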
This PR is intended to resolve this issue: linalg.fill was being created in the wrong location.
I need to understand this change a bit better. Please give me some time to understand.
Sorry, I forgot to rebase the code, which caused it to fail the check.
Just to check: this change is mostly about hoisting of extract_slice, for If so, couldn't we just make use of the composability of transform ops and do: Doesn't that already give you the desired result?
The example you provided does not resolve the issue. It runs the loop-invariant passes on my IR above, after tile-and-fuse has already finished, and that doesn't work. This PR runs the loop-invariant hoisting immediately after tiling, so the correct position for inserting the fuseOp can be identified.
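To illustrate where hoisting places a slice, here is a toy model in standalone C++. It is only a sketch under simplifying assumptions, not the MLIR implementation: `ToyOp` and `hoistedDepth` are hypothetical names, every op is assumed to be side-effect free and to start in the innermost loop of a perfect nest, and an op is described solely by the depths of the loops whose induction variables or iter_args it reads.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy model (not the MLIR API): a side-effect-free op that starts in the
// innermost loop of a perfect nest. Depth 0 is the outermost loop.
struct ToyOp {
  std::string name;
  // Depths of the loops whose induction variable or iter_arg the op reads.
  std::vector<int> usedLoopDepths;
};

// Returns the depth of the loop whose body the op ends up in once it has been
// hoisted out of every loop it does not depend on, or -1 if it can be placed
// before the whole nest.
int hoistedDepth(const ToyOp &op) {
  int depth = -1;
  for (int d : op.usedLoopDepths)
    depth = std::max(depth, d);
  return depth;
}
```

In this model a slice that reads only the outer induction variable (like `LHS_TILE`, which uses `%[[IV0]]` in the updated test) lands in the outer loop body, while a slice with constant offsets (like `RHS1_TILE` with `[0, 0]`) lands before the loop nest, which matches the moved CHECK lines in the diff.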
Ok, I looked at this again. I do have concerns with this change. The expectation of tile and fuse is that it
(a) creates a perfectly nested loop that iterates over the loops, and
(b) fuses into the innermost loop of the given tiled loops.
This transformation is already fairly complex. Adding hoisting in the middle of it adds further complexity. It becomes much harder to reason about what the tile and fuse methods do if the code suddenly fuses somewhere unexpected.
Apart from complexity, the hoisting is a "heuristic". Depending on the use case, hoisting might be beneficial or it might not be. That really comes down to a cost model. If we add hoisting, what is the rationale for a later change not to do "sinking"? The best approach as far as I can see is to keep to steps (a) and (b) highlighted above. Any subsequent hoisting, sinking, etc. should be done as post-processing after tile and fuse.
I think I understand what you mean. Thank you. @MaheshRavishankar #164108