Skip to content

Commit

Permalink
Clamp argmin / argmax min indicies values to 0 (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 committed Dec 8, 2022
1 parent d541bb3 commit 41a36aa
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 6 additions & 1 deletion aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1620,8 +1620,13 @@ Tensor min_mps(const Tensor& input_t) {
toType:MPSDataTypeInt64
name:@"castOutputTensor"];

MPSGraphTensor* outputClampedTensor = [mpsGraph clampWithTensor:outputTensor
minValueTensor:[mpsGraph constantWithScalar:0 dataType:MPSDataTypeInt64]
maxValueTensor:[mpsGraph constantWithScalar:LLONG_MAX dataType:MPSDataTypeInt64]
name: nil];

newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
newCachedGraph->outputTensor_ = outputClampedTensor;
}
return newCachedGraph;
});
Expand Down
8 changes: 2 additions & 6 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8001,8 +8001,8 @@ class TestConsistency(TestCase):
'__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'],
'__rpow__': ['f16'],
'__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'],
'masked.argmax': ['i16', 'i64', 'u8'],
'masked.argmin': ['i16', 'i64', 'u8'],
'masked.argmax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'masked.argmin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'masked.log_softmax': ['f32'],
'masked.logaddexp': ['f32'],
'masked.logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
Expand Down Expand Up @@ -8460,10 +8460,6 @@ class TestConsistency(TestCase):
'masked.std': [torch.int32],
'masked.var': [torch.int32],

# Failures due to inconsistency between CPU and GPU for `inf` case
'masked.argmax': ['f16', 'f32', 'i32'],
'masked.argmin': ['f16', 'f32', 'i32'],

'as_strided_scatter': [torch.uint8],
'atan2': [torch.int64],
'bfloat16': None,
Expand Down

0 comments on commit 41a36aa

Please sign in to comment.