Disabled tests known to fail on Mac, and optionally slow tests.
Issue: google#2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
gnecula committed Feb 5, 2020
1 parent d01210e commit b18a4d8
Showing 7 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ These are the release notes for JAX.
and Numba.
* JAX CPU device buffers now implement the Python buffer protocol, which allows
zero-copy buffer sharing between JAX and NumPy.
* Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.

## jaxlib 0.1.38 (January 29, 2020)

3 changes: 3 additions & 0 deletions docs/developer.rst
@@ -124,6 +124,9 @@ file directly to see more detailed information about the cases being run:
python tests/lax_numpy_test.py --num_generated_cases=5
You can skip tests that are known to be slow by setting the environment variable
JAX_SKIP_SLOW_TESTS=1.
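For example (assuming a bash-compatible shell), the variable can be set inline for the invocation shown above:

JAX_SKIP_SLOW_TESTS=1 python tests/lax_numpy_test.py --num_generated_cases=5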

The Colab notebooks are tested for errors as part of the documentation build.

Update documentation
8 changes: 8 additions & 0 deletions jax/test_util.py
@@ -14,6 +14,7 @@


from contextlib import contextmanager
from distutils.util import strtobool
import functools
import re
import itertools as it
@@ -49,6 +50,13 @@
int(os.getenv('JAX_NUM_GENERATED_CASES', 10)),
help='Number of generated cases to test')

flags.DEFINE_bool(
'jax_skip_slow_tests',
strtobool(os.getenv('JAX_SKIP_SLOW_TESTS', '0')),
help=
'Skip tests marked as slow (> 5 sec).'
)

EPS = 1e-4

def _dtype(x):
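The jtu.skip_on_flag decorator applied throughout the test files below is not part of this diff. As a rough sketch only (an assumption, not the actual jax.test_util implementation), such a helper can be built on top of the absl flag defined above:

import functools
import unittest

from absl import flags

FLAGS = flags.FLAGS


def skip_on_flag(flag_name, skip_value):
  """Skips the decorated test method when FLAGS.<flag_name> equals skip_value."""
  def decorator(test_method):
    @functools.wraps(test_method)
    def wrapper(self, *args, **kwargs):
      # Hypothetical sketch: read the absl flag at call time and skip when it matches.
      if getattr(FLAGS, flag_name) == skip_value:
        raise unittest.SkipTest(
            "skipped because --{}={}".format(flag_name, skip_value))
      return test_method(self, *args, **kwargs)
    return wrapper
  return decorator

A test method is then marked as in the hunks below, e.g. @jtu.skip_on_flag("jax_skip_slow_tests", True), so it runs only while jax_skip_slow_tests keeps its default value of False.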
9 changes: 9 additions & 0 deletions tests/lax_control_flow_test.py
@@ -957,6 +957,7 @@ def f(c, a):
"jit_scan": jit_scan, "jit_f": jit_f}
for jit_scan in [False, True]
for jit_f in [False, True])
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testScanGrad(self, jit_scan, jit_f):
rng = onp.random.RandomState(0)

@@ -987,6 +988,7 @@ def f(c, a):
jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["rev"],
atol=1e-3, rtol=2e-3)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testScanRnn(self):
r = npr.RandomState(0)

@@ -1444,6 +1446,7 @@ def sqrt_cubed(x, tangent_solve=scalar_solve):
self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False,
rtol={onp.float64:1e-7})

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_root_vector_with_solve_closure(self):

def vector_solve(f, y):
@@ -1513,6 +1516,7 @@ def dummy_root_usage(x):
{"testcase_name": "nonsymmetric", "symmetric": False},
{"testcase_name": "symmetric", "symmetric": True},
)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve(self, symmetric):

def explicit_jacobian_solve(matvec, b):
@@ -1542,6 +1546,7 @@ def linear_solve(a, b):
actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
self.assertAllClose(expected, actual, check_dtypes=True)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_zeros(self):
def explicit_jacobian_solve(matvec, b):
return lax.stop_gradient(np.linalg.solve(api.jacobian(matvec)(b), b))
@@ -1561,6 +1566,7 @@ def linear_solve(a, b):
jtu.check_grads(lambda x: linear_solve(a, x), (b,), order=2,
rtol={onp.float32: 5e-3})

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_iterative(self):

def richardson_iteration(matvec, b, omega=0.1, tolerance=1e-6):
@@ -1622,6 +1628,7 @@ def solve(matvec, x):
lambda x, y: positive_definite_solve(high_precision_dot(x, x.T), y),
(a, b), order=2, rtol=1e-2)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_lu(self):

# TODO(b/143528110): re-enable when underlying XLA TPU issue is fixed
@@ -1652,6 +1659,7 @@ def transpose_solve(vecmat, x):
jtu.check_grads(api.jit(linear_solve), (a, b), order=2,
rtol={onp.float32: 2e-3})

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_without_transpose_solve(self):

def explicit_jacobian_solve(matvec, b):
@@ -1674,6 +1682,7 @@ def loss(a, b):
with self.assertRaisesRegex(TypeError, "transpose_solve required"):
api.grad(loss)(a, b)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_pytree(self):
"""Test custom linear solve with inputs and outputs that are pytrees."""

4 changes: 4 additions & 0 deletions tests/lax_test.py
@@ -564,6 +564,7 @@ def _transpose_conv_kernel(data, kernel, dimension_numbers):
for dspec in [('NHWC', 'HWIO', 'NHWC'),]
for rhs_dilation in [None, (2, 2)]
for rng_factory in [jtu.rand_small]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testConvTranspose2DT(self, lhs_shape, rhs_shape, dtype, strides,
padding, dspec, rhs_dilation, rng_factory):
rng = rng_factory()
@@ -602,6 +603,7 @@ def fun_via_grad(lhs, rhs):
for dspec in [('NHWC', 'HWIO', 'NHWC'),]
for rhs_dilation in [None, (2, 2)]
for rng_factory in [jtu.rand_small]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testConvTranspose2D(self, lhs_shape, rhs_shape, dtype, strides,
padding, dspec, rhs_dilation, rng_factory):
rng = rng_factory()
@@ -2281,6 +2283,7 @@ def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
for dtype in dtypes
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
rng = rng_factory()
tol = {onp.float16: 1e-1, onp.float32: 1e-3}
@@ -2970,6 +2973,7 @@ def fun(operand):
for dtype in float_dtypes
for padding in ["VALID", "SAME"]
for rng_factory in [jtu.rand_small]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSelectAndGatherAdd(self, dtype, padding, rng_factory):
if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
19 changes: 19 additions & 0 deletions tests/linalg_test.py
@@ -17,6 +17,7 @@
from functools import partial
import itertools
import unittest
import sys

import numpy as onp
import scipy as osp
@@ -51,6 +52,10 @@ def _skip_if_unsupported_type(dtype):
dtype in (onp.dtype('float64'), onp.dtype('complex128'))):
raise unittest.SkipTest("--jax_enable_x64 is not set")

# TODO(phawkins): bug https://github.com/google/jax/issues/2166
def _skip_on_mac_xla_bug():
if sys.platform == "darwin" and osp.version.version > "1.0.0":
raise unittest.SkipTest("Test fails on Mac with new scipy (issue #2166)")

class NumpyLinalgTest(jtu.JaxTestCase):

@@ -134,6 +139,7 @@ def testSlogdet(self, shape, dtype, rng_factory):
for dtype in float_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("tpu")
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSlogdetGrad(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
@@ -188,6 +194,8 @@ def norm(x):
def testEigvals(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if shape == (50, 50) and dtype == onp.complex64:
_skip_on_mac_xla_bug()
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
@@ -565,6 +573,8 @@ def testSolve(self, lhs_shape, rhs_shape, dtype, rng_factory):
def testInv(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if shape == (200, 200) and dtype == onp.float32:
_skip_on_mac_xla_bug()
if jtu.device_under_test() == "gpu" and shape == (200, 200):
raise unittest.SkipTest("Test is flaky on GPU")

@@ -594,6 +604,8 @@ def args_maker():
def testPinv(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if shape == (7, 10000) and dtype in [onp.complex64, onp.float32]:
_skip_on_mac_xla_bug()
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(onp.linalg.pinv, np.linalg.pinv, args_maker,
@@ -650,6 +662,7 @@ def test(x):
xc = onp.eye(3, dtype=onp.complex)
self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1151(self):
A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
b = np.array(onp.random.randn(100, 3), dtype=np.float32)
@@ -661,6 +674,7 @@ def testIssue1151(self):
jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testIssue1383(self):
seed = jax.random.PRNGKey(0)
tmp = jax.random.uniform(seed, (2,2))
@@ -726,6 +740,7 @@ def testLuOfSingularMatrix(self):
for dtype in float_types + complex_types
for rng_factory in [jtu.rand_default]))
@jtu.skip_on_devices("tpu") # TODO(phawkins): precision problems on TPU.
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testLuGrad(self, shape, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
@@ -764,6 +779,8 @@ def testLuBatching(self, shape, dtype, rng_factory):
def testLuFactor(self, n, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if n == 200 and dtype == onp.complex64:
_skip_on_mac_xla_bug()
args_maker = lambda: [rng((n, n), dtype)]

x, = args_maker()
@@ -985,6 +1002,8 @@ def testTriangularSolveGradPrecision(self):
def testExpm(self, n, dtype, rng_factory):
rng = rng_factory()
_skip_if_unsupported_type(dtype)
if n == 50 and dtype in [onp.complex64, onp.float32]:
_skip_on_mac_xla_bug()
args_maker = lambda: [rng((n, n), dtype)]

osp_fun = lambda a: osp.linalg.expm(a)
2 changes: 2 additions & 0 deletions tests/nn_test.py
@@ -34,6 +34,7 @@

class NNFunctionsTest(jtu.JaxTestCase):

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
check_grads(nn.softplus, (1e-8,), 4,
rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
@@ -42,6 +43,7 @@ def testSoftplusValue(self):
val = nn.softplus(89.)
self.assertAllClose(val, 89., check_dtypes=False)

@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testEluGrad(self):
check_grads(nn.elu, (1e4,), 4, eps=1.)
