diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 1b458f410af60..899fe7374dbbd 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -444,7 +444,15 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( loc, rewriter.getIntegerAttr(t, maxStage)); Value maxStageByStep = rewriter.create(loc, step, maxStageValue); - newUb = rewriter.create(loc, ub, maxStageByStep); + Value hasAtLeastOneIteration = rewriter.create( + loc, arith::CmpIPredicate::slt, maxStageByStep, ub); + Value possibleNewUB = + rewriter.create(loc, ub, maxStageByStep); + // In case of `index` or `unsigned` type, we need to make sure that the + // subtraction does not result in a negative value, instead we use lb + // to avoid entering the kernel loop. + newUb = rewriter.create( + loc, hasAtLeastOneIteration, possibleNewUB, forOp.getLowerBound()); } auto newForOp = rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir index c879c83275bf8..d13d258412f2d 100644 --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -770,7 +770,10 @@ func.func @stage_0_value_escape(%A: memref, %result: memref, %ub: // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index -// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}} +// CHECK: %[[IT2_UB:.*]] = arith.muli %[[STEP:.*]], %[[C2:.*]] +// CHECK: %[[ENTERKERNEL:.*]] = arith.cmpi slt, %[[IT2_UB:.*]], %[[UB:.*]] +// CHECK: %[[PUBM:.*]] = arith.subi %[[UB:.*]], %[[IT2_UB:.*]] +// CHECK: %[[UBM:.*]] = arith.select %[[ENTERKERNEL:.*]], %[[PUBM:.*]], %[[LB:.*]] // CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) // CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]] // CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}} @@ -844,7 +847,9 @@ func.func @dynamic_loop(%A: memref, %result: memref, %lb: index, % // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index // CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 -// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}} +// CHECK: %[[ENTERKERNEL:.*]] = arith.cmpi slt, %[[STEP:.*]], %[[UB:.*]] +// CHECK: %[[PUBM:.*]] = arith.subi %[[UB:.*]], %[[STEP:.*]] +// CHECK: %[[UBM:.*]] = arith.select %[[ENTERKERNEL:.*]], %[[PUBM:.*]], %[[LB:.*]] // CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) // CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]] // CHECK: %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}