Skip to content

Commit

Permalink
Add jax_numpy_dtype_promotion='strict' mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 26, 2022
1 parent 563a633 commit ceae6fe
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 42 deletions.
1 change: 1 addition & 0 deletions jax/__init__.py
Expand Up @@ -47,6 +47,7 @@
log_compiles as log_compiles,
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,
numpy_dtype_promotion as numpy_dtype_promotion,
numpy_rank_promotion as numpy_rank_promotion,
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
transfer_guard as transfer_guard,
Expand Down
18 changes: 17 additions & 1 deletion jax/_src/config.py
Expand Up @@ -350,7 +350,8 @@ def _trace_context(self):
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision, self.jax_dynamic_shapes)
self.jax_default_matmul_precision, self.jax_dynamic_shapes,
self.jax_numpy_dtype_promotion)

class NoDefault: pass
no_default = NoDefault()
Expand Down Expand Up @@ -437,6 +438,7 @@ def __setattr__(self, name, val):

class GlobalJitState(NamedTuple):
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False

Expand All @@ -450,6 +452,7 @@ def update_global_jit_state(**kw):
class ThreadLocalJitState(NamedTuple):
dynamic_trace_state: Optional[Any] = None
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False

Expand Down Expand Up @@ -627,6 +630,19 @@ def update_thread_local_jit_state(**kw):
'This is a temporary flag that will be used during the process '
'of deprecating the ``jax_enable_x64`` flag.'))

numpy_dtype_promotion = config.define_enum_state(
name='jax_numpy_dtype_promotion',
enum_values=['standard', 'strict'],
default='standard',
help=('Specify the rules used for implicit type promotion in operations '
'between arrays. Options are "standard" or "strict"; in strict-mode, '
'binary operations between arrays of differing strongly-specified '
'dtypes will result in an error.'),
update_global_hook=lambda val: \
update_global_jit_state(numpy_dtype_promotion=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(numpy_dtype_promotion=val))

def _update_x64_global(val):
lib.jax_jit.global_state().enable_x64 = val

Expand Down
115 changes: 77 additions & 38 deletions jax/_src/dtypes.py
Expand Up @@ -21,7 +21,7 @@


import functools
from typing import Any, Dict
from typing import Any, Dict, List

import numpy as np

Expand Down Expand Up @@ -243,24 +243,29 @@ def issubdtype(a, b):

# Enumeration of all valid JAX types in order.
_weak_types = [int, float, complex]
_jax_types = [
np.dtype('bool'),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
np.dtype(bfloat16),
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
np.dtype('complex64'),
np.dtype('complex128'),
_bool_types: List[np.dtype] = [np.dtype(bool)]
_int_types: List[np.dtype] = [
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
_jax_dtype_set = set(_jax_types) | {float0}
_float_types: List[np.dtype] = [
np.dtype(bfloat16),
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
]
_complex_types: List[np.dtype] = [
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}

def _jax_type(dtype, weak_type):
"""Return the jax type for a dtype and weak type."""
Expand All @@ -270,23 +275,37 @@ def _dtype_and_weaktype(value):
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

def _type_promotion_lattice():
def _type_promotion_lattice(jax_numpy_dtype_promotion):
"""
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8 = _jax_types
b1, = _bool_types
u1, u2, u4, u8, i1, i2, i4, i8 = _int_types
bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
return {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}

def _make_lattice_upper_bounds():
lattice = _type_promotion_lattice()
if jax_numpy_dtype_promotion == 'standard':
return {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}
elif jax_numpy_dtype_promotion == 'strict':
return {
i_: [f_] + _int_types,
f_: [c_] + _float_types,
c_: _complex_types,
**{t: [] for t in _jax_types}
}
else:
raise ValueError(
f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}")

def _make_lattice_upper_bounds(jax_numpy_dtype_promotion):
lattice = _type_promotion_lattice(jax_numpy_dtype_promotion)
upper_bounds = {node: {node} for node in lattice}
for n in lattice:
while True:
Expand All @@ -297,10 +316,17 @@ def _make_lattice_upper_bounds():
break
upper_bounds[n] |= new_upper_bounds
return upper_bounds
_lattice_upper_bounds = _make_lattice_upper_bounds()

_lattice_upper_bounds = {
'standard': _make_lattice_upper_bounds('standard'),
'strict': _make_lattice_upper_bounds('strict'),
}

class TypePromotionError(ValueError):
pass

@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
def _least_upper_bound(jax_numpy_dtype_promotion, *nodes):
"""Compute the least upper bound of a set of nodes.
Args:
Expand All @@ -327,13 +353,23 @@ def _least_upper_bound(*nodes):
# ∀ c ∈ N: CUB(N) ⊆ UB(c)
# So if N ∩ CUB(N) is nonempty, if follows that LUB(N) = N ∩ CUB(N).
N = set(nodes)
UB = _lattice_upper_bounds
UB = _lattice_upper_bounds[jax_numpy_dtype_promotion]
CUB = set.intersection(*(UB[n] for n in N))
LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
if len(LUB) == 1:
return LUB.pop()
elif len(LUB) == 0:
# TODO(jakevdp): surface some error about jax_numpy_rank_promotion flag.
raise TypePromotionError(
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. Try explicitly casting inputs to the desired output type.")
else:
raise ValueError(f"{nodes} do not have a unique least upper bound.")
# If we get here, it means the lattice is ill-formed.
raise TypePromotionError(
f"Internal Type Promotion error: {nodes} do not have a unique least upper bound "
f"on the specified lattice; options are {LUB}. If you see this error, please "
"report it to the JAX maintainers."
)

def promote_types(a, b):
"""Returns the type to which a binary operation should cast its arguments.
Expand All @@ -351,7 +387,7 @@ def promote_types(a, b):
# object identity, not object equality, due to the behavior of np.dtype.__eq__
a = a if any(a is t for t in _weak_types) else np.dtype(a)
b = b if any(b is t for t in _weak_types) else np.dtype(b)
return np.dtype(_least_upper_bound(a, b))
return np.dtype(_least_upper_bound(config.jax_numpy_dtype_promotion, a, b))

def is_weakly_typed(x):
try:
Expand Down Expand Up @@ -388,11 +424,14 @@ def _lattice_result_type(*args):
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
if all(weak_types):
result_type = _least_upper_bound(*{_jax_type(dtype, False) for dtype in dtypes})
# TODO(jakevdp): explore removing this special case.
if all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
*{_jax_type(dtype, False) for dtype in dtypes})
return dtype(result_type), True
else:
result_type = _least_upper_bound(*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
return dtype(result_type), any(result_type is t for t in _weak_types)

def result_type(*args, return_weak_type_flag=False):
Expand Down
51 changes: 48 additions & 3 deletions tests/dtypes_test.py
Expand Up @@ -29,7 +29,7 @@
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal

from jax.config import config
from jax._src.config import config
config.parse_flags_with_absl()

FLAGS = config.FLAGS
Expand Down Expand Up @@ -136,7 +136,44 @@ def testBinaryPromotion(self, swap, jit):
self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))

def testPromoteDtypes(self):
@jax.numpy_dtype_promotion('strict')
def testPromoteDtypesStrict(self):
# Check that strong types have diagonal promotion table:
for t1 in all_dtypes:
for t2 in all_dtypes:
if t1 == t2:
self.assertEqual(t1, dtypes.promote_types(t1, t2))
else:
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, t2)

# Promotion between weak types matches numpy promotion
for t1 in [int, float, complex]:
for t2 in [int, float, complex]:
py_result = type(t1(0) + t2(0))
lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2)
self.assertTrue(lattice_weak_type)
self.assertEqual(lattice_dtype, np.dtype(py_result))

# Check that weak promotion only works if strong value is not cast:
for t1 in bool_dtypes:
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, int)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, float)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
for t1 in signed_dtypes + unsigned_dtypes:
self.assertEqual(dtypes.promote_types(t1, int), t1)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, float)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
for t1 in float_dtypes:
self.assertEqual(dtypes.promote_types(t1, int), t1)
self.assertEqual(dtypes.promote_types(t1, float), t1)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
for t1 in complex_dtypes:
self.assertEqual(dtypes.promote_types(t1, int), t1)
self.assertEqual(dtypes.promote_types(t1, float), t1)
self.assertEqual(dtypes.promote_types(t1, complex), t1)

@jax.numpy_dtype_promotion('standard')
def testPromoteDtypesStandard(self):
for t1 in all_dtypes:
self.assertEqual(t1, dtypes.promote_types(t1, t1))

Expand All @@ -163,7 +200,15 @@ def testPromoteDtypes(self):
np_float_dtypes + complex_dtypes]:
for t1, t2 in itertools.combinations(groups, 2):
self.assertEqual(np.promote_types(t1, t2),
dtypes.promote_types(t1, t2))
dtypes.promote_types(t1, t2))

# Promotion between weak types matches numpy promotion
for t1 in [int, float, complex]:
for t2 in [int, float, complex]:
py_result = type(t1(0) + t2(0))
lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2)
self.assertTrue(lattice_weak_type)
self.assertEqual(lattice_dtype, np.dtype(py_result))

@parameterized.parameters([jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64])
def testScalarInstantiation(self, scalar_type):
Expand Down

0 comments on commit ceae6fe

Please sign in to comment.