diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 1ec5fbfef50c3..c7d88e64c7d0a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -1015,7 +1015,8 @@ struct GreedyFusion { // redundant execution of the source happens (1:1 pointwise dep on the // producer-consumer memref access for example). Check this and allow // fusion accordingly. - if (hasCyclicDependence(srcAffineForOp)) { + bool srcHasCyclicDep = hasCyclicDependence(srcAffineForOp); + if (srcHasCyclicDep) { LDBG() << "Source nest has a cyclic dependence."; // Maximal fusion does not check for compute tolerance threshold; so // perform the maximal fusion only when the redundanation computation @@ -1075,6 +1076,17 @@ struct GreedyFusion { srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, *mdg); + // If the source loop has a cyclic dependence (e.g., a reduction that + // reads and writes the same memref) and cannot be removed after fusion, + // skip this fusion. Fusing a cyclic source without removing it would + // result in its cyclic computation executing twice: once in the + // original source and once in the fused copy. 
+ if (srcHasCyclicDep && !removeSrcNode) { + LDBG() << "Can't fuse: source has cyclic dependence and " + << "can't be removed after fusion"; + continue; + } + DenseSet privateMemrefs; for (Value memref : producerConsumerMemrefs) { if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, diff --git a/mlir/test/Dialect/Affine/loop-fusion-3.mlir b/mlir/test/Dialect/Affine/loop-fusion-3.mlir index 70d6c82105543..4d329d54fcaa8 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-3.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-3.mlir @@ -1294,5 +1294,91 @@ func.func @unknown_memref_def_op() { } func.func private @bar() -> memref<10xf32> +// ----- + +// CHECK-LABEL: func @no_double_reduction_cyclic_src_non_removable +// Test that a cyclic source loop (reduction) is not fused as a separate copy +// into a consumer when the fusion cannot remove the source. Without this check, +// the reduction would run twice, producing incorrect results. +// The reinterpret_cast makes %acc an "escaping" memref, which forces an +// isMaximal check. The slice is non-maximal (consumer only covers a subset of +// the producer's iteration space), making the source non-removable. +// The fix ensures fusion is skipped in that case, allowing the loops to be +// correctly combined later when the source can be fully removed. 
+//
+// CHECK: affine.for
+// CHECK-NOT: affine.for
+// CHECK: affine.for
+// CHECK-NOT: affine.for
+// CHECK: affine.for
+// CHECK-NOT: affine.for
+// CHECK: affine.for
+// CHECK-NOT: affine.for
+// CHECK: }
+// CHECK-NOT: affine.for
+// CHECK: }
+// CHECK-NOT: affine.for
+// CHECK: }
+// CHECK-NOT: affine.for
+// CHECK: }
+// CHECK-NOT: affine.for
+// CHECK: }
+// CHECK-NOT: affine.for
+
+func.func @no_double_reduction_cyclic_src_non_removable() {
+  %cst = arith.constant 5.000000e-01 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x1x7x5xf32>
+  // Producer: fill every element of %alloc with 0.5.
+  affine.for %arg0 = 0 to 3 {
+    affine.for %arg1 = 0 to 1 {
+      affine.for %arg2 = 0 to 7 {
+        affine.for %arg3 = 0 to 5 {
+          affine.store %cst, %alloc[%arg0, %arg1, %arg2, %arg3] : memref<3x1x7x5xf32>
+        }
+      }
+    }
+  }
+  // Accumulator seen through a reinterpret_cast (makes it "escaping", which triggers the isMaximal check).
+  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<7x5xf32>
+  %reinterpret_cast = memref.reinterpret_cast %alloc_2 to offset: [0], sizes: [1, 7, 5], strides: [35, 5, 1] : memref<7x5xf32> to memref<1x7x5xf32>
+  // Zero-initialize the accumulator.
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 7 {
+      affine.for %arg2 = 0 to 5 {
+        affine.store %cst_1, %reinterpret_cast[%arg0, %arg1, %arg2] : memref<1x7x5xf32>
+      }
+    }
+  }
+  // Cyclic reduction: sums %alloc over %arg0 into %reinterpret_cast, which it both reads and writes.
+  affine.for %arg0 = 0 to 3 {
+    affine.for %arg1 = 0 to 1 {
+      affine.for %arg2 = 0 to 7 {
+        affine.for %arg3 = 0 to 5 {
+          %0 = affine.load %alloc[%arg0, %arg1, %arg2, %arg3] : memref<3x1x7x5xf32>
+          %1 = affine.load %reinterpret_cast[%arg1, %arg2, %arg3] : memref<1x7x5xf32>
+          %2 = arith.addf %0, %1 : f32
+          affine.store %2, %reinterpret_cast[%arg1, %arg2, %arg3] : memref<1x7x5xf32>
+        }
+      }
+    }
+  }
+  // Consumer: element-wise sigmoid, 1 / (1 + exp(-x)), of the accumulator into %alloc_3.
+  %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<1x7x5xf32>
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 7 {
+      affine.for %arg2 = 0 to 5 {
+        %0 = affine.load %reinterpret_cast[%arg0, %arg1, %arg2] : memref<1x7x5xf32>
+        %1 = arith.negf %0 : f32
+        %2 = math.exp %1 : f32
+        %3 = arith.addf %2, %cst_0 : f32
+        %4 = arith.divf %cst_0, %3 : f32
+        affine.store %4, %alloc_3[%arg0, %arg1, %arg2] : memref<1x7x5xf32>
+      }
+    }
+  }
+  return
+}
 
 // Add further tests in mlir/test/Transforms/loop-fusion-4.mlir