@@ -191,7 +191,7 @@ tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
191191// / @param candidateSliceOp: %4 = extract %args2
192192// / @param backwardSlice: in-out parameter populated by backward extractSliceOps
193193// / @return OpResult Producer : %0 = producer
194- FailureOr<OpResult> mlir::scfX::getRealProducerOfExtractSliceOp (
194+ FailureOr<OpResult> mlir::scfX::getRealProducerFromExtractSliceOp (
195195 Operation *candidateSliceOp,
196196 SmallVector<tensor::ExtractSliceOp> &backwardSlice, unsigned curDepth,
197197 unsigned maxDepth) {
@@ -216,8 +216,8 @@ FailureOr<OpResult> mlir::scfX::getRealProducerOfExtractSliceOp(
216216 }
217217 if (auto sliceOp = rootSource.getDefiningOp <tensor::ExtractSliceOp>()) {
218218 // walk up loop to find larger candidate extractSliceOp
219- return getRealProducerOfExtractSliceOp (sliceOp, backwardSlice,
220- curDepth + 1 );
219+ return getRealProducerFromExtractSliceOp (sliceOp, backwardSlice,
220+ curDepth + 1 );
221221 }
222222 break ;
223223 }
@@ -278,6 +278,21 @@ struct ErasedOpListener : public RewriterBase::Listener {
278278 bool isErased (Operation *op) { return erased.count (op); }
279279};
280280
281+ // / Returns success iff `loop` is an scf::ForOp whose first nested loop-like op is the second-to-last op of its body; returns failure otherwise.
282+ static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
283+ if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
284+ Block::OpListType &opsInLoopBody = forOp.getBody ()->getOperations ();
285+ for (auto &&[index, op] : llvm::enumerate (opsInLoopBody)) {
286+ // Only the first loop-like op in the body is examined. If it is the second-to-last op
287+ // (directly before the last op, presumably the scf.yield terminator), the ForOp is treated as yielding the result of that inner loop.
288+ if (isa<LoopLikeOpInterface>(op)) {
289+ return success ((index + 2 ) == opsInLoopBody.size ()); // NOTE(review): positional check only; the yield operands are not inspected
290+ }
291+ }
292+ }
293+ return failure (); // not an scf::ForOp, or its body contains no loop-like op
294+ }
295+
281296// / Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
282297// / multi-level `extractSliceOp`. E.g.
283298// /
@@ -293,7 +308,9 @@ std::optional<scf::SCFFuseProducerOfSliceResult>
293308mlir::scfX::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
294309 Operation *candidateSliceOp) {
295310 SmallVector<tensor::ExtractSliceOp> backwardSlice;
296- if (failed (getRealProducerOfExtractSliceOp (candidateSliceOp, backwardSlice)))
311+ FailureOr<OpResult> realProducer =
312+ getRealProducerFromExtractSliceOp (candidateSliceOp, backwardSlice);
313+ if (failed (realProducer))
297314 return std::nullopt ;
298315
299316 std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
@@ -303,14 +320,18 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
303320 for (auto &&[index, sliceOp] : llvm::enumerate (backwardSlice)) {
304321 // get nest loops between next candidate sliceOp and tiled producer.
305322 auto whileProducerOutOfLoopBlock =
306- [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
307- if (fuseProducerResult) {
308- Block &body = loop->getRegion (0 ).front ();
309- if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
310- ->getBlock () == &body)
311- return failure ();
312- }
313- return success ();
323+ [&fuseProducerResult,
324+ &realProducer](LoopLikeOpInterface loop) -> LogicalResult {
325+ // Reject the loop unless it passes isForOpYieldResultOfInnerLoop, i.e. every
326+ // surrounding outer loop is expected to merely yield the result of its inner loop.
327+ if (failed (isForOpYieldResultOfInnerLoop (loop)))
328+ return failure ();
329+ Operation *originalOp = // the tiled-and-fused producer if fuseProducerResult holds a value, else the real producer
330+ fuseProducerResult
331+ ? fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
332+ : realProducer->getDefiningOp ();
333+ Block &body = loop->getRegion (0 ).front (); // first block of the first region of the loop
334+ return success (originalOp->getBlock () != &body); // succeeds only while the producer op lives outside this block. NOTE(review): originalOp is dereferenced without a null check - confirm getDefiningOp () cannot return null here
314335 };
315336 SmallVector<LoopLikeOpInterface> outerLoops =
316337 getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
@@ -515,21 +536,6 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
515536 return operand;
516537}
517538
518- // / Returns success iff `loop` is an scf::ForOp whose first nested loop-like op is the second-to-last op of its body; returns failure otherwise.
519- static LogicalResult isForOpYieldResultOfInnerLoop (LoopLikeOpInterface loop) {
520- if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ())) {
521- for (auto &&[index, op] :
522- llvm::enumerate (forOp.getBody ()->getOperations ())) {
523- // Only the first loop-like op in the body is examined. If it is the second-to-last op
524- // (directly before the last op, presumably the scf.yield terminator), the ForOp is treated as yielding the result of that inner loop.
525- if (isa<LoopLikeOpInterface>(op)) {
526- return success ((index + 2 ) == forOp.getBody ()->getOperations ().size ()); // NOTE(review): positional check only; the yield operands are not inspected
527- }
528- }
529- }
530- return failure (); // not an scf::ForOp, or its body contains no loop-like op
531- }
532-
533539// / Fetch the untiled consumer of a scf.for's result which is yielded by a
534540// / tensor.insert_slice. This function makes the following assumptions that
535541// / tensor.insert_slice has scf.yield as its only user.
0 commit comments