Skip to content

Commit

Permalink
fix numpy compatibility in test for torch.kthvalue (pytorch#59214)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#59201. Should be merged after pytorch#59067 to ensure this actually working correctly.

Pull Request resolved: pytorch#59214

Reviewed By: albanD

Differential Revision: D28792363

Pulled By: mruberry

fbshipit-source-id: 0cf613463139352906fb567f1efcc582c2c25de8
  • Loading branch information
pmeier authored and deniskokarev committed Jun 9, 2021
1 parent 63216b2 commit 260ffc8
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions test/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2428,19 +2428,21 @@ def test_tensor_compare_ops_argmax_argmix_kthvalue_dim_empty(self, device):
test_functions = [
('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax),
('argmin', torch.argmin, {'dtype': torch.int64}, np.argmin),
('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs).values,
{}, lambda *args, **kwargs: np.partition(*args, 1, **kwargs))
('kthvalue', lambda *args, k=1, **kwargs: torch.kthvalue(*args, k=1, **kwargs).values,
{}, lambda *args, k=1, axis=None, **kwargs: np.partition(*args, k, **kwargs).take(k - 1, axis=axis))
]

for name, fn, dtype, np_function in test_functions:
error_msg = f"test function: {name}"
self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=2), msg=error_msg)
self.assertEqual(np_function(np_input, axis=2),
fn(master_input, dim=2).cpu().numpy(), msg=error_msg)
self.assertEqual(
np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False
)

self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=-1), msg=error_msg)
self.assertEqual(np_function(np_input, axis=-1),
fn(master_input, dim=-1).cpu().numpy(), msg=error_msg)
self.assertEqual(
np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False
)

# keepdim variant does not exist for numpy
self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=2, keepdim=True),
Expand Down

0 comments on commit 260ffc8

Please sign in to comment.