diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index bc9d8a2496b4b..065d1d3545b1c 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -346,10 +346,13 @@ std::optional constantTripCount( std::optional> maybeLbCst = getConstantAPIntValue(lb); std::optional> maybeUbCst = getConstantAPIntValue(ub); if (maybeLbCst) { + APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned); + if (lbCst.isZero() && step == ub) + return APInt(bitwidth, 1); + // If one of the bounds is not a constant, we can't compute the trip count. if (!maybeUbCst) return std::nullopt; - APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned); APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned); if (ubCst <= lbCst) { LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "(" diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index e770f595bd262..230ea843e7057 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -789,6 +789,25 @@ func.func @replace_single_iteration_loop_unsigned_cmp() { // ----- +// CHECK-LABEL: @replace_single_iteration_loop_ub_equal_step +func.func @replace_single_iteration_loop_ub_equal_step(%ub_step : index) { + // CHECK: %[[LB:.*]] = arith.constant 0 + %c0 = arith.constant 0 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]]) + %0 = scf.for %i = %c0 to %ub_step step %ub_step iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[VAL]]) + "test.consume"(%0) : (i32) -> () + return +} + +// ----- + // CHECK-LABEL: @remove_empty_parallel_loop func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { // CHECK: %[[INIT:.*]] = "test.init"