[MLIR] Add a getStaticTripCount method to LoopLikeOpInterface #158679
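@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: Mehdi Amini (joker-eph)
Changes
This patch adds a `getStaticTripCount` to the LoopLikeOpInterface, allowing loops to optionally return a static trip count when possible. This is implemented on SCF ForOp, revamping the implementation of `constantTripCount`, removing redundant duplicate implementations from SCF.cpp.
Patch is 47.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158679.diff
10 Files Affected: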
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d3c01c31636a7..fadd3fc10bfc4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -152,7 +152,7 @@ def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
- "getLoopUpperBounds", "getYieldedValuesMutable",
+ "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..aa0cc35a1d675 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -105,6 +105,10 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
ArrayRef<int64_t> values);
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+/// The second return value indicates whether the value is an index type
+/// and thus the bitwidth is not defined.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// If all ofrs are constant integers or IntegerAttrs, return the integers.
@@ -201,9 +205,24 @@ foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
-/// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
- OpFoldResult step);
+/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
+/// comparison between lb and ub is signed or unsigned. A negative step or a
+/// lower bound greater than the upper bound are considered invalid and will
+/// yield a zero trip count.
+/// The `computeUbMinusLb` callback is invoked to compute the difference between
+/// the upper and lower bound when not constant. It can be used by the client
+/// to match the pattern:
+///
+/// %ub = arith.addi nsw %lb, %c16_i32 : i32
+/// %1 = scf.for %arg0 = %lb to %ub ...
+///
+/// where %ub is computed as a static offset from %lb.
+/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
+/// to avoid overflow, otherwise an overflow would imply a zero trip count.
+std::optional<APInt> constantTripCount(
+ OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+ llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+ computeUbMinusLb);
/// Idiomatic saturated operations on values like offsets, sizes, and strides.
struct SaturatedInteger {
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 6c95b4802837b..cfd15a7746e19 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -232,6 +232,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
+ >,
+ InterfaceMethod<[{
+ Compute the static trip count if possible.
+ }],
+ /*retTy=*/"::std::optional<APInt>",
+ /*methodName=*/"getStaticTripCount",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::std::nullopt;
+ }]
>
];
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..e68dc04de231b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
@@ -26,6 +28,9 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
+#include <optional>
using namespace mlir;
using namespace mlir::scf;
@@ -105,6 +110,23 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
return nullptr;
}
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
+ bool isSigned) {
+ llvm::APSInt diff;
+ auto addOp = ub.getDefiningOp<arith::AddIOp>();
+ if (!addOp)
+ return std::nullopt;
+ if ((isSigned && !addOp.hasNoSignedWrap()) ||
+ (!isSigned && !addOp.hasNoUnsignedWrap()))
+ return std::nullopt;
+
+ if (!matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+ return std::nullopt;
+ return diff;
+}
+
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@@ -408,11 +430,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
- std::optional<int64_t> tripCount =
- constantTripCount(getLowerBound(), getUpperBound(), getStep());
- if (!tripCount.has_value() || tripCount != 1)
+ std::optional<APInt> tripCount = getStaticTripCount();
+ LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
+ << " for loop "
+ << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
+ if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
return failure();
+ if (*tripCount == 0) {
+ rewriter.replaceAllUsesWith(getResults(), getInitArgs());
+ rewriter.eraseOp(*this);
+ return success();
+ }
+
// Replace all results with the yielded values.
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
@@ -646,7 +676,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
for (auto [lb, ub, step] :
llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
- auto tripCount = constantTripCount(lb, ub, step);
+ auto tripCount =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!tripCount.has_value() || *tripCount != 1)
return failure();
}
@@ -1003,27 +1034,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
}
};
-/// Util function that tries to compute a constant diff between u and l.
-/// Returns std::nullopt when the difference between two AffineValueMap is
-/// dynamic.
-static std::optional<APInt> computeConstDiff(Value l, Value u) {
- IntegerAttr clb, cub;
- if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
- llvm::APInt lbValue = clb.getValue();
- llvm::APInt ubValue = cub.getValue();
- return ubValue - lbValue;
- }
-
- // Else a simple pattern match for x + c or c + x
- llvm::APInt diff;
- if (matchPattern(
- u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
- matchPattern(
- u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff;
- return std::nullopt;
-}
-
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1042,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
- // If the upper bound is the same as the lower bound, the loop does not
- // iterate, just remove it.
- if (op.getLowerBound() == op.getUpperBound()) {
+ std::optional<APInt> tripCount = op.getStaticTripCount();
+ if (!tripCount.has_value())
+ return rewriter.notifyMatchFailure(op,
+ "can't compute constant trip count");
+
+ if (tripCount->isZero()) {
+ LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
- std::optional<APInt> diff =
- computeConstDiff(op.getLowerBound(), op.getUpperBound());
- if (!diff)
- return failure();
-
- // If the loop is known to have 0 iterations, remove it.
- bool zeroOrLessIterations =
- diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
- if (zeroOrLessIterations) {
- rewriter.replaceOp(op, op.getInitArgs());
- return success();
- }
-
- std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
- if (!maybeStepValue)
- return failure();
-
- // If the loop is known to have 1 iteration, inline its body and remove the
- // loop.
- llvm::APInt stepValue = *maybeStepValue;
- if (stepValue.sge(*diff)) {
+ if (tripCount->getSExtValue() == 1) {
+ LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getInitArgs().size() + 1);
blockArgs.push_back(op.getLowerBound());
@@ -1072,11 +1069,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
Block &block = op.getRegion().front();
if (!llvm::hasSingleElement(block))
return failure();
- // If the loop is empty, iterates at least once, and only returns values
+ // The loop is empty and iterates at least once, if it only returns values
// defined outside of the loop, remove it and replace it with yield values.
if (llvm::any_of(op.getYieldedValues(),
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
return failure();
+ LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
+ "yield operands for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getYieldedValues());
return success();
}
@@ -1172,6 +1172,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
return Speculation::NotSpeculatable;
}
+std::optional<APInt> ForOp::getStaticTripCount() {
+ return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
+ /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
+}
+
//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//
@@ -1768,7 +1773,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) {
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
@@ -1839,7 +1845,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
op.getMixedStep(), op.getInductionVars())) {
if (iv.hasNUses(0))
continue;
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!numIterations.has_value() || numIterations.value() != 1) {
continue;
}
@@ -3084,7 +3091,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index f1203b2bdfee5..e3717aa9d940e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -94,7 +94,9 @@ static void specializeForLoopForUnrolling(ForOp op) {
OpBuilder b(op);
IRMapping map;
- Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant);
+ Value constant = arith::ConstantOp::create(
+ b, op.getLoc(),
+ IntegerAttr::get(op.getUpperBound().getType(), minConstant));
Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq,
bound, constant);
map.map(bound, constant);
@@ -150,6 +152,9 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
ValueRange{forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep()});
+ if (splitBound.getType() != forOp.getLowerBound().getType())
+ splitBound = b.createOrFold<arith::IndexCastOp>(
+ loc, forOp.getLowerBound().getType(), splitBound);
// Create ForOp for partial iteration.
b.setInsertionPointAfter(forOp);
@@ -230,6 +235,9 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
auto loc = forOp.getLoc();
Value splitBound = b.createOrFold<AffineApplyOp>(
loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
+ if (splitBound.getType() != forOp.getUpperBound().getType())
+ splitBound = b.createOrFold<arith::IndexCastOp>(
+ loc, forOp.getUpperBound().getType(), splitBound);
// Peel the first iteration.
IRMapping map;
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 684dff8121de6..57c4deab89321 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
@@ -291,14 +292,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return arith::DivUIOp::create(builder, loc, sum, divisor);
}
-/// Returns the trip count of `forOp` if its' low bound, high bound and step are
-/// constants, or optional otherwise. Trip count is computed as
-/// ceilDiv(highBound - lowBound, step).
-static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
- return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep());
-}
-
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -377,7 +370,7 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
Value stepUnrolled;
bool generateEpilogueLoop = true;
- std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+ std::optional<APInt> constTripCount = forOp.getStaticTripCount();
if (constTripCount) {
// Constant loop bounds computation.
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
@@ -391,7 +384,8 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
}
int64_t tripCountEvenMultiple =
- *constTripCount - (*constTripCount % unrollFactor);
+ constTripCount->getSExtValue() -
+ (constTripCount->getSExtValue() % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -487,15 +481,15 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
IRRewriter rewriter(forOp.getContext());
- std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
if (!mayBeConstantTripCount.has_value())
return failure();
- uint64_t tripCount = *mayBeConstantTripCount;
- if (tripCount == 0)
+ APInt &tripCount = *mayBeConstantTripCount;
+ if (tripCount.isZero())
return success();
- if (tripCount == 1)
+ if (tripCount.getSExtValue() == 1)
return forOp.promoteIfSingleIteration(rewriter);
- return loopUnrollByFactor(forOp, tripCount);
+ return loopUnrollByFactor(forOp, tripCount.getSExtValue());
}
/// Check if bounds of all inner loops are defined outside of `forOp`
@@ -535,18 +529,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
// Currently, only constant trip count that divided by the unroll factor is
// supported.
- std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ std::optional<APInt> tripCount = forOp.getStaticTripCount();
if (!tripCount.has_value()) {
// If the trip count is dynamic, do not unroll & jam.
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
- if (unrollJamFactor > *tripCount) {
+ if (unrollJamFactor > tripCount->getZExtValue()) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
- unrollJamFactor = *tripCount;
- } else if (*tripCount % unrollJamFactor != 0) {
+ unrollJamFactor = tripCount->getZExtValue();
+ } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 34385d76f133a..e5deea6fa21ab 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
@@ -112,21 +113,30 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
}
/// If ofr is a constant integer or an IntegerAttr, return the integer.
-std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+/// The boolean indicates whether the value is an index type.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
- APSInt intVal;
+ APInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
- return intVal.getSExtValue();
+ return std::make_pair(intVal, val.getType().isIndex());
return std::null...
[truncated]
Force-pushed from 6f99f0f to cd6dc69.
Why can we not compute the trip count? `index` has "infinite" bitwidth, so it can never overflow.
You may have missed this thread I opened: https://discourse.llvm.org/t/index-type-and-assumption-about-bitwidth/88287 :)
Force-pushed from cd6dc69 to 81802df.
… Index handling
This patch adds a `getStaticTripCount` to the LoopLikeOpInterface, allowing loops to optionally return a static trip count when possible. This is implemented on SCF ForOp, revamping the implementation of `constantTripCount`, removing redundant duplicate implementations from SCF.cpp.
Force-pushed from 81802df to 6b8fbb9.
/// Helper function to compute the difference between two values. This is used
/// by the loop implementations to compute the trip count.
static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
This is for a restricted form, right? It looks for `ub = lb + const`. E.g., this would fail in the case where ub and lb are both constants.
Yes, it's not meant for constants.
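As a rough illustration of the restricted form discussed in this thread, the standalone sketch below models the `ub = lb + const` match without any MLIR dependency. `ToyValue` and `toyUbMinusLb` are hypothetical names invented for this example, and the check that the add's left operand is the lower bound is an assumption of the sketch rather than something taken from the patch.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for an SSA value. A value is either opaque
// (addLhs == nullptr) or defined as `addLhs + addRhsConst`, loosely
// mirroring an integer add with overflow flags.
struct ToyValue {
  const ToyValue *addLhs = nullptr;
  std::optional<int64_t> addRhsConst;
  bool noSignedWrap = false;
  bool noUnsignedWrap = false;
};

// Returns `ub - lb` only when `ub` is `lb + c` and the add carries the
// no-wrap flag that matches the loop comparison; otherwise gives up.
std::optional<int64_t> toyUbMinusLb(const ToyValue &lb, const ToyValue &ub,
                                    bool isSigned) {
  // Not defined by an add (e.g. a plain constant): nothing to match.
  if (!ub.addLhs)
    return std::nullopt;
  // Assumption of this sketch: the add must be based on the lower bound.
  if (ub.addLhs != &lb)
    return std::nullopt;
  // The overflow flag has to match the signedness of the comparison.
  if (isSigned ? !ub.noSignedWrap : !ub.noUnsignedWrap)
    return std::nullopt;
  // Empty when the right-hand operand is not a known constant.
  return ub.addRhsConst;
}
```

Two opaque values, which is how two constant bounds look to this helper, produce no result here; that case is left to the generic constant path.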
return failure();
- uint64_t tripCount = *mayBeConstantTripCount;
- if (tripCount == 0)
+ APInt &tripCount = *mayBeConstantTripCount;
const reference?
// We infer it from the type of the bound if it isn't an index type.
bool isIndex = true;
auto getBitwidth = [&](OpFoldResult ofr) -> int {
if (auto attr = dyn_cast<Attribute>(ofr)) {
I'd consider merging these two into one by using `dyn_cast_if_present`.
if (auto attr = dyn_cast<Attribute>(ofr)) {
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
isIndex = intType.isIndex();
Why not return a pair and then unpack?
History of how the code was written :)
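The suggestion in this thread could look roughly like the following standalone sketch, which returns the bitwidth and the index flag together instead of writing the flag through a captured reference. `ToyType`, `kToyIndexStorageBits`, and `toyBitwidthAndIndex` are hypothetical names, and the 64-bit storage width for index values is an assumption made for the example.

```cpp
#include <utility>

// Hypothetical stand-in for an integer-like type; `width` is ignored when
// `isIndex` is true because an index has no fixed bitwidth.
struct ToyType {
  bool isIndex;
  int width;
};

// Assumed storage width used for index values in this sketch.
constexpr int kToyIndexStorageBits = 64;

// Returns {bitwidth, isIndex} as one value, so the caller does not need a
// mutable flag that the helper updates as a side effect.
std::pair<int, bool> toyBitwidthAndIndex(const ToyType &type) {
  if (type.isIndex)
    return {kToyIndexStorageBits, true};
  return {type.width, false};
}
```

A caller can then unpack both results at once, for example with `auto [bitwidth, isIndex] = toyBitwidthAndIndex(type);`.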
return intType.getWidth();
}
}
return IndexType::kInternalStorageBitWidth;
Should isIndex be true here?
std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
if (maybeLbCst) {
// If one of the bounds is not a constant, we can't compute the trip count.
So if one bound is constant and the other isn't, it gives up? E.g., it would handle A and A+k when A is unknown, but not when A is known?
Yes, the expectation is that A+k should have been folded to a constant if A was a constant.
return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
/// Compute the difference between the upper and lower bound: either from the
/// constant value or using the computeUbMinusLb callback.
Why isn't this all in that callback? It seems one is handling special cases here which I'd expect to be handled when computing UB minus LB.
Because we don't want everyone reimplementing the callback: we provide the "normal" constant diff logic here.
The user can inject their own "offset" op there. We need the callback because the utility can't depend on arith.add and the callback makes it possible to work with other adds.
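The split described in this reply can be sketched in plain C++ as follows: the shared utility handles two constant bounds itself, gives up on a mixed pair, and only defers to the injected callback when both bounds are non-constant. `ToySymbol`, `ToyBound`, `DiffCallback`, and `toyDiff` are hypothetical names for this example, and overflow of the subtraction is deliberately ignored.

```cpp
#include <cstdint>
#include <functional>
#include <optional>
#include <variant>

// Hypothetical non-constant value, identified only by an id.
struct ToySymbol {
  int id;
};

// A bound is either a known constant or an opaque symbol.
using ToyBound = std::variant<int64_t, ToySymbol>;

// Client-provided hook that knows how to recognize "offset" operations.
using DiffCallback =
    std::function<std::optional<int64_t>(ToySymbol lb, ToySymbol ub)>;

std::optional<int64_t> toyDiff(const ToyBound &lb, const ToyBound &ub,
                               const DiffCallback &computeDiff) {
  const int64_t *lbCst = std::get_if<int64_t>(&lb);
  const int64_t *ubCst = std::get_if<int64_t>(&ub);
  // Generic path: both bounds are constants, no callback needed.
  // (Sketch only: signed overflow of the subtraction is not handled.)
  if (lbCst && ubCst)
    return *ubCst - *lbCst;
  // Mixed case: a constant plus an offset is expected to have been folded
  // into a constant already, so the utility simply gives up.
  if (lbCst || ubCst)
    return std::nullopt;
  // Both bounds are symbolic: let the client match its own add operation.
  return computeDiff(std::get<ToySymbol>(lb), std::get<ToySymbol>(ub));
}
```

Keeping the constant logic in the utility means each client only supplies the dialect-specific match, which is the design rationale given in the reply.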
Add post-merge review comments on llvm#158679
The commit from #158679 added getStaticTripCount to LoopLikeOpInterface, which the CIRLoopOpInterface uses. However, it doesn't include APInt. This patch adds an include for APInt to CIRLoopOpInterface, plus a 'using', as we're likely to run into this again.
This PR adapts to a change in `LoopLikeOpInterface` from llvm/llvm-project#158679, which changed the API for getting a statically known iteration count. The PR adapts to the change by using the accessor of the `ForOp` for exactly that information instead of relying on a free function that takes the properties/operands of the op. Signed-off-by: Ingo Müller <ingomueller@google.com>
This patch adds a `getStaticTripCount` to the LoopLikeOpInterface, allowing loops to optionally return a static trip count when possible. This is implemented on SCF ForOp, revamping the implementation of `constantTripCount`, removing redundant duplicate implementations from SCF.cpp.
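For readers unfamiliar with the notion of a static trip count, the sketch below shows the arithmetic in plain C++ for a signed comparison: `ceilDiv(ub - lb, step)`, with a negative step or `lb >= ub` treated as zero iterations, as the updated header comment describes. `toyTripCount` is a hypothetical name, the sketch uses fixed 64-bit integers instead of `APInt`, and returning no result for a zero step is an assumption of the example.

```cpp
#include <cstdint>
#include <optional>

// Returns the number of iterations of `for (i = lb; i < ub; i += step)` when
// all three values are known, or nothing when any of them is missing.
std::optional<uint64_t> toyTripCount(std::optional<int64_t> lb,
                                     std::optional<int64_t> ub,
                                     std::optional<int64_t> step) {
  // A non-constant bound or step means no static trip count.
  if (!lb || !ub || !step)
    return std::nullopt;
  // Assumption of this sketch: a zero step has no meaningful trip count.
  if (*step == 0)
    return std::nullopt;
  // Negative step or empty range: treated as zero iterations.
  if (*step < 0 || *lb >= *ub)
    return uint64_t{0};
  // ub > lb here, so the unsigned difference is exact even when the signed
  // subtraction would overflow.
  const uint64_t diff =
      static_cast<uint64_t>(*ub) - static_cast<uint64_t>(*lb);
  const uint64_t s = static_cast<uint64_t>(*step);
  // Ceiling division without risking overflow in `diff + s - 1`.
  return diff / s + (diff % s != 0 ? 1 : 0);
}
```

A caller would typically branch on the result the way the unrolling utilities in the diff do: no value means the loop is left alone, zero means the loop can be dropped, one means the body can be promoted, and anything larger can feed an unroll factor.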