Skip to content

Conversation

@abulavin
Copy link
Contributor

@abulavin abulavin commented Nov 30, 2025

This PR fixes a subtle bug in the linalg-fuse-elementwise-ops pass where some producers are not being erased after fusion.

The pass relies on the DCE done by the greedy pattern rewriter in order to erase any trivially dead producer ops that no longer have uses after fusion.
However if the producer's region contains ops with memory effects, the producer itself cannot be erased because linalg.generic ops have the RecursiveMemoryEffects trait. This makes the producer not trivially dead even when it has no uses. This causes miscompiles as the body of the 'old' producer is essentially "duplicated", as opposed to "moved", in the IR after fusion.

The fix is to erase the producer within the pattern if it no longer has uses instead of solely relying on DCE.

@llvmbot
Copy link
Member

llvmbot commented Nov 30, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Artemiy Bulavin (abulavin)

Changes

This PR fixes a subtle bug in the linalg-fuse-elementwise-ops pass where some producers are not being erased after fusion.

The pass relies on the DCE done by the greedy pattern rewriter in order to erase any trivially dead producer ops that no longer have uses after fusion.
However if the producer's region contains ops with memory effects, the producer itself cannot be erased because linalg.generic ops have the RecursiveMemoryEffects trait. This makes the producer not trivially dead even when it has no uses. This causes miscompiles as the body of the 'old' producer is essentially repeated in the IR.

The fix is to erase the producer within the pattern if it no longer has uses instead of solely relying on DCE.


Full diff: https://github.com/llvm/llvm-project/pull/170036.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+6)
  • (modified) mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir (+35-1)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..b65e5dd35ff3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -493,6 +493,12 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
         });
       }
       rewriter.eraseOp(genericOp);
+      // If after fusion, the producer no longer has uses, erase it. Usually the
+      // greedy pattern driver takes care of this, however if the producer
+      // contains ops with memory effects it won't be considered trivially dead.
+      if (producer->use_empty())
+        rewriter.eraseOp(producer);
+
       return success();
     }
     return failure();
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..dedb95db4f4b6 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1079,4 +1079,38 @@ module {
 // CHECK-NOT:     linalg.generic
 // CHECK:         tensor.expand_shape
 // CHECK:         linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME:     ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+module {
+  func.func @remove_effectful_producer_after_fusion_if_no_uses(%arg0: bf16, %arg1: memref<2xbf16>, %arg2: memref<2xbf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = tensor.empty() : tensor<2xbf16>
+    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<2xbf16>) {
+    ^bb0(%out: bf16):
+      %2 = memref.atomic_rmw addf %arg0, %arg1[%c0] : (bf16, memref<2xbf16>) -> bf16
+      linalg.yield %2 : bf16
+    } -> tensor<2xbf16>
+  
+    linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%1 : tensor<2xbf16>) {
+    ^bb0(%in: bf16):
+      memref.store %in, %arg2[%c0] : memref<2xbf16>
+      linalg.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL:   func.func @remove_effectful_producer_after_fusion_if_no_uses
+// CHECK-SAME:      %[[ARG0:.*]]: bf16,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<2xbf16>,
+// CHECK-SAME:      %[[ARG2:.*]]: memref<2xbf16>)
+// CHECK-NOT:       memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK:           linalg.generic
+// CHECK:             %[[ATOMIC_RMW:.*]] = memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK:             memref.store %[[ATOMIC_RMW]], %[[ARG2]]
+// CHECK:             linalg.yield %[[ATOMIC_RMW]]
+// CHECK:           return

@abulavin abulavin force-pushed the fix-linalg-elementwise-fusion branch from cc3d406 to af8cd23 Compare November 30, 2025 15:01
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sympathetic to your use case, but I think this is the wrong place to make this change. I would rather add an option DCE, or greedy pattern rewriter to erase even with side-effecting operations. I would strongly suggest writing your own pass to remove such ops.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants