diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cde966592212d..5de21cc350ac1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -578,6 +578,10 @@ def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
     /// Return the number of leading operands before `offsets`, `sizes` and
     /// `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the OpResult of the enclosing ForeachThreadOp that corresponds
+    /// to this ParallelInsertSliceOp.
+    OpResult getTiedOpResult();
   }];
 
   let builders = [
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bd0f16dbd0e07..557a9edc2f18e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1215,6 +1215,18 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
 // ParallelInsertSliceOp
 //===----------------------------------------------------------------------===//
 
+OpResult ParallelInsertSliceOp::getTiedOpResult() {
+  auto foreachThreadOp = getOperation()->getParentOfType<ForeachThreadOp>();
+  assert(foreachThreadOp && "unlinked ParallelInsertSliceOp");
+  PerformConcurrentlyOp performConcurrentlyOp = foreachThreadOp.getTerminator();
+  for (const auto &it : llvm::enumerate(performConcurrentlyOp.yieldingOps())) {
+    Operation &nextOp = it.value();
+    if (&nextOp == getOperation())
+      return foreachThreadOp->getResult(it.index());
+  }
+  llvm_unreachable("ParallelInsertSliceOp not found");
+}
+
 // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                   Value source, Value dest,
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 0227816b77845..1f6359bfbb497 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -961,42 +961,6 @@ struct ForeachThreadOpInterface
     return BufferRelation::Equivalent;
   }
 
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
-    auto bufferizableOp = cast<BufferizableOpInterface>(op);
-    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
-      return failure();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto foreachThreadOp = cast<ForeachThreadOp>(op);
-    for (OpResult opResult : foreachThreadOp->getOpResults()) {
-      SmallVector<OpOperand *> destOperands =
-          state.getAliasingOpOperand(opResult);
-      assert(destOperands.size() == 1 &&
-             "expected exactly one aliasing OpOperand");
-      assert(isa<ParallelInsertSliceOp>(destOperands.front()->getOwner()) &&
-             "expected ParallelInsertSliceOp");
-
-      // Nothing to do if there is no conflict.
-      if (state.isInPlace(*destOperands.front()))
-        continue;
-
-      // Insert tensor allocation.
-      bool isYielded = state.isTensorYielded(opResult);
-      FailureOr<Value> alloc = allocateTensorForShapedValue(
-          rewriter, op->getLoc(), destOperands.front()->get(),
-          /*escape=*/isYielded, state.getOptions());
-      if (failed(alloc))
-        return failure();
-
-      // Update terminator operand.
-      rewriter.updateRootInPlace(destOperands.front()->getOwner(),
-                                 [&]() { destOperands.front()->set(*alloc); });
-    }
-
-    return success();
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
@@ -1118,7 +1082,55 @@ struct ParallelInsertSliceOpInterface
 
   LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                  const AnalysisState &state) const {
-    // RaW conflicts are resolved as part of ForeachThreadOp.
+    // This interface method is overridden because we want to set a custom
+    // insertion point for tensor copies. They should be inserted right before
+    // the ForeachThreadOp. E.g.:
+    //
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
+    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
+    //   }
+    // }
+    //
+    // After TensorCopyInsertion:
+    //
+    // %copy = bufferization.alloc_tensor() copy(%d)
+    // %r0, %r1 = foreach_thread ... {
+    //   ...
+    //   perform_concurrently {
+    //     parallel_insert_slice %a into %b ...
+    //     parallel_insert_slice %c into %copy ...
+    //   }
+    // }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto insertOp = cast<ParallelInsertSliceOp>(op);
+    auto foreachThreadOp = insertOp->getParentOfType<ForeachThreadOp>();
+
+    // Nothing to do if the destination tensor is inplace.
+    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
+           "source is always in-place");
+    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
+      return success();
+
+    // Find the corresponding OpResult.
+    OpResult opResult = insertOp.getTiedOpResult();
+
+    // Insert a tensor allocation right before the ForeachThreadOp.
+    rewriter.setInsertionPoint(foreachThreadOp);
+    bool isYielded = state.isTensorYielded(opResult);
+    FailureOr<Value> alloc =
+        allocateTensorForShapedValue(rewriter, op->getLoc(), insertOp.getDest(),
+                                     /*escape=*/isYielded, state.getOptions());
+    if (failed(alloc))
+      return failure();
+
+    // Update the destination operand.
+    rewriter.updateRootInPlace(
+        insertOp, [&]() { insertOp.getDestMutable().assign(*alloc); });
+
     return success();
   }
 
@@ -1149,29 +1161,20 @@ struct ParallelInsertSliceOpInterface
     if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                     subview)))
       return failure();
-    rewriter.eraseOp(op);
 
     // Replace all uses of ForeachThreadOp (just the corresponding result).
     rewriter.setInsertionPointAfter(foreachThreadOp);
     Value toTensorOp =
         rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
-    // PerformConcurrentlyOp can have multiple ParallelInserSliceOps. Find the
-    // index of `op` within yielding ops.
-    unsigned resultNum = 0;
-    for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
-      if (&nextOp == op)
-        break;
-      resultNum++;
-    }
-    assert(resultNum < foreachThreadOp->getNumResults() &&
-           "ParallelInsertSliceOp not found in PerformConcurrentlyOp");
-    SmallVector<OpOperand *> resultUses = llvm::to_vector(
-        llvm::map_range(foreachThreadOp->getResult(resultNum).getUses(),
-                        [](OpOperand &use) { return &use; }));
+    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
+    SmallVector<OpOperand *> resultUses =
+        llvm::to_vector(llvm::map_range(insertOp.getTiedOpResult().getUses(),
+                                        [](OpOperand &use) { return &use; }));
     for (OpOperand *use : resultUses) {
       rewriter.updateRootInPlace(use->getOwner(),
                                  [&]() { use->set(toTensorOp); });
     }
+    rewriter.eraseOp(op);
     return success();
   }
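
Note (not part of the patch): a minimal usage sketch of the new ParallelInsertSliceOp::getTiedOpResult() accessor, assuming the SCF dialect header and the getTerminator()/yieldingOps() API already used in the diff above; the helper name collectTiedResults is hypothetical.

// Hypothetical illustration only -- pairs every ParallelInsertSliceOp in a
// ForeachThreadOp terminator with the loop result it is tied to, using the
// accessor introduced by this change.
#include <utility>

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::scf;

static SmallVector<std::pair<ParallelInsertSliceOp, OpResult>>
collectTiedResults(ForeachThreadOp foreachThreadOp) {
  SmallVector<std::pair<ParallelInsertSliceOp, OpResult>> tied;
  // Walk the yielding ops of the perform_concurrently terminator; each
  // ParallelInsertSliceOp maps to exactly one ForeachThreadOp result.
  for (Operation &yieldingOp : foreachThreadOp.getTerminator().yieldingOps())
    if (auto insertOp = dyn_cast<ParallelInsertSliceOp>(&yieldingOp))
      tied.emplace_back(insertOp, insertOp.getTiedOpResult());
  return tied;
}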