-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][linalg] Erase effectful producer after fusion if it's no longer used #170036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Artemiy Bulavin (abulavin) ChangesThis PR fixes a subtle bug in the The pass relies on the DCE done by the greedy pattern rewriter in order to erase any trivially dead producer ops that no longer have uses after fusion. The fix is to erase the producer within the pattern if it no longer has uses instead of solely relying on DCE. Full diff: https://github.com/llvm/llvm-project/pull/170036.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..b65e5dd35ff3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -493,6 +493,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
});
}
rewriter.eraseOp(genericOp);
+ // If after fusion, the producer no longer has uses, erase it. Usually the
+ // greedy pattern driver takes care of this, however if the producer
+ // contains ops with memory effects it won't be considered trivially dead.
+ if (producer->use_empty())
+ rewriter.eraseOp(producer);
+
return success();
}
return failure();
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..dedb95db4f4b6 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1079,4 +1079,38 @@ module {
// CHECK-NOT: linalg.generic
// CHECK: tensor.expand_shape
// CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+module {
+ func.func @remove_effectful_producer_after_fusion_if_no_uses(%arg0: bf16, %arg1: memref<2xbf16>, %arg2: memref<2xbf16>) {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<2xbf16>
+ %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<2xbf16>) {
+ ^bb0(%out: bf16):
+ %2 = memref.atomic_rmw addf %arg0, %arg1[%c0] : (bf16, memref<2xbf16>) -> bf16
+ linalg.yield %2 : bf16
+ } -> tensor<2xbf16>
+
+ linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%1 : tensor<2xbf16>) {
+ ^bb0(%in: bf16):
+ memref.store %in, %arg2[%c0] : memref<2xbf16>
+ linalg.yield
+ }
+ return
+ }
+}
+
+// CHECK-LABEL: func.func @remove_effectful_producer_after_fusion_if_no_uses
+// CHECK-SAME: %[[ARG0:.*]]: bf16,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xbf16>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xbf16>)
+// CHECK-NOT: memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK: linalg.generic
+// CHECK: %[[ATOMIC_RMW:.*]] = memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK: memref.store %[[ATOMIC_RMW]], %[[ARG2]]
+// CHECK: linalg.yield %[[ATOMIC_RMW]]
+// CHECK: return
|
cc3d406 to
af8cd23
Compare
MaheshRavishankar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am sympathetic to your use case, but I think this is the wrong place to make this change. I would rather add an option DCE, or greedy pattern rewriter to erase even with side-effecting operations. I would strongly suggest writing your own pass to remove such ops.
This PR fixes a subtle bug in the
linalg-fuse-elementwise-opspass where some producers are not being erased after fusion.The pass relies on the DCE done by the greedy pattern rewriter in order to erase any trivially dead producer ops that no longer have uses after fusion.
However if the producer's region contains ops with memory effects, the producer itself cannot be erased because
linalg.genericops have theRecursiveMemoryEffectstrait. This makes the producer not trivially dead even when it has no uses. This causes miscompiles as the body of the 'old' producer is essentially "duplicated", as opposed to "moved", in the IR after fusion.The fix is to erase the producer within the pattern if it no longer has uses instead of solely relying on DCE.