diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 83c9cf69ba036..1b458f410af60 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -642,22 +642,25 @@ LogicalResult LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, llvm::SmallVector &returnValues) { Location loc = forOp.getLoc(); + Type t = lb.getType(); + // Emit different versions of the induction variable. They will be // removed by dead code if not used. - // bounds_range = ub - lb - // total_iterations = (bounds_range + step - 1) / step - Type t = lb.getType(); - Value zero = - rewriter.create(loc, rewriter.getIntegerAttr(t, 0)); - Value one = - rewriter.create(loc, rewriter.getIntegerAttr(t, 1)); - Value minusOne = - rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + auto createConst = [&](int v) { + return rewriter.create(loc, + rewriter.getIntegerAttr(t, v)); + }; + + // total_iterations = cdiv(range_diff, step); + // - range_diff = ub - lb + // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step + Value zero = createConst(0); + Value one = createConst(1); Value stepLessZero = rewriter.create( loc, arith::CmpIPredicate::slt, step, zero); Value stepDecr = - rewriter.create(loc, stepLessZero, one, minusOne); + rewriter.create(loc, stepLessZero, one, createConst(-1)); Value rangeDiff = rewriter.create(loc, ub, lb); Value rangeIncrStep = rewriter.create(loc, rangeDiff, step); @@ -665,25 +668,31 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, rewriter.create(loc, rangeIncrStep, stepDecr); Value totalIterations = rewriter.create(loc, rangeDecr, step); + // If total_iters < max_stage, start the epilogue at zero to match the + // ramp-up in the prologue. + // start_iter = max(0, total_iters - max_stage) + Value iterI = rewriter.create(loc, totalIterations, + createConst(maxStage)); + iterI = rewriter.create(loc, zero, iterI); + + // Capture predicates for dynamic loops. SmallVector predicates(maxStage + 1); - for (int64_t i = 0; i < maxStage; i++) { - // iterI = total_iters - 1 - i - // May go negative... - Value minusI = - rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); - Value iterI = rewriter.create( - loc, rewriter.create(loc, totalIterations, minusOne), - minusI); + + for (int64_t i = 1; i <= maxStage; i++) { // newLastIter = lb + step * iterI Value newlastIter = rewriter.create( loc, lb, rewriter.create(loc, step, iterI)); - setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + setValueMapping(forOp.getInductionVar(), newlastIter, i); + + // increment to next iterI + iterI = rewriter.create(loc, iterI, one); if (dynamicLoop) { - // pred = iterI >= 0 - predicates[i + 1] = rewriter.create( - loc, arith::CmpIPredicate::sge, iterI, zero); + // Disable stages when `i` is greater than total_iters. + // pred = total_iters >= i + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); } } diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir index af49d2afc049b..c879c83275bf8 100644 --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref, %result: memref, %ub: // Check for predicated epilogue for dynamic loop. // CHECK-LABEL: dynamic_loop( // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index // CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}} @@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref, %result: memref, %ub: // CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]] // CHECK: } // CHECK: %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] -// CHECK: %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]] -// CHECK: %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]] -// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]] -// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]] -// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]] -// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]] -// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]] -// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]] -// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]] -// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1 -// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1 -// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]] -// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]] -// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]] -// CHECK: scf.if %[[CMPI_17]] { -// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]] +// CHECK: %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]] +// CHECK: %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]] +// CHECK: %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]] +// CHECK: %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]] +// CHECK: %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]] +// CHECK: %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]] +// CHECK: %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]] +// CHECK: %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]] +// CHECK: %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]] +// CHECK: %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]] +// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]] +// CHECK: %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]] +// CHECK: %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]] +// CHECK: %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]] +// CHECK: scf.if %[[CMPI_22]] { +// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]] // CHECK: } else { // CHECK: } -// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) { -// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}} -// CHECK: scf.yield %[[ADDF_24]] +// CHECK: %[[IF_26:.*]] = scf.if %[[CMPI_25]] +// CHECK: %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}} +// CHECK: scf.yield %[[ADDF_27]] // CHECK: } else { // CHECK: scf.yield %{{.*}} // CHECK: } -// CHECK: scf.if %[[CMPI_22]] { -// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]] +// CHECK: scf.if %[[CMPI_25]] { +// CHECK: memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]] // CHECK: } else { // CHECK: } // CHECK: return @@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref, %result: memref, %lb: index, % // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index +// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 // CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}} // CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) // CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]] @@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref, %result: memref, %lb: index, % // CHECK: %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]] // CHECK: %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]] // CHECK: %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]] -// CHECK: %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]] -// CHECK: %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]] -// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_11]] -// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0 -// CHECK: scf.yield %[[ADDF_13]] +// CHECK: %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]] +// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_10]] +// CHECK: %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0 +// CHECK: scf.yield %[[ADDF_14]] // CHECK: } else { -// CHECK: scf.yield %{{.*}} +// CHECK: scf.yield %[[CF0]] // CHECK: } -// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_11]] -// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}} -// CHECK: scf.yield %[[MULF_13]] +// CHECK: %[[IF_12:.*]] = scf.if %[[CMPI_10]] +// CHECK: %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}} +// CHECK: scf.yield %[[MULF_14]] // CHECK: } else { -// CHECK: scf.yield %{{.*}} +// CHECK: scf.yield %[[CF0]] // CHECK: } -// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0 -// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}] +// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0 +// CHECK: memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]] func.func @dynamic_loop_result(%A: memref, %result: memref, %lb: index, %ub: index, %step: index) { %cf0 = arith.constant 1.0 : f32 %cf1 = arith.constant 33.0 : f32