Skip to content

Commit

Permalink
[mlir][SCF] Modernize coalesceLoops method to handle scf.for loop…
Browse files Browse the repository at this point in the history
…s with iter_args (#87019)

As part of this extension, this change also does some general cleanup:

1) Make all the methods take `RewriterBase` as arguments instead of
   creating their own builders that tend to crash when used within
   pattern rewrites
2) Split `coalescePerfectlyNestedLoops` into two separate methods, one
   for `scf.for` and the other for `affine.for`. The templatization didn't
   seem to be buying much there.

Also general clean up of tests.
  • Loading branch information
MaheshRavishankar committed Apr 4, 2024
1 parent fd2a5c4 commit 5aeb604
Show file tree
Hide file tree
Showing 13 changed files with 587 additions and 240 deletions.
49 changes: 2 additions & 47 deletions mlir/include/mlir/Dialect/Affine/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,53 +299,8 @@ LogicalResult
separateFullTiles(MutableArrayRef<AffineForOp> nest,
SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);

/// Walk the perfect loop nest rooted at `op` (either an scf.for or an
/// affine.for), look for bands of loops that can be coalesced, and coalesce
/// them. Returns success if at least one band was coalesced and failure
/// otherwise.
template <typename LoopOpTy>
LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
  SmallVector<LoopOpTy> nest;
  getPerfectlyNestedLoops(nest, op);
  const unsigned numLoops = nest.size();

  // Step 1. For every loop in the nest, record the index of the outermost
  // enclosing loop above which all of its operands are defined. A loop maps to
  // its own index when no enclosing loop satisfies that.
  SmallVector<unsigned, 4> definedAbove(numLoops);
  for (unsigned inner = 0; inner < numLoops; ++inner) {
    unsigned outermost = inner;
    for (unsigned outer = 0; outer < inner; ++outer) {
      if (!areValuesDefinedAbove(nest[inner].getOperands(),
                                 nest[outer].getRegion()))
        continue;
      outermost = outer;
      break;
    }
    definedAbove[inner] = outermost;
  }

  // Step 2. Find bands [bandStart, end) in which every loop has its operands
  // defined above the first loop of the band. The nest is traversed bottom-up
  // so that modifications don't invalidate the inner loops.
  LogicalResult result = failure();
  for (unsigned end = numLoops; end > 0; --end) {
    const unsigned lastIdx = end - 1;
    unsigned bandStart = 0;
    while (bandStart < lastIdx) {
      const unsigned maxPos = *std::max_element(
          definedAbove.begin() + bandStart, definedAbove.begin() + end);
      if (maxPos <= bandStart) {
        assert(maxPos == bandStart &&
               "expected loop bounds to be known at the start of the band");
        llvm::MutableArrayRef<LoopOpTy> band(nest.data() + bandStart,
                                             end - bandStart);
        if (succeeded(coalesceLoops(band)))
          result = success();
        break;
      }
      ++bandStart;
    }
    // When a band was found, resume the search with the loops above the
    // outermost loop of that band (the loop header decrements `end` again).
    if (bandStart != lastIdx)
      end = bandStart + 1;
  }
  return result;
}
/// Walk an affine.for to find a band to coalesce.
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);

} // namespace affine
} // namespace mlir
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,16 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
/// `loops` contains a list of perfectly nested loops with bounds and steps
/// independent of any loop induction variable involved in the nest.
LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
LogicalResult coalesceLoops(RewriterBase &rewriter,
MutableArrayRef<scf::ForOp>);

/// Walk an scf.for to find a band to coalesce.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);

/// Take the ParallelLoop and for each set of dimension indices, combine them
/// into a single dimension. combinedDimensions must contain each index into
/// loops exactly once.
void collapseParallelLoops(scf::ParallelOp loops,
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);

/// Unrolls this for operation by the specified unroll factor. Returns failure
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/Support/TypeName.h"
#include <optional>

using llvm::SmallPtrSetImpl;
namespace mlir {

class PatternRewriter;
Expand Down Expand Up @@ -704,6 +705,8 @@ class RewriterBase : public OpBuilder {
return user != exceptedUser;
});
}
void replaceAllUsesExcept(Value from, Value to,
const SmallPtrSetImpl<Operation *> &preservedUsers);

/// Used to notify the listener that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ struct LoopCoalescingPass
func::FuncOp func = getOperation();
func.walk<WalkOrder::PreOrder>([](Operation *op) {
if (auto scfForOp = dyn_cast<scf::ForOp>(op))
(void)coalescePerfectlyNestedLoops(scfForOp);
(void)coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (auto affineForOp = dyn_cast<AffineForOp>(op))
(void)coalescePerfectlyNestedLoops(affineForOp);
(void)coalescePerfectlyNestedAffineLoops(affineForOp);
});
}
};
Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,3 +2765,51 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,

return success();
}

/// Walk the perfect affine.for nest rooted at `op`, look for bands of loops
/// that can be coalesced, and coalesce them. Returns success when the nest has
/// at most one loop or when at least one band was coalesced; returns failure
/// otherwise.
LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
  SmallVector<AffineForOp> nest;
  getPerfectlyNestedLoops(nest, op);
  const unsigned depth = nest.size();
  // A nest of at most one loop has nothing to coalesce.
  if (depth <= 1)
    return success();

  // Phase 1: for each loop, find the outermost enclosing loop above which all
  // of its operands are defined. The loop's own index is used when there is
  // no such enclosing loop.
  SmallVector<unsigned> boundsKnownAt(depth);
  for (unsigned loopIdx = 0; loopIdx != depth; ++loopIdx) {
    unsigned pos = loopIdx;
    for (unsigned parentIdx = 0; parentIdx < loopIdx; ++parentIdx) {
      const bool hoistable = areValuesDefinedAbove(
          nest[loopIdx].getOperands(), nest[parentIdx].getRegion());
      if (hoistable) {
        pos = parentIdx;
        break;
      }
    }
    boundsKnownAt[loopIdx] = pos;
  }

  // Phase 2: identify bands [first, end) whose loops all have operands defined
  // above the first loop in the band. The nest is walked bottom-up so that
  // modifications don't invalidate the inner loops.
  LogicalResult status = failure();
  for (unsigned end = depth; end > 0; --end) {
    const unsigned innermost = end - 1;
    unsigned first = 0;
    while (first < innermost) {
      auto rangeBegin = std::next(boundsKnownAt.begin(), first);
      auto rangeEnd = std::next(boundsKnownAt.begin(), end);
      const unsigned maxPos = *std::max_element(rangeBegin, rangeEnd);
      if (maxPos > first) {
        ++first;
        continue;
      }
      assert(maxPos == first &&
             "expected loop bounds to be known at the start of the band");
      llvm::MutableArrayRef<AffineForOp> band(nest.data() + first,
                                              end - first);
      if (succeeded(coalesceLoops(band)))
        status = success();
      break;
    }
    // If a band was found, keep looking at the loops above the outermost loop
    // of that band (the loop header decrements `end` once more).
    if (first != innermost)
      end = first + 1;
  }
  return status;
}
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
result = coalescePerfectlyNestedLoops(scfForOp);
result = coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
result = coalescePerfectlyNestedLoops(affineForOp);
result = coalescePerfectlyNestedAffineLoops(affineForOp);

results.push_back(op);
if (failed(result)) {
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace {
struct TestSCFParallelLoopCollapsing
: public impl::TestSCFParallelLoopCollapsingBase<
TestSCFParallelLoopCollapsing> {

void runOnOperation() override {
Operation *module = getOperation();

Expand Down Expand Up @@ -88,6 +89,7 @@ struct TestSCFParallelLoopCollapsing
// Only apply the transformation on parallel loops where the specified
// transformation is valid, but do NOT early abort in the case of invalid
// loops.
IRRewriter rewriter(&getContext());
module->walk([&](scf::ParallelOp op) {
if (flattenedCombinedLoops.size() != op.getNumLoops()) {
op.emitOpError("has ")
Expand All @@ -97,7 +99,7 @@ struct TestSCFParallelLoopCollapsing
<< flattenedCombinedLoops.size() << " iter args.";
return;
}
collapseParallelLoops(op, combinedLoops);
collapseParallelLoops(rewriter, op, combinedLoops);
});
}
};
Expand Down

0 comments on commit 5aeb604

Please sign in to comment.