Review notes (text before the first "diff --git" line is not part of any hunk).

Summary of what the hunks below show:
  * ForeachThreadOpInterface::bufferize no longer bufferizes the terminator.
    In builds without NDEBUG it asserts that no result of the op has uses
    left, then creates a result-less ForeachThreadOp, drops the automatically
    introduced terminator, merges the old body into the new op and erases the
    old op.
  * ParallelInsertSliceOpInterface::bufferize now does the per-insert work:
    it obtains the destination buffer with the insertion point set before the
    ForeachThreadOp, creates the subview and memcpy right before the
    PerformConcurrentlyOp, erases the insert op, creates a ToTensorOp after
    the ForeachThreadOp and redirects the uses of one ForeachThreadOp result
    to it.

NOTE(review): in the second hunk `rewriter.eraseOp(op)` runs before the loop
that searches `performConcurrentlyOp.yieldingOps()` for `&nextOp == op`. If
the rewriter erases eagerly, that comparison can no longer match and
`resultNum` ends up as the number of yielding ops still present. Please
confirm the intended result index when a PerformConcurrentlyOp holds more than
one ParallelInsertSliceOp.

NOTE(review): line breaks were restored so that they agree with the hunk
headers (64/32 and 9/50 lines). Template arguments inside `<...>` were missing
from the text under review and were restored from the surrounding variable
names and comments; verify them against the upstream file before applying.

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 785cd0d7806dc..fb514a2f2b085 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -949,64 +949,32 @@ struct ForeachThreadOpInterface
     return success();
   }
 
-  LogicalResult bufferize(Operation *op, RewriterBase &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    OpBuilder::InsertionGuard g(b);
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
 
-    // Gather new results of the ForeachThreadOp.
-    SmallVector<Value> newResults;
-    for (OpResult opResult : foreachThreadOp->getOpResults()) {
-      OpOperand *insertDest =
-          getInsertionDest(foreachThreadOp)[opResult.getResultNumber()];
-      // Insert copies right before the PerformConcurrentlyOp terminator. They
-      // should not be inside terminator (which would be the default insertion
-      // point).
-      Value buffer = getBuffer(b, insertDest->get(), options);
-      newResults.push_back(buffer);
-    }
+#ifndef NDEBUG
+    // ParallelInsertSliceOpInterface replaces all uses.
+    for (OpResult opResult : foreachThreadOp->getOpResults())
+      assert(opResult.getUses().empty() &&
+             "expected that all uses were already replaced");
+#endif // NDEBUG
 
     // Create new ForeachThreadOp without any results and drop the automatically
     // introduced terminator.
     TypeRange newResultTypes;
-    auto newForeachThreadOp =
-        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
-                                  foreachThreadOp.getNumThreads());
+    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
+        foreachThreadOp.getLoc(), newResultTypes,
+        foreachThreadOp.getNumThreads());
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
-    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
-                  {newForeachThreadOp.getBody()->getArguments()});
-
-    // Bufferize terminator.
-    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
-        newForeachThreadOp.getBody()->getTerminator());
-    b.setInsertionPoint(performConcurrentlyOp);
-    unsigned resultCounter = 0;
-    WalkResult walkResult =
-        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
-          Location loc = insertOp.getLoc();
-          Type srcType = getMemRefType(
-              insertOp.getSource().getType().cast<RankedTensorType>(), options);
-          // ParallelInsertSliceOp bufferizes to a copy.
-          auto srcMemref = b.create<bufferization::ToMemrefOp>(
-              loc, srcType, insertOp.getSource());
-          Value destMemref = newResults[resultCounter++];
-          Value subview = b.create<memref::SubViewOp>(
-              loc, destMemref, insertOp.getMixedOffsets(),
-              insertOp.getMixedSizes(), insertOp.getMixedStrides());
-          // This memcpy will fold away if everything bufferizes in-place.
-          if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
-                                          subview)))
-            return WalkResult::interrupt();
-          b.eraseOp(insertOp);
-          return WalkResult::advance();
-        });
-    if (walkResult.wasInterrupted())
-      return failure();
+    rewriter.mergeBlocks(foreachThreadOp.getBody(),
+                         newForeachThreadOp.getBody(),
+                         {newForeachThreadOp.getBody()->getArguments()});
 
-    // Replace the op.
-    replaceOpWithBufferizedValues(b, op, newResults);
+    // Remove the old op.
+    rewriter.eraseOp(op);
     return success();
   }
 
@@ -1104,9 +1072,50 @@ struct ParallelInsertSliceOpInterface
     return success();
   }
 
-  LogicalResult bufferize(Operation *op, RewriterBase &b,
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
-    // Will be bufferized as part of ForeachThreadOp.
+    OpBuilder::InsertionGuard g(rewriter);
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(op->getParentOp());
+    auto foreachThreadOp =
+        cast<ForeachThreadOp>(performConcurrentlyOp->getParentOp());
+
+    // If the op bufferizes out-of-place, allocate the copy before the
+    // ForeachThreadOp.
+    rewriter.setInsertionPoint(foreachThreadOp);
+    Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+
+    // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
+    rewriter.setInsertionPoint(performConcurrentlyOp);
+    Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+    Value subview = rewriter.create<memref::SubViewOp>(
+        insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+        insertOp.getMixedSizes(), insertOp.getMixedStrides());
+    // This memcpy will fold away if everything bufferizes in-place.
+    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+                                    subview)))
+      return failure();
+    rewriter.eraseOp(op);
+
+    // Replace all uses of ForeachThreadOp (just the corresponding result).
+    rewriter.setInsertionPointAfter(foreachThreadOp);
+    Value toTensorOp =
+        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), destBuffer);
+    unsigned resultNum = 0;
+    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
+      if (&nextOp == op)
+        break;
+      resultNum++;
+    }
+    assert(resultNum < foreachThreadOp->getNumResults() &&
+           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
+    SmallVector<OpOperand *> resultUses = llvm::to_vector(
+        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
+                        [](OpOperand &use) { return &use; }));
+    for (OpOperand *use : resultUses) {
+      rewriter.updateRootInPlace(use->getOwner(),
+                                 [&]() { use->set(toTensorOp); });
+    }
     return success();
   }
 