Add amax and amin with tests (#33)
* Add amax and amin with tests
abhudev authored and kulinseth committed Jun 16, 2022
1 parent f1fb575 commit 7a740ea
Showing 3 changed files with 73 additions and 0 deletions.
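For orientation, a minimal usage sketch (not part of the commit) of the calls this change adds MPS support for, assuming a build where the MPS backend is available:

import torch

# Hypothetical quick check, mirroring what the new tests below exercise.
x = torch.randn(2, 8, 4, 5, device='mps')
print(torch.amax(x, dim=[0, 1], keepdim=True).shape)   # torch.Size([1, 1, 4, 5])
print(torch.amin(x, dim=[2, 3], keepdim=False).shape)  # torch.Size([2, 8])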
27 changes: 27 additions & 0 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -220,6 +220,15 @@ void set_axes_and_shapes(const Tensor& input_t,
                                        axes:axes
                                        name:nil];
  }
  else if(reduction_type == "amax") {
    castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
                                                       axes:axes
                                                       name:nil];
  } else if(reduction_type == "amin") {
    castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
                                                       axes:axes
                                                       name:nil];
  }

  MPSGraphTensor* outputTensor = nil;

@@ -294,6 +303,24 @@ inline ScalarType get_dtype_from_self(
return src_type;
}

TORCH_IMPL_FUNC(amax_out_mps)
   (const Tensor& input_t,
    IntArrayRef dim,
    bool keepdim,
    const Tensor& output_t) {

  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps");
}

TORCH_IMPL_FUNC(amin_out_mps)
   (const Tensor& input_t,
    IntArrayRef dim,
    bool keepdim,
    const Tensor& output_t) {

  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps");
}

Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {

  auto num_dims = self.dim();
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3121,6 +3121,7 @@
  structured: True
  dispatch:
    CPU, CUDA: amax_out
    MPS: amax_out_mps

# Return: (Tensor output, Tensor indices)
- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@@ -3277,6 +3278,7 @@
  structured: True
  dispatch:
    CPU, CUDA: amin_out
    MPS: amin_out_mps

# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
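The two `MPS: ..._out_mps` entries above register the new kernels with the dispatcher. A rough way to see the effect (again a sketch, not from the commit) is to compare an MPS run against the CPU reference, much as the tests below do:

import torch

# Sketch only: the MPS results of amax/amin should match the CPU results.
cpu_x = torch.randn(2, 8, 4, 5)
mps_x = cpu_x.to('mps')

for dim in ([0], [0, 1], [2, 3]):
    assert torch.allclose(torch.amax(mps_x, dim=dim).cpu(), torch.amax(cpu_x, dim=dim))
    assert torch.allclose(torch.amin(mps_x, dim=dim).cpu(), torch.amin(cpu_x, dim=dim))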
44 changes: 44 additions & 0 deletions test/test_mps.py
@@ -2618,6 +2618,50 @@ def helper(shape):

        helper((4, 5, 6, 7))

    # Test forward amax
    def test_amax(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amax(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)

    # Test forward amin
    def test_amin(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amin(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)

    # Test minimum and maximum
    def test_minimum_maximum(self):
        def helper(n, c, h, w):
