Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
Expand Down
27 changes: 24 additions & 3 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
ArrayRef<int64_t> values);

/// If ofr is a constant integer or an IntegerAttr, return the integer.
/// The second element of the returned pair indicates whether the value has
/// index type, in which case the bitwidth is not defined and the APInt is
/// created with a width of 64 bits.
std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// If all ofrs are constant integers or IntegerAttrs, return the integers.
Expand Down Expand Up @@ -201,9 +205,26 @@ foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);

/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
OpFoldResult step);
/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
/// comparison between lb and ub is signed or unsigned. A negative step or a
/// lower bound greater than the upper bound are considered invalid and will
/// yield a zero trip count.
/// The `computeUbMinusLb` callback is invoked to compute the difference between
/// the upper and lower bound when not constant. It can be used by the client
/// to compute a static difference when the bounds are not constant.
///
/// For example, in the following code:
///
/// %ub = arith.addi nsw %lb, %c16_i32 : i32
/// %1 = scf.for %arg0 = %lb to %ub ...
///
/// %ub is computed as a static offset from %lb, so the callback can return
/// that offset as the difference even though neither bound is constant.
/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
/// to avoid overflow; otherwise an overflow would imply a zero trip count.
std::optional<APInt> constantTripCount(
OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
computeUbMinusLb);

/// Idiomatic saturated operations on values like offsets, sizes, and strides.
struct SaturatedInteger {
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
>,
InterfaceMethod<[{
Compute the static trip count if possible.
}],
/*retTy=*/"::std::optional<APInt>",
/*methodName=*/"getStaticTripCount",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::std::nullopt;
}]
>
];

Expand Down
115 changes: 62 additions & 53 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
#include <optional>

using namespace mlir;
using namespace mlir::scf;
Expand Down Expand Up @@ -105,6 +110,24 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
return nullptr;
}

/// Helper function to compute the difference between two values. This is used
/// by the loop implementations to compute the trip count.
static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for a restricted form right? looks for ub = lb + const. E.g., this would fail in the case where ub and lb are constants.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's not meant for constants.

bool isSigned) {
llvm::APSInt diff;
auto addOp = ub.getDefiningOp<arith::AddIOp>();
if (!addOp)
return std::nullopt;
if ((isSigned && !addOp.hasNoSignedWrap()) ||
(!isSigned && !addOp.hasNoUnsignedWrap()))
return std::nullopt;

if (addOp.getLhs() != lb ||
!matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
return std::nullopt;
return diff;
}

//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -408,11 +431,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
std::optional<int64_t> tripCount =
constantTripCount(getLowerBound(), getUpperBound(), getStep());
if (!tripCount.has_value() || tripCount != 1)
std::optional<APInt> tripCount = getStaticTripCount();
LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
<< " for loop "
<< OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
return failure();

if (*tripCount == 0) {
rewriter.replaceAllUsesWith(getResults(), getInitArgs());
rewriter.eraseOp(*this);
return success();
}

// Replace all results with the yielded values.
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
Expand Down Expand Up @@ -646,7 +677,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
for (auto [lb, ub, step] :
llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
auto tripCount = constantTripCount(lb, ub, step);
auto tripCount =
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!tripCount.has_value() || *tripCount != 1)
return failure();
}
Expand Down Expand Up @@ -1003,27 +1035,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
}
};

/// Tries to compute a constant difference `u - l` between two values.
/// Two shapes are recognized:
///   1. both `l` and `u` are integer constants;
///   2. `u` is an arith.addi of `l` and a constant, in either operand order.
/// Returns std::nullopt when neither shape matches.
static std::optional<APInt> computeConstDiff(Value l, Value u) {
  // Shape 1: both bounds are constants, so subtract them directly.
  IntegerAttr lowerAttr, upperAttr;
  const bool bothConstant = matchPattern(l, m_Constant(&lowerAttr)) &&
                            matchPattern(u, m_Constant(&upperAttr));
  if (bothConstant)
    return upperAttr.getValue() - lowerAttr.getValue();

  // Shape 2: u == l + c, then u == c + l. The matched constant is the diff.
  llvm::APInt offset;
  if (matchPattern(u, m_Op<arith::AddIOp>(matchers::m_Val(l),
                                          m_ConstantInt(&offset))))
    return offset;
  if (matchPattern(u, m_Op<arith::AddIOp>(m_ConstantInt(&offset),
                                          matchers::m_Val(l))))
    return offset;

  return std::nullopt;
}

/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
Expand All @@ -1032,34 +1043,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {

LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
// If the upper bound is the same as the lower bound, the loop does not
// iterate, just remove it.
if (op.getLowerBound() == op.getUpperBound()) {
std::optional<APInt> tripCount = op.getStaticTripCount();
if (!tripCount.has_value())
return rewriter.notifyMatchFailure(op,
"can't compute constant trip count");

if (tripCount->isZero()) {
LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getInitArgs());
return success();
}

std::optional<APInt> diff =
computeConstDiff(op.getLowerBound(), op.getUpperBound());
if (!diff)
return failure();

// If the loop is known to have 0 iterations, remove it.
bool zeroOrLessIterations =
diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
if (zeroOrLessIterations) {
rewriter.replaceOp(op, op.getInitArgs());
return success();
}

std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
if (!maybeStepValue)
return failure();

// If the loop is known to have 1 iteration, inline its body and remove the
// loop.
llvm::APInt stepValue = *maybeStepValue;
if (stepValue.sge(*diff)) {
if (tripCount->getSExtValue() == 1) {
LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getInitArgs().size() + 1);
blockArgs.push_back(op.getLowerBound());
Expand All @@ -1072,11 +1070,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
Block &block = op.getRegion().front();
if (!llvm::hasSingleElement(block))
return failure();
// If the loop is empty, iterates at least once, and only returns values
// The loop is empty and iterates at least once, if it only returns values
// defined outside of the loop, remove it and replace it with yield values.
if (llvm::any_of(op.getYieldedValues(),
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
return failure();
LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
"yield operands for loop "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getYieldedValues());
return success();
}
Expand Down Expand Up @@ -1172,6 +1173,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
return Speculation::NotSpeculatable;
}

/// Returns the trip count of this loop as computed by `constantTripCount`
/// from the lower bound, upper bound and step, or std::nullopt if that helper
/// cannot determine it. `computeUbMinusLb` is passed as the callback for
/// computing the difference between the bounds.
std::optional<APInt> ForOp::getStaticTripCount() {
  // The bounds are compared as signed unless the op requests an unsigned
  // comparison.
  const bool isSignedCmp = !getUnsignedCmp();
  return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
                           /*isSigned=*/isSignedCmp, computeUbMinusLb);
}

//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1768,7 +1774,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) {
auto numIterations = constantTripCount(lb, ub, step);
auto numIterations =
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
Expand Down Expand Up @@ -1839,7 +1846,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
op.getMixedStep(), op.getInductionVars())) {
if (iv.hasNUses(0))
continue;
auto numIterations = constantTripCount(lb, ub, step);
auto numIterations =
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!numIterations.has_value() || numIterations.value() != 1) {
continue;
}
Expand Down Expand Up @@ -3084,7 +3092,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
auto numIterations = constantTripCount(lb, ub, step);
auto numIterations =
constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
Expand Down
32 changes: 13 additions & 19 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
Expand Down Expand Up @@ -291,14 +292,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return arith::DivUIOp::create(builder, loc, sum, divisor);
}

/// Returns the trip count of `forOp` if its low bound, high bound and step
/// are constants, or std::nullopt otherwise. The trip count is computed as
/// ceilDiv(highBound - lowBound, step).
static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
  auto lowBound = forOp.getLowerBound();
  auto highBound = forOp.getUpperBound();
  auto stride = forOp.getStep();
  return constantTripCount(lowBound, highBound, stride);
}

/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
Expand Down Expand Up @@ -377,7 +370,7 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
Value stepUnrolled;
bool generateEpilogueLoop = true;

std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
std::optional<APInt> constTripCount = forOp.getStaticTripCount();
if (constTripCount) {
// Constant loop bounds computation.
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
Expand All @@ -391,7 +384,8 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
}

int64_t tripCountEvenMultiple =
*constTripCount - (*constTripCount % unrollFactor);
constTripCount->getSExtValue() -
(constTripCount->getSExtValue() % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;

Expand Down Expand Up @@ -487,15 +481,15 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
IRRewriter rewriter(forOp.getContext());
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
if (!mayBeConstantTripCount.has_value())
return failure();
uint64_t tripCount = *mayBeConstantTripCount;
if (tripCount == 0)
APInt &tripCount = *mayBeConstantTripCount;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const reference?

if (tripCount.isZero())
return success();
if (tripCount == 1)
if (tripCount.getSExtValue() == 1)
return forOp.promoteIfSingleIteration(rewriter);
return loopUnrollByFactor(forOp, tripCount);
return loopUnrollByFactor(forOp, tripCount.getSExtValue());
}

/// Check if bounds of all inner loops are defined outside of `forOp`
Expand Down Expand Up @@ -535,18 +529,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,

// Currently, only constant trip count that divided by the unroll factor is
// supported.
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
std::optional<APInt> tripCount = forOp.getStaticTripCount();
if (!tripCount.has_value()) {
// If the trip count is dynamic, do not unroll & jam.
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
if (unrollJamFactor > *tripCount) {
if (unrollJamFactor > tripCount->getZExtValue()) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
unrollJamFactor = *tripCount;
} else if (*tripCount % unrollJamFactor != 0) {
unrollJamFactor = tripCount->getZExtValue();
} else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
Expand Down
Loading