Skip to content

Commit

Permalink
Improve behavior of core.valid_jaxtype
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 22, 2023
1 parent 4269705 commit bfed3d8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jax/_src/core.py
Expand Up @@ -1671,10 +1671,10 @@ def __init__(self, dtype, val, weak_type=None):
super().__init__(
np.shape(val), dtype,
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
dtypes.check_valid_dtype(self.dtype)
# Note: canonicalized self.dtype doesn't necessarily match self.val
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
self.val = val
assert self.dtype != np.dtype('O'), val

def update(self, dtype=None, val=None, weak_type=None):
dtype = self.dtype if dtype is None else dtype
Expand Down
13 changes: 11 additions & 2 deletions tests/api_test.py
Expand Up @@ -2958,8 +2958,12 @@ def check_warning(warn, nowarn):
lambda: jnp.arange(1.0).astype(int))

def test_error_for_invalid_dtype(self):
with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
lax.add(jnp.array(7), np.array("hello"))
with jax.enable_checks(False):
with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
lax.add(jnp.array(7), np.array("hello"))
with jax.enable_checks(True):
with self.assertRaises(AssertionError):
lax.add(jnp.array(7), np.array("hello"))

def test_vmap_preserves_docstr(self):
def superfun(a):
Expand Down Expand Up @@ -3213,6 +3217,11 @@ def f(x, y): return x + y
"positional arguments.",
lambda: partial(df, x=0.)(y=1.))

def test_grad_object_array_error(self):
x = np.array([1, 2, 3], dtype=object)
with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"):
jax.grad(lambda x: x)(x)

def test_jit_compilation_time_logging(self):
@api.jit
def f(x):
Expand Down
10 changes: 10 additions & 0 deletions tests/core_test.py
Expand Up @@ -196,6 +196,16 @@ def test_tree_unflatten(self):
nodes_equal = tree_map(operator.eq, tree, tree2)
assert tree_reduce(operator.and_, nodes_equal)

@jtu.sample_product(
dtype=[*jtu.dtypes.all, object, [('i', 'i4'), ('f', 'f4')]]
)
def test_is_valid_jaxtype(self, dtype):
arr = np.zeros(10, dtype=dtype)
if dtype in jtu.dtypes.all:
self.assertTrue(core.valid_jaxtype(arr))
else:
self.assertFalse(core.valid_jaxtype(arr))

@parameterized.named_parameters(
(str(i), *spec) for i, spec in enumerate(test_specs))
def test_jit(self, f, args):
Expand Down

0 comments on commit bfed3d8

Please sign in to comment.