diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 59f068c205cf3..bc9d8a2496b4b 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -351,8 +351,6 @@ std::optional constantTripCount( return std::nullopt; APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned); APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned); - if (!maybeUbCst) - return std::nullopt; if (ubCst <= lbCst) { LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "(" << lbCst.getBitWidth() << ") <= " << ubCst << "(" @@ -385,9 +383,9 @@ std::optional constantTripCount( return std::nullopt; } auto &stepCst = maybeStepCst->first; - llvm::APInt tripCount = diff.sdiv(stepCst); - llvm::APInt r = diff.srem(stepCst); - if (!r.isZero()) + llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst); + llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst); + if (!remainder.isZero()) tripCount = tripCount + 1; LDBG() << "constantTripCount found: " << tripCount; return tripCount; diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index d5d0aee3bbe25..365c0e1d5c86f 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -762,6 +762,33 @@ func.func @replace_single_iteration_const_diff(%arg0 : index) { // ----- +func.func @replace_single_iteration_loop_unsigned_cmp() { +// CHECK-LABEL: func.func @replace_single_iteration_loop_unsigned_cmp() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[CONSTANT_1:.*]] = arith.constant -100 : i32 +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2147483647 : i32 +// CHECK: %[[VAL_0:.*]] = "test.init"() : () -> i32 +// CHECK: %[[FOR_0:.*]] = scf.for unsigned %[[VAL_1:.*]] = %[[CONSTANT_0]] to %[[CONSTANT_1]] step %[[CONSTANT_2]] iter_args(%[[VAL_2:.*]] = %[[VAL_0]]) -> (i32) : i32 { +// CHECK: %[[VAL_3:.*]] = "test.op"(%[[VAL_1]], %[[VAL_2]]) : (i32, i32) -> i32 +// CHECK: scf.yield %[[VAL_3]] : i32 +// CHECK: } +// CHECK: "test.consume"(%[[FOR_0]]) : (i32) -> () +// CHECK: return +// CHECK: } + %lowerBound = arith.constant 0 : i32 + %upperBound = arith.constant -100 : i32 + %step = arith.constant 2147483647 : i32 + %init = "test.init"() : () -> i32 + %0 = scf.for unsigned %i = %lowerBound to %upperBound step %step iter_args(%arg = %init) -> (i32) : i32 { + %1 = "test.op"(%i, %arg) : (i32, i32) -> i32 + scf.yield %1 : i32 + } + "test.consume"(%0) : (i32) -> () + return +} + +// ----- + // CHECK-LABEL: @remove_empty_parallel_loop func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { // CHECK: %[[INIT:.*]] = "test.init"