diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 3475bb2151777..881125bd0da2f 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -251,6 +251,7 @@ FailureOr parallelLoopUnrollByFactors( /// Get constant trip counts for each of the induction variables of the given /// loop operation. If any of the loop's trip counts is not constant, return an /// empty vector. +/// TODO(#178506): Should return SmallVector for correct signedness. llvm::SmallVector getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp); diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index ba8a0304de9d3..020a124656d19 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -212,10 +212,11 @@ foldDynamicOffsetSizeList(SmallVectorImpl &offsetsOrSizes); LogicalResult foldDynamicStrideList(SmallVectorImpl &strides); /// Return the number of iterations for a loop with a lower bound `lb`, upper -/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop -/// comparison between lb and ub is signed or unsigned. A negative step or a -/// lower bound greater than the upper bound are considered invalid and will -/// yield a zero trip count. +/// bound `ub` and step `step`, as an unsigned integer. The `isSigned` flag +/// indicates whether the loop comparison between lb and ub is signed or +/// unsigned. (The result of this function must be interpreted as an unsigned +/// integer.) A lower bound greater than the upper bound is considered invalid +/// and will yield a zero trip count. /// The `computeUbMinusLb` callback is invoked to compute the difference between /// the upper and lower bound when not constant. It can be used by the client /// to compute a static difference when the bounds are not constant. diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index b900c63044d44..c46a0577c4b96 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -443,7 +443,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { LDBG() << "promoteIfSingleIteration tripCount is " << tripCount << " for loop " << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions()); - if (!tripCount.has_value() || tripCount->getSExtValue() > 1) + if (!tripCount.has_value() || tripCount->getZExtValue() > 1) return failure(); if (*tripCount == 0) { diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 888dd448b66f9..e46a3b6ce6c06 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -396,9 +396,12 @@ FailureOr mlir::loopUnrollByFactor( return UnrolledLoopInfo{forOp, std::nullopt}; } + // TODO(#178506): This may overflow for large trip counts. Should use + // uint64_t. int64_t tripCountEvenMultiple = - constTripCount->getSExtValue() - - (constTripCount->getSExtValue() % unrollFactor); + constTripCount->getZExtValue() - + (constTripCount->getZExtValue() % unrollFactor); + // TODO(#178506): This may overflow when computing upperBoundUnrolledCst. int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; int64_t stepUnrolledCst = stepCst * unrollFactor; @@ -500,9 +503,9 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) { const APInt &tripCount = *mayBeConstantTripCount; if (tripCount.isZero()) return success(); - if (tripCount.getSExtValue() == 1) + if (tripCount.getZExtValue() == 1) return forOp.promoteIfSingleIteration(rewriter); - return loopUnrollByFactor(forOp, tripCount.getSExtValue()); + return loopUnrollByFactor(forOp, tripCount.getZExtValue()); } /// Check if bounds of all inner loops are defined outside of `forOp` @@ -553,7 +556,7 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, "trip " "count"; unrollJamFactor = tripCount->getZExtValue(); - } else if (tripCount->getSExtValue() % unrollJamFactor != 0) { + } else if (tripCount->getZExtValue() % unrollJamFactor != 0) { LDBG() << "failed to unroll and jam: unsupported trip count that is not a " "multiple of unroll jam factor"; return failure(); @@ -1566,13 +1569,16 @@ mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) { std::optional> steps = loopOp.getLoopSteps(); if (!loBnds || !upBnds || !steps) return {}; + // TODO(#178506): The result should be SmallVector and use uint64_t + // for trip counts. llvm::SmallVector tripCounts; for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) { + // TODO(#178506): Signedness is not handled correctly here. std::optional numIter = constantTripCount( lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb); if (!numIter) return {}; - tripCounts.push_back(numIter->getSExtValue()); + tripCounts.push_back(numIter->getZExtValue()); } return tripCounts; } diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 7fb0d4e9710f8..a10925c8c2e65 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -336,8 +336,6 @@ std::optional constantTripCount( // case applies, so the static trip count is unknown. return std::nullopt; } - if (stepCst.isNegative()) - return APInt(bitwidth, 0); } if (isIndex) { @@ -391,6 +389,14 @@ std::optional constantTripCount( return std::nullopt; } auto &stepCst = maybeStepCst->first; + // For signed loops, a negative step size could indicate an infinite number of + // iterations. + if (isSigned && stepCst.isSignBitSet()) { + LDBG() << "constantTripCount is infinite because step is negative"; + return std::nullopt; + } + + // Create new APSInt instances with explicit signedness to ensure they match llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst); llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst); if (!remainder.isZero()) diff --git a/mlir/test/Dialect/SCF/trip_count.mlir b/mlir/test/Dialect/SCF/trip_count.mlir index 54883d7bb874c..927b405dbaea6 100644 --- a/mlir/test/Dialect/SCF/trip_count.mlir +++ b/mlir/test/Dialect/SCF/trip_count.mlir @@ -699,4 +699,105 @@ func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(%lb : i32, %other : i3 scf.yield %arg0 : i32 } return %1 : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_i8_unsigned_full_range( +func.func @trip_count_i8_unsigned_full_range(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i8 + %c255 = arith.constant 255 : i8 + %c1 = arith.constant 1 : i8 + + // Unsigned i8 from 0 to 255: trip count is 255 + // Trip counts are returned in their natural bitwidth and printed as signed. + // 255 in i8 is represented as -1 when printed in signed format. + // CHECK: "test.trip-count" = -1 : i8 + %r = scf.for unsigned %i = %c0 to %c255 step %c1 iter_args(%0 = %a) -> i32 : i8 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_i8_unsigned_partial_range( +func.func @trip_count_i8_unsigned_partial_range(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i8 + %c200 = arith.constant 200 : i8 + %c1 = arith.constant 1 : i8 + + // Unsigned i8 from 0 to 200: trip count is 200 + // 200 in i8 is represented as -56 when printed in signed format. + // CHECK: "test.trip-count" = -56 : i8 + %r = scf.for unsigned %i = %c0 to %c200 step %c1 iter_args(%0 = %a) -> i32 : i8 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_i8_unsigned_high_range( +func.func @trip_count_i8_unsigned_high_range(%a : i32, %b : i32) -> i32 { + %c200 = arith.constant 200 : i8 + %c255 = arith.constant 255 : i8 + %c1 = arith.constant 1 : i8 + + // Unsigned i8 from 200 to 255: trip count is 55 + // CHECK: "test.trip-count" = 55 : i8 + %r = scf.for unsigned %i = %c200 to %c255 step %c1 iter_args(%0 = %a) -> i32 : i8 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_i8_signed_crossing_zero( +func.func @trip_count_i8_signed_crossing_zero(%a : i32, %b : i32) -> i32 { + %c-128 = arith.constant -128 : i32 + %c127 = arith.constant 127 : i32 + %c1 = arith.constant 1 : i32 + + // Signed i32 from -128 to 127, crossing zero + // CHECK: "test.trip-count" = 255 + %r = scf.for %i = %c-128 to %c127 step %c1 iter_args(%0 = %a) -> i32 : i32 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_i16_unsigned_full_range( +func.func @trip_count_i16_unsigned_full_range(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i16 + %c65535 = arith.constant 65535 : i16 + %c1 = arith.constant 1 : i16 + + // Unsigned i16 from 0 to 65535: trip count is 65535 + // 65535 in i16 is represented as -1 when printed in signed format. + // CHECK: "test.trip-count" = -1 : i16 + %r = scf.for unsigned %i = %c0 to %c65535 step %c1 iter_args(%0 = %a) -> i32 : i16 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + +// CHECK-LABEL:func.func @trip_count_i8_unsigned_step_2( +func.func @trip_count_i8_unsigned_step_2(%a : i32, %b : i32) -> i32 { + %c0 = arith.constant 0 : i8 + %c255 = arith.constant 255 : i8 + %c2 = arith.constant 2 : i8 + + // Unsigned i8 from 0 to 255 step 2: trip count is 128 (255/2 rounded up) + // 128 in i8 is represented as -128 when printed in signed format. + // CHECK: "test.trip-count" = -128 : i8 + %r = scf.for unsigned %i = %c0 to %c255 step %c2 iter_args(%0 = %a) -> i32 : i8 { + scf.yield %b : i32 + } + return %r : i32 } \ No newline at end of file diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index 6199cb13149af..3fc2df5c1e477 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -50,7 +50,7 @@ struct TestSCFForUtilsPass "test.trip-count", IntegerAttr::get(IntegerType::get(&getContext(), tripCount.value().getBitWidth()), - tripCount.value().getSExtValue())); + tripCount.value().getZExtValue())); else loopOp->setDiscardableAttr("test.trip-count", StringAttr::get(&getContext(), "none"));