Erase the producer after elementwise fusion when it is left without uses.

After FuseElementwiseOps erases the consumer generic op, the first hunk also
erases the producer when `producer->use_empty()` holds. Per the comment added
in the hunk, the greedy pattern driver usually takes care of this, but a
producer containing ops with memory effects is not considered trivially dead.

The second hunk adds a test whose producer contains `memref.atomic_rmw` and
whose consumer contains `memref.store`; the CHECK lines expect both ops inside
a single `linalg.generic`.

NOTE(review): the line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. Confirm the context-line whitespace
against the target tree (or apply with whitespace-insensitive matching).

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..b65e5dd35ff3c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -493,6 +493,12 @@ class FuseElementwiseOps : public OpRewritePattern {
         });
       }
       rewriter.eraseOp(genericOp);
+      // If after fusion, the producer no longer has uses, erase it. Usually the
+      // greedy pattern driver takes care of this, however if the producer
+      // contains ops with memory effects it won't be considered trivially dead.
+      if (producer->use_empty())
+        rewriter.eraseOp(producer);
+
       return success();
     }
     return failure();
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6f1a422324e08..b47aeb2812210 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1056,3 +1056,37 @@ module {
 // CHECK: tensor.expand_shape
 // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+module {
+  func.func @remove_effectful_producer_after_fusion_if_no_uses(%arg0: bf16, %arg1: memref<2xbf16>, %arg2: memref<2xbf16>) {
+    %c0 = arith.constant 0 : index
+    %0 = tensor.empty() : tensor<2xbf16>
+    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<2xbf16>) {
+    ^bb0(%out: bf16):
+      %2 = memref.atomic_rmw addf %arg0, %arg1[%c0] : (bf16, memref<2xbf16>) -> bf16
+      linalg.yield %2 : bf16
+    } -> tensor<2xbf16>
+
+    linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%1 : tensor<2xbf16>) {
+    ^bb0(%in: bf16):
+      memref.store %in, %arg2[%c0] : memref<2xbf16>
+      linalg.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @remove_effectful_producer_after_fusion_if_no_uses
+// CHECK-SAME: %[[ARG0:.*]]: bf16,
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xbf16>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<2xbf16>)
+// CHECK-NOT: memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK: linalg.generic
+// CHECK: %[[ATOMIC_RMW:.*]] = memref.atomic_rmw addf %[[ARG0]], %[[ARG1]]
+// CHECK: memref.store %[[ATOMIC_RMW]], %[[ARG2]]
+// CHECK: linalg.yield %[[ATOMIC_RMW]]
+// CHECK: return