Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for copying cpu tensors into strided mps tensors #142

Merged
merged 2 commits into the base branch (branch names not captured in this extract)
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output, id<MTLBuffer> updatesBuffer = nil);

MPSShape* getMPSShape(const Tensor& t);
MPSShape* getMPSShape(IntArrayRef sizes);
Expand Down
26 changes: 21 additions & 5 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,30 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
NSUInteger alignedLength = 0;

void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
sourceOffset += src_.storage_offset() * src_.itemsize();

stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
id<MTLBuffer> sourceBuffer = nil;
// If the destination is a strided MPS tensor, we cannot perform a blit directly to copy the
// memory from the CPU tensor into the MPS tensor. We need to scatter the data into the right indices
bool doScatter = (!dst_.is_contiguous() && src.is_contiguous());
if (doScatter) {
sourceBuffer = [device newBufferWithBytes:(void*)((uint8_t*)host_src + (src_.storage_offset() * src_.itemsize()))
length:size_to_copy
options:options];
}
else {
sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
}

if (doScatter) {
scatterViewTensor(src, dst_, sourceBuffer);
} else {
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
}
[sourceBuffer release];
}

Expand Down
13 changes: 9 additions & 4 deletions aten/src/ATen/native/mps/operations/View.mm
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
}

// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output, bool needsScatter)
static Tensor& runViewGraph(
ViewCachedGraph* cachedGraph,
const at::Tensor& src,
Tensor& output,
bool needsScatter,
id<MTLBuffer> updatesBuffer = nil)
{
const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
Expand All @@ -49,7 +54,7 @@
shape: inputShape
dataType: inputType] autorelease];
if (needsScatter) {
feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: (updatesBuffer != nil) ? updatesBuffer : sourceBuffer
shape: getMPSShape(src.numel())
razarmehr marked this conversation as resolved.
Show resolved Hide resolved
dataType: inputType] autorelease];
}
Expand Down Expand Up @@ -603,11 +608,11 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
}

Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
// Scatters the values of `src` into the non-contiguous (view) tensor `output`.
// A scatter graph is built (or fetched from the cache) keyed on output's sizes,
// strides and storage offset, then executed by runViewGraph.
//
// `updatesBuffer` optionally supplies a pre-filled MTLBuffer holding the update
// values to scatter; when nil (the default declared in OperationUtils.h), the
// graph reads the updates from src's own MPS storage instead. The non-nil path
// is used by Copy.mm when the source tensor lives in CPU memory and its bytes
// were staged into a fresh Metal buffer. NOTE(review): description based on the
// visible diff hunks only — confirm against the full View.mm/Copy.mm sources.
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output, id<MTLBuffer> updatesBuffer)
{
ViewCachedGraph* cachedGraph = createViewGraph(output, output.sizes(), output.strides(),
output.storage_offset(), /*needsScatter*/ true);
return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, updatesBuffer);
}

} // namespace mps
Expand Down
14 changes: 13 additions & 1 deletion test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,19 @@ def test_expand_cpu_to_mps_copy(self):

self.assertEqual(x_cpu, x.cpu())

def test_cpu_to_strided_mps_copy(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/86975:
    # assigning into a strided (non-contiguous) MPS view must work whether the
    # right-hand side lives on the CPU or already on the MPS device.
    mps_device = torch.device("mps")

    # Case 1: updates come from a CPU tensor.
    base_from_cpu = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(mps_device)
    cpu_updates = torch.Tensor([-1, -1])
    base_from_cpu[1:, 1] = cpu_updates

    # Case 2: identical scatter, but the updates are already on MPS.
    base_from_mps = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(mps_device)
    mps_updates = torch.Tensor([-1, -1]).to(mps_device)
    base_from_mps[1:, 1] = mps_updates

    # Both paths must produce the same result.
    self.assertEqual(base_from_cpu, base_from_mps)

def test_view_slice(self):
# https://github.com/pytorch/pytorch/issues/83995
NUM_SAMPLES=60
Expand Down Expand Up @@ -6128,7 +6141,6 @@ def test_view(self, device="mps"):
self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

# RuntimeError: Invalid device for storage: mps
def test_contiguous(self, device="mps"):
x = torch.randn(1, 16, 5, 5, device=device)
self.assertTrue(x.is_contiguous())
Expand Down