Skip to content

Commit

Permalink
[MLIR][Affine] Fix/improve affine sibling fusion
Browse files Browse the repository at this point in the history
The sibling fusion profitability checks shouldn't rely on the presence
of a store op in the sibling. The reuse is between the loads.

Fixes issues raised at https://discourse.llvm.org/t/understanding-the-affine-loop-fusion-pass/69452

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D146763
  • Loading branch information
bondhugula committed Mar 25, 2023
1 parent 22ebb49 commit 721defb
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 35 deletions.
14 changes: 7 additions & 7 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Expand Up @@ -1785,9 +1785,6 @@ struct GreedyFusion {
// Currently findSiblingNodeToFuse searches for siblings with one load.
assert(sibLoadOpInsts.size() == 1);
Operation *sibLoadOpInst = sibLoadOpInsts[0];
assert(!sibNode->stores.empty());
// TODO: Choose the store which postdominates all other stores.
auto *sibStoreOpInst = sibNode->stores.back();

// Gather 'dstNode' load ops to 'memref'.
SmallVector<Operation *, 2> dstLoadOpInsts;
Expand Down Expand Up @@ -1818,8 +1815,11 @@ struct GreedyFusion {

unsigned bestDstLoopDepth = maxLegalFusionDepth;
if (!maximalFusion) {
// Check if fusion would be profitable.
if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp,
// Check if fusion would be profitable. For sibling fusion, the sibling
// load op is treated as the src "store" op for fusion profitability
// purposes. The footprint of the load in the slice relative to the
// unfused source's determines reuse.
if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
depthSliceUnions, maxLegalFusionDepth,
&bestDstLoopDepth, computeToleranceThreshold))
continue;
Expand Down Expand Up @@ -1875,13 +1875,13 @@ struct GreedyFusion {
}))
return false;

// Check that all stores are to the same memref.
// Check that all stores are to the same memref if any.
DenseSet<Value> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
storeMemrefs.insert(
cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
}
if (storeMemrefs.size() != 1)
if (storeMemrefs.size() > 1)
return false;

// Skip if a memref value in one node is used by a non-affine memref
Expand Down
40 changes: 20 additions & 20 deletions mlir/test/Transforms/loop-fusion-2.mlir
Expand Up @@ -587,32 +587,32 @@ func.func @fuse_across_varying_dims_complex(%arg0: f32) {
// MAXIMAL-NEXT: memref.alloc() : memref<2x2x3x3x16x1xf32>
// MAXIMAL-NEXT: memref.alloc() : memref<144x4xf32>
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 {
// MAXIMAL-NEXT: affine.apply [[$MAP0]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP1]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP2]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<2x2x3x3x16x1xf32>
// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 {
// MAXIMAL-NEXT: affine.apply [[$MAP0]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP1]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP2]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<2x2x3x3x16x1xf32>
// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 {
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<144x4xf32>
// MAXIMAL-NEXT: affine.apply [[$MAP8]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, 0] : memref<64x1xf32>
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: affine.apply [[$MAP8]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, 0] : memref<64x1xf32>
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<144x4xf32>
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: }
// MAXIMAL-NEXT: }

Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Transforms/loop-fusion-4.mlir
Expand Up @@ -144,6 +144,22 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m

// -----

// SIBLING-MAXIMAL-LABEL: func @sibling_load_only
func.func @sibling_load_only(%arg0: memref<10xf32>) {
affine.for %arg1 = 0 to 10 {
%0 = affine.load %arg0[%arg1] : memref<10xf32>
}
affine.for %arg1 = 0 to 10 {
%0 = affine.load %arg0[%arg1] : memref<10xf32>
}
// SIBLING-MAXIMAL-NEXT: affine.for
// SIBLING-MAXIMAL-NEXT: affine.load
// SIBLING-MAXIMAL-NEXT: affine.load
return
}

// -----

// PRODUCER-CONSUMER-LABEL: func @fusion_for_multiple_blocks() {
func.func @fusion_for_multiple_blocks() {
^bb0:
Expand Down
11 changes: 3 additions & 8 deletions mlir/test/Transforms/loop-fusion.mlir
Expand Up @@ -1189,8 +1189,8 @@ func.func @should_fuse_at_depth1_with_trip_count_19() {

// -----

// CHECK-LABEL: func @should_fuse_with_private_memrefs_with_diff_shapes() {
func.func @should_fuse_with_private_memrefs_with_diff_shapes() {
// CHECK-LABEL: func @should_fuse_with_private_memref() {
func.func @should_fuse_with_private_memref() {
%m = memref.alloc() : memref<100xf32>
%cf7 = arith.constant 7.0 : f32

Expand All @@ -1203,16 +1203,11 @@ func.func @should_fuse_with_private_memrefs_with_diff_shapes() {
affine.for %i2 = 0 to 82 {
%v1 = affine.load %m[%i2] : memref<100xf32>
}
// Should create two new private memrefs customized to the shapes accessed
// by loops %{{.*}} and %{{.*}}.
// CHECK-DAG: memref.alloc() : memref<1xf32>
// Should create a new private memref.
// CHECK-DAG: memref.alloc() : memref<1xf32>
// CHECK: affine.for %{{.*}} = 0 to 17 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: affine.for %{{.*}} = 0 to 82 {
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
Expand Down

0 comments on commit 721defb

Please sign in to comment.