Skip to content

Commit

Permalink
[mlir][scf] Simplify the logic for replaceLoopWithNewYields for per…
Browse files Browse the repository at this point in the history
…fectly nested loops.

Based on discussion in https://reviews.llvm.org/D134411, instead of
first modifying the inner most loop first followed by modifying the
outer loops from inside out, this patch restructures the logic to
start the modification from the outer most loop.

Differential Revision: https://reviews.llvm.org/D134832
  • Loading branch information
Mahesh Ravishankar committed Sep 29, 2022
1 parent f079ba7 commit 0d5cb90
Showing 1 changed file with 59 additions and 49 deletions.
108 changes: 59 additions & 49 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Expand Up @@ -111,56 +111,66 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
bool replaceIterOperandsUsesInLoop) {
if (loopNest.empty())
return {};
SmallVector<scf::ForOp> newLoopNest(loopNest.size());

newLoopNest.back() = replaceLoopWithNewYields(
builder, loopNest.back(), newIterOperands, newYieldValueFn);

for (unsigned loopDepth :
llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
ArrayRef<BlockArgument> innerNewBBArgs) {
SmallVector<Value> newYields(
newLoopNest[loopDepth + 1]->getResults().take_back(
newIterOperands.size()));
return newYields;
};
newLoopNest[loopDepth] =
replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands,
fn, replaceIterOperandsUsesInLoop);
if (!replaceIterOperandsUsesInLoop) {
/// The yield is expected to producer the following structure
/// ```
/// %0 = scf.for ... iter_args(%arg0 = %init) {
/// %1 = scf.for ... iter_args(%arg1 = %arg0) {
/// scf.yield %yield
/// }
/// }
/// ```
///
/// since the yield is propagated from inside out, after the inner
/// loop is processed the IR is in this form
///
/// ```
/// scf.for ... iter_args {
/// %1 = scf.for ... iter_args(%arg1 = %init) {
/// scf.yield %yield
/// }
/// ```
///
/// If `replaceIterOperandUsesInLoops` is true, there is nothing to do.
/// `%init` will be replaced with `%arg0` when it is created for the
/// outer loop. But without that this has to be done explicitly.
unsigned subLen = newIterOperands.size();
unsigned subStart =
newLoopNest[loopDepth + 1].getNumIterOperands() - subLen;
auto resetOperands =
newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart,
subLen);
resetOperands.assign(
newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen));
}
// This method is recursive (to make it more readable). Adding an
// assertion here to limit the recursion. (See
// https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
assert(loopNest.size() <= 6 &&
"exceeded recursion limit when yielding value from loop nest");

// To yield a value from a perfectly nested loop nest, the following
// pattern needs to be created, i.e. starting with
//
// ```mlir
// scf.for .. {
// scf.for .. {
// scf.for .. {
// %value = ...
// }
// }
// }
// ```
//
// needs to be modified to
//
// ```mlir
// %0 = scf.for .. iter_args(%arg0 = %init) {
// %1 = scf.for .. iter_args(%arg1 = %arg0) {
// %2 = scf.for .. iter_args(%arg2 = %arg1) {
// %value = ...
// scf.yield %value
// }
// scf.yield %2
// }
// scf.yield %1
// }
// ```
//
// The inner most loop is handled using the `replaceLoopWithNewYields`
// that works on a single loop.
if (loopNest.size() == 1) {
auto innerMostLoop = replaceLoopWithNewYields(
builder, loopNest.back(), newIterOperands, newYieldValueFn,
replaceIterOperandsUsesInLoop);
return {innerMostLoop};
}
// The outer loops are modified by calling this method recursively
// - The return value of the inner loop is the value yielded by this loop.
// - The region iter args of this loop are the init_args for the inner loop.
SmallVector<scf::ForOp> newLoopNest;
NewYieldValueFn fn =
[&](OpBuilder &innerBuilder, Location loc,
ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(),
innerNewBBArgs, newYieldValueFn,
replaceIterOperandsUsesInLoop);
return llvm::to_vector(llvm::map_range(
newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
[](OpResult r) -> Value { return r; }));
};
scf::ForOp outerMostLoop =
replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn,
replaceIterOperandsUsesInLoop);
newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
return newLoopNest;
}

Expand Down

0 comments on commit 0d5cb90

Please sign in to comment.