diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index d3c01c31636a7..fadd3fc10bfc4 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -152,7 +152,7 @@ def ForOp : SCF_Op<"for", [AutomaticAllocationScope, DeclareOpInterfaceMethods, AllTypesMatch<["lowerBound", "upperBound", "step"]>, diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 77c376fb9973a..2e7f85cce4654 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -105,6 +105,10 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val); SmallVector getAsIndexOpFoldResult(MLIRContext *ctx, ArrayRef values); +/// If ofr is a constant integer or an IntegerAttr, return the integer. +/// The second return value indicates whether the value is an index type +/// and thus the bitwidth is not defined (the APInt will be set with 64bits). +std::optional> getConstantAPIntValue(OpFoldResult ofr); /// If ofr is a constant integer or an IntegerAttr, return the integer. std::optional getConstantIntValue(OpFoldResult ofr); /// If all ofrs are constant integers or IntegerAttrs, return the integers. @@ -201,9 +205,26 @@ foldDynamicOffsetSizeList(SmallVectorImpl &offsetsOrSizes); LogicalResult foldDynamicStrideList(SmallVectorImpl &strides); /// Return the number of iterations for a loop with a lower bound `lb`, upper -/// bound `ub` and step `step`. -std::optional constantTripCount(OpFoldResult lb, OpFoldResult ub, - OpFoldResult step); +/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop +/// comparison between lb and ub is signed or unsigned. A negative step or a +/// lower bound greater than the upper bound are considered invalid and will +/// yield a zero trip count. 
+/// The `computeUbMinusLb` callback is invoked to compute the difference between +/// the upper and lower bound when not constant. It can be used by the client +/// to compute a static difference when the bounds are not constant. +/// +/// For example, the following code: +/// +/// %ub = arith.addi nsw %lb, %c16_i32 : i32 +/// %1 = scf.for %arg0 = %lb to %ub ... +/// +/// where %ub is computed as a static offset from %lb. +/// Note: the matched addition should be nsw/nuw (matching the loop comparison) +/// to avoid overflow, otherwise an overflow would imply a zero trip count. +std::optional constantTripCount( + OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, + llvm::function_ref(Value, Value, bool)> + computeUbMinusLb); /// Idiomatic saturated operations on values like offsets, sizes, and strides. struct SaturatedInteger { diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td index 6c95b4802837b..cfd15a7746e19 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -232,6 +232,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> { /*defaultImplementation=*/[{ return ::mlir::failure(); }] + >, + InterfaceMethod<[{ + Compute the static trip count if possible. 
+ }], + /*retTy=*/"::std::optional", + /*methodName=*/"getStaticTripCount", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::std::nullopt; + }] > ]; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index c35989ecba6cd..ae55eaded0554 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -19,6 +19,8 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" @@ -26,6 +28,9 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/DebugLog.h" +#include using namespace mlir; using namespace mlir::scf; @@ -105,6 +110,24 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, return nullptr; } +/// Helper function to compute the difference between two values. This is used +/// by the loop implementations to compute the trip count. 
+static std::optional computeUbMinusLb(Value lb, Value ub, + bool isSigned) { + llvm::APSInt diff; + auto addOp = ub.getDefiningOp(); + if (!addOp) + return std::nullopt; + if ((isSigned && !addOp.hasNoSignedWrap()) || + (!isSigned && !addOp.hasNoUnsignedWrap())) + return std::nullopt; + + if (addOp.getLhs() != lb || + !matchPattern(addOp.getRhs(), m_ConstantInt(&diff))) + return std::nullopt; + return diff; +} + //===----------------------------------------------------------------------===// // ExecuteRegionOp //===----------------------------------------------------------------------===// @@ -408,11 +431,19 @@ std::optional ForOp::getLoopResults() { return getResults(); } /// Promotes the loop body of a forOp to its containing block if the forOp /// it can be determined that the loop has a single iteration. LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { - std::optional tripCount = - constantTripCount(getLowerBound(), getUpperBound(), getStep()); - if (!tripCount.has_value() || tripCount != 1) + std::optional tripCount = getStaticTripCount(); + LDBG() << "promoteIfSingleIteration tripCount is " << tripCount + << " for loop " + << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions()); + if (!tripCount.has_value() || tripCount->getSExtValue() > 1) return failure(); + if (*tripCount == 0) { + rewriter.replaceAllUsesWith(getResults(), getInitArgs()); + rewriter.eraseOp(*this); + return success(); + } + // Replace all results with the yielded values. 
auto yieldOp = cast(getBody()->getTerminator()); rewriter.replaceAllUsesWith(getResults(), getYieldedValues()); @@ -646,7 +677,8 @@ SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { for (auto [lb, ub, step] : llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) { - auto tripCount = constantTripCount(lb, ub, step); + auto tripCount = + constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); if (!tripCount.has_value() || *tripCount != 1) return failure(); } @@ -1003,27 +1035,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern { } }; -/// Util function that tries to compute a constant diff between u and l. -/// Returns std::nullopt when the difference between two AffineValueMap is -/// dynamic. -static std::optional computeConstDiff(Value l, Value u) { - IntegerAttr clb, cub; - if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { - llvm::APInt lbValue = clb.getValue(); - llvm::APInt ubValue = cub.getValue(); - return ubValue - lbValue; - } - - // Else a simple pattern match for x + c or c + x - llvm::APInt diff; - if (matchPattern( - u, m_Op(matchers::m_Val(l), m_ConstantInt(&diff))) || - matchPattern( - u, m_Op(m_ConstantInt(&diff), matchers::m_Val(l)))) - return diff; - return std::nullopt; -} - /// Rewriting pattern that erases loops that are known not to iterate, replaces /// single-iteration loops with their bodies, and removes empty loops that /// iterate at least once and only return values defined outside of the loop. @@ -1032,34 +1043,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern { LogicalResult matchAndRewrite(ForOp op, PatternRewriter &rewriter) const override { - // If the upper bound is the same as the lower bound, the loop does not - // iterate, just remove it. 
- if (op.getLowerBound() == op.getUpperBound()) { + std::optional tripCount = op.getStaticTripCount(); + if (!tripCount.has_value()) + return rewriter.notifyMatchFailure(op, + "can't compute constant trip count"); + + if (tripCount->isZero()) { + LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); rewriter.replaceOp(op, op.getInitArgs()); return success(); } - std::optional diff = - computeConstDiff(op.getLowerBound(), op.getUpperBound()); - if (!diff) - return failure(); - - // If the loop is known to have 0 iterations, remove it. - bool zeroOrLessIterations = - diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative()); - if (zeroOrLessIterations) { - rewriter.replaceOp(op, op.getInitArgs()); - return success(); - } - - std::optional maybeStepValue = op.getConstantStep(); - if (!maybeStepValue) - return failure(); - - // If the loop is known to have 1 iteration, inline its body and remove the - // loop. - llvm::APInt stepValue = *maybeStepValue; - if (stepValue.sge(*diff)) { + if (tripCount->getSExtValue() == 1) { + LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); SmallVector blockArgs; blockArgs.reserve(op.getInitArgs().size() + 1); blockArgs.push_back(op.getLowerBound()); @@ -1072,11 +1070,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern { Block &block = op.getRegion().front(); if (!llvm::hasSingleElement(block)) return failure(); - // If the loop is empty, iterates at least once, and only returns values + // The loop is empty and iterates at least once, if it only returns values // defined outside of the loop, remove it and replace it with yield values. 
if (llvm::any_of(op.getYieldedValues(), [&](Value v) { return !op.isDefinedOutsideOfLoop(v); })) return failure(); + LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with " + "yield operands for loop " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); rewriter.replaceOp(op, op.getYieldedValues()); return success(); } @@ -1172,6 +1173,11 @@ Speculation::Speculatability ForOp::getSpeculatability() { return Speculation::NotSpeculatable; } +std::optional ForOp::getStaticTripCount() { + return constantTripCount(getLowerBound(), getUpperBound(), getStep(), + /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb); +} + //===----------------------------------------------------------------------===// // ForallOp //===----------------------------------------------------------------------===// @@ -1768,7 +1774,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder for (auto [lb, ub, step, iv] : llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), op.getMixedStep(), op.getInductionVars())) { - auto numIterations = constantTripCount(lb, ub, step); + auto numIterations = + constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); if (numIterations.has_value()) { // Remove the loop if it performs zero iterations. 
if (*numIterations == 0) { @@ -1839,7 +1846,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern { op.getMixedStep(), op.getInductionVars())) { if (iv.hasNUses(0)) continue; - auto numIterations = constantTripCount(lb, ub, step); + auto numIterations = + constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); if (!numIterations.has_value() || numIterations.value() != 1) { continue; } @@ -3084,7 +3092,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder for (auto [lb, ub, step, iv] : llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(), op.getInductionVars())) { - auto numIterations = constantTripCount(lb, ub, step); + auto numIterations = + constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb); if (numIterations.has_value()) { // Remove the loop if it performs zero iterations. if (*numIterations == 0) { diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 684dff8121de6..57c4deab89321 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/DebugLog.h" @@ -291,14 +292,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, return arith::DivUIOp::create(builder, loc, sum, divisor); } -/// Returns the trip count of `forOp` if its' low bound, high bound and step are -/// constants, or optional otherwise. Trip count is computed as -/// ceilDiv(highBound - lowBound, step). 
-static std::optional getConstantTripCount(scf::ForOp forOp) { - return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep()); -} - /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each @@ -377,7 +370,7 @@ FailureOr mlir::loopUnrollByFactor( Value stepUnrolled; bool generateEpilogueLoop = true; - std::optional constTripCount = getConstantTripCount(forOp); + std::optional constTripCount = forOp.getStaticTripCount(); if (constTripCount) { // Constant loop bounds computation. int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value(); @@ -391,7 +384,8 @@ FailureOr mlir::loopUnrollByFactor( } int64_t tripCountEvenMultiple = - *constTripCount - (*constTripCount % unrollFactor); + constTripCount->getSExtValue() - + (constTripCount->getSExtValue() % unrollFactor); int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; int64_t stepUnrolledCst = stepCst * unrollFactor; @@ -487,15 +481,15 @@ FailureOr mlir::loopUnrollByFactor( /// Unrolls this loop completely. 
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) { IRRewriter rewriter(forOp.getContext()); - std::optional mayBeConstantTripCount = getConstantTripCount(forOp); + std::optional mayBeConstantTripCount = forOp.getStaticTripCount(); if (!mayBeConstantTripCount.has_value()) return failure(); - uint64_t tripCount = *mayBeConstantTripCount; - if (tripCount == 0) + APInt &tripCount = *mayBeConstantTripCount; + if (tripCount.isZero()) return success(); - if (tripCount == 1) + if (tripCount.getSExtValue() == 1) return forOp.promoteIfSingleIteration(rewriter); - return loopUnrollByFactor(forOp, tripCount); + return loopUnrollByFactor(forOp, tripCount.getSExtValue()); } /// Check if bounds of all inner loops are defined outside of `forOp` @@ -535,18 +529,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, // Currently, only constant trip count that divided by the unroll factor is // supported. - std::optional tripCount = getConstantTripCount(forOp); + std::optional tripCount = forOp.getStaticTripCount(); if (!tripCount.has_value()) { // If the trip count is dynamic, do not unroll & jam. 
LDBG() << "failed to unroll and jam: trip count could not be determined"; return failure(); } - if (unrollJamFactor > *tripCount) { + if (unrollJamFactor > tripCount->getZExtValue()) { LDBG() << "unroll and jam factor is greater than trip count, set factor to " "trip " "count"; - unrollJamFactor = *tripCount; - } else if (*tripCount % unrollJamFactor != 0) { + unrollJamFactor = tripCount->getZExtValue(); + } else if (tripCount->getSExtValue() % unrollJamFactor != 0) { LDBG() << "failed to unroll and jam: unsupported trip count that is not a " "multiple of unroll jam factor"; return failure(); diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 34385d76f133a..5048b19b2891f 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" namespace mlir { @@ -112,21 +113,30 @@ SmallVector getAsIndexOpFoldResult(MLIRContext *ctx, } /// If ofr is a constant integer or an IntegerAttr, return the integer. -std::optional getConstantIntValue(OpFoldResult ofr) { +/// The boolean indicates whether the value is an index type. +std::optional> getConstantAPIntValue(OpFoldResult ofr) { // Case 1: Check for Constant integer. if (auto val = llvm::dyn_cast_if_present(ofr)) { - APSInt intVal; + APInt intVal; if (matchPattern(val, m_ConstantInt(&intVal))) - return intVal.getSExtValue(); + return std::make_pair(intVal, val.getType().isIndex()); return std::nullopt; } // Case 2: Check for IntegerAttr. Attribute attr = llvm::dyn_cast_if_present(ofr); if (auto intAttr = dyn_cast_or_null(attr)) - return intAttr.getValue().getSExtValue(); + return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex()); return std::nullopt; } +/// If ofr is a constant integer or an IntegerAttr, return the integer. 
+std::optional getConstantIntValue(OpFoldResult ofr) { + std::optional> apInt = getConstantAPIntValue(ofr); + if (!apInt) + return std::nullopt; + return apInt->first.getSExtValue(); +} + std::optional> getConstantIntValues(ArrayRef ofrs) { bool failed = false; @@ -264,22 +274,108 @@ getValuesSortedByKey(ArrayRef keys, ArrayRef values, /// Return the number of iterations for a loop with a lower bound `lb`, upper /// bound `ub` and step `step`. -std::optional constantTripCount(OpFoldResult lb, OpFoldResult ub, - OpFoldResult step) { +std::optional constantTripCount( + OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, + llvm::function_ref(Value, Value, bool)> + computeUbMinusLb) { + // This is the bitwidth used to return 0 when loop does not execute. + // We infer it from the type of the bound if it isn't an index type. + bool isIndex = true; + auto getBitwidth = [&](OpFoldResult ofr) -> int { + if (auto attr = dyn_cast(ofr)) { + if (auto intAttr = dyn_cast(attr)) { + if (auto intType = dyn_cast(intAttr.getType())) { + isIndex = intType.isIndex(); + return intType.getWidth(); + } + } + } else { + auto val = cast(ofr); + if (auto intType = dyn_cast(val.getType())) { + isIndex = intType.isIndex(); + return intType.getWidth(); + } + } + return IndexType::kInternalStorageBitWidth; + }; + int bitwidth = getBitwidth(lb); + assert(bitwidth == getBitwidth(ub) && + "lb and ub must have the same bitwidth"); if (lb == ub) - return 0; + return APInt(bitwidth, 0); + + std::optional> maybeStepCst = + getConstantAPIntValue(step); + + if (maybeStepCst) { + auto &stepCst = maybeStepCst->first; + assert(static_cast(stepCst.getBitWidth()) == bitwidth && + "step must have the same bitwidth as lb and ub"); + if (stepCst.isZero()) + return stepCst; + if (stepCst.isNegative()) + return APInt(bitwidth, 0); + } - std::optional lbConstant = getConstantIntValue(lb); - if (!lbConstant) - return std::nullopt; - std::optional ubConstant = getConstantIntValue(ub); - if 
(!ubConstant) - return std::nullopt; - std::optional stepConstant = getConstantIntValue(step); - if (!stepConstant || *stepConstant == 0) - return std::nullopt; + if (isIndex) { + LDBG() + << "Computing loop trip count for index type may break with overflow"; + // TODO: we can't compute the trip count for index type. We should fix this + // but too many tests are failing right now. + // return {}; + } - return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant); + /// Compute the difference between the upper and lower bound: either from the + /// constant value or using the computeUbMinusLb callback. + llvm::APSInt diff; + std::optional> maybeLbCst = getConstantAPIntValue(lb); + std::optional> maybeUbCst = getConstantAPIntValue(ub); + if (maybeLbCst) { + // If one of the bounds is not a constant, we can't compute the trip count. + if (!maybeUbCst) + return std::nullopt; + APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned); + APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned); + if (!maybeUbCst) + return std::nullopt; + if (ubCst <= lbCst) { + LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "(" + << lbCst.getBitWidth() << ") <= " << ubCst << "(" + << ubCst.getBitWidth() << "), " + << (isSigned ? "isSigned" : "isUnsigned") << ")"; + return APInt(bitwidth, 0); + } + diff = ubCst - lbCst; + } else { + if (maybeUbCst) + return std::nullopt; + + /// Non-constant bound, let's try to compute the difference between the + /// upper and lower bound + std::optional maybeDiff = + computeUbMinusLb(cast(lb), cast(ub), isSigned); + if (!maybeDiff) + return std::nullopt; + diff = *maybeDiff; + } + LDBG() << "constantTripCount: " << (isSigned ? 
"isSigned" : "isUnsigned") + << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)"; + if (diff.isNegative()) { + LDBG() << "constantTripCount is 0 because ub-lb diff is negative"; + return APInt(bitwidth, 0); + } + if (!maybeStepCst) { + LDBG() + << "constantTripCount can't be computed because step is not a constant"; + return std::nullopt; + } + auto &stepCst = maybeStepCst->first; + llvm::APInt tripCount = diff.sdiv(stepCst); + llvm::APInt r = diff.srem(stepCst); + if (!r.isZero()) + tripCount = tripCount + 1; + LDBG() << "constantTripCount found: " << tripCount; + return tripCount; } bool hasValidSizesOffsets(SmallVector sizesOrOffsets) { diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 4ad2da8388eb7..5e89f74075252 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -749,7 +749,7 @@ func.func @replace_single_iteration_const_diff(%arg0 : index) { // CHECK-NEXT: %[[CST:.*]] = arith.constant 2 %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %5 = arith.addi %arg0, %c1 : index + %5 = arith.addi %arg0, %c1 overflow : index // CHECK-NOT: scf.for scf.for %arg2 = %arg0 to %5 step %c1 { // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]] @@ -1933,8 +1933,9 @@ func.func @index_switch_fold_no_res() { // ----- +// Step 0 is invalid, the loop is eliminated. 
// CHECK-LABEL: func @scf_for_all_step_size_0() -// CHECK: scf.forall (%{{.*}}) = (0) to (1) step (0) +// CHECK-NOT: scf.forall func.func @scf_for_all_step_size_0() { %x = arith.constant 0 : index scf.forall (%i, %j) = (0, 4) to (1, 5) step (%x, 8) { diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir index f59b79603b489..c08d6f6968bc4 100644 --- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir @@ -293,10 +293,9 @@ func.func @regression(%arg0: memref, %arg1: index) { // ----- // Regression test: Make sure that we do not crash. - +// The step is 0, the loop will be eliminated. // CHECK-LABEL: func @zero_step( -// CHECK: scf.for -// CHECK: scf.for +// CHECK-NOT: scf.for func.func @zero_step(%arg0: memref) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/mlir/test/Dialect/SCF/trip_count.mlir b/mlir/test/Dialect/SCF/trip_count.mlir new file mode 100644 index 0000000000000..54883d7bb874c --- /dev/null +++ b/mlir/test/Dialect/SCF/trip_count.mlir @@ -0,0 +1,702 @@ +// RUN: mlir-opt %s -test-scf-for-utils --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @trip_count_index_zero_to_zero( +func.func @trip_count_index_zero_to_zero(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: "test.trip-count" = 0 + %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_zero_to_zero_step_dyn( +func.func @trip_count_index_zero_to_zero_step_dyn(%a : i32, %b : i32, %step : index) -> i32 { + %c0 = arith.constant 0 : index + + // CHECK: "test.trip-count" = 0 + %r = scf.for %i = %c0 to %c0 step %step iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_zero_to_zero( +func.func @trip_count_i32_zero_to_zero(%a : 
i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + + // CHECK: "test.trip-count" = 0 + %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + + +// CHECK-LABEL: func.func @trip_count_i32_zero_to_zero_step_dyn( +func.func @trip_count_i32_zero_to_zero_step_dyn(%a : i32, %b : i32, %step : i32) -> i32 { + %c0 = arith.constant 0 : i32 + + // CHECK: "test.trip-count" = 0 + %r = scf.for %i = %c0 to %c0 step %step iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_one_to_zero( +func.func @trip_count_index_one_to_zero(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Index type has an unknown bitwidth, so we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 0 + %r2 = scf.for %i = %c1 to %c0 step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_one_to_zero( +func.func @trip_count_i32_one_to_zero(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + + // CHECK: "test.trip-count" = 0 + %r2 = scf.for %i = %c1 to %c0 step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_one_to_zero_dyn_step( +func.func @trip_count_i32_one_to_zero_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + + // CHECK: "test.trip-count" = 0 + %r2 = scf.for %i = %c1 to %c0 step %step iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_negative_step( +func.func @trip_count_index_negative_step(%a : i32, %b : i32) -> i32 { + %c0 = 
arith.constant 0 : index + %c1 = arith.constant 1 : index + %c-1 = arith.constant -1 : index + + // Negative step is invalid, loop won't execute. + // CHECK: "test.trip-count" = 0 + %r3 = scf.for %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r3 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_negative_step( +func.func @trip_count_i32_negative_step(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c-1 = arith.constant -1 : i32 + + // Negative step is invalid, loop won't execute. + // CHECK: "test.trip-count" = 0 + %r3 = scf.for %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r3 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_negative_step_unsigned_loop( +func.func @trip_count_index_negative_step_unsigned_loop(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c-1 = arith.constant -1 : index + + // Negative step is invalid, loop won't execute. + // CHECK: "test.trip-count" = 0 + %r3 = scf.for unsigned %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r3 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_negative_step_unsigned_loop( +func.func @trip_count_i32_negative_step_unsigned_loop(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c-1 = arith.constant -1 : i32 + + // Negative step is invalid, loop won't execute. 
+ // CHECK: "test.trip-count" = 0 + %r3 = scf.for unsigned %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r3 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_normal_loop( +func.func @trip_count_index_normal_loop(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c10 = arith.constant 10 : index + + // Index type has an unknown bitwidth, so we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 5 + %r4 = scf.for %i = %c0 to %c10 step %c2 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r4 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_normal_loop( +func.func @trip_count_i32_normal_loop(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i32 + %c2 = arith.constant 2 : i32 + %c10 = arith.constant 10 : i32 + + // Normal loop + // CHECK: "test.trip-count" = 5 + %r4 = scf.for %i = %c0 to %c10 step %c2 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r4 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_signed_crossing_zero( +func.func @trip_count_index_signed_crossing_zero(%a : i32, %b : i32) -> i32 { + %c-1 = arith.constant -1 : index + %c1 = arith.constant 1 : index + + // Index type has an unknown bitwidth, so we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 2 + %r5 = scf.for %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r5 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_signed_crossing_zero( +func.func @trip_count_i32_signed_crossing_zero(%a : i32, %b : i32) -> i32 { + %c-1 = arith.constant -1 : i32 + %c1 = arith.constant 1 : i32 + + // This loop executes with signed comparison, but not unsigned, because it is crossing 0. 
+ // CHECK: "test.trip-count" = 2 + %r5 = scf.for %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r5 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_unsigned_crossing_zero( +func.func @trip_count_index_unsigned_crossing_zero(%a : i32, %b : i32) -> i32 { + %c-1 = arith.constant -1 : index + %c1 = arith.constant 1 : index + + // Index type has an unknown bitwidth, so we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 0 + %r6 = scf.for unsigned %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r6 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_unsigned_crossing_zero( +func.func @trip_count_i32_unsigned_crossing_zero(%a : i32, %b : i32) -> i32 { + %c-1 = arith.constant -1 : i32 + %c1 = arith.constant 1 : i32 + + // This loop executes with signed comparison, but not unsigned, because it is crossing 0. + // CHECK: "test.trip-count" = 0 + %r6 = scf.for unsigned %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r6 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_unsigned_crossing_zero_dyn_step( +func.func @trip_count_i32_unsigned_crossing_zero_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 { + %c-1 = arith.constant -1 : i32 + %c1 = arith.constant 1 : i32 + + // This loop executes with signed comparison, but not unsigned, because it is crossing 0. 
+ // CHECK: "test.trip-count" = 0 + %r6 = scf.for unsigned %i = %c-1 to %c1 step %step iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r6 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_negative_bounds_signed( +func.func @trip_count_index_negative_bounds_signed(%a : i32, %b : i32) -> i32 { + %c-10 = arith.constant -10 : index + %c-1 = arith.constant -1 : index + %c2 = arith.constant 2 : index + + // Index type has an unknown bitwidth, so we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 5 + %r7 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r7 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_negative_bounds_signed( +func.func @trip_count_i32_negative_bounds_signed(%a : i32, %b : i32) -> i32 { + %c-10 = arith.constant -10 : i32 + %c-1 = arith.constant -1 : i32 + %c2 = arith.constant 2 : i32 + + // This loop executes with signed comparison, because both bounds are + // negative and there is no crossing of 0 here. + // CHECK: "test.trip-count" = 5 + %r7 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r7 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_negative_bounds_unsigned( +func.func @trip_count_index_negative_bounds_unsigned(%a : i32, %b : i32) -> i32 { + %c-10 = arith.constant -10 : index + %c-1 = arith.constant -1 : index + %c2 = arith.constant 2 : index + + // Index type has an unknown bitwidth, so we can't compute a loop trip count + // in theory because of overflow concerns. 
+ // CHECK: "test.trip-count" = 5 + %r8 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r8 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_negative_bounds_unsigned( +func.func @trip_count_i32_negative_bounds_unsigned(%a : i32, %b : i32) -> i32 { + %c-10 = arith.constant -10 : i32 + %c-1 = arith.constant -1 : i32 + %c2 = arith.constant 2 : i32 + + // CHECK: "test.trip-count" = 5 + %r8 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r8 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_overflow_signed( +func.func @trip_count_index_overflow_signed(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : index + %c_max = arith.constant 2147483647 : index // 2^31 - 1 + %c_min = arith.constant 2147483648 : index // 2^31 + + // Index type has an unknown bitwidth, we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 1 + %r9 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r9 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_overflow_signed( +func.func @trip_count_i32_overflow_signed(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : i32 + %c_max = arith.constant 2147483647 : i32 // 2^31 - 1 + %c_min = arith.constant 2147483648 : i32 // -2^31 + + // This loop crosses the 2^31 threshold, which would overflow a signed 32-bit integer. 
+ // CHECK: "test.trip-count" = 0 + %r9 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r9 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_overflow_signed_dyn_step( +func.func @trip_count_i32_overflow_signed_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 { + %c_max = arith.constant 2147483647 : i32 // 2^31 - 1 + %c_min = arith.constant 2147483648 : i32 // -2^31 + + // This loop crosses the 2^31 threshold, which would overflow a signed 32-bit integer. + // CHECK: "test.trip-count" = 0 + %r9 = scf.for %i = %c_max to %c_min step %step iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r9 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_overflow_unsigned( +func.func @trip_count_index_overflow_unsigned(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : index + %c_max = arith.constant 2147483647 : index // 2^31 - 1 + %c_min = arith.constant 2147483648 : index // 2^31 + + // Index type has an unknown bitwidth, we can't compute a loop trip count + // in theory because of overflow concerns. 
+ // CHECK: "test.trip-count" = 1 + %r10 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r10 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_overflow_unsigned( +func.func @trip_count_i32_overflow_unsigned(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : i32 + %c_max = arith.constant 2147483647 : i32 // 2^31 - 1 + %c_min = arith.constant 2147483648 : i32 // -2^31 + + // The same loop with unsigned comparison executes normally + // CHECK: "test.trip-count" = 1 + %r10 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r10 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_overflow_64bit_signed( +func.func @trip_count_index_overflow_64bit_signed(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : index + %c_max = arith.constant 9223372036854775807 : index // 2^63 - 1 + %c_min = arith.constant -9223372036854775808 : index // -2^63 + + // This loop crosses the 2^63 threshold, which would overflow a signed 64-bit integer. + // Index type has an unknown bitwidth, we can't compute a loop trip count. + // CHECK: "test.trip-count" = 0 + %r11 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r11 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i64_overflow_64bit_signed( +func.func @trip_count_i64_overflow_64bit_signed(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : i64 + %c_max = arith.constant 9223372036854775807 : i64 // 2^63 - 1 + %c_min = arith.constant -9223372036854775808 : i64 // -2^63 + + // This loop crosses the 2^63 threshold, which would overflow a signed 64-bit integer. 
+ // CHECK: "test.trip-count" = 0 + %r11 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i64 { + scf.yield %b : i32 + } + return %r11 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_index_overflow_64bit_unsigned( +func.func @trip_count_index_overflow_64bit_unsigned(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : index + %c_max = arith.constant 9223372036854775807 : index // 2^63 - 1 + %c_min = arith.constant -9223372036854775808 : index // -2^63 + + // Index type has an unknown bitwidth, we can't compute a loop trip count + // in theory because of overflow concerns. + // CHECK: "test.trip-count" = 1 + %r12 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 { + scf.yield %b : i32 + } + return %r12 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @trip_count_i32_overflow_64bit_unsigned( +func.func @trip_count_i32_overflow_64bit_unsigned(%a : i32, %b : i32) -> i32 { + %c1 = arith.constant 1 : i64 + %c_max = arith.constant 9223372036854775807 : i64 // 2^63 - 1 + %c_min = arith.constant -9223372036854775808 : i64 // -2^63 + + // The same loop with unsigned comparison executes normally + // CHECK: "test.trip-count" = 1 + %r12 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i64 { + scf.yield %b : i32 + } + return %r12 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_step_greater_than_iteration( +func.func @trip_count_step_greater_than_iteration() -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c17_i32 = arith.constant 17 : i32 + %c16_i32 = arith.constant 16 : i32 + // CHECK: "test.trip-count" = 1 + %1 = scf.for %arg0 = %c16_i32 to %c17_i32 step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add( +func.func @trip_count_arith_add(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + 
%c17_i32 = arith.constant 17 : i32 + %c16_i32 = arith.constant 16 : i32 + // Can't compute a trip-count in the absence of an overflow flag. + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %lb, %c16_i32 : i32 + %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative( +func.func @trip_count_arith_add_negative(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c-16_i32 = arith.constant -16 : i32 + // Can't compute a trip-count in the absence of an overflow flag. + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %lb, %c-16_i32 : i32 + %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_nsw_loop_signed( +func.func @trip_count_arith_add_nsw_loop_signed(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c16_i32 = arith.constant 16 : i32 + %ub = arith.addi %lb, %c16_i32 overflow : i32 + // CHECK: "test.trip-count" = 4 + %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_signed( +func.func @trip_count_arith_add_negative_nsw_loop_signed(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c-16_i32 = arith.constant -16 : i32 + %ub = arith.addi %lb, %c-16_i32 overflow : i32 + // CHECK: "test.trip-count" = 0 + %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_signed_step_dyn( +func.func 
@trip_count_arith_add_negative_nsw_loop_signed_step_dyn(%lb : i32, %step : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c-16_i32 = arith.constant -16 : i32 + %ub = arith.addi %lb, %c-16_i32 overflow : i32 + // CHECK: "test.trip-count" = 0 + %1 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_nsw_loop_unsigned( +func.func @trip_count_arith_add_nsw_loop_unsigned(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c16_i32 = arith.constant 16 : i32 + // Can't compute a trip-count when the overflow flag mismatches the loop comparison signedness + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %lb, %c16_i32 overflow : i32 + %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_unsigned( +func.func @trip_count_arith_add_negative_nsw_loop_unsigned(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c-16_i32 = arith.constant -16 : i32 + // Can't compute a trip-count when the overflow flag mismatches the loop comparison signedness + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %lb, %c-16_i32 overflow : i32 + %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_signed( +func.func @trip_count_arith_add_nuw_loop_signed(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c16_i32 = arith.constant 16 : i32 + // Can't compute a trip-count when the overflow flag mismatches the loop comparison signedness + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %lb, 
%c16_i32 overflow : i32 + %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_signed( +func.func @trip_count_arith_add_negative_nuw_loop_signed(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c-16_i32 = arith.constant -16 : i32 + // Can't compute a trip-count when the overflow flag mismatches the loop comparison signedness + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %lb, %c-16_i32 overflow : i32 + %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_unsigned( +func.func @trip_count_arith_add_nuw_loop_unsigned(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c16_i32 = arith.constant 16 : i32 + // CHECK: "test.trip-count" = 4 + %ub = arith.addi %lb, %c16_i32 overflow : i32 + %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_unsigned( +func.func @trip_count_arith_add_negative_nuw_loop_unsigned(%lb : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c-16_i32 = arith.constant -16 : i32 + // CHECK: "test.trip-count" = 0 + %ub = arith.addi %lb, %c-16_i32 overflow : i32 + %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_unsigned_step_dyn( +func.func @trip_count_arith_add_negative_nuw_loop_unsigned_step_dyn(%lb : i32, %step : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + 
%c-16_i32 = arith.constant -16 : i32 + // CHECK: "test.trip-count" = 0 + %ub = arith.addi %lb, %c-16_i32 overflow : i32 + %1 = scf.for unsigned %arg0 = %lb to %ub step %step iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_unsigned_invalid( +func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(%lb : i32, %other : i32) -> i32 { + %c0_i32 = arith.constant 0 : i32 + %c4_i32 = arith.constant 4 : i32 + %c16_i32 = arith.constant 16 : i32 + // The upper bound is computed from %other, not from %lb. + // CHECK: "test.trip-count" = "none" + %ub = arith.addi %other, %c16_i32 overflow : i32 + %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 { + scf.yield %arg0 : i32 + } + return %1 : i32 +} \ No newline at end of file diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 9a394d2875b67..6199cb13149af 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -42,6 +42,20 @@ struct TestSCFForUtilsPass void runOnOperation() override { func::FuncOp func = getOperation(); + // Annotate every loop-like operation with the static trip count. + func.walk([&](LoopLikeOpInterface loopOp) { + std::optional tripCount = loopOp.getStaticTripCount(); + if (tripCount.has_value()) + loopOp->setDiscardableAttr( + "test.trip-count", + IntegerAttr::get(IntegerType::get(&getContext(), + tripCount.value().getBitWidth()), + tripCount.value().getSExtValue())); + else + loopOp->setDiscardableAttr("test.trip-count", + StringAttr::get(&getContext(), "none")); + }); + if (testReplaceWithNewYields) { func.walk([&](scf::ForOp forOp) { if (forOp.getNumResults() == 0)