Skip to content

Commit

Permalink
[MLIR][Affine] Improve load elimination
Browse files Browse the repository at this point in the history
Fixes #62639.

Differential Revision: https://reviews.llvm.org/D154769
  • Loading branch information
rikhuijzer committed Jul 9, 2023
1 parent 758c464 commit 71513a7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 16 deletions.
36 changes: 20 additions & 16 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Expand Up @@ -862,9 +862,10 @@ bool mlir::affine::hasNoInterveningEffect(Operation *start, T memOp) {
/// other operations will overwrite the memory loaded between the given load
/// and store. If such a value exists, the replaced `loadOp` will be added to
/// `loadOpsToErase` and its memref will be added to `memrefsToErase`.
static LogicalResult forwardStoreToLoad(
AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
static void forwardStoreToLoad(AffineReadOpInterface loadOp,
SmallVectorImpl<Operation *> &loadOpsToErase,
SmallPtrSetImpl<Value> &memrefsToErase,
DominanceInfo &domInfo) {

// The store op candidate for forwarding that satisfies all conditions
// to replace the load, if any.
Expand Down Expand Up @@ -911,21 +912,20 @@ static LogicalResult forwardStoreToLoad(
}

if (!lastWriteStoreOp)
return failure();
return;

// Perform the actual store to load forwarding.
Value storeVal =
cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
// Check if 2 values have the same shape. This is needed for affine vector
// loads and stores.
if (storeVal.getType() != loadOp.getValue().getType())
return failure();
return;
loadOp.getValue().replaceAllUsesWith(storeVal);
// Record the memref for a later sweep to optimize away.
memrefsToErase.insert(loadOp.getMemRef());
// Record this to erase later.
loadOpsToErase.push_back(loadOp);
return success();
}

template bool
Expand Down Expand Up @@ -995,16 +995,16 @@ static void loadCSE(AffineReadOpInterface loadA,
MemRefAccess srcAccess(loadB);
MemRefAccess destAccess(loadA);

// 1. The accesses have to be to the same location.
// 1. The accesses should be to the same location.
if (srcAccess != destAccess) {
continue;
}

// 2. The store has to dominate the load op to be candidate.
// 2. loadB should dominate loadA.
if (!domInfo.dominates(loadB, loadA))
continue;

// 3. There is no write between loadA and loadB.
// 3. There should not be a write between loadA and loadB.
if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
loadB.getOperation(), loadA))
continue;
Expand Down Expand Up @@ -1073,13 +1073,8 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,

// Walk all load's and perform store to load forwarding.
f.walk([&](AffineReadOpInterface loadOp) {
if (failed(
forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
loadCSE(loadOp, opsToErase, domInfo);
}
forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo);
});

// Erase all load op's whose results were replaced with store fwd'ed ones.
for (auto *op : opsToErase)
op->erase();
opsToErase.clear();
Expand All @@ -1088,9 +1083,9 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
f.walk([&](AffineWriteOpInterface storeOp) {
findUnusedStore(storeOp, opsToErase, postDomInfo);
});
// Erase all store op's which don't impact the program
for (auto *op : opsToErase)
op->erase();
opsToErase.clear();

// Check if the store fwd'ed memrefs are now left with only stores and
// deallocs and can thus be completely deleted. Note: the canonicalize pass
Expand All @@ -1114,6 +1109,15 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
user->erase();
defOp->erase();
}

// To eliminate as many loads as possible, run load CSE after eliminating
// stores. Otherwise, some stores are wrongly seen as having an intervening
// effect.
f.walk([&](AffineReadOpInterface loadOp) {
loadCSE(loadOp, opsToErase, domInfo);
});
for (auto *op : opsToErase)
op->erase();
}

// Perform the replacement in `op`.
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Dialect/Affine/scalrep.mlir
Expand Up @@ -280,6 +280,31 @@ func.func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index)
return
}

// CHECK-LABEL: func @elim_load_after_store
func.func @elim_load_after_store(%arg0: memref<100xf32>, %arg1: memref<100xf32>) {
%alloc = memref.alloc() : memref<1xf32>
%alloc_0 = memref.alloc() : memref<1xf32>
// CHECK: affine.for
affine.for %arg2 = 0 to 100 {
// CHECK: affine.load
%0 = affine.load %arg0[%arg2] : memref<100xf32>
%1 = affine.load %arg0[%arg2] : memref<100xf32>
// CHECK: arith.addf
%2 = arith.addf %0, %1 : f32
affine.store %2, %alloc_0[0] : memref<1xf32>
%3 = affine.load %arg0[%arg2] : memref<100xf32>
%4 = affine.load %alloc_0[0] : memref<1xf32>
// CHECK-NEXT: arith.addf
%5 = arith.addf %3, %4 : f32
affine.store %5, %alloc[0] : memref<1xf32>
%6 = affine.load %arg0[%arg2] : memref<100xf32>
%7 = affine.load %alloc[0] : memref<1xf32>
%8 = arith.addf %6, %7 : f32
affine.store %8, %arg1[%arg2] : memref<100xf32>
}
return
}

// The test checks for value forwarding from vector stores to vector loads.
// The value loaded from %in can directly be stored to %out by eliminating
// store and load from %tmp.
Expand Down

0 comments on commit 71513a7

Please sign in to comment.