From bfed3d862ec8921ce323ece9a9e09d308c0d7f13 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 22 Sep 2023 13:46:09 -0700 Subject: [PATCH] Improve behavior of core.valid_jaxtype --- jax/_src/core.py | 2 +- tests/api_test.py | 13 +++++++++++-- tests/core_test.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index bc062917f9ae..521cd657c183 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1671,10 +1671,10 @@ def __init__(self, dtype, val, weak_type=None): super().__init__( np.shape(val), dtype, weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type) + dtypes.check_valid_dtype(self.dtype) # Note: canonicalized self.dtype doesn't necessarily match self.val assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype) self.val = val - assert self.dtype != np.dtype('O'), val def update(self, dtype=None, val=None, weak_type=None): dtype = self.dtype if dtype is None else dtype diff --git a/tests/api_test.py b/tests/api_test.py index 5e6ed2096416..146997e0908e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2958,8 +2958,12 @@ def check_warning(warn, nowarn): lambda: jnp.arange(1.0).astype(int)) def test_error_for_invalid_dtype(self): - with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"): - lax.add(jnp.array(7), np.array("hello")) + with jax.enable_checks(False): + with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"): + lax.add(jnp.array(7), np.array("hello")) + with jax.enable_checks(True): + with self.assertRaises(AssertionError): + lax.add(jnp.array(7), np.array("hello")) def test_vmap_preserves_docstr(self): def superfun(a): @@ -3213,6 +3217,11 @@ def f(x, y): return x + y "positional arguments.", lambda: partial(df, x=0.)(y=1.)) + def test_grad_object_array_error(self): + x = np.array([1, 2, 3], dtype=object) + with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"): + jax.grad(lambda x: x)(x) + def test_jit_compilation_time_logging(self): @api.jit def f(x): diff --git a/tests/core_test.py b/tests/core_test.py index 783d3721e7ef..3956d2a20a5f 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -196,6 +196,16 @@ def test_tree_unflatten(self): nodes_equal = tree_map(operator.eq, tree, tree2) assert tree_reduce(operator.and_, nodes_equal) + @jtu.sample_product( + dtype=[*jtu.dtypes.all, object, [('i', 'i4'), ('f', 'f4')]] + ) + def test_is_valid_jaxtype(self, dtype): + arr = np.zeros(10, dtype=dtype) + if dtype in jtu.dtypes.all: + self.assertTrue(core.valid_jaxtype(arr)) + else: + self.assertFalse(core.valid_jaxtype(arr)) + @parameterized.named_parameters( (str(i), *spec) for i, spec in enumerate(test_specs)) def test_jit(self, f, args):