-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR] Add a getStaticTripCount method to LoopLikeOpInterface #158679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,13 +19,18 @@ | |
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/Matchers.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/OperationSupport.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h" | ||
#include "mlir/Interfaces/ValueBoundsOpInterface.h" | ||
#include "mlir/Transforms/InliningUtils.h" | ||
#include "llvm/ADT/MapVector.h" | ||
#include "llvm/ADT/SmallPtrSet.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/DebugLog.h" | ||
#include <optional> | ||
|
||
using namespace mlir; | ||
using namespace mlir::scf; | ||
|
@@ -105,6 +110,24 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, | |
return nullptr; | ||
} | ||
|
||
/// Helper function to compute the difference between two values. This is used | ||
/// by the loop implementations to compute the trip count. | ||
static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is for a restricted form right? looks for ub = lb + const. E.g., this would fail in the case where ub and lb are constants. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes it's not meant for constants. |
||
bool isSigned) { | ||
llvm::APSInt diff; | ||
auto addOp = ub.getDefiningOp<arith::AddIOp>(); | ||
if (!addOp) | ||
return std::nullopt; | ||
if ((isSigned && !addOp.hasNoSignedWrap()) || | ||
(!isSigned && !addOp.hasNoUnsignedWrap())) | ||
return std::nullopt; | ||
|
||
if (addOp.getLhs() != lb || | ||
!matchPattern(addOp.getRhs(), m_ConstantInt(&diff))) | ||
return std::nullopt; | ||
return diff; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ExecuteRegionOp | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -408,11 +431,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); } | |
/// Promotes the loop body of a forOp to its containing block if the forOp | ||
/// it can be determined that the loop has a single iteration. | ||
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { | ||
std::optional<int64_t> tripCount = | ||
constantTripCount(getLowerBound(), getUpperBound(), getStep()); | ||
if (!tripCount.has_value() || tripCount != 1) | ||
std::optional<APInt> tripCount = getStaticTripCount(); | ||
LDBG() << "promoteIfSingleIteration tripCount is " << tripCount | ||
<< " for loop " | ||
<< OpWithFlags(getOperation(), OpPrintingFlags().skipRegions()); | ||
if (!tripCount.has_value() || tripCount->getSExtValue() > 1) | ||
return failure(); | ||
|
||
if (*tripCount == 0) { | ||
rewriter.replaceAllUsesWith(getResults(), getInitArgs()); | ||
rewriter.eraseOp(*this); | ||
return success(); | ||
} | ||
|
||
// Replace all results with the yielded values. | ||
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); | ||
rewriter.replaceAllUsesWith(getResults(), getYieldedValues()); | ||
|
@@ -646,7 +677,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } | |
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { | ||
for (auto [lb, ub, step] : | ||
llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) { | ||
auto tripCount = constantTripCount(lb, ub, step); | ||
auto tripCount = | ||
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); | ||
if (!tripCount.has_value() || *tripCount != 1) | ||
return failure(); | ||
} | ||
|
@@ -1003,27 +1035,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { | |
} | ||
}; | ||
|
||
/// Util function that tries to compute a constant diff between u and l. | ||
/// Returns std::nullopt when the difference between two AffineValueMap is | ||
/// dynamic. | ||
static std::optional<APInt> computeConstDiff(Value l, Value u) { | ||
IntegerAttr clb, cub; | ||
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { | ||
llvm::APInt lbValue = clb.getValue(); | ||
llvm::APInt ubValue = cub.getValue(); | ||
return ubValue - lbValue; | ||
} | ||
|
||
// Else a simple pattern match for x + c or c + x | ||
llvm::APInt diff; | ||
if (matchPattern( | ||
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) || | ||
matchPattern( | ||
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l)))) | ||
return diff; | ||
return std::nullopt; | ||
} | ||
|
||
/// Rewriting pattern that erases loops that are known not to iterate, replaces | ||
/// single-iteration loops with their bodies, and removes empty loops that | ||
/// iterate at least once and only return values defined outside of the loop. | ||
|
@@ -1032,34 +1043,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { | |
|
||
LogicalResult matchAndRewrite(ForOp op, | ||
PatternRewriter &rewriter) const override { | ||
// If the upper bound is the same as the lower bound, the loop does not | ||
// iterate, just remove it. | ||
if (op.getLowerBound() == op.getUpperBound()) { | ||
std::optional<APInt> tripCount = op.getStaticTripCount(); | ||
if (!tripCount.has_value()) | ||
return rewriter.notifyMatchFailure(op, | ||
"can't compute constant trip count"); | ||
|
||
if (tripCount->isZero()) { | ||
LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop " | ||
<< OpWithFlags(op, OpPrintingFlags().skipRegions()); | ||
rewriter.replaceOp(op, op.getInitArgs()); | ||
return success(); | ||
} | ||
|
||
std::optional<APInt> diff = | ||
computeConstDiff(op.getLowerBound(), op.getUpperBound()); | ||
if (!diff) | ||
return failure(); | ||
|
||
// If the loop is known to have 0 iterations, remove it. | ||
bool zeroOrLessIterations = | ||
diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative()); | ||
if (zeroOrLessIterations) { | ||
rewriter.replaceOp(op, op.getInitArgs()); | ||
return success(); | ||
} | ||
|
||
std::optional<llvm::APInt> maybeStepValue = op.getConstantStep(); | ||
if (!maybeStepValue) | ||
return failure(); | ||
|
||
// If the loop is known to have 1 iteration, inline its body and remove the | ||
// loop. | ||
llvm::APInt stepValue = *maybeStepValue; | ||
if (stepValue.sge(*diff)) { | ||
if (tripCount->getSExtValue() == 1) { | ||
LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop " | ||
<< OpWithFlags(op, OpPrintingFlags().skipRegions()); | ||
SmallVector<Value, 4> blockArgs; | ||
blockArgs.reserve(op.getInitArgs().size() + 1); | ||
blockArgs.push_back(op.getLowerBound()); | ||
|
@@ -1072,11 +1070,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { | |
Block &block = op.getRegion().front(); | ||
if (!llvm::hasSingleElement(block)) | ||
return failure(); | ||
// If the loop is empty, iterates at least once, and only returns values | ||
// The loop is empty and iterates at least once, if it only returns values | ||
// defined outside of the loop, remove it and replace it with yield values. | ||
if (llvm::any_of(op.getYieldedValues(), | ||
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) | ||
return failure(); | ||
LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with " | ||
"yield operands for loop " | ||
<< OpWithFlags(op, OpPrintingFlags().skipRegions()); | ||
rewriter.replaceOp(op, op.getYieldedValues()); | ||
return success(); | ||
} | ||
|
@@ -1172,6 +1173,11 @@ Speculation::Speculatability ForOp::getSpeculatability() { | |
return Speculation::NotSpeculatable; | ||
} | ||
|
||
std::optional<APInt> ForOp::getStaticTripCount() { | ||
return constantTripCount(getLowerBound(), getUpperBound(), getStep(), | ||
/*isSigned=*/!getUnsignedCmp(), computeUbMinusLb); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ForallOp | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -1768,7 +1774,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder | |
for (auto [lb, ub, step, iv] : | ||
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), | ||
op.getMixedStep(), op.getInductionVars())) { | ||
auto numIterations = constantTripCount(lb, ub, step); | ||
auto numIterations = | ||
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); | ||
if (numIterations.has_value()) { | ||
// Remove the loop if it performs zero iterations. | ||
if (*numIterations == 0) { | ||
|
@@ -1839,7 +1846,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> { | |
op.getMixedStep(), op.getInductionVars())) { | ||
if (iv.hasNUses(0)) | ||
continue; | ||
auto numIterations = constantTripCount(lb, ub, step); | ||
auto numIterations = | ||
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); | ||
if (!numIterations.has_value() || numIterations.value() != 1) { | ||
continue; | ||
} | ||
|
@@ -3084,7 +3092,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder | |
for (auto [lb, ub, step, iv] : | ||
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(), | ||
op.getInductionVars())) { | ||
auto numIterations = constantTripCount(lb, ub, step); | ||
auto numIterations = | ||
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); | ||
if (numIterations.has_value()) { | ||
// Remove the loop if it performs zero iterations. | ||
if (*numIterations == 0) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Interfaces/SideEffectInterfaces.h" | ||
#include "mlir/Transforms/RegionUtils.h" | ||
#include "llvm/ADT/APInt.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/Support/DebugLog.h" | ||
|
@@ -291,14 +292,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, | |
return arith::DivUIOp::create(builder, loc, sum, divisor); | ||
} | ||
|
||
/// Returns the trip count of `forOp` if its' low bound, high bound and step are | ||
/// constants, or optional otherwise. Trip count is computed as | ||
/// ceilDiv(highBound - lowBound, step). | ||
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) { | ||
return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(), | ||
forOp.getStep()); | ||
} | ||
|
||
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with | ||
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap | ||
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each | ||
|
@@ -377,7 +370,7 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor( | |
Value stepUnrolled; | ||
bool generateEpilogueLoop = true; | ||
|
||
std::optional<int64_t> constTripCount = getConstantTripCount(forOp); | ||
std::optional<APInt> constTripCount = forOp.getStaticTripCount(); | ||
if (constTripCount) { | ||
// Constant loop bounds computation. | ||
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value(); | ||
|
@@ -391,7 +384,8 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor( | |
} | ||
|
||
int64_t tripCountEvenMultiple = | ||
*constTripCount - (*constTripCount % unrollFactor); | ||
constTripCount->getSExtValue() - | ||
(constTripCount->getSExtValue() % unrollFactor); | ||
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; | ||
int64_t stepUnrolledCst = stepCst * unrollFactor; | ||
|
||
|
@@ -487,15 +481,15 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor( | |
/// Unrolls this loop completely. | ||
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) { | ||
IRRewriter rewriter(forOp.getContext()); | ||
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp); | ||
std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount(); | ||
if (!mayBeConstantTripCount.has_value()) | ||
return failure(); | ||
uint64_t tripCount = *mayBeConstantTripCount; | ||
if (tripCount == 0) | ||
APInt &tripCount = *mayBeConstantTripCount; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const reference? |
||
if (tripCount.isZero()) | ||
return success(); | ||
if (tripCount == 1) | ||
if (tripCount.getSExtValue() == 1) | ||
return forOp.promoteIfSingleIteration(rewriter); | ||
return loopUnrollByFactor(forOp, tripCount); | ||
return loopUnrollByFactor(forOp, tripCount.getSExtValue()); | ||
} | ||
|
||
/// Check if bounds of all inner loops are defined outside of `forOp` | ||
|
@@ -535,18 +529,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, | |
|
||
// Currently, only constant trip count that divided by the unroll factor is | ||
// supported. | ||
std::optional<uint64_t> tripCount = getConstantTripCount(forOp); | ||
std::optional<APInt> tripCount = forOp.getStaticTripCount(); | ||
if (!tripCount.has_value()) { | ||
// If the trip count is dynamic, do not unroll & jam. | ||
LDBG() << "failed to unroll and jam: trip count could not be determined"; | ||
return failure(); | ||
} | ||
if (unrollJamFactor > *tripCount) { | ||
if (unrollJamFactor > tripCount->getZExtValue()) { | ||
LDBG() << "unroll and jam factor is greater than trip count, set factor to " | ||
"trip " | ||
"count"; | ||
unrollJamFactor = *tripCount; | ||
} else if (*tripCount % unrollJamFactor != 0) { | ||
unrollJamFactor = tripCount->getZExtValue(); | ||
} else if (tripCount->getSExtValue() % unrollJamFactor != 0) { | ||
LDBG() << "failed to unroll and jam: unsupported trip count that is not a " | ||
"multiple of unroll jam factor"; | ||
return failure(); | ||
|
Uh oh!
There was an error while loading. Please reload this page.