Adaptive max pool 2d forward and backward with test #11

Merged 3 commits on May 27, 2022
@@ -26,6 +26,8 @@
kernel_sizeW = isizeW - (osizeW-1) * strideW;
}
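For context, the hunk above ends the existing set_kernel_params helper that the new adaptive max pooling path reuses. A minimal Python sketch of that derivation (an illustrative reconstruction, not part of the diff; it assumes the stride is the integer ratio of input to output size, which is the case this kernel targets):

def set_kernel_params(isizeH, isizeW, osizeH, osizeW):
    # Stride: how far each pooling window advances.
    strideH = isizeH // osizeH
    strideW = isizeW // osizeW
    # Kernel size: large enough that the last window ends at the input edge
    # (matches the kernel_sizeW = isizeW - (osizeW-1) * strideW line above).
    kernel_sizeH = isizeH - (osizeH - 1) * strideH
    kernel_sizeW = isizeW - (osizeW - 1) * strideW
    return strideH, strideW, kernel_sizeH, kernel_sizeW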

// Adaptive average pooling

Tensor& adaptive_avg_pool2d_out_mps
(const Tensor& input,
IntArrayRef output_size,
@@ -150,5 +152,93 @@

}

// Adaptive max pooling

TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input,
IntArrayRef output_size,
const Tensor& output,
const Tensor& indices) {

for (int64_t i = 1; i < input.ndimension(); i++) {
TORCH_CHECK(input.size(i) > 0,
"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
"but input has sizes ", input.sizes(), " with dimension ", i, " being "
"empty");
}

int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);

int64_t osizeH = output_size[0];
int64_t osizeW = output_size[1];

if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
TORCH_CHECK(input.ndimension() == 4,
"adaptive_max_pool2d(): Expected 4D tensor, but got ",
input.sizes())

switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous:
case at::MemoryFormat::ChannelsLast:
break;
default:
TORCH_CHECK(
false,
"Unsupported memory format. Supports only ChannelsLast, Contiguous")
}

int64_t strideH;
int64_t strideW;
int64_t kernel_sizeH;
int64_t kernel_sizeW;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);

auto outputs = at::max_pool2d_with_indices(input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
IntArrayRef({1, 1}),
false);

output.copy_(std::get<0>(outputs));
indices.copy_(std::get<1>(outputs));
}
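The forward path therefore reduces adaptive max pooling to a regular max pool with the computed kernel and stride. A quick equivalence check (illustrative sketch, not part of the diff; it assumes the output size evenly divides the input size, matching the cases exercised in the test below):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16, 16)
osizeH, osizeW = 4, 4
strideH, strideW = 16 // osizeH, 16 // osizeW
kH, kW = 16 - (osizeH - 1) * strideH, 16 - (osizeW - 1) * strideW

adaptive_out = F.adaptive_max_pool2d(x, (osizeH, osizeW))
fixed_out = F.max_pool2d(x, kernel_size=(kH, kW), stride=(strideH, strideW))
assert torch.equal(adaptive_out, fixed_out)  # identical windows, identical maxima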

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput,
const Tensor& input,
const Tensor& indices,
const Tensor& gradInput) {

int64_t isizeH = input.size(-2);
int64_t isizeW = input.size(-1);
int64_t osizeH = gradOutput.size(-2);
int64_t osizeW = gradOutput.size(-1);

int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;

set_kernel_params(isizeH, isizeW,
osizeH, osizeW,
strideH, strideW,
kernel_sizeH, kernel_sizeW);

auto returnGradInput = at::max_pool2d_with_indices_backward(gradOutput,
input,
IntArrayRef({kernel_sizeH, kernel_sizeW}),
IntArrayRef({strideH, strideW}),
IntArrayRef({0, 0}),
IntArrayRef({1, 1}),
false,
indices);

gradInput.copy_(returnGradInput);

}
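The backward kernel reuses the indices saved during the forward pass, so the incoming gradient is routed to each window's argmax position. A small illustration of that routing (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.tensor([[[[1., 5.],
                    [3., 2.]]]], requires_grad=True)
out = F.adaptive_max_pool2d(x, (1, 1))   # single window; the max is 5
out.backward(torch.ones_like(out))
print(x.grad)
# tensor([[[[0., 1.],
#           [0., 0.]]]])  -- the gradient lands only on the max position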

}
}
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -9793,6 +9793,7 @@
dispatch:
CPU: adaptive_max_pool2d_out_cpu
CUDA: adaptive_max_pool2d_out_cuda
MPS: adaptive_max_pool2d_out_mps

# Return: (Tensor output, Tensor indices)
- func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
@@ -9805,6 +9806,7 @@
dispatch:
CPU: adaptive_max_pool2d_backward_out_cpu
CUDA: adaptive_max_pool2d_backward_out_cuda
MPS: adaptive_max_pool2d_backward_out_mps

- func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
python_module: nn
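With these two dispatch entries registered, AdaptiveMaxPool2d on MPS tensors routes to the new kernels. A minimal usage sketch (assuming a machine where the MPS backend is available):

import torch
import torch.nn as nn

if torch.backends.mps.is_available():
    pool = nn.AdaptiveMaxPool2d((4, 4), return_indices=True)
    x = torch.randn(2, 3, 16, 16, device='mps', requires_grad=True)
    out, indices = pool(x)       # forward -> adaptive_max_pool2d_out_mps
    out.sum().backward()         # backward -> adaptive_max_pool2d_backward_out_mps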
44 changes: 44 additions & 0 deletions test/test_mps.py
@@ -3089,6 +3089,50 @@ def helper(input_shape, out_shape, channels_last):

helper((2, 16, 16), (4, 4), False)

# Test adaptive max pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
def test_adaptive_max_pool2d_simple(self):
    def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
        cpu_x = None
        if(dtype in [torch.float16, torch.float32]):
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        else:
            cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
        if(channels_last):
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        max_result, max_indices = None, None
        max_result_cpu, max_indices_cpu = None, None

        if(return_indices):
            max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
            max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
        else:
            max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
            max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)

        cpu_grad = torch.randn(max_result_cpu.shape)
        grad = cpu_grad.to('mps')

        max_result.backward(gradient=grad)
        max_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(max_result, max_result_cpu)
        if(return_indices):
            self.assertEqual(max_indices, max_indices_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for dtype in [torch.float32]:
        for return_indices in [False, True]:
            helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
            helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
            helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
            helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
            helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
            helper((2, 16, 16), (4, 4), return_indices, dtype)

def test_gelu_simple(self):
def helper(shape):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)