diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 3af3e5dc59533..0227816b77845 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1118,6 +1118,7 @@ struct ParallelInsertSliceOpInterface LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, const AnalysisState &state) const { + // RaW conflicts are resolved as part of ForeachThreadOp. return success(); } @@ -1129,9 +1130,7 @@ struct ParallelInsertSliceOpInterface auto foreachThreadOp = cast(performConcurrentlyOp->getParentOp()); - // If the op bufferizes out-of-place, allocate the copy before the - // ForeachThreadOp. - rewriter.setInsertionPoint(foreachThreadOp); + // Get destination buffer. FailureOr destBuffer = getBuffer(rewriter, insertOp.getDest(), options); if (failed(destBuffer)) @@ -1156,6 +1155,8 @@ struct ParallelInsertSliceOpInterface rewriter.setInsertionPointAfter(foreachThreadOp); Value toTensorOp = rewriter.create(foreachThreadOp.getLoc(), *destBuffer); + // PerformConcurrentlyOp can have multiple ParallelInserSliceOps. Find the + // index of `op` within yielding ops. unsigned resultNum = 0; for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) { if (&nextOp == op)