Skip to content

Commit

Permalink
[MPS] Fix type casting copy with storage offset (pytorch#95573)
Browse files Browse the repository at this point in the history
This PR handles the case where the `dst` tensor of type casting has a storage offset by creating a temporary buffer to store results and then copy them back to the dst with the offset added.

Fixes pytorch#95417

Pull Request resolved: pytorch#95573
Approved by: https://github.com/kulinseth
  • Loading branch information
qqaatw authored and Ho Yin Chau committed Apr 10, 2023
1 parent f08d20a commit 79783a0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
14 changes: 10 additions & 4 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,19 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
src._set_conj(src_.is_conj());
src._set_neg(src_.is_neg());

const size_t src_size = src.nbytes();
MPSStream* stream = getCurrentMPSStream();
if (sameDataType) {
MPSStream* stream = getCurrentMPSStream();
// for GPU to GPU copies we only encode to stream's command buffer (no flushing)
stream->copy(sourceBuffer, destBuffer, src_size, src_byte_offset, dst_byte_offset);
stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset);
} else {
copy_cast_mps(dst_, src, destBuffer, sourceBuffer);
if (dst_byte_offset) {
auto tmp = at::native::empty_mps(dst_.sizes(), dst_.scalar_type(), c10::nullopt, kMPS);
auto tmpBuffer = getMTLBufferStorage(tmp);
copy_cast_mps(tmp, src, tmpBuffer, sourceBuffer);
stream->copy(tmpBuffer, destBuffer, dst_.nbytes(), 0, dst_byte_offset);
} else {
copy_cast_mps(dst_, src, destBuffer, sourceBuffer);
}
}
return dst_;
}
Expand Down
10 changes: 10 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2631,6 +2631,16 @@ def test_copy_non_contiguous(self):
y.permute(3, 2, 1, 0)[1::, ::2] = z
self.assertEqual(x, y.to('cpu'))

# See https://github.com/pytorch/pytorch/issues/95417
def test_copy_storage_offset(self):
x_cpu = torch.zeros(5, device="cpu", dtype=torch.float32)
x_mps = torch.zeros(5, device="mps", dtype=torch.float32)
update_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64)
update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)
x_cpu[2:4] = update_cpu
x_mps[2:4] = update_mps # implicit type casting and copy
self.assertEqual(x_cpu, x_mps)

# See https://github.com/pytorch/pytorch/pull/84742
# and https://github.com/pytorch/pytorch/pull/78319
def test_binops_dtype_precedence(self):
Expand Down

0 comments on commit 79783a0

Please sign in to comment.