Skip to content

Commit

Permalink
issubdtype: fix corner cases with extended dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 22, 2023
1 parent 4269705 commit 8125e8b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
31 changes: 23 additions & 8 deletions jax/_src/dtypes.py
Expand Up @@ -325,20 +325,33 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
"""Returns True if first argument is a typecode lower/equal in type hierarchy.
This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as
:obj:`jax.dtypes.bfloat16`.
:obj:`jax.dtypes.bfloat16` and `jax.dtypes.prng_key`.
"""
# Handle extended types & canonicalizes all concrete types to np.dtype instances.
# Main departures from np.issubdtype are:
# - "extended" dtypes (like prng key types) are not normal numpy dtypes, so we
# need to handle them specifically. However, their scalar types do conform to
# the numpy scalar type hierarchy.
# - custom dtypes (like bfloat16, int4, etc.) are normal numpy dtypes, but they
# don't conform to the standard numpy type hierarchy (e.g. the bfloat16 scalar
# type is not a subclass of np.floating) so we must also handle these specially.

# First we handle extended dtypes (like prng key types)
if isinstance(a, ExtendedDType):
return _issubclass(a.type, b)
a = a if _is_typeclass(a) else np.dtype(a)

if _issubclass(b, extended):
if isinstance(b, ExtendedDType):
return a == b
else:
a = a.type
elif isinstance(b, ExtendedDType):
return False
b = b if _is_typeclass(b) else np.dtype(b)

# Now do special handling of custom float and int types. To do this, we first
# convert scalar types to dtypes in order to recognize custom floats & ints.
# We cannot use issubclass(a, np.generic) because scalar types would satisfy this.
a = a if _is_typeclass(a) or _issubclass(a, extended) else np.dtype(a)
b = b if _is_typeclass(b) or _issubclass(b, extended) else np.dtype(b)

if isinstance(a, np.dtype):
if a in _custom_float_dtypes:
# Avoid implicitly casting list elements below to a dtype.
if isinstance(b, np.dtype):
return a == b
return b in [np.floating, np.inexact, np.number, np.generic]
Expand All @@ -350,6 +363,8 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
if isinstance(b, np.dtype):
return a == b
return b in [np.unsignedinteger, np.integer, np.number, np.generic]

# Otherwise, fall back to numpy.issubdtype
return np.issubdtype(a, b)

can_cast = np.can_cast
Expand Down
10 changes: 10 additions & 0 deletions tests/dtypes_test.py
Expand Up @@ -288,6 +288,16 @@ def testIsSubdtype(self):
self.assertEqual(dtypes.issubdtype(t, category),
np.issubdtype(np.dtype(t).type, category))

def testIsSubdtypeExtended(self):
self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended))
self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic))
self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number))

self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key))
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended))
self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic))
self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number))

@parameterized.product(dtype=custom_float_dtypes)
def testIsSubdtypeCustomFloats(self, dtype):
for dt in [dtype, np.dtype(dtype), str(np.dtype(dtype))]:
Expand Down
6 changes: 6 additions & 0 deletions tests/random_test.py
Expand Up @@ -1740,8 +1740,14 @@ def test_construction(self):

def test_issubdtype(self):
key = random.key(42)

self.assertTrue(jnp.issubdtype(key.dtype, key.dtype))
self.assertTrue(jnp.issubdtype(key.dtype, dtypes.prng_key))
self.assertTrue(jnp.issubdtype(key.dtype, dtypes.extended))
self.assertTrue(jnp.issubdtype(key.dtype, np.generic))

self.assertFalse(jnp.issubdtype(key.dtype, np.integer))
self.assertFalse(jnp.issubdtype(key.dtype, np.number))
with self.assertRaisesRegex(TypeError, "Cannot interpret"):
jnp.issubdtype(key, dtypes.prng_key)

Expand Down

0 comments on commit 8125e8b

Please sign in to comment.