diff --git a/brainpy/math/tests/test_numpy_indexing.py b/brainpy/math/tests/test_numpy_indexing.py index 05071745c..51ebea6bc 100644 --- a/brainpy/math/tests/test_numpy_indexing.py +++ b/brainpy/math/tests/test_numpy_indexing.py @@ -1013,6 +1013,7 @@ def _update_tol(op): return tol +@jtu.with_config(jax_numpy_dtype_promotion='standard') class IndexedUpdateTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ diff --git a/brainpy/math/tests/test_numpy_ops.py b/brainpy/math/tests/test_numpy_ops.py index 9cc81cbe9..678bcd555 100644 --- a/brainpy/math/tests/test_numpy_ops.py +++ b/brainpy/math/tests/test_numpy_ops.py @@ -546,6 +546,7 @@ def wrapper(*args, **kw): return wrapper +@jtu.with_config(jax_numpy_dtype_promotion='standard') class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -4214,8 +4215,6 @@ def bm_fun(a, c): else: self._CompileAndCheck(bm_func(bm_fun), args_maker) - -class u(jtu.JaxTestCase): def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): def f(): out = [rng(shape, dtype or jnp.float_) @@ -4251,8 +4250,6 @@ def np_fun(index, shape): self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) self._CompileAndCheck(bm_func(bm_fun), args_maker) - -class u1(jtu.JaxTestCase): def testAstype(self): rng = self.rng() args_maker = lambda: [rng.randn(3, 4).astype("float32")] @@ -5994,9 +5991,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): GradSpecialValuesTestSpec(bm.sinc, [0.], 1), ] - +@jtu.with_config(jax_numpy_dtype_promotion='standard') class NumpyGradTests(jtu.JaxTestCase): - @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( @@ -6099,9 +6095,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]: else: yield arg_dtypes - +@jtu.with_config(jax_numpy_dtype_promotion='standard') class NumpyUfuncTests(jtu.JaxTestCase): - @parameterized.named_parameters( {"testcase_name": f"_{name}_{','.join(arg_dtypes)}", "name": name, "arg_dtypes": arg_dtypes}