From b18a4d8583c0e11e228a0792793d6f6e99292766 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Wed, 5 Feb 2020 17:35:46 +0100
Subject: [PATCH] Disabled tests known to fail on Mac, and optionally slow
 tests.

Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be
slow.
---
 CHANGELOG.md                   |  1 +
 docs/developer.rst             |  3 +++
 jax/test_util.py               |  8 ++++++++
 tests/lax_control_flow_test.py |  9 +++++++++
 tests/lax_test.py              |  4 ++++
 tests/linalg_test.py           | 19 +++++++++++++++++++
 tests/nn_test.py               |  2 ++
 7 files changed, 46 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c7dce33adef4..071e89b2b59e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ These are the release notes for JAX.
   and Numba.
 * JAX CPU device buffers now implement the Python buffer protocol, which allows
   zero-copy buffer sharing between JAX and NumPy.
+* Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
 
 ## jaxlib 0.1.38 (January 29, 2020)
 
diff --git a/docs/developer.rst b/docs/developer.rst
index ea55a9fd6b77..5d47127bd626 100644
--- a/docs/developer.rst
+++ b/docs/developer.rst
@@ -124,6 +124,9 @@ file directly to see more detailed information about the cases being run:
 
   python tests/lax_numpy_test.py --num_generated_cases=5
 
+You can skip tests known to be slow by setting the environment variable
+JAX_SKIP_SLOW_TESTS=1.
+
 The Colab notebooks are tested for errors as part of the documentation build.
 
 Update documentation
diff --git a/jax/test_util.py b/jax/test_util.py
index fee9c33b2cfd..b6dca099bece 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -14,6 +14,7 @@
 
 from contextlib import contextmanager
+from distutils.util import strtobool
 import functools
 import re
 import itertools as it
@@ -49,6 +50,13 @@
     int(os.getenv('JAX_NUM_GENERATED_CASES', 10)),
     help='Number of generated cases to test')
 
+flags.DEFINE_bool(
+    'jax_skip_slow_tests',
+    strtobool(os.getenv('JAX_SKIP_SLOW_TESTS', '0')),
+    help=
+    'Skip tests marked as slow (> 5 sec).'
+)
+
 EPS = 1e-4
 
 def _dtype(x):
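The tests below rely on the decorator jtu.skip_on_flag, which already exists in jax.test_util and is not part of this patch. As a rough, hypothetical sketch (not the actual implementation), such a decorator can be built on top of absl flags like this:

    # Hypothetical sketch of a skip-on-flag decorator; the real
    # jax.test_util.skip_on_flag is not shown in this patch.
    import functools
    import unittest

    from absl import flags

    FLAGS = flags.FLAGS

    def skip_on_flag(flag_name, skip_value):
      """Skips the decorated test method when FLAGS.<flag_name> == skip_value."""
      def skip(test_method):
        @functools.wraps(test_method)
        def test_method_wrapper(self, *args, **kwargs):
          flag_value = getattr(FLAGS, flag_name)
          if flag_value == skip_value:
            raise unittest.SkipTest(
                "skipped because flag {} is {}".format(flag_name, flag_value))
          return test_method(self, *args, **kwargs)
        return test_method_wrapper
      return skip

Since the flag's default value is parsed with distutils.util.strtobool, the environment variable accepts any of '1'/'true'/'yes'/'on' (skip) and '0'/'false'/'no'/'off' (run), and rejects other spellings with a ValueError.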
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 0b4943d339cf..fa06b1bcc18b 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -957,6 +957,7 @@ def f(c, a):
        "jit_scan": jit_scan, "jit_f": jit_f}
       for jit_scan in [False, True]
       for jit_f in [False, True])
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testScanGrad(self, jit_scan, jit_f):
     rng = onp.random.RandomState(0)
 
@@ -987,6 +988,7 @@ def f(c, a):
     jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
                     atol=1e-3, rtol=2e-3)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testScanRnn(self):
     r = npr.RandomState(0)
 
@@ -1444,6 +1446,7 @@ def sqrt_cubed(x, tangent_solve=scalar_solve):
     self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False,
                         rtol={onp.float64:1e-7})
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_root_vector_with_solve_closure(self):
 
     def vector_solve(f, y):
@@ -1513,6 +1516,7 @@ def dummy_root_usage(x):
       {"testcase_name": "nonsymmetric", "symmetric": False},
       {"testcase_name": "symmetric", "symmetric": True},
   )
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve(self, symmetric):
 
     def explicit_jacobian_solve(matvec, b):
@@ -1542,6 +1546,7 @@ def linear_solve(a, b):
     actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
     self.assertAllClose(expected, actual, check_dtypes=True)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_zeros(self):
     def explicit_jacobian_solve(matvec, b):
       return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))
@@ -1561,6 +1566,7 @@ def linear_solve(a, b):
     jtu.check_grads(lambda x: linear_solve(a, x), (b,), order=2,
                     rtol={onp.float32: 5e-3})
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_iterative(self):
 
     def richardson_iteration(matvec, b, omega=0.1, tolerance=1e-6):
@@ -1622,6 +1628,7 @@ def solve(matvec, x):
         lambda x, y: positive_definite_solve(high_precision_dot(x, x.T), y),
         (a, b), order=2, rtol=1e-2)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_lu(self):
     # TODO(b/143528110): re-enable when underlying XLA TPU issue is fixed
 
@@ -1652,6 +1659,7 @@ def transpose_solve(vecmat, x):
     jtu.check_grads(api.jit(linear_solve), (a, b), order=2,
                     rtol={onp.float32: 2e-3})
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_without_transpose_solve(self):
 
     def explicit_jacobian_solve(matvec, b):
@@ -1674,6 +1682,7 @@ def loss(a, b):
     with self.assertRaisesRegex(TypeError, "transpose_solve required"):
       api.grad(loss)(a, b)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_pytree(self):
     """Test custom linear solve with inputs and outputs that are pytrees."""
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index d8101901b894..0ec72acbd851 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -564,6 +564,7 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers):
        for dspec in [('NHWC', 'HWIO', 'NHWC'),]
        for rhs_dilation in [None, (2, 2)]
        for rng_factory in [jtu.rand_small]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides,
                            padding, dspec, rhs_dilation, rng_factory):
     rng = rng_factory()
@@ -602,6 +603,7 @@ def fun_via_grad(lhs, rhs):
        for dspec in [('NHWC', 'HWIO', 'NHWC'),]
        for rhs_dilation in [None, (2, 2)]
        for rng_factory in [jtu.rand_small]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides,
                           padding, dspec, rhs_dilation, rng_factory):
     rng = rng_factory()
@@ -2281,6 +2283,7 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
       for dtype in dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_default]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
     rng = rng_factory()
     tol = {onp.float16: 1e-1, onp.float32: 1e-3}
@@ -2970,6 +2973,7 @@ def fun(operand):
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_small]))
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
     if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
       raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index e40b90b729a9..48247569ad7a 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -17,6 +17,7 @@
 from functools import partial
 import itertools
 import unittest
+import sys
 
 import numpy as onp
 import scipy as osp
@@ -51,6 +52,10 @@ def _skip_if_unsupported_type(dtype):
       dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
     raise unittest.SkipTest("--jax_enable_x64 is not set")
 
+# TODO(phawkins): bug https://github.com/google/jax/issues/2166
+def _skip_on_mac_xla_bug():
+  if sys.platform == "darwin" and osp.version.version > "1.0.0":
+    raise unittest.SkipTest("Test fails on Mac with new scipy (issue #2166)")
 
 class NumpyLinalgTest(jtu.JaxTestCase):
 
@@ -134,6 +139,7 @@ def testSlogdet(self, shape, dtype, rng_factory):
           for dtype in float_types
           for rng_factory in [jtu.rand_default]))
   @jtu.skip_on_devices("tpu")
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSlogdetGrad(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
@@ -188,6 +194,8 @@ def norm(x):
   def testEigvals(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if shape == (50, 50) and dtype == onp.complex64:
+      _skip_on_mac_xla_bug()
     n = shape[-1]
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
@@ -565,6 +573,8 @@ def testSolve(self, lhs_shape, rhs_shape, dtype, rng_factory):
   def testInv(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if shape == (200, 200) and dtype == onp.float32:
+      _skip_on_mac_xla_bug()
     if jtu.device_under_test() == "gpu" and shape == (200, 200):
       raise unittest.SkipTest("Test is flaky on GPU")
 
@@ -594,6 +604,8 @@ def args_maker():
   def testPinv(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if shape == (7, 10000) and dtype in [onp.complex64, onp.float32]:
+      _skip_on_mac_xla_bug()
     args_maker = lambda: [rng(shape, dtype)]
 
     self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
@@ -650,6 +662,7 @@ def test(x):
     xc = onp.eye(3, dtype=onp.complex)
     self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testIssue1151(self):
     A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
     b = np.array(onp.random.randn(100, 3), dtype=np.float32)
@@ -661,6 +674,7 @@ def testIssue1151(self):
     jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
     jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testIssue1383(self):
     seed = jax.random.PRNGKey(0)
     tmp = jax.random.uniform(seed, (2,2))
@@ -726,6 +740,7 @@ def testLuOfSingularMatrix(self):
       for dtype in float_types + complex_types
       for rng_factory in [jtu.rand_default]))
   @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testLuGrad(self, shape, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
@@ -764,6 +779,8 @@ def testLuBatching(self, shape, dtype, rng_factory):
   def testLuFactor(self, n, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if n == 200 and dtype == onp.complex64:
+      _skip_on_mac_xla_bug()
     args_maker = lambda: [rng((n, n), dtype)]
 
     x, = args_maker()
@@ -985,6 +1002,8 @@ def testTriangularSolveGradPrecision(self):
   def testExpm(self, n, dtype, rng_factory):
     rng = rng_factory()
     _skip_if_unsupported_type(dtype)
+    if n == 50 and dtype in [onp.complex64, onp.float32]:
+      _skip_on_mac_xla_bug()
     args_maker = lambda: [rng((n, n), dtype)]
 
     osp_fun = lambda a: osp.linalg.expm(a)
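A note on the _skip_on_mac_xla_bug helper added above: osp.version.version > "1.0.0" compares version strings lexicographically. That happens to be correct against "1.0.0" for any scipy 1.x release, but it is fragile in general (for example, "1.10.0" < "1.2.0" as strings). A more defensive variant, sketched here only as a suggestion and not part of the patch, could parse the versions first:

    # Sketch only: the patch itself uses a plain string comparison,
    # which is adequate for scipy 1.x at the time of writing.
    import sys
    import unittest

    from distutils.version import LooseVersion
    import scipy as osp

    # TODO(phawkins): bug https://github.com/google/jax/issues/2166
    def _skip_on_mac_xla_bug():
      if (sys.platform == "darwin" and
          LooseVersion(osp.version.version) > LooseVersion("1.0.0")):
        raise unittest.SkipTest("Test fails on Mac with new scipy (issue #2166)")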
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 0d6bf3add13e..b621f66307db 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -34,6 +34,7 @@
 
 
 class NNFunctionsTest(jtu.JaxTestCase):
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testSoftplusGrad(self):
     check_grads(nn.softplus, (1e-8,), 4,
                 rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
@@ -42,6 +43,7 @@ def testSoftplusValue(self):
     val = nn.softplus(89.)
     self.assertAllClose(val, 89., check_dtypes=False)
 
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def testEluGrad(self):
     check_grads(nn.elu, (1e4,), 4, eps=1.)
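With the patch applied, the slow tests can be skipped either through the environment variable documented in developer.rst or, because jax_skip_slow_tests is defined as an absl flag, directly on the command line of a test file. For example, using test files touched by this patch:

  JAX_SKIP_SLOW_TESTS=1 python tests/lax_control_flow_test.py
  python tests/linalg_test.py --jax_skip_slow_tests=true

Both forms cause every test decorated with @jtu.skip_on_flag("jax_skip_slow_tests", True) to be reported as skipped rather than run.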