Summary of this patch (free text before the first "diff --git" header):

  * Adds LoopPipelinerInternal::getDefiningOpAndDistance(Value). For a value
    defined by an op it returns {op, 0}. For a block argument of the loop body
    other than the induction variable it follows the matching yield operand
    and returns {its defining op, 1}. In every other case it returns
    {nullptr, 0}.
  * analyzeCrossStageValues() resolves operands through that helper and also
    skips operands whose defining stage equals `stage + distance`.
  * createKernel():
      - a loop-carried operand whose producer is not in `useStage + 1` no
        longer asserts; it falls through to the previous-stage remapping
        using the yielded value as the source;
      - when the epilogue is not peeled and the loop result has uses, the
        yielded value is wrapped in a select on the predicate of the defining
        stage, falling back to the corresponding region argument;
      - the iter-arg value mapping for the epilogue is only recorded when
        `defStage > 0`.
  * Adds the tests @distance_1_use and @stage_0_value_escape.

NOTE(review): this copy of the patch had its line structure collapsed and all
angle-bracketed arguments stripped. Template arguments and `memref<?xf32>`
types below were reconstructed from usage inside the patch; the MapVector
parameter type and the memref shapes in particular should be confirmed
against the target files before applying.

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 5537a8b212c51..20fa8089201aa 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
   /// `idx` of `key` in the epilogue.
   void setValueMapping(Value key, Value el, int64_t idx);
 
+  /// Return the defining op of the given value, if the Value is an argument of
+  /// the loop return the associated defining op in the loop and its distance to
+  /// the Value.
+  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
+
 public:
   /// Initalize the information for the given `op`, return true if it
   /// satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
     unsigned stage = stages[op];
 
     auto analyzeOperand = [&](OpOperand &operand) {
-      Operation *def = operand.get().getDefiningOp();
+      auto [def, distance] = getDefiningOpAndDistance(operand.get());
       if (!def)
         return;
       auto defStage = stages.find(def);
-      if (defStage == stages.end() || defStage->second == stage)
+      if (defStage == stages.end() || defStage->second == stage ||
+          defStage->second == stage + distance)
         return;
       assert(stage > defStage->second);
       LiverangeInfo &info = crossStageValues[operand.get()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
   return crossStageValues;
 }
 
+std::pair<Operation *, int64_t>
+LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
+  int64_t distance = 0;
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    if (arg.getOwner() != forOp.getBody())
+      return {nullptr, 0};
+    // Ignore induction variable.
+    if (arg.getArgNumber() == 0)
+      return {nullptr, 0};
+    distance++;
+    value =
+        forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
+  }
+  Operation *def = value.getDefiningOp();
+  if (!def)
+    return {nullptr, 0};
+  return {def, distance};
+}
+
 scf::ForOp LoopPipelinerInternal::createKernelLoop(
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
         &crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
           rewriter.setInsertionPointAfter(newOp);
           continue;
         }
-        auto arg = dyn_cast<BlockArgument>(operand->get());
+        Value source = operand->get();
+        auto arg = dyn_cast<BlockArgument>(source);
         if (arg && arg.getOwner() == forOp.getBody()) {
-          // If the value is a loop carried value coming from stage N + 1 remap,
-          // it will become a direct use.
           Value ret = forOp.getBody()->getTerminator()->getOperand(
               arg.getArgNumber() - 1);
           Operation *dep = ret.getDefiningOp();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
           auto stageDep = stages.find(dep);
           if (stageDep == stages.end() || stageDep->second == useStage)
             continue;
-          assert(stageDep->second == useStage + 1);
-          nestedNewOp->setOperand(operand->getOperandNumber(),
-                                  mapping.lookupOrDefault(ret));
-          continue;
+          // If the value is a loop carried value coming from stage N + 1 remap,
+          // it will become a direct use.
+          if (stageDep->second == useStage + 1) {
+            nestedNewOp->setOperand(operand->getOperandNumber(),
+                                    mapping.lookupOrDefault(ret));
+            continue;
+          }
+          source = ret;
         }
         // For operands defined in a previous stage we need to remap it to use
         // the correct region argument. We look for the right version of the
         // Value based on the stage where it is used.
-        Operation *def = operand->get().getDefiningOp();
+        Operation *def = source.getDefiningOp();
         if (!def)
           continue;
         auto stageDef = stages.find(def);
@@ -418,9 +446,29 @@ LogicalResult LoopPipelinerInternal::createKernel(
   // We create a mapping between original values and the associated loop
   // returned values that will be needed by the epilogue.
   llvm::SmallVector<Value> yieldOperands;
-  for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
-    yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+  for (OpOperand &yieldOperand :
+       forOp.getBody()->getTerminator()->getOpOperands()) {
+    Value source = mapping.lookupOrDefault(yieldOperand.get());
+    // When we don't peel the epilogue and the yield value is used outside the
+    // loop we need to make sure we return the version from numStages -
+    // defStage.
+    if (!peelEpilogue &&
+        !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
+      Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
+      if (def) {
+        auto defStage = stages.find(def);
+        if (defStage != stages.end() && defStage->second < maxStage) {
+          Value pred = predicates[defStage->second];
+          source = rewriter.create<arith::SelectOp>(
+              pred.getLoc(), pred, source,
+              newForOp.getBody()
+                  ->getArguments()[yieldOperand.getOperandNumber() + 1]);
+        }
+      }
+    }
+    yieldOperands.push_back(source);
   }
+
   for (auto &it : crossStageValues) {
     int64_t version = maxStage - it.second.lastUseStage + 1;
     unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -444,9 +492,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
     Operation *def = retVal.value().getDefiningOp();
     assert(def && "Only support loop carried dependencies of distance 1");
     unsigned defStage = stages[def];
-    setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
-                    newForOp->getResult(retVal.index()),
-                    maxStage - defStage + 1);
+    if (defStage > 0) {
+      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+                      newForOp->getResult(retVal.index()),
+                      maxStage - defStage + 1);
+    }
   }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
   return success();
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 0309287e409c1..4cd686d2cdb86 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
     scf.yield %A3_elem : f32
   } { __test_pipelining_loop__ }
   return %r : f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @distance_1_use
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// Prologue:
+// CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
+// CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
+// CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
+// CHECK: %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
+// CHECK: %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
+// CHECK: memref.store %[[L2]]
+// CHECK: scf.yield %[[IDX1]], %[[L3]], %[[L4]]
+func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+    %A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+    %idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
+    memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %idx1 : index
+  } { __test_pipelining_loop__ }
+  return
+}
+
+// -----
+
+// NOEPILOGUE-LABEL: stage_0_value_escape(
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
+// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
+// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S]]
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %A1_elem : f32
+  } { __test_pipelining_loop__ }
+  memref.store %r, %result[%c1] : memref<?xf32>
+  return
+}