Skip to content

Commit

Permalink
[mlir][Interfaces] LoopLikeOpInterface: Support ops with multiple r…
Browse files Browse the repository at this point in the history
…egions (#66754)

This commit implements `LoopLikeOpInterface` on `scf.while`. This
enables LICM (and potentially other transforms) on `scf.while`.

`LoopLikeOpInterface::getLoopBody()` is renamed to `getLoopRegions` and
can now return multiple regions.

Also fix a bug in the default implementation of
`LoopLikeOpInterface::isDefinedOutsideOfLoop()`, which returned "false"
for some values that are defined outside of the loop (in a nested op, in
such a way that the value does not dominate the loop). This interface is
currently only used for LICM and there is no way to trigger this bug, so
no test is added.
  • Loading branch information
matthias-springer committed Sep 19, 2023
1 parent d69293c commit 9b5ef2b
Show file tree
Hide file tree
Showing 23 changed files with 101 additions and 71 deletions.
8 changes: 6 additions & 2 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1947,7 +1947,9 @@ void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) {
/*printBlockTerminators=*/true);
}

mlir::Region &fir::IterWhileOp::getLoopBody() { return getRegion(); }
llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() {
return {&getRegion()};
}

mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) {
for (auto i : llvm::enumerate(getInitArgs()))
Expand Down Expand Up @@ -2234,7 +2236,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
printBlockTerminators);
}

mlir::Region &fir::DoLoopOp::getLoopBody() { return getRegion(); }
llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() {
return {&getRegion()};
}

/// Translate a value passed as an iter_arg to the corresponding block
/// argument in the body of the loop.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,7 @@ def ReduceReturnOp :
def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*args=*/(ins "::mlir::Value ":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return value.getParentRegion()->isProperAncestor(&$_op.getLoopBody());
return !$_op->isAncestor(value.getParentRegion()->getParentOp());
}]
>,
InterfaceMethod<[{
Returns the region that makes up the body of the loop and should be
Returns the regions that make up the body of the loop and should be
inspected for loop-invariant operations.
}],
/*retTy=*/"::mlir::Region &",
/*methodName=*/"getLoopBody"
/*retTy=*/"::llvm::SmallVector<::mlir::Region *>",
/*methodName=*/"getLoopRegions"
>,
InterfaceMethod<[{
Moves the given loop-invariant operation out of the loop.
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

#include "mlir/Support/LLVM.h"

#include "llvm/ADT/SmallVector.h"

namespace mlir {

class LoopLikeOpInterface;
class Operation;
class Region;
class RegionRange;
class Value;

/// Given a list of regions, perform loop-invariant code motion. An operation is
Expand Down Expand Up @@ -61,7 +62,7 @@ class Value;
///
/// Returns the number of operations moved.
size_t moveLoopInvariantCode(
RegionRange regions,
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
signatureConverter.remapInput(0, newIndVar);
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
body = rewriter.applySignatureConversion(&forOp.getRegion(),
signatureConverter);

// Move the blocks from the forOp into the loopOp. This is the body of the
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separatly for the loop to be correct.
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
scf::ForOp loop,
ValueRange newInitArgs) {
Expand All @@ -1119,9 +1119,8 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
operands);
newLoop.getBody()->erase();

newLoop.getLoopBody().getBlocks().splice(
newLoop.getLoopBody().getBlocks().begin(),
loop.getLoopBody().getBlocks());
newLoop.getRegion().getBlocks().splice(
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
for (Value operand : newInitArgs)
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2380,8 +2380,7 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert((point.isParent() || point == getLoopBody()) &&
"invalid region point");
assert((point.isParent() || point == getRegion()) && "invalid region point");

// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
Expand All @@ -2395,16 +2394,15 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
/// not a constant.
void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
assert((point.isParent() || point == getLoopBody()) &&
"expected loop region");
assert((point.isParent() || point == getRegion()) && "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
if (point.isParent() && tripCount.has_value()) {
if (tripCount.value() > 0) {
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
return;
}
if (tripCount.value() == 0) {
Expand All @@ -2422,7 +2420,7 @@ void AffineForOp::getSuccessorRegions(

// In all other cases, the loop may branch back to itself or the parent
// operation.
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
regions.push_back(RegionSuccessor(getResults()));
}

Expand Down Expand Up @@ -2561,7 +2559,7 @@ bool AffineForOp::matchingBoundOperandList() {
return true;
}

Region &AffineForOp::getLoopBody() { return getRegion(); }
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }

std::optional<Value> AffineForOp::getSingleInductionVar() {
return getInductionVar();
Expand Down Expand Up @@ -2758,9 +2756,9 @@ AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
loop.getStep(), operands);
// Take the body of the original parent loop.
newLoop.getLoopBody().takeBody(loop.getLoopBody());
newLoop.getRegion().takeBody(loop.getRegion());
for (Value val : newIterArgs)
newLoop.getLoopBody().addArgument(val.getType(), val.getLoc());
newLoop.getRegion().addArgument(val.getType(), val.getLoc());

// Update yield operation with new values to be added.
if (!newYieldedValues.empty()) {
Expand Down Expand Up @@ -3848,7 +3846,9 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
ensureTerminator(*bodyRegion, builder, result.location);
}

Region &AffineParallelOp::getLoopBody() { return getRegion(); }
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
return {&getRegion()};
}

unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
opsToHoist))
return false;
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
opsWithUsers, opsToHoist))
return false;
} else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
if (!areAllOpsInTheBlockListInvariant(parOp.getLoopBody(), indVar, iterArgs,
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
opsWithUsers, opsToHoist))
return false;
} else if (!isMemoryEffectFree(&op) &&
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ static ParallelComputeFunction createParallelComputeFunction(
mapping.map(op.getInductionVars(), computeBlockInductionVars);
mapping.map(computeFuncType.captures, captures);

for (auto &bodyOp : op.getLoopBody().getOps())
for (auto &bodyOp : op.getRegion().getOps())
b.clone(bodyOp, mapping);
};
};
Expand Down Expand Up @@ -732,7 +732,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,

// Make sure that all constants will be inside the parallel operation body to
// reduce the number of parallel compute function arguments.
cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

// Compute trip count for each loop induction variable:
// tripCount = ceil_div(upperBound - lowerBound, step);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
// Replace all uses of the `transferRead` with the corresponding
// basic block argument.
transferRead.getVector().replaceUsesWithIf(
newForOp.getLoopBody().getArguments().back(),
[&](OpOperand &use) {
newForOp.getBody()->getArguments().back(), [&](OpOperand &use) {
Operation *user = use.getOwner();
return newForOp->isProperAncestor(user);
});
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
// Replace the index operations in the body of the innermost loop op.
if (!loopOps.empty()) {
auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
for (IndexOp indexOp :
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
for (Region *r : loopOp.getLoopRegions())
for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
}
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
// pass-through tensor arguments left from previous level of
// hoisting.
if (auto forUser = dyn_cast<scf::ForOp>(user)) {
Value arg = forUser.getLoopBody().getArgument(
Value arg = forUser.getBody()->getArgument(
use.getOperandNumber() - forUser.getNumControlOperands() +
/*iv value*/ 1);
uses.push_back(arg.getUses());
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
if (!inductionVar || !lowerBound || !singleStep) {
LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb or step\n");
if (!inductionVar || !lowerBound || !singleStep ||
!llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
return failure();
}

Expand Down Expand Up @@ -184,7 +185,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,

// 3. Within the loop, build the modular leading index (i.e. each loop
// iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
rewriter.setInsertionPointToStart(&candidateLoop.getLoopBody().front());
rewriter.setInsertionPointToStart(
&candidateLoop.getLoopRegions().front()->front());
Value ivVal = *inductionVar;
Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
Expand Down
20 changes: 12 additions & 8 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

Region &ForOp::getLoopBody() { return getRegion(); }
SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }

ForOp mlir::scf::getForInductionVarOwner(Value val) {
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
Expand Down Expand Up @@ -558,11 +558,11 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
regions.push_back(RegionSuccessor(getResults()));
}

Region &ForallOp::getLoopBody() { return getRegion(); }
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }

/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
Expand Down Expand Up @@ -894,7 +894,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
blockArgs.reserve(op.getInitArgs().size() + 1);
blockArgs.push_back(op.getLowerBound());
llvm::append_range(blockArgs, op.getInitArgs());
replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
return success();
}

Expand Down Expand Up @@ -2872,7 +2872,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
/*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
}

Region &ParallelOp::getLoopBody() { return getRegion(); }
SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }

ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
auto ivArg = llvm::dyn_cast<BlockArgument>(val);
Expand Down Expand Up @@ -2926,7 +2926,7 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
// loop body and nested ReduceOp's
SmallVector<Value> results;
results.reserve(op.getInitVals().size());
for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
for (auto &bodyOp : op.getBody()->without_terminator()) {
auto reduce = dyn_cast<ReduceOp>(bodyOp);
if (!reduce) {
rewriter.clone(bodyOp, mapping);
Expand Down Expand Up @@ -2965,7 +2965,7 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {

LogicalResult matchAndRewrite(ParallelOp op,
PatternRewriter &rewriter) const override {
Block &outerBody = op.getLoopBody().front();
Block &outerBody = *op.getBody();
if (!llvm::hasSingleElement(outerBody.without_terminator()))
return failure();

Expand All @@ -2985,7 +2985,7 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {

auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
ValueRange iterVals, ValueRange) {
Block &innerBody = innerOp.getLoopBody().front();
Block &innerBody = *innerOp.getBody();
assert(iterVals.size() ==
(outerBody.getNumArguments() + innerBody.getNumArguments()));
IRMapping mapping;
Expand Down Expand Up @@ -3203,6 +3203,10 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
regions.emplace_back(&getAfter(), getAfter().getArguments());
}

SmallVector<Region *> WhileOp::getLoopRegions() {
return {&getBefore(), &getAfter()};
}

/// Parses a `while` op.
///
/// op ::= `scf.while` assignments `:` function-type region `do` region
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ struct ForOpInterface

// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
assert(forOp.getLoopBody().hasOneBlock() &&
"multiple blocks not supported");
Value yieldedValue =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
.getOperand(iterArgIdx);
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
.getOperand(iterArgIdx);
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
Value initArg = forOp.getInitArgs()[iterArgIdx];

Expand Down Expand Up @@ -68,7 +65,7 @@ struct ForOpInterface
// Stop when reaching a value that is defined outside of the loop. It
// is impossible to reach an iter_arg from there.
Operation *op = v.getDefiningOp();
return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr;
return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
});
if (failed(status))
return;
Expand Down

0 comments on commit 9b5ef2b

Please sign in to comment.