Skip to content

Commit

Permalink
[MPS] Fix views with 3 or more sliced dimensions (#95762)
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 authored and cyyever committed Mar 2, 2023
1 parent 3ac4aa3 commit f711760
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
27 changes: 13 additions & 14 deletions aten/src/ATen/native/mps/operations/View.mm
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,6 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil;
MPSNDArray *srcTensorNDArray = nil;
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();

int64_t base_idx = 0;

std::vector<int64_t> src_base_shape_vec;
Expand Down Expand Up @@ -544,20 +543,20 @@ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) {
}

int64_t sliceOffset = src.storage_offset() / view_numel;
// There are cases where both dimensions of a view can shrink
// E.g: x = torch.randn((3,6))[1, 1:3]
int64_t nextSliceOffset = 0;
bool sliceNextDim = (firstDimToSlice < (src_base_shape.size() - 1)) &&
(src_view_shape[firstDimToSlice + 1] != src_base_shape[firstDimToSlice + 1]);

[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];
if (sliceNextDim) {
if (firstDimToSlice + 1 == src_base_shape.size() - 1) {
nextSliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
} else {
nextSliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[firstDimToSlice + 1]);
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice])}];

// Slice any remaining dimensions
for (const auto crtSliceOffset: c10::irange(firstDimToSlice + 1, src_base_shape.size())) {
if (src_view_shape[crtSliceOffset] != src_base_shape[crtSliceOffset]) {
if (crtSliceOffset == src_base_shape.size() - 1) {
sliceOffset = src.storage_offset() % src_base_shape[src_base_shape.size() - 1];
} else {
sliceOffset = (src.storage_offset() % view_numel) / (view_numel / src_base_shape[crtSliceOffset]);
}
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - crtSliceOffset
withSubrange:{static_cast<NSUInteger>(sliceOffset), static_cast<NSUInteger>(src.sizes()[crtSliceOffset])}];
}
[srcTensorNDArrayDesc sliceDimension:src_ndim_base - 2 - firstDimToSlice withSubrange:{static_cast<NSUInteger>(nextSliceOffset), static_cast<NSUInteger>(src.sizes()[firstDimToSlice+1])}];
}
srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer
descriptor:srcTensorNDArrayDesc
Expand Down
9 changes: 9 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,15 @@ def helper(shape):
helper([3, 4, 18, 22])
helper([3, 4, 18, 22, 150])

def test_contiguous_slice_3d(self):
x = torch.randn(2, 3, 3, device="mps")
x_cpu = x.detach().clone().cpu()
x = x[:1]
x_cpu = x_cpu[:1]
out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
self.assertEqual(out, out_cpu)

def test_view_slice(self):
# https://github.com/pytorch/pytorch/issues/83995
NUM_SAMPLES = 60
Expand Down

0 comments on commit f711760

Please sign in to comment.