diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index bae6e68400e22..c789b4c8904d3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -387,9 +387,26 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
   std::optional<APInt> constTripCount = forOp.getStaticTripCount();
   if (constTripCount) {
     // Constant loop bounds computation.
-    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
-    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
-    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
+    bool isUnsignedLoop = forOp.getUnsignedCmp();
+    // For unsigned loops, bounds must be zero-extended: narrow integer types
+    // (e.g. i1, i2, i3) may have bit patterns that are negative in a signed
+    // context (e.g., i1 value 1 has getSExtValue() == -1, getZExtValue() == 1).
+    // Zero-extension is only safe when the unsigned value fits in int64_t, i.e.
+    // the type's bitwidth is < 64. Bail out for unsigned loops of width >= 64.
+    if (isUnsignedLoop) {
+      if (auto intTy = dyn_cast<IntegerType>(forOp.getUpperBound().getType()))
+        if (intTy.getWidth() >= 64)
+          return failure();
+    }
+    auto getLoopBound = [&](Value v) -> int64_t {
+      auto apInt = getConstantAPIntValue(v);
+      assert(apInt && "expected constant loop bound");
+      return isUnsignedLoop ? static_cast<int64_t>(apInt->first.getZExtValue())
+                            : apInt->first.getSExtValue();
+    };
+    int64_t lbCst = getLoopBound(forOp.getLowerBound());
+    int64_t ubCst = getLoopBound(forOp.getUpperBound());
+    int64_t stepCst = getLoopBound(step);
     if (unrollFactor == 1) {
       if (constTripCount->isOne() &&
           failed(forOp.promoteIfSingleIteration(rewriter)))
@@ -412,9 +429,16 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
     else
       upperBoundUnrolled = forOp.getUpperBound();
 
-    // Create constant for 'stepUnrolled'.
+    // Create constant for 'stepUnrolled'. When the main loop has zero
+    // iterations (tripCountEvenMultiple == 0), keep the original step.
+    // stepCst * unrollFactor may produce a value that, when truncated to the
+    // bound type's bitwidth during IntegerAttr construction, wraps to zero; a
+    // zero step causes constantTripCount to return nullopt instead of 0, which
+    // prevents the zero-trip main loop from being elided.
+    bool mainLoopHasNoIter = (tripCountEvenMultiple == 0);
+    bool stepUnchanged = (stepCst == stepUnrolledCst);
     stepUnrolled =
-        stepCst == stepUnrolledCst
+        (mainLoopHasNoIter || stepUnchanged)
             ? step
             : arith::ConstantOp::create(boundsBuilder, loc,
                                         boundsBuilder.getIntegerAttr(
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
index 4c72d9e99d049..f764013ed50f9 100644
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -518,3 +518,145 @@ func.func @loop_unroll_static_yield_value_defined_above(%init: i32) {
 // UNROLL-OUTER-BY-2: %[[SUM1:.*]] = arith.andi %[[INIT]], %[[SUM]]
 // UNROLL-OUTER-BY-2: scf.yield %[[SUM1]], %[[INIT]] : i32, i32
 
+// -----
+
+// Test unrolling an unsigned scf.for whose bounds use narrow integer types.
+// The unsigned upper bound 2 in i2 has the same bit pattern as signed -2, so
+// getConstantIntValue (which sign-extends) would return -2. The fix ensures we
+// zero-extend when the loop comparison is unsigned.
+
+// Case 1: trip count 2 (i2 unsigned: lb=0, ub=2, step=1), unroll by 2 =>
+// fully unrolled, no residual loop.
+func.func @unroll_unsigned_i2_tc2() -> (i32, i32) {
+  %seed = arith.constant 7 : i32
+  %lo = arith.constant 0 : i2
+  %hi = arith.constant 2 : i2
+  %stride = arith.constant 1 : i2
+  %res:2 = scf.for unsigned %iv = %lo to %hi step %stride
+      iter_args(%acc0 = %seed, %acc1 = %seed) -> (i32, i32) : i2 {
+    %sum = arith.addi %acc0, %acc1 : i32
+    %prod = arith.muli %acc0, %acc1 : i32
+    scf.yield %sum, %prod : i32, i32
+  }
+  return %res#0, %res#1 : i32, i32
+}
+// UNROLL-BY-2-LABEL: @unroll_unsigned_i2_tc2
+// UNROLL-BY-2-NOT: scf.for
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: return
+
+// -----
+
+// Case 2: i3 unsigned loop, lb=0, ub=4, step=1 (trip count 4), unrolled by 2.
+// Expect a main loop of step 2 holding two body copies and no epilogue.
+func.func @unroll_unsigned_i3_tc4() -> (i32, i32) {
+  %seed = arith.constant 7 : i32
+  %lo = arith.constant 0 : i3
+  %hi = arith.constant 4 : i3
+  %stride = arith.constant 1 : i3
+  %res:2 = scf.for unsigned %iv = %lo to %hi step %stride
+      iter_args(%acc0 = %seed, %acc1 = %seed) -> (i32, i32) : i3 {
+    %sum = arith.addi %acc0, %acc1 : i32
+    %prod = arith.muli %acc0, %acc1 : i32
+    scf.yield %sum, %prod : i32, i32
+  }
+  return %res#0, %res#1 : i32, i32
+}
+// UNROLL-BY-2-LABEL: @unroll_unsigned_i3_tc4
+// UNROLL-BY-2: scf.for unsigned %{{.*}} = %{{.*}} to %c-4_i3 step %c2_i3
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2-NOT: scf.for
+// UNROLL-BY-2: return
+
+// -----
+
+// Case 3: i1 unsigned loop, lb=0, ub=1, step=1 (trip count 1), unrolled by 2.
+// tripCountEvenMultiple=0, so the zero-iteration main loop is elided and the
+// lone iteration comes from the epilogue, which is promoted out of its loop.
+// Covers Bug 1 (sext(1:i1)=-1 would suppress the epilogue) together with Bug 2
+// (stepCst*2=2 truncated to i1 wraps to 0, which blocks zero-trip elision).
+func.func @unroll_unsigned_i1_tc1() -> (i32, i32) {
+  %seed = arith.constant 7 : i32
+  %lo = arith.constant 0 : i1
+  %hi = arith.constant 1 : i1
+  %stride = arith.constant 1 : i1
+  %res:2 = scf.for unsigned %iv = %lo to %hi step %stride
+      iter_args(%acc0 = %seed, %acc1 = %seed) -> (i32, i32) : i1 {
+    %sum = arith.addi %acc0, %acc1 : i32
+    %prod = arith.muli %acc0, %acc1 : i32
+    scf.yield %sum, %prod : i32, i32
+  }
+  return %res#0, %res#1 : i32, i32
+}
+// UNROLL-BY-2-LABEL: @unroll_unsigned_i1_tc1
+// UNROLL-BY-2-NOT: scf.for
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: return
+
+// -----
+
+// Case 4: i3 unsigned loop, lb=0, ub=5, step=1 (trip count 5), unrolled by 2.
+// The main loop iterates twice at step 2 (4 of 5 iterations); the epilogue
+// supplies the last one. ub=5 in i3 shares its bit pattern with signed -3, so
+// the old sign-extended ubCst=-3 made upperBoundUnrolledCst(4) < ubCst(-3)
+// false and dropped the epilogue. Checks Bug 1 independently of Bug 2.
+func.func @unroll_unsigned_i3_tc5_with_epilogue() -> (i32, i32) {
+  %seed = arith.constant 7 : i32
+  %lo = arith.constant 0 : i3
+  %hi = arith.constant 5 : i3
+  %stride = arith.constant 1 : i3
+  %res:2 = scf.for unsigned %iv = %lo to %hi step %stride
+      iter_args(%acc0 = %seed, %acc1 = %seed) -> (i32, i32) : i3 {
+    %sum = arith.addi %acc0, %acc1 : i32
+    %prod = arith.muli %acc0, %acc1 : i32
+    scf.yield %sum, %prod : i32, i32
+  }
+  return %res#0, %res#1 : i32, i32
+}
+// UNROLL-BY-2-LABEL: @unroll_unsigned_i3_tc5_with_epilogue
+// Main loop: 4 of the 5 iterations, step=2, with the body emitted twice.
+// The unrolled upper bound 4 prints as -4 in i3 (signed two's complement).
+// UNROLL-BY-2: scf.for unsigned %{{.*}} = %{{.*}} to %c-4_i3 step %c2_i3
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// Epilogue: the 1 leftover iteration, promoted out of its loop.
+// UNROLL-BY-2-NOT: scf.for
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: return
+
+// -----
+
+// Case 5: i2 unsigned loop, lb=0, ub=2, step=2 (trip count 1), unrolled by 2.
+// tripCountEvenMultiple=0 again, this time with step > 1. Absent the Bug 2
+// fix, stepCst*2=4 truncated to i2 (4 mod 4 = 0) would yield a zero step, and
+// with a zero step constantTripCount returns nullopt rather than 0, so the
+// zero-trip main loop would not be elided.
+// Expected: no scf.for, one body copy (promoted epilogue), then the return.
+func.func @unroll_unsigned_i2_step2_bug2() -> (i32, i32) {
+  %seed = arith.constant 7 : i32
+  %lo = arith.constant 0 : i2
+  %hi = arith.constant 2 : i2
+  %stride = arith.constant 2 : i2
+  %res:2 = scf.for unsigned %iv = %lo to %hi step %stride
+      iter_args(%acc0 = %seed, %acc1 = %seed) -> (i32, i32) : i2 {
+    %sum = arith.addi %acc0, %acc1 : i32
+    %prod = arith.muli %acc0, %acc1 : i32
+    scf.yield %sum, %prod : i32, i32
+  }
+  return %res#0, %res#1 : i32, i32
+}
+// UNROLL-BY-2-LABEL: @unroll_unsigned_i2_step2_bug2
+// UNROLL-BY-2-NOT: scf.for
+// UNROLL-BY-2: arith.addi
+// UNROLL-BY-2: arith.muli
+// UNROLL-BY-2: return