Skip to content

Commit

Permalink
enabled min/max tests in test_cuda (ROCm#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
lcskrishna authored and iotamudelta committed Oct 9, 2018
1 parent 2647a53 commit 85b22a3
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,25 +353,23 @@ def tmp(t):
('kthvalue', small_3d_unique, lambda t: [3],),
('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'),
('kthvalue', small_3d_unique, lambda t: [3, -1], 'neg_dim'),
('lerp', small_3d, lambda t: [small_3d(t), 0.3], '', types, False, "skipIfRocm:HalfTensor"),
('max', small_3d_unique, lambda t: [], '', types, False, "skipIfRocm:HalfTensor"),
('max', small_3d_unique, lambda t: [1], 'dim', types, False, skipIfRocm),
('max', small_3d_unique, lambda t: [-1], 'neg_dim', types, False, skipIfRocm),
('lerp', small_3d, lambda t: [small_3d(t), 0.3]),
('max', small_3d_unique, lambda t: []),
('max', small_3d_unique, lambda t: [1], 'dim'),
('max', small_3d_unique, lambda t: [-1], 'neg_dim'),
('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
('min', small_3d_unique, lambda t: [], '', types, False, "skipIfRocm:HalfTensor"),
('min', small_3d_unique, lambda t: [1], 'dim', types, False, skipIfRocm),
('min', small_3d_unique, lambda t: [-1], 'neg_dim', types, False, skipIfRocm),
('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
('mean', small_3d, lambda t: [], '', types, False, "skipIfRocm:HalfTensor"),
('mean', small_3d, lambda t: [-1], 'neg_dim', types, False, "skipIfRocm:DoubleTensor,FloatTensor,HalfTensor"),
('mean', small_3d, lambda t: [1], 'dim', types, False, "skipIfRocm:DoubleTensor,FloatTensor,HalfTensor"),
('mode', small_3d, lambda t: [], '', types, False, skipIfRocm),
('mode', small_3d, lambda t: [1], 'dim', types, False, skipIfRocm),
('mode', small_3d, lambda t: [-1], 'neg_dim', types, False, skipIfRocm),
('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.1, 10), lambda t: [1], '2d_p=1', float_types_no_half,
False, "skipIfRocm:DoubleTensor,FloatTensor"),
('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.6, 10), lambda t: [2], '2d_p=2', float_types_no_half,
False, "skipIfRocm:DoubleTensor,FloatTensor"),
('mean', small_3d, lambda t: []),
('mean', small_3d, lambda t: [-1], 'neg_dim'),
('mean', small_3d, lambda t: [1], 'dim'),
('mode', small_3d, lambda t: []),
('mode', small_3d, lambda t: [1], 'dim'),
('mode', small_3d, lambda t: [-1], 'neg_dim'),
('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.1, 10), lambda t: [1], '2d_p=1', float_types_no_half),
('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.6, 10), lambda t: [2], '2d_p=2', float_types_no_half),
('remainder', small_3d, lambda t: [3], 'value', types, False, "skipIfRocm:HalfTensor"),
('remainder', small_3d, lambda t: [-3], 'negative_value', signed_types),
('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
Expand Down Expand Up @@ -977,7 +975,6 @@ def test_broadcast_cpu(self):
def test_broadcast_gpu(self):
self._test_broadcast(torch.randn(5, 5).cuda())

@skipIfRocm
def test_min_max_nan(self):
tests = [(lambda x: x.min(), 'min'),
(lambda x: x.max(), 'max'),
Expand Down Expand Up @@ -1743,7 +1740,6 @@ def test_tensor_scatterAdd(self):
def test_tensor_scatterFill(self):
TestTorch._test_scatter_base(self, lambda t: t.cuda(), 'scatter_', True, test_bounds=False)

@skipIfRocm
def test_min_max_inits(self):
# Testing if THC_reduceAll received the correct index initialization.
# This affects the result of THC_reduceAll operations at extreme values
Expand All @@ -1757,11 +1753,9 @@ def test_min_max_inits(self):
_, v = y.min(dim=0)
self.assertEqual(v, expected)

@skipIfRocm
def test_max_with_inf(self):
TestTorch._test_max_with_inf(self, (torch.half, torch.float, torch.double), 'cuda')

@skipIfRocm
def test_min_with_inf(self):
TestTorch._test_min_with_inf(self, (torch.half, torch.float, torch.double), 'cuda')

Expand Down

0 comments on commit 85b22a3

Please sign in to comment.