Skip to content

Commit

Permalink
Disabled one and enabled several unit tests for ROCm.
Browse files Browse the repository at this point in the history
  • Loading branch information
rsanthanam-amd committed May 10, 2022
1 parent a62ca21 commit 8d9f17d
Show file tree
Hide file tree
Showing 8 changed files with 1 addition and 18 deletions.
1 change: 1 addition & 0 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def setUp(self):
for dtype in dlpack_dtypes
for take_ownership in [False, True]
for gpu in [False, True]))
@jtu.skip_on_devices("rocm") # relevant dlpack protocol is N/A for ROCm ATM
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
Expand Down
1 change: 0 additions & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,6 @@ def np_fun(lhs, rhs):
for full in [False, True]
for w in [False, True]
for cov in [False, True, "unscaled"]))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov):
rng = jtu.rand_default(self.rng())
tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5}
Expand Down
3 changes: 0 additions & 3 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,6 @@ def testPinv(self, shape, dtype):
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)

@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinvGradIssue2792(self):
def f(p):
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
Expand Down Expand Up @@ -910,7 +909,6 @@ def testMatrixPower(self, shape, dtype, n):
"shape": shape, "dtype": dtype}
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testMatrixRank(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
Expand Down Expand Up @@ -958,7 +956,6 @@ def testMultiDot(self, shapes, dtype):
]
for rcond in [-1, None, 0.5]
for dtype in float_types + complex_types))
@jtu.skip_on_devices("tpu","rocm") # SVD not implemented on TPU. will be fixed in ROCm-5.1
def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond):
rng = jtu.rand_default(self.rng())
np_fun = partial(np.linalg.lstsq, rcond=rcond)
Expand Down
2 changes: 0 additions & 2 deletions tests/qdwh_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class QdwhTest(jtu.JaxTestCase):
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([8, 10, 20], [6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
@jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1
def testQdwhUnconvergedAfterMaxNumberIterations(
self, m, n, log_cond):
"""Tests unconvergence after maximum number of iterations."""
Expand Down Expand Up @@ -138,7 +137,6 @@ def testQdwhWithUpperTriangularInputAllOnes(self, m, n, log_cond):
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([6, 8], [6, 4])
for log_cond in np.linspace(1, 4, 4)))
@jtu.skip_on_devices("rocm") # will be solved rocm-5.1
def testQdwhWithRandomMatrix(self, m, n, log_cond):
"""Tests qdwh with random input."""
rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
Expand Down
2 changes: 0 additions & 2 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,6 @@ def testT(self, df, dtype):
for dim in [1, 3, 5]
for dtype in float_dtypes
for method in ['svd', 'eigh', 'cholesky']))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testMultivariateNormal(self, dim, dtype, method):
r = self.rng()
mean = r.randn(dim)
Expand Down Expand Up @@ -1138,7 +1137,6 @@ def testMultivariateNormal(self, dim, dtype, method):
for cov_batch_size in [(), (3,), (2, 3)]
for shape in [(), (1,), (5,)]
for method in ['cholesky', 'svd', 'eigh']))
@jtu.skip_on_devices("rocm") # will be solved in rocm-5.1
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
shape, method):
r = self.rng()
Expand Down
6 changes: 0 additions & 6 deletions tests/scipy_signal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
for axis in [0, -1]
for type in ['constant', 'linear']
for bp in [0, [0, 2]]))
@jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1
def testDetrend(self, shape, dtype, axis, type, bp):
signal = np.random.normal(loc=2, size=shape)

Expand Down Expand Up @@ -171,7 +170,6 @@ def jsp_fun(signal, noise):
for detrend in ['constant', 'linear', False]
for boundary in [None, 'even', 'odd', 'zeros']
for padded in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm 5.1
def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, boundary, padded,
timeaxis):
Expand Down Expand Up @@ -223,7 +221,6 @@ def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
@jtu.skip_on_devices("rocm") # will be fixed in next ROCm version
def testCsdAgainstNumpy(
self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
Expand Down Expand Up @@ -274,7 +271,6 @@ def testCsdAgainstNumpy(
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
@jtu.skip_on_devices("rocm") # will be fixed in next rocm release
def testCsdWithSameParamAgainstNumpy(
self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
Expand Down Expand Up @@ -332,7 +328,6 @@ def osp_fun(x, y):
for return_onesided in [True, False]
for scaling in ['density', 'spectrum']
for average in ['mean', 'median']))
@jtu.skip_on_devices("rocm") # will be fixed in next ROCm release
def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, return_onesided,
scaling, timeaxis, average):
Expand Down Expand Up @@ -423,7 +418,6 @@ def testWelchWithDefaultStepArgsAgainstNumpy(
for nfft in [None, nperseg, int(nperseg * 1.5), nperseg * 2]
for onesided in [False, True]
for boundary in [False, True]))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm 5.1
def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, onesided, boundary,
timeaxis, freqaxis):
Expand Down
1 change: 0 additions & 1 deletion tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,6 @@ def f_sparse(data, indices, lhs, rhs):
[(5, 3), (5, 2), [0], [0]],
]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def test_bcoo_dot_general_cusparse(
self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting):
rng = jtu.rand_small(self.rng())
Expand Down
3 changes: 0 additions & 3 deletions tests/svd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class SvdTest(jtu.JaxTestCase):
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)
for full_matrices in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSvdWithRectangularInput(self, m, n, log_cond, full_matrices):
"""Tests SVD with rectangular input."""
with jax.default_matmul_precision('float32'):
Expand Down Expand Up @@ -122,7 +121,6 @@ def testSvdWithSkinnyTallInput(self, m, n):
'm': m, 'r': r, 'log_cond': log_cond}
for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
for log_cond in np.linspace(1, 3, 3)))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
"""Tests SVD with rank-deficient input."""
with jax.default_matmul_precision('float32'):
Expand All @@ -149,7 +147,6 @@ def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)
for full_matrices in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSingularValues(self, m, n, log_cond, full_matrices):
"""Tests singular values."""
with jax.default_matmul_precision('float32'):
Expand Down

0 comments on commit 8d9f17d

Please sign in to comment.