From 05da18ab54c11eed30bb53d6551e5ac42f8d129a Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 21 Dec 2023 17:43:31 -0800 Subject: [PATCH] tweaks to enable adding custom tangent dtypes tweaks to enable adding custom tangent dtypes: * fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar * upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition * convert_element_type shouldn't blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other this PR also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass --- jax/_src/core.py | 2 +- jax/_src/lax/lax.py | 35 ++++++++++++++------- jax/_src/prng.py | 12 +++++-- tests/dtypes_test.py | 75 ++++++++++++++++++++++++++++++++++++++++++++ tests/random_test.py | 4 ++- 5 files changed, 113 insertions(+), 15 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index cf7078bb8798..e181525aa5ba 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1911,7 +1911,7 @@ def __len__(self): lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) -@dataclass(frozen=True, eq=True) +@dataclass(frozen=True) class bint(dtypes.ExtendedDType): bound: int diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 5861dbb023e0..d386a09fac64 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1224,14 +1224,20 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) -> def zeros_like_shaped_array(aval: ShapedArray) -> Array: assert isinstance(aval, 
ShapedArray) if dtypes.issubdtype(aval.dtype, dtypes.extended): - scalar_zero = aval.dtype._rules.zero(aval) + scalar_zero = aval.dtype._rules.zero(aval.dtype) elif aval.dtype == dtypes.float0: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) return broadcast(scalar_zero, aval.shape) -ad_util.aval_adders[ShapedArray] = add +def add_shaped_arrays(x, y) -> Array: + aval = core.raise_to_shaped(core.get_aval(x)) + if dtypes.issubdtype(aval.dtype, dtypes.extended): + return aval.dtype._rules.add(aval.dtype, x, y) # type: ignore + return add(x, y) + +ad_util.aval_adders[ShapedArray] = add_shaped_arrays ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array def iota(dtype: DTypeLike, size: int) -> Array: @@ -2332,15 +2338,14 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): return operand.shape def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type): - if operand.dtype != new_dtype: - if (dtypes.issubdtype(operand.dtype, dtypes.extended) and - not isinstance(operand.dtype, core.bint)): - raise ValueError( - f"Cannot call convert_element_type on dtype {dtype_to_string(operand.dtype)}") - if (dtypes.issubdtype(new_dtype, dtypes.extended) and - not isinstance(new_dtype, core.bint)): - raise ValueError( - f"Cannot convert_element_type to dtype={dtype_to_string(new_dtype)}") + if (operand.dtype != new_dtype and + ((dtypes.issubdtype(operand.dtype, dtypes.extended) and + not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or # type: ignore + (dtypes.issubdtype(new_dtype, dtypes.extended) and + not new_dtype._rules.convert_to(operand.dtype, new_dtype)))): # type: ignore + raise ValueError( + f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} " + f"to {dtype_to_string(new_dtype)}") return new_dtype def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type): @@ -5015,4 +5020,12 @@ def handler(bufs): def 
physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: return hlo_sharding + @staticmethod + def convert_from(bint_dtype, other_dtype) -> bool: + return other_dtype in (np.dtype('int32'), np.dtype('int64')) + + @staticmethod + def convert_to(other_dtype, bint_dtype) -> bool: + return other_dtype in (np.dtype('int32'), np.dtype('int64')) + core.bint._rules = BIntRules diff --git a/jax/_src/prng.py b/jax/_src/prng.py index bba41633f682..e6af55718a64 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -599,8 +599,16 @@ def tangent_dtype(_): # tangents, our ad.replace_float0s in custom_jvp/vjp means passing in zeros # like the primal to user rules @staticmethod - def zero(aval): - return lax_internal.zeros_like_shaped_array(aval.update(dtype=dtypes.float0)) + def zero(_): + return np.zeros((), dtypes.float0) + + @staticmethod + def convert_from(key_dtype, other_dtype) -> bool: + return False + + @staticmethod + def convert_to(other_dtype, key_dtype) -> bool: + return False class KeyTy(dtypes.ExtendedDType): diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 6e3b45d5f6cf..6d972c9504cf 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -13,8 +13,10 @@ # limitations under the License. 
+import dataclasses import enum import functools +from functools import partial import itertools import operator @@ -376,6 +378,79 @@ def testDefaultDtypes(self): self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64) self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128) + def test_custom_tangent_dtype(self): + from jax._src import core + + class scale(dtypes.extended): + pass + + class ScalesTyRules: + @staticmethod + def physical_element_aval(dtype) -> core.ShapedArray: + return core.ShapedArray((), dtype.float_dtype) + + @staticmethod + def global_sharded_result_handler(aval, sharding, committed, is_from_xla): + raise NotImplementedError("convert back under the jit") + + @staticmethod + def add(dt, x, y): + fromscale = partial(jax.lax.convert_element_type, new_dtype=dt.float_dtype) + toscale = partial(jax.lax.convert_element_type, new_dtype=dt) + return toscale(jax.lax.max(fromscale(x), fromscale(y))) + + @staticmethod + def zero(dt): + neginf = np.array(-np.inf if dtypes.supports_inf(dt.float_dtype) + else dtypes.finfo(dt.float_dtype).min, dt.float_dtype) + return jax.lax.convert_element_type(neginf, dt) + + @staticmethod + def convert_from(dtype, other_dtype) -> bool: + return dtype.float_dtype == other_dtype + + @staticmethod + def convert_to(other_dtype, dtype) -> bool: + return dtype.float_dtype == other_dtype + + @dataclasses.dataclass(frozen=True) + class ScaleTy(dtypes.ExtendedDType): + float_dtype: dtypes.DType + name: str = 'scale' + _rules: type = ScalesTyRules + type: type = scale + + @jax.custom_vjp + def g(x): + return x + def g_fwd(x): + return x, None + def g_bwd(_, ct): + ct = jax.lax.convert_element_type(ct, ScaleTy(dtypes.float8_e5m2)) + return ct, + g.defvjp(g_fwd, g_bwd) + + @jax.custom_vjp + def convert(x): + return x + def convert_fwd(x): + return x, None + def convert_bwd(_, ct): + ct = jax.lax.convert_element_type(ct, ct.dtype.float_dtype) + return ct, + 
convert.defvjp(convert_fwd, convert_bwd) + + @jax.jit + def f(x): + x = convert(x) + x = g(x) + g(x) + return x + + x = jnp.array(3., dtypes.float8_e5m2) + out = jax.grad(f)(x) + self.assertAllClose(out, 1., check_dtypes=False) + self.assertTrue(dtypes.issubdtype(ScaleTy(dtypes.float8_e5m2), scale)) + class TestPromotionTables(jtu.JaxTestCase): diff --git a/tests/random_test.py b/tests/random_test.py index fb76aeb5d26c..6c572c72de9c 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1399,8 +1399,10 @@ def test_errors(self): jnp.negative(key) with self.assertRaisesRegex(TypeError, "neg does not accept dtype key"): -key - with self.assertRaisesRegex(ValueError, "Cannot call convert_element_type on dtype key"): + with self.assertRaisesRegex(ValueError, "Cannot convert_element_type from key to int64"): lax.convert_element_type(key, int) + with self.assertRaisesRegex(ValueError, "Cannot convert_element_type from int32 to key"): + lax.convert_element_type(np.int32(0), key.dtype) def test_eval_shape(self): key = random.key(1701)