diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 72475e96f55b..c79d01acfdd3 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5090,8 +5090,9 @@ def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): ) for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS)) def testOpGradSpecialValue(self, op, special_value, order): - check_grads(op, (special_value,), order, ["fwd", "rev"], - atol={np.float32: 3e-3}) + check_grads( + op, (special_value,), order, ['fwd', 'rev'], atol={np.float32: 4e-3} + ) def testSincAtZero(self): # Some manual tests for sinc at zero, since it doesn't have well-behaved