diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index df8f8bff368f8..223a57c3624a7 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -48,7 +48,7 @@ std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value
 std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
-Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
+Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output, id<MTLBuffer> updatesBuffer = nil);
 
 MPSShape* getMPSShape(const Tensor& t);
 MPSShape* getMPSShape(IntArrayRef sizes);
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 22de684454d5c..35309341af142 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -208,14 +208,30 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
       NSUInteger alignedLength = 0;
 
       void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
-      id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
-                                                              length:alignedLength
-                                                             options:options
-                                                         deallocator:nil];
       sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
       sourceOffset += src_.storage_offset() * src_.itemsize();
 
-      stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
+      id<MTLBuffer> sourceBuffer = nil;
+      // If the destination is a strided MPS tensor, we cannot perform a blit directly to copy the
+      // memory from the CPU tensor into the MPS tensor. We need to scatter the data into the right indices
+      bool doScatter = (!dst_.is_contiguous() && src.is_contiguous());
+      if (doScatter) {
+        sourceBuffer = [device newBufferWithBytes:(void*)((uint8_t*)host_src + (src_.storage_offset() * src_.itemsize()))
+                                           length:size_to_copy
+                                          options:options];
+      }
+      else {
+        sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
+                                                  length:alignedLength
+                                                 options:options
+                                             deallocator:nil];
+      }
+
+      if (doScatter) {
+        scatterViewTensor(src, dst_, sourceBuffer);
+      } else {
+        stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
+      }
       [sourceBuffer release];
     }
diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm
index c4f1cc3ca28f6..043f0a110bb94 100644
--- a/aten/src/ATen/native/mps/operations/View.mm
+++ b/aten/src/ATen/native/mps/operations/View.mm
@@ -28,7 +28,12 @@
 }
 
 // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
-static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output, bool needsScatter)
+static Tensor& runViewGraph(
+    ViewCachedGraph* cachedGraph,
+    const at::Tensor& src,
+    Tensor& output,
+    bool needsScatter,
+    id<MTLBuffer> updatesBuffer = nil)
 {
   const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
   const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
@@ -49,7 +54,7 @@
                                                                              shape: inputShape
                                                                           dataType: inputType] autorelease];
     if (needsScatter) {
-      feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
+      feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: (updatesBuffer != nil) ? updatesBuffer : sourceBuffer
                                                                                    shape: getMPSShape(src.numel())
                                                                                 dataType: inputType] autorelease];
     }
@@ -603,11 +608,11 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
   return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
 }
 
-Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
+Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output, id<MTLBuffer> updatesBuffer)
 {
   ViewCachedGraph* cachedGraph = createViewGraph(output, output.sizes(), output.strides(), output.storage_offset(), /*needsScatter*/ true);
-  return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true);
+  return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, updatesBuffer);
 }
 
 } // namespace mps
diff --git a/test/test_mps.py b/test/test_mps.py
index 29ac0872d6a6f..7e0dafda8ab60 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1290,6 +1290,19 @@ def test_expand_cpu_to_mps_copy(self):
 
         self.assertEqual(x_cpu, x.cpu())
 
+    def test_cpu_to_strided_mps_copy(self):
+        # https://github.com/pytorch/pytorch/issues/86975
+
+        a1 = torch.Tensor([[1,2],[3,4], [5,6]]).to(torch.device("mps"))
+        b1 = torch.Tensor([-1, -1])
+        a1[1:,1] = b1
+
+        a2 = torch.Tensor([[1,2],[3,4], [5,6]]).to(torch.device("mps"))
+        b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
+        a2[1:,1] = b2
+
+        self.assertEqual(a1, a2)
+
     def test_view_slice(self):
         # https://github.com/pytorch/pytorch/issues/83995
         NUM_SAMPLES=60
@@ -6128,7 +6141,6 @@ def test_view(self, device="mps"):
         self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
 
-    # RuntimeError: Invalid device for storage: mps
     def test_contiguous(self, device="mps"):
         x = torch.randn(1, 16, 5, 5, device=device)
         self.assertTrue(x.is_contiguous())
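
For reference, a minimal standalone reproduction of the scenario this patch addresses, mirroring the new test_cpu_to_strided_mps_copy test above. This is a sketch, not part of the patch, and assumes a PyTorch build with the MPS backend available; the final torch.equal check is illustrative only.

import torch

# Assumes an MPS-enabled PyTorch build (Apple Silicon / Metal backend).
assert torch.backends.mps.is_available()

# Copy a contiguous CPU tensor into a strided (sliced) view of an MPS tensor.
# With this patch, the CPU data is scattered into the correct indices instead of
# being blitted into contiguous storage.
a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
b1 = torch.Tensor([-1, -1])               # CPU source
a1[1:, 1] = b1                            # CPU -> strided MPS copy

a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
a2[1:, 1] = b2                            # MPS -> strided MPS copy (reference)

print(torch.equal(a1.cpu(), a2.cpu()))    # expected: True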