Skip to content

Commit

Permalink
Relax some test tolerances for TPU.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576192162
  • Loading branch information
hawkinsp authored and jax authors committed Oct 24, 2023
1 parent 8b05b16 commit e7f1d29
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
7 changes: 4 additions & 3 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -455,20 +455,21 @@ def jnp_op(x, idx):

# Test with traced integer index
args_maker = lambda: [rng(size, dtype), idx_rng(size, int)]
atol = (
tol = (
5e-5
if jtu.test_device_matches(["tpu"]) and funcname in ("log", "exp")
else None
)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=atol)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker, atol=tol)
self._CompileAndCheck(jnp_op, args_maker)

# Test with slice index
idx = slice(1, 5)
np_op_idx = partial(np_op, idx=idx)
jnp_op_idx = partial(jnp_op, idx=idx)
args_maker = lambda: [rng(size, dtype)]
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=atol)
self._CheckAgainstNumpy(np_op_idx, jnp_op_idx, args_maker, atol=tol,
rtol=tol)
self._CompileAndCheck(jnp_op_idx, args_maker)

def testIndexApplyBatchingBug(self):
Expand Down
5 changes: 4 additions & 1 deletion tests/lax_numpy_test.py
Expand Up @@ -2565,7 +2565,10 @@ def testWindowFunction(self, name, size, **kwds):
jnp_fun = partial(getattr(jnp, name), size, **kwds)
np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds))
args_maker = lambda: []
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
tol = (
5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None
)
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
Expand Down
2 changes: 1 addition & 1 deletion tests/linalg_test.py
Expand Up @@ -726,7 +726,7 @@ def compare_orthogonal(q1, q2):

# Check a ~= qr
norm_error = norm(a - np.matmul(lq, lr))
self.assertTrue(np.all(norm_error < 45), msg=np.amax(norm_error))
self.assertTrue(np.all(norm_error < 60), msg=np.amax(norm_error))

# Compare the first 'k' vectors of Q; the remainder form an arbitrary
# orthonormal basis for the null space.
Expand Down
2 changes: 1 addition & 1 deletion tests/nn_test.py
Expand Up @@ -150,7 +150,7 @@ def testSoftmaxWhereMask(self, fn):

def testSoftmaxGrad(self):
x = jnp.array([5.5, 1.3, -4.2, 0.9])
jtu.check_grads(nn.softmax, (x,), order=2, atol=3e-3)
jtu.check_grads(nn.softmax, (x,), order=2, atol=5e-3)

def testSoftmaxGradResiduals(self):
if not config.softmax_custom_jvp.value:
Expand Down

0 comments on commit e7f1d29

Please sign in to comment.