From 8125e8bd0321f34d0d4776deb5620ea57a4b27b1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 22 Sep 2023 11:37:31 -0700 Subject: [PATCH] issubdtype: fix corner cases with extended dtypes --- jax/_src/dtypes.py | 31 +++++++++++++++++++++++-------- tests/dtypes_test.py | 10 ++++++++++ tests/random_test.py | 6 ++++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f4e243b608f7..878712d4060a 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -325,20 +325,33 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool: """Returns True if first argument is a typecode lower/equal in type hierarchy. This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as - :obj:`jax.dtypes.bfloat16`. + :obj:`jax.dtypes.bfloat16` and `jax.dtypes.prng_key`. """ - # Handle extended types & canonicalizes all concrete types to np.dtype instances. + # Main departures from np.issubdtype are: + # - "extended" dtypes (like prng key types) are not normal numpy dtypes, so we + # need to handle them specifically. However, their scalar types do conform to + # the numpy scalar type hierarchy. + # - custom dtypes (like bfloat16, int4, etc.) are normal numpy dtypes, but they + # don't conform to the standard numpy type hierarchy (e.g. the bfloat16 scalar + # type is not a subclass of np.floating) so we must also handle these specially. + + # First we handle extended dtypes (like prng key types) if isinstance(a, ExtendedDType): - return _issubclass(a.type, b) - a = a if _is_typeclass(a) else np.dtype(a) - - if _issubclass(b, extended): + if isinstance(b, ExtendedDType): + return a == b + else: + a = a.type + elif isinstance(b, ExtendedDType): return False - b = b if _is_typeclass(b) else np.dtype(b) + + # Now do special handling of custom float and int types. To do this, we first + # convert scalar types to dtypes in order to recognize custom floats & ints. + # We cannot use issubclass(a, np.generic) because scalar types would satisfy this. + a = a if _is_typeclass(a) or _issubclass(a, extended) else np.dtype(a) + b = b if _is_typeclass(b) or _issubclass(b, extended) else np.dtype(b) if isinstance(a, np.dtype): if a in _custom_float_dtypes: - # Avoid implicitly casting list elements below to a dtype. if isinstance(b, np.dtype): return a == b return b in [np.floating, np.inexact, np.number, np.generic] @@ -350,6 +363,8 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool: if isinstance(b, np.dtype): return a == b return b in [np.unsignedinteger, np.integer, np.number, np.generic] + + # Otherwise, fall back to numpy.issubdtype return np.issubdtype(a, b) can_cast = np.can_cast diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 0c25b50cda3a..3512362a04a7 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -288,6 +288,16 @@ def testIsSubdtype(self): self.assertEqual(dtypes.issubdtype(t, category), np.issubdtype(np.dtype(t).type, category)) + def testIsSubdtypeExtended(self): + self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended)) + self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic)) + self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number)) + + self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key)) + self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended)) + self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic)) + self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number)) + @parameterized.product(dtype=custom_float_dtypes) def testIsSubdtypeCustomFloats(self, dtype): for dt in [dtype, np.dtype(dtype), str(np.dtype(dtype))]: diff --git a/tests/random_test.py b/tests/random_test.py index f64d301af5b8..83de9adf550c 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1740,8 +1740,14 @@ def test_construction(self): def test_issubdtype(self): key = random.key(42) + + self.assertTrue(jnp.issubdtype(key.dtype, key.dtype)) self.assertTrue(jnp.issubdtype(key.dtype, dtypes.prng_key)) + self.assertTrue(jnp.issubdtype(key.dtype, dtypes.extended)) + self.assertTrue(jnp.issubdtype(key.dtype, np.generic)) + self.assertFalse(jnp.issubdtype(key.dtype, np.integer)) + self.assertFalse(jnp.issubdtype(key.dtype, np.number)) with self.assertRaisesRegex(TypeError, "Cannot interpret"): jnp.issubdtype(key, dtypes.prng_key)