diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 8a421c0e2f273..5e0983a817c62 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -216,6 +216,10 @@ LogicalResult foldDynamicStrideList(SmallVectorImpl &strides); /// unsigned. (The result of this function must be interpreted as an unsigned /// integer.) A lower bound greater than the upper bound is considered invalid /// and will yield a zero trip count. +/// +/// Note: The loops modeled here use a less-than comparison (`<`), meaning the +/// loop continues while `iv < ub`. This is different from arbitrary C++ loops +/// which can use various comparison operators. /// The `computeUbMinusLb` callback is invoked to compute the difference between /// the upper and lower bound when not constant. It can be used by the client /// to compute a static difference when the bounds are not constant. diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 1c19b995b9f3f..c9f95ba69ac0a 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -365,7 +365,13 @@ std::optional constantTripCount( << (isSigned ? "isSigned" : "isUnsigned") << ")"; return APInt(bitwidth, 0); } + // Compute the difference. Since we've already checked that ub > lb, the + // result can be interpreted as an unsigned value without overflow concerns. diff = ubCst - lbCst; + // Convert diff to unsigned. This handles cases like i8: ub=127, lb=-128 + // where the subtraction yields 255, which wraps to -1 in signed i8 but is + // correctly represented as 255 when interpreted as unsigned. + diff.setIsUnsigned(true); } else { if (maybeUbCst) return std::nullopt; @@ -397,11 +403,14 @@ std::optional constantTripCount( return std::nullopt; } - // Create new APSInt instances with explicit signedness to ensure they match - llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst); - llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst); + // Both diff and step are non-negative at this point (negative steps are + // rejected earlier), so we use unsigned division regardless of the loop + // comparison signedness. + llvm::APInt tripCount = diff.udiv(stepCst); + llvm::APInt remainder = diff.urem(stepCst); if (!remainder.isZero()) tripCount = tripCount + 1; + LDBG() << "constantTripCount found: " << tripCount; return tripCount; } diff --git a/mlir/test/Dialect/SCF/trip_count.mlir b/mlir/test/Dialect/SCF/trip_count.mlir index 927b405dbaea6..7e74988b35019 100644 --- a/mlir/test/Dialect/SCF/trip_count.mlir +++ b/mlir/test/Dialect/SCF/trip_count.mlir @@ -770,6 +770,24 @@ func.func @trip_count_i8_signed_crossing_zero(%a : i32, %b : i32) -> i32 { // ----- +// CHECK-LABEL:func.func @trip_count_i8_signed_overflow_fix( +func.func @trip_count_i8_signed_overflow_fix(%a : i32, %b : i32) -> i32 { + %c-128 = arith.constant -128 : i8 + %c127 = arith.constant 127 : i8 + %c1 = arith.constant 1 : i8 + + // Signed i8 from -128 to 127: tests overflow fix + // Without the fix, computing (127 - (-128)) would overflow in i8. + // The trip count should be 255, but will be printed as -1 in i8 signed format. + // CHECK: "test.trip-count" = -1 : i8 + %r = scf.for %i = %c-128 to %c127 step %c1 iter_args(%0 = %a) -> i32 : i8 { + scf.yield %b : i32 + } + return %r : i32 +} + +// ----- + // CHECK-LABEL:func.func @trip_count_i16_unsigned_full_range( func.func @trip_count_i16_unsigned_full_range(%a : i32, %b : i32) -> i32 { %c0 = arith.constant 0 : i16