Skip to content

Commit

Permalink
tweaks to enable adding custom tangent dtypes
Browse files Browse the repository at this point in the history
tweaks to enable adding custom tangent dtypes:
* fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar
* upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition
* convert_element_type shouldn't blanket-disallow extended dtypes; actually that can be a key operation for working with them! Instead, add new `convert_from` and `convert_to` rules. Rather than letting these rules perform arbitrary logic, for now they just return a bool indicating whether the conversion is legitimate; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other

This PR also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass.
  • Loading branch information
mattjj committed Dec 22, 2023
1 parent 32e1a0c commit 05da18a
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 15 deletions.
2 changes: 1 addition & 1 deletion jax/_src/core.py
Expand Up @@ -1911,7 +1911,7 @@ def __len__(self):
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
x._data)

@dataclass(frozen=True, eq=True)
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
bound: int

Expand Down
35 changes: 24 additions & 11 deletions jax/_src/lax/lax.py
Expand Up @@ -1224,14 +1224,20 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None) ->
def zeros_like_shaped_array(aval: ShapedArray) -> Array:
  """Build an all-zeros array matching the shape and dtype of ``aval``.

  A scalar zero is chosen according to the dtype and then broadcast to
  ``aval.shape``:

  * extended dtypes: the dtype's ``_rules.zero`` rule supplies the scalar;
  * ``float0``: a zero-dimensional NumPy array of that dtype;
  * anything else: ``0`` converted to ``aval.dtype`` (keeping ``weak_type``).

  Args:
    aval: the abstract value to mirror; must be a ``ShapedArray``.

  Returns:
    The scalar zero broadcast to ``aval.shape``.
  """
  assert isinstance(aval, ShapedArray)
  if dtypes.issubdtype(aval.dtype, dtypes.extended):
    # Hand the rule only the dtype (not the whole aval) so that what comes
    # back is a scalar; the shape is applied by the broadcast below.
    scalar_zero = aval.dtype._rules.zero(aval.dtype)
  elif aval.dtype == dtypes.float0:
    scalar_zero = np.zeros((), dtype=aval.dtype)
  else:
    scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
  return broadcast(scalar_zero, aval.shape)

def add_shaped_arrays(x, y) -> Array:
  """Add two values, deferring to the dtype's own rule for extended dtypes.

  The abstract value of ``x`` decides the path: when its dtype is an extended
  dtype, the dtype's ``_rules.add`` is called with ``(dtype, x, y)``;
  otherwise the ordinary ``add`` is used.

  Args:
    x: left operand; its abstract value selects the addition rule.
    y: right operand.

  Returns:
    The result of the selected addition.
  """
  aval = core.raise_to_shaped(core.get_aval(x))
  if dtypes.issubdtype(aval.dtype, dtypes.extended):
    return aval.dtype._rules.add(aval.dtype, x, y)  # type: ignore
  return add(x, y)

# Register the ShapedArray handlers exactly once: addition goes through
# add_shaped_arrays (not plain `add`) so extended dtypes can supply a rule.
ad_util.aval_adders[ShapedArray] = add_shaped_arrays
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

def iota(dtype: DTypeLike, size: int) -> Array:
Expand Down Expand Up @@ -2332,15 +2338,14 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
return operand.shape

def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
  """Dtype rule for convert_element_type: validate and return ``new_dtype``.

  Conversions that involve an extended dtype are not rejected wholesale.
  Instead the extended dtype's rules are consulted:

  * if the source dtype is extended, ``convert_from(source, target)`` must
    return a truthy value;
  * if the target dtype is extended, ``convert_to(source, target)`` must
    return a truthy value.

  No check is made when the source and target dtypes are equal.

  Args:
    operand: abstract operand; only its ``dtype`` is read.
    new_dtype: the requested result dtype.
    weak_type: unused by this rule.

  Returns:
    ``new_dtype``.

  Raises:
    ValueError: if a consulted rule refuses the conversion.
  """
  src, dst = operand.dtype, new_dtype
  if src != dst:
    src_refuses = (dtypes.issubdtype(src, dtypes.extended) and
                   not src._rules.convert_from(src, dst))  # type: ignore
    # Only ask the target's rule when the source has not already refused.
    dst_refuses = (not src_refuses and
                   dtypes.issubdtype(dst, dtypes.extended) and
                   not dst._rules.convert_to(src, dst))  # type: ignore
    if src_refuses or dst_refuses:
      raise ValueError(
          f"Cannot convert_element_type from {dtype_to_string(src)} "
          f"to {dtype_to_string(dst)}")
  return new_dtype

def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type):
Expand Down Expand Up @@ -5015,4 +5020,12 @@ def handler(bufs):
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
  """Return ``hlo_sharding`` unchanged; ``aval`` is not consulted."""
  return hlo_sharding

@staticmethod
def convert_from(bint_dtype, other_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))

@staticmethod
def convert_to(other_dtype, bint_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))

# Attach the rule set to the bint extended dtype.
core.bint._rules = BIntRules
12 changes: 10 additions & 2 deletions jax/_src/prng.py
Expand Up @@ -599,8 +599,16 @@ def tangent_dtype(_):
# tangents, our ad.replace_float0s in custom_jvp/vjp means passing in zeros
# like the primal to user rules
@staticmethod
def zero(aval):
return lax_internal.zeros_like_shaped_array(aval.update(dtype=dtypes.float0))
def zero(_):
return np.zeros((), dtypes.float0)

@staticmethod
def convert_from(key_dtype, other_dtype) -> bool:
return False

@staticmethod
def convert_to(other_dtype, key_dtype) -> bool:
return False


class KeyTy(dtypes.ExtendedDType):
Expand Down
75 changes: 75 additions & 0 deletions tests/dtypes_test.py
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.


import dataclasses
import enum
import functools
from functools import partial
import itertools
import operator

Expand Down Expand Up @@ -376,6 +378,79 @@ def testDefaultDtypes(self):
self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64)
self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128)

def test_custom_tangent_dtype(self):
  """Differentiate through a function whose cotangents use a custom dtype.

  A ``ScaleTy`` extended dtype wraps a float dtype. Its rules define addition
  as an elementwise max (computed in the wrapped float dtype) and zero as -inf
  (or the most negative finite value when the float dtype has no inf).
  ``g``'s backward pass emits cotangents of that dtype and ``convert``'s
  backward pass turns them back into the wrapped float dtype.
  """
  from jax._src import core

  class scale(dtypes.extended):
    pass

  class ScalesTyRules:
    @staticmethod
    def physical_element_aval(dtype) -> core.ShapedArray:
      # Each element is physically a scalar of the wrapped float dtype.
      return core.ShapedArray((), dtype.float_dtype)

    @staticmethod
    def global_sharded_result_handler(aval, sharding, committed, is_from_xla):
      raise NotImplementedError("convert back under the jit")

    @staticmethod
    def add(dt, x, y):
      # "Addition" is a max taken in the wrapped float dtype.
      x_float = jax.lax.convert_element_type(x, new_dtype=dt.float_dtype)
      y_float = jax.lax.convert_element_type(y, new_dtype=dt.float_dtype)
      largest = jax.lax.max(x_float, y_float)
      return jax.lax.convert_element_type(largest, new_dtype=dt)

    @staticmethod
    def zero(dt):
      # Identity element for max: -inf, or the lowest finite value.
      if dtypes.supports_inf(dt.float_dtype):
        lowest = -np.inf
      else:
        lowest = dtypes.finfo(dt.float_dtype).min
      return jax.lax.convert_element_type(np.array(lowest, dt.float_dtype), dt)

    @staticmethod
    def convert_from(dtype, other_dtype) -> bool:
      return dtype.float_dtype == other_dtype

    @staticmethod
    def convert_to(other_dtype, dtype) -> bool:
      return dtype.float_dtype == other_dtype

  @dataclasses.dataclass(frozen=True)
  class ScaleTy(dtypes.ExtendedDType):
    float_dtype: dtypes.DType
    name: str = 'scale'
    _rules: type = ScalesTyRules
    type: type = scale

  # Identity whose backward pass re-types the cotangent as a ScaleTy.
  @jax.custom_vjp
  def g(x):
    return x

  def g_fwd(x):
    return x, None

  def g_bwd(_, ct):
    return (jax.lax.convert_element_type(ct, ScaleTy(dtypes.float8_e5m2)),)

  g.defvjp(g_fwd, g_bwd)

  # Identity whose backward pass converts the cotangent back to the float
  # dtype wrapped by the cotangent's own dtype.
  @jax.custom_vjp
  def convert(x):
    return x

  def convert_fwd(x):
    return x, None

  def convert_bwd(_, ct):
    return (jax.lax.convert_element_type(ct, ct.dtype.float_dtype),)

  convert.defvjp(convert_fwd, convert_bwd)

  @jax.jit
  def f(x):
    x = convert(x)
    # Two uses of g mean the two cotangents get combined by the add rule.
    x = g(x) + g(x)
    return x

  primal = jnp.array(3., dtypes.float8_e5m2)
  grad_out = jax.grad(f)(primal)
  self.assertAllClose(grad_out, 1., check_dtypes=False)
  self.assertTrue(dtypes.issubdtype(ScaleTy(dtypes.float8_e5m2), scale))


class TestPromotionTables(jtu.JaxTestCase):

Expand Down
4 changes: 3 additions & 1 deletion tests/random_test.py
Expand Up @@ -1399,8 +1399,10 @@ def test_errors(self):
jnp.negative(key)
with self.assertRaisesRegex(TypeError, "neg does not accept dtype key<fry>"):
-key
with self.assertRaisesRegex(ValueError, "Cannot call convert_element_type on dtype key<fry>"):
with self.assertRaisesRegex(ValueError, "Cannot convert_element_type from key<fry> to int64"):
lax.convert_element_type(key, int)
with self.assertRaisesRegex(ValueError, "Cannot convert_element_type from int32 to key<fry>"):
lax.convert_element_type(np.int32(0), key.dtype)

def test_eval_shape(self):
key = random.key(1701)
Expand Down

0 comments on commit 05da18a

Please sign in to comment.