diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 01323563d70d..a1da8d097625 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -1310,8 +1310,8 @@ def f1(x, y, indices): def f2(x, y, indices): return sparse.bcoo_dot_general_sampled(x, y, indices, dimension_numbers=dimension_numbers) - self._CheckAgainstNumpy(f1, f2, args_maker) - self._CompileAndCheck(f2, args_maker) + self._CheckAgainstNumpy(f1, f2, args_maker, tol=MATMUL_TOL) + self._CompileAndCheck(f2, args_maker, tol=MATMUL_TOL) @jtu.sample_product( [dict(n_batch=n_batch, n_dense=n_dense, lhs_shape=lhs_shape,