Skip to content

Commit

Permalink
[mlir] Add hoisting of transfer ops in affine loops
Browse files Browse the repository at this point in the history
The only way to do this with the current hoisting strategy is by
lowering Affine to Scf first, but that prevents further passes on
Affine.

Differential Revision: https://reviews.llvm.org/D137600
  • Loading branch information
jsetoain committed Dec 7, 2022
1 parent 19cde2d commit da291ba
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 11 deletions.
50 changes: 39 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Expand Up @@ -14,7 +14,9 @@
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -29,6 +31,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;
Expand Down Expand Up @@ -425,10 +428,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {

LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
<< *transferRead.getOperation() << "\n");
auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
<< "\n");
if (!loop)
if (!isa_and_nonnull<scf::ForOp, AffineForOp>(loop))
return WalkResult::advance();

LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
Expand Down Expand Up @@ -513,18 +516,43 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{transferWrite.getVector()};
};
auto newForOp =
replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);

// Transfer write has been hoisted, need to update the written vector by
// the value yielded by the newForOp.
transferWrite.getVectorMutable().assign(newForOp.getResults().back());

changed = true;
loop.erase();
// Need to interrupt and restart because erasing the loop messes up the
// walk.
return WalkResult::interrupt();
return TypeSwitch<Operation *, WalkResult>(loop)
.Case<scf::ForOp>([&](scf::ForOp scfForOp) {
auto newForOp = replaceLoopWithNewYields(
b, scfForOp, transferRead.getVector(), yieldFn);
transferWrite.getVectorMutable().assign(
newForOp.getResults().back());
changed = true;
loop.erase();
// Need to interrupt and restart because erasing the loop messes up
// the walk.
return WalkResult::interrupt();
})
.Case<AffineForOp>([&](AffineForOp affineForOp) {
auto newForOp = replaceForOpWithNewYields(
b, affineForOp, transferRead.getVector(),
SmallVector<Value>{transferWrite.getVector()},
transferWrite.getVector());
// Replace all uses of the `transferRead` with the corresponding
// basic block argument.
transferRead.getVector().replaceUsesWithIf(
newForOp.getLoopBody().getArguments().back(),
[&](OpOperand &use) {
Operation *user = use.getOwner();
return newForOp->isProperAncestor(user);
});
transferWrite.getVectorMutable().assign(
newForOp.getResults().back());
changed = true;
loop.erase();
// Need to interrupt and restart because erasing the loop messes up
// the walk.
return WalkResult::interrupt();
})
.Default([](Operation *) { return WalkResult::interrupt(); });
});
}
}
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Linalg/hoisting.mlir
Expand Up @@ -469,3 +469,39 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
return %1 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: affine.for %[[I:.*]] = 0 to 64 {
// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 {
// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
// CHECK: affine.yield %[[T1]] : vector<16xi32>
// CHECK: }
// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
// CHECK: }
// CHECK: }
func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
%c0_i32 = arith.constant 0 : i32
affine.for %arg3 = 0 to 64 {
affine.for %arg4 = 0 to 64 step 16 {
affine.for %arg5 = 0 to 64 {
%0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
%1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
%2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
%3 = arith.muli %0, %1 : vector<16xi32>
%4 = arith.addi %2, %3 : vector<16xi32>
vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
}
}
}
return
}

0 comments on commit da291ba

Please sign in to comment.