From c56452e86ad3613342b43f428fdf6524c4a4a1e4 Mon Sep 17 00:00:00 2001 From: Brandon Zhang Date: Tue, 12 Jul 2022 20:51:50 +0800 Subject: [PATCH 1/2] fix: add dtype promotion = standard --- brainpy/math/tests/test_numpy_ops.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/brainpy/math/tests/test_numpy_ops.py b/brainpy/math/tests/test_numpy_ops.py index 9cc81cbe9..c182242c3 100644 --- a/brainpy/math/tests/test_numpy_ops.py +++ b/brainpy/math/tests/test_numpy_ops.py @@ -546,6 +546,7 @@ def wrapper(*args, **kw): return wrapper +@jtu.with_config(jax_numpy_dtype_promotion='standard') class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -4214,8 +4215,6 @@ def bm_fun(a, c): else: self._CompileAndCheck(bm_func(bm_fun), args_maker) - -class u(jtu.JaxTestCase): def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): def f(): out = [rng(shape, dtype or jnp.float_) @@ -4251,8 +4250,6 @@ def np_fun(index, shape): self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) self._CompileAndCheck(bm_func(bm_fun), args_maker) - -class u1(jtu.JaxTestCase): def testAstype(self): rng = self.rng() args_maker = lambda: [rng.randn(3, 4).astype("float32")] From 2a299392f7cda9837b4a0c7bb84535bb5a595e8a Mon Sep 17 00:00:00 2001 From: Brandon Zhang Date: Tue, 12 Jul 2022 22:02:18 +0800 Subject: [PATCH 2/2] fix: add dtype promotion = standard --- brainpy/math/tests/test_numpy_indexing.py | 1 + brainpy/math/tests/test_numpy_ops.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/brainpy/math/tests/test_numpy_indexing.py b/brainpy/math/tests/test_numpy_indexing.py index 05071745c..51ebea6bc 100644 --- a/brainpy/math/tests/test_numpy_indexing.py +++ b/brainpy/math/tests/test_numpy_indexing.py @@ -1013,6 +1013,7 @@ def _update_tol(op): return tol +@jtu.with_config(jax_numpy_dtype_promotion='standard') class IndexedUpdateTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ diff --git a/brainpy/math/tests/test_numpy_ops.py b/brainpy/math/tests/test_numpy_ops.py index c182242c3..678bcd555 100644 --- a/brainpy/math/tests/test_numpy_ops.py +++ b/brainpy/math/tests/test_numpy_ops.py @@ -5991,9 +5991,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): GradSpecialValuesTestSpec(bm.sinc, [0.], 1), ] - +@jtu.with_config(jax_numpy_dtype_promotion='standard') class NumpyGradTests(jtu.JaxTestCase): - @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( @@ -6096,9 +6095,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]: else: yield arg_dtypes - +@jtu.with_config(jax_numpy_dtype_promotion='standard') class NumpyUfuncTests(jtu.JaxTestCase): - @parameterized.named_parameters( {"testcase_name": f"_{name}_{','.join(arg_dtypes)}", "name": name, "arg_dtypes": arg_dtypes}