
Commit 834133c

Anand Kodnani authored and dcaballe committed
[MLIR] Vector store to load forwarding
The MemRefDataFlow pass previously performed store-to-load forwarding only for affine store/load ops. This patch updates the pass to use the affine read/write op interfaces, which enables forwarding for vector loads and stores as well.

Reviewed By: dcaballe, bondhugula, ftynse

Differential Revision: https://reviews.llvm.org/D84302
1 parent 394db22 commit 834133c
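To make the effect concrete, here is a minimal sketch of the core forwarding step as it reads after this patch; the surrounding analysis that finds the unique store postdominating all others is elided, and the free function forwardValue is a hypothetical wrapper, not part of the commit:

  // Forward the value written by the last store to all uses of the loaded
  // value. Because this goes through the affine read/write interfaces, it
  // applies to affine.load/affine.store and affine.vector_load/
  // affine.vector_store alike.
  void forwardValue(AffineReadOpInterface loadOp, Operation *lastWriteStoreOp) {
    Value storeVal =
        cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
    loadOp.getValue().replaceAllUsesWith(storeVal);
  }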

File tree

4 files changed: +51, -9 lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td

Lines changed: 21 additions & 0 deletions
@@ -81,6 +81,16 @@ def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
         op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value read by this operation.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValue",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return cast<ConcreteOp>(this->getOperation());
+      }]
+    >,
   ];
 }

@@ -150,6 +160,17 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
         op.getAffineMapAttr()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value to store.",
+      /*retTy=*/"Value",
+      /*methodName=*/"getValueToStore",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getOperand(op.getStoredValOperandIndex());
+      }]
+    >,
   ];
 }
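A hedged sketch of how these interface methods are meant to be used from transformation code; collectAffineStores is a hypothetical helper mirroring the user-list walk in MemRefDataFlowOpt.cpp below:

  // Collect every affine write to `memref`. dyn_cast succeeds for any op
  // implementing AffineWriteOpInterface, i.e. affine.store as well as
  // affine.vector_store.
  SmallVector<Operation *, 8> collectAffineStores(Value memref) {
    SmallVector<Operation *, 8> storeOps;
    for (Operation *user : memref.getUsers())
      if (auto storeOp = dyn_cast<AffineWriteOpInterface>(user))
        storeOps.push_back(storeOp);
    return storeOps;
  }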

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 2 deletions
@@ -725,8 +725,8 @@ class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
     [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
   code extraClassDeclarationBase = [{
-    /// Get value to be stored by store operation.
-    Value getValueToStore() { return getOperand(0); }
+    /// Returns the operand index of the value to be stored.
+    unsigned getStoredValOperandIndex() { return 0; }

     /// Returns the operand index of the memref.
     unsigned getMemRefOperandIndex() { return 1; }
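The design choice here: concrete store ops no longer define getValueToStore() themselves; each op only reports the operand index of its stored value, and the interface's default implementation (added above) does the lookup. For AffineStoreOp, whose stored value is operand 0, the generated method behaves roughly like this hypothetical expansion:

  // Hypothetical expansion of the interface default for AffineStoreOp.
  Value getValueToStore(AffineStoreOp op) {
    return op.getOperand(op.getStoredValOperandIndex()); // == getOperand(0)
  }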

mlir/lib/Transforms/MemRefDataFlowOpt.cpp

Lines changed: 8 additions & 7 deletions
@@ -63,7 +63,7 @@ namespace {
 struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;

-  void forwardStoreToLoad(AffineLoadOp loadOp);
+  void forwardStoreToLoad(AffineReadOpInterface loadOp);

   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;

@@ -84,14 +84,14 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {

 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
+void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
   SmallVector<Operation *, 8> storeOps;
   unsigned minSurroundingLoops = getNestingDepth(loadOp);
   for (auto *user : loadOp.getMemRef().getUsers()) {
-    auto storeOp = dyn_cast<AffineStoreOp>(user);
+    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
     if (!storeOp)
       continue;
     unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);

@@ -167,8 +167,9 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
     return;

   // Perform the actual store to load forwarding.
-  Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
-  loadOp.replaceAllUsesWith(storeVal);
+  Value storeVal =
+      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
+  loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.

@@ -190,7 +191,7 @@ void MemRefDataFlowOpt::runOnFunction() {
   memrefsToErase.clear();

   // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); });
+  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });

   // Erase all load op's whose results were replaced with store fwd'ed ones.
   for (auto *loadOp : loadOpsToErase)

@@ -207,7 +208,7 @@ void MemRefDataFlowOpt::runOnFunction() {
       // could still erase it if the call had no side-effects.
       continue;
     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
-          return !isa<AffineStoreOp, DeallocOp>(ownerOp);
+          return !isa<AffineWriteOpInterface, DeallocOp>(ownerOp);
         }))
       continue;
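For context, a sketch of scheduling this pass programmatically, using the factory function visible in the hunk above; header paths and the pass-manager nesting API reflect MLIR of this era and may differ in newer trees:

  #include "mlir/IR/Function.h"
  #include "mlir/IR/Module.h"
  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  // Run the function pass on every func in the module.
  mlir::LogicalResult runForwarding(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    pm.nest<mlir::FuncOp>().addPass(mlir::createMemRefDataFlowOptPass());
    return pm.run(module);
  }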

mlir/test/Transforms/memref-dataflow-opt.mlir

Lines changed: 20 additions & 0 deletions
@@ -280,3 +280,23 @@ func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index) {
   }
   return
 }
+
+// The test checks for value forwarding from vector stores to vector loads.
+// The value loaded from %in can directly be stored to %out by eliminating
+// store and load from %tmp.
+func @vector_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %tmp = alloc() : memref<512xf32>
+  affine.for %i = 0 to 16 {
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    %ld1 = affine.vector_load %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_forwarding
+// CHECK:      affine.for %{{.*}} = 0 to 16 {
+// CHECK-NEXT:   %[[LDVAL:.*]] = affine.vector_load
+// CHECK-NEXT:   affine.vector_store %[[LDVAL]],{{.*}}
+// CHECK-NEXT: }
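The CHECK lines above expect exactly one vector load and one vector store to remain in the loop body: the load from %tmp is forwarded the value %ld0, which makes the store into %tmp dead, and the later sweep in the pass can then erase %tmp's remaining store and its alloc, since %tmp has no other users.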
