diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index eb54ceeadea0a7..d31ec52cdb313e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -615,11 +615,12 @@ static OpResult getAliasingOpResult(OpOperand &opOperand) { // Predeclaration of function. static bool bufferizesToMemoryRead(OpOperand &opOperand); -/// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its -/// matching bbArg may. -static bool bufferizesToMemoryRead(scf::ForOp forOp, OpOperand &opOperand) { +/// Return true if the given value is read by an op that bufferizes to a memory +/// read. Also takes into account ops that create an alias but do not read by +/// themselves (e.g., ExtractSliceOp). +static bool isValueRead(Value value) { SmallVector<OpOperand *> workingSet; - for (OpOperand &use : forOp.getRegionIterArgForOpOperand(opOperand).getUses()) + for (OpOperand &use : value.getUses()) workingSet.push_back(&use); while (!workingSet.empty()) { @@ -647,8 +648,10 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) { // may. if (isa<ExtractSliceOp>(opOperand.getOwner())) return false; + // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its + // matching bbArg may. if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) - return bufferizesToMemoryRead(forOp, opOperand); + return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) { @@ -1437,7 +1440,13 @@ static Value getResultBuffer(OpBuilder &b, OpResult result, // Allocate the result buffer. Value resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo); - if (!skipCopy && !isInitTensorOp(operand)) { + // Do not copy the result of an InitTensorOp. 
+ if (isInitTensorOp(operand)) + skipCopy = true; + // Do not copy if the copied data is never read. + if (!isValueRead(result)) + skipCopy = true; + if (!skipCopy) { // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(op); b.create<linalg::CopyOp>(loc, operandBuffer, resultBuffer); } @@ -2002,7 +2011,9 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp, /// If not inplaceable, copy. if (alloc) { - b.create<linalg::CopyOp>(extractSliceOp.getLoc(), subView, alloc); + // Do not copy if the copied data is never read. + if (isValueRead(extractSliceOp.result())) + b.create<linalg::CopyOp>(extractSliceOp.getLoc(), subView, alloc); subView = alloc; } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir index 693974cf9296cf..b012409fd873fa 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -720,9 +720,6 @@ func @matmul( tensor<256x192xf32> to tensor<256x16xf32> // %4 does not match an insert_slice, it cannot be bufferized inplace and needs to alloc. - // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1] - // TODO: %4 is never read but just overwritten, this copy can be elided. - // CHECK: linalg.copy(%[[T]], %[[ALLOC]]) %4 = tensor.extract_slice %C[%arg3, %arg5] [8, 16] [1, 1] : tensor<128x192xf32> to tensor<8x16xf32> @@ -748,6 +745,7 @@ func @matmul( // insert_slice is inplace but its source comes from an equivalent buffer // that is not in place. So we must insert a copy of the small buffer into // the bigger buffer. + // CHECK: %[[T:.*]] = memref.subview %[[C]][%[[I]], %[[J]]] [8, 16] [1, 1] // CHECK: linalg.copy(%[[ALLOC]], %[[T]]) %7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] : tensor<8x16xf32> into tensor<128x192xf32> @@ -819,9 +817,6 @@ func @buffer_forwarding_conflict( // insert_slice. 
InitTensorOp replaces the init_tensor with an out-of-place // extract_slice. // CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]]) - // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] - // TODO: This copy can be avoided because the copied data is never read. - // CHECK: linalg.copy(%[[T_SUBVIEW]], %[[EXTRACT_SLICE_ALLOC]]) %a = linalg.init_tensor[%sz] : tensor<?xf32> // CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref<?xf32> @@ -832,6 +827,7 @@ func @buffer_forwarding_conflict( // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]]) : memref<?xf32>, memref<?xf32> %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32> + // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]]) %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>