
Commit

Copybara import of the project:
--
b243ea7 by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
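
For downstream callers, the migration implied by these deprecations is roughly the
following (a minimal sketch; jax.random.key is used here only as a convenient source
of an extended dtype and is otherwise unrelated to this change):

    import jax
    import jax.numpy as jnp

    dt = jax.random.key(0).dtype    # an extended (formerly "opaque") dtype
    x = jax.random.key(1)           # an array carrying such a dtype

    # Deprecated spellings:
    #   jax.core.is_opaque_dtype(dt)
    #   jax.core.has_opaque_dtype(x)
    #   jax.core.canonicalize_dtype(dt, allow_opaque_dtype=True)

    # Replacements:
    jnp.issubdtype(dt, jax.dtypes.extended)        # replaces is_opaque_dtype(dt)
    jnp.issubdtype(x.dtype, jax.dtypes.extended)   # replaces has_opaque_dtype(x)
    jax.core.canonicalize_dtype(dt, allow_extended_dtype=True)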

COPYBARA_INTEGRATE_REVIEW=#16824 from jakevdp:extended-dtype b243ea7
PiperOrigin-RevId: 550674205
jakevdp authored and jax authors committed Jul 24, 2023
1 parent c6fa3d9 commit b4132b4
Showing 19 changed files with 139 additions and 112 deletions.
14 changes: 7 additions & 7 deletions jax/_src/api.py
@@ -773,7 +773,7 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
if (dtypes.is_opaque_dtype(aval.dtype) or
if (dtypes.issubdtype(aval.dtype, dtypes.extended) or
dtypes.issubdtype(aval.dtype, np.integer) or
dtypes.issubdtype(aval.dtype, np.bool_)):
if not allow_int:
@@ -788,7 +788,7 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):

def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if dtypes.is_opaque_dtype(aval.dtype):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
raise TypeError(
f"{name} with output element type {aval.dtype.name}")
if holomorphic:
@@ -874,7 +874,7 @@ def jacfun(*args, **kwargs):
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
dispatch.check_arg(x)
aval = core.get_aval(x)
if dtypes.is_opaque_dtype(aval.dtype):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
raise TypeError(
f"jacfwd with input element type {aval.dtype.name}")
if holomorphic:
@@ -2585,7 +2585,7 @@ def _device_put_sharded(*xs):
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
sharding = PmapSharding(np.array(devices), sharding_spec)
if dtypes.is_opaque_dtype(stacked_aval.dtype):
if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices))

@@ -2635,7 +2635,7 @@ def _device_put_replicated(x):
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
buf = device_put(x, devices[0])
sharding = PmapSharding(np.array(devices), sharding_spec)
if dtypes.is_opaque_dtype(aval.dtype):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
assert len(xla.aval_to_xla_shapes(aval)) == 1
return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)
@@ -2711,7 +2711,7 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None):
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
self.dtype = dtype if dtypes.is_opaque_dtype(dtype) else np.dtype(dtype)
self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
if sharding is not None and not isinstance(sharding, Sharding):
raise ValueError(
"sharding should be an instance of `jax.sharding.Sharding`. "
@@ -2750,7 +2750,7 @@ def __hash__(self):
return hash((self.shape, self.dtype, named, self.sharding))

core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True),
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=False, named_shape=x.named_shape))

@api_boundary
6 changes: 3 additions & 3 deletions jax/_src/api_util.py
@@ -567,7 +567,7 @@ def _shaped_abstractify_slow(x):
weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
@@ -592,14 +592,14 @@ def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True))
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify

def _np_scalar_abstractify(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x),
dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True))
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers.update((t, _np_scalar_abstractify)
for t in numpy_scalar_types)

8 changes: 4 additions & 4 deletions jax/_src/array.py
@@ -600,7 +600,7 @@ def make_array_from_callback(
for device in sharding.addressable_devices
]
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if dtypes.is_opaque_dtype(aval.dtype):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
return ArrayImpl(aval, sharding, arrays, committed=True)

@@ -661,7 +661,7 @@ def make_array_from_single_device_arrays(
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if dtypes.is_opaque_dtype(aval.dtype):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
# TODO(phawkins): ideally the cast() could be checked. Revisit this after
# removing DeviceArray.
@@ -785,7 +785,7 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
is_out_sharding_from_xla):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
if dtypes.is_opaque_dtype(global_aval.dtype):
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)
return xc.array_result_handler(
@@ -800,7 +800,7 @@ def _array_local_result_handler(aval, sharding, indices):
def _array_local_result_handler(aval, sharding, indices):
if aval.dtype == dtypes.float0:
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
if dtypes.is_opaque_dtype(aval.dtype):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
return xc.array_result_handler(
32 changes: 16 additions & 16 deletions jax/_src/core.py
@@ -1398,26 +1398,26 @@ def concrete_dim_or_error(val: Any, context=""):
else:
return concrete_or_error(operator.index, val, context=context)

### Opaque dtypes
### Extended dtypes
#
# Opaque dtypes are JAX-specific dtypes that allow us to represent logical
# Extended dtypes are JAX-specific dtypes that allow us to represent logical
# arrays of element types that do not have an obvious direct correspondence
# to ("physical") arrays of basic types in a compiler. In particular, their
# element types differ from those of XLA and NumPy (e.g. int32). These dtypes
# are only known to JAX. Their implementation is determined by:
# a) an object representing the opaque dtype, accessible via the `dtype`
# a) an object representing the extended dtype, accessible via the `dtype`
# attribute on corresponding JAX arrays and, internally, on avals such
# as ShapedArrays that correspond to such JAX arrays;
# b) a set of rules, available via a private attribute on the opaque dtype
# b) a set of rules, available via a private attribute on the extended dtype
# object in (a).
# The rules in (b) tell JAX internals how to ground out the element
# type for interaction with the compiler and runtime, e.g. when lowering
# to the compiler's language.


# TODO(frostig): update inliners of the four functions below to call them
# TODO(jakevdp): remove this function once it's unused downstream.
def has_opaque_dtype(x: Any) -> bool:
return dtypes.is_opaque_dtype(get_aval(x).dtype)
return dtypes.issubdtype(get_aval(x).dtype, dtypes.extended)

@overload
def physical_aval(aval: ShapedArray) -> ShapedArray: ...
@@ -1428,7 +1428,7 @@ def physical_aval(aval: AbstractValue) -> AbstractValue: ...

def physical_aval(aval):
aval_dtype = getattr(aval, 'dtype', None)
if aval_dtype and dtypes.is_opaque_dtype(aval_dtype):
if aval_dtype and dtypes.issubdtype(aval_dtype, dtypes.extended):
ctor = type(aval)
aval_shape = getattr(aval, 'shape', None)
assert aval_shape is not None, (ctor, aval)
@@ -1439,14 +1439,14 @@ def physical_aval(aval):
return aval

def _short_dtype_name(dtype) -> str:
if dtypes.issubdtype(dtype, dtypes.opaque):
if dtypes.issubdtype(dtype, dtypes.extended):
return str(dtype)
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))

def _dtype_object(dtype):
return dtype if dtypes.issubdtype(dtype, dtypes.opaque) else np.dtype(dtype)
return dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)

class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
@@ -1653,11 +1653,11 @@ def str_short(self, short_dtypes=False) -> str:
_complex = concretization_function_error(complex, True)

def primal_dtype_to_tangent_dtype(primal_dtype):
# TODO(frostig,mattjj): determines that all opaque dtypes have
# float0 tangent type, which works fine for all our current opaque
# dtype applications. We may some day want to delegate this
# decision to the dtype rules.
if (dtypes.is_opaque_dtype(primal_dtype) or
# TODO(frostig,mattjj): determines that all extended dtypes have
# float0 tangent type, which works fine for all our current
# extended dtype applications. We may some day want to delegate
# this decision to the dtype rules.
if (dtypes.issubdtype(primal_dtype, dtypes.extended) or
not dtypes.issubdtype(primal_dtype, np.inexact)):
return dtypes.float0
else:
@@ -1780,12 +1780,12 @@ def __len__(self):
x._data)

@dataclass(frozen=True, eq=True)
class bint(dtypes.OpaqueDType):
class bint(dtypes.ExtendedDType):
bound: int

@property
def type(self) -> type:
return dtypes.opaque
return dtypes.extended

@property
def name(self) -> str:
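The comment block in jax/_src/core.py above describes the mechanism: an extended dtype
is a dtype object carried on arrays and avals, paired with a private rules object that
tells JAX how to ground the element type out physically. A minimal sketch of how this
surfaces at the aval level (internal APIs, shown only for illustration; the exact
physical shape and dtype depend on the PRNG implementation):

    import jax
    from jax._src import core, dtypes

    key_dtype = jax.random.key(0).dtype               # an extended dtype object
    aval = core.ShapedArray((4,), key_dtype)           # logical aval with an extended element type

    dtypes.issubdtype(aval.dtype, dtypes.extended)     # True
    core.physical_aval(aval)                           # the same aval grounded out physically,
                                                       # e.g. ShapedArray((4, 2), dtype=uint32)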
60 changes: 32 additions & 28 deletions jax/_src/dtypes.py
@@ -46,9 +46,8 @@
FLAGS = flags.FLAGS


# TODO(jakevdp): rename opaque dtypes to something more user-friendly
class opaque(np.generic):
"""Scalar class for opaque dtypes.
class extended(np.generic):
"""Scalar class for extended dtypes.
This is an abstract class that should never be instantiated, but rather
exists for the sake of `jnp.issubdtype`.
@@ -57,13 +56,13 @@ class opaque(np.generic):
>>> from jax import random
>>> from jax._src import dtypes
>>> key = random.key(0)
>>> jnp.issubdtype(key.dtype, dtypes.opaque)
>>> jnp.issubdtype(key.dtype, dtypes.extended)
True
"""
pass


class prng_key(opaque):
class prng_key(extended):
"""Scalar class for PRNG Key dtypes.
This is an abstract class that should never be instantiated, but rather
@@ -79,17 +78,16 @@ class prng_key(opaque):
pass


# TODO(jakevdp): rename opaque dtypes to something more user-friendly
class OpaqueDType(metaclass=abc.ABCMeta):
"""Abstract Base Class for opaque dtypes"""
class ExtendedDType(metaclass=abc.ABCMeta):
"""Abstract Base Class for extended dtypes"""
@property
@abc.abstractmethod
def type(self) -> type: ...


# TODO(jakevdp): remove this function once it's unused downstream.
def is_opaque_dtype(dtype: Any) -> bool:
# TODO(vanderplas, frostig): remove in favor of inlining `issubdtype`
return issubdtype(dtype, opaque)
return issubdtype(dtype, extended)

# fp8 support
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
@@ -185,11 +183,11 @@ def to_complex_dtype(dtype: DTypeLike) -> DType:


@functools.cache
def _canonicalize_dtype(x64_enabled: bool, allow_opaque_dtype: bool, dtype: Any) -> Union[DType, OpaqueDType]:
if is_opaque_dtype(dtype):
if not allow_opaque_dtype:
raise ValueError(f"Internal: canonicalize_dtype called on opaque dtype {dtype} "
"with allow_opaque_dtype=False")
def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> Union[DType, ExtendedDType]:
if issubdtype(dtype, extended):
if not allow_extended_dtype:
raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} "
"with allow_extended_dtype=False")
return dtype
try:
dtype_ = np.dtype(dtype)
@@ -202,14 +200,20 @@ def _canonicalize_dtype(x64_enabled: bool, allow_opaque_dtype: bool, dtype: Any)
return _dtype_to_32bit_dtype.get(dtype_, dtype_)

@overload
def canonicalize_dtype(dtype: Any, allow_opaque_dtype: Literal[False] = False) -> DType: ...
def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False, allow_opaque_dtype: Any = None) -> DType: ...

@overload
def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DType, OpaqueDType]: ...
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opaque_dtype: Any = None) -> Union[DType, ExtendedDType]: ...

def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DType, OpaqueDType]:
def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opaque_dtype: Any = None) -> Union[DType, ExtendedDType]:
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype) # type: ignore[bad-return-type]
if allow_opaque_dtype is not None:
# TODO(jakevdp): complete the deprecation cycle (Deprecated July 24 2023).
warnings.warn(
"allow_opaque_dtype argument is deprecated; use allow_extended_dtype.",
DeprecationWarning)
allow_extended_dtype = allow_opaque_dtype
return _canonicalize_dtype(config.x64_enabled, allow_extended_dtype, dtype) # type: ignore[bad-return-type]

# Default dtypes corresponding to Python scalars.
python_scalar_dtypes : dict[type, DType] = {
@@ -315,9 +319,9 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as
:obj:`jax.dtypes.bfloat16`.
"""
if isinstance(a, OpaqueDType):
if isinstance(a, ExtendedDType):
return _issubclass(a.type, b)
elif _issubclass(b, opaque):
elif _issubclass(b, extended):
return False
# Canonicalizes all concrete types to np.dtype instances
a = a if _is_typeclass(a) else np.dtype(a)
@@ -573,26 +577,26 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType:
dt = python_scalar_dtypes[x]
elif type(x) in python_scalar_dtypes:
dt = python_scalar_dtypes[type(x)]
elif is_opaque_dtype(getattr(x, 'dtype', None)):
elif issubdtype(getattr(x, 'dtype', None), extended):
dt = x.dtype
else:
try:
dt = np.result_type(x)
except TypeError as err:
raise TypeError(f"Cannot determine dtype of {x}") from err
if dt not in _jax_dtype_set and not is_opaque_dtype(dt):
if dt not in _jax_dtype_set and not issubdtype(dt, extended):
raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
# TODO(jakevdp): fix return type annotation and remove this ignore.
return canonicalize_dtype(dt, allow_opaque_dtype=True) if canonicalize else dt # type: ignore[return-value]
return canonicalize_dtype(dt, allow_extended_dtype=True) if canonicalize else dt # type: ignore[return-value]

def _lattice_result_type(*args: Any) -> tuple[DType, bool]:
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
out_dtype = dtypes[0]
out_weak_type = weak_types[0]
elif len(set(dtypes)) == 1 and not all(weak_types):
# Trivial promotion case. This allows opaque dtypes through.
# Trivial promotion case. This allows extended dtypes through.
out_dtype = dtypes[0]
out_weak_type = False
elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
@@ -632,18 +636,18 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType,
"""
if len(args) == 0:
raise ValueError("at least one array or dtype is required")
dtype: DType | OpaqueDType
dtype: DType | ExtendedDType
dtype, weak_type = _lattice_result_type(*(float_ if arg is None else arg for arg in args))
if weak_type:
dtype = canonicalize_dtype(
_default_types['f' if dtype in _custom_float_dtypes else dtype.kind])
else:
dtype = canonicalize_dtype(dtype, allow_opaque_dtype=True)
dtype = canonicalize_dtype(dtype, allow_extended_dtype=True)
# TODO(jakevdp): fix return type annotation and remove this ignore.
return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value]

def check_user_dtype_supported(dtype, fun_name=None):
if is_opaque_dtype(dtype):
if issubdtype(dtype, extended):
return
# Avoid using `dtype in [...]` because of numpy dtype equality overloading.
if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
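The canonicalize_dtype change above keeps the old keyword as a deprecated alias. A
minimal sketch of the resulting behavior (internal API, shown only to illustrate the
deprecation path; as noted in the commit message, this is not covered by the public
deprecation policy):

    import warnings
    import jax
    from jax._src import dtypes

    key_dtype = jax.random.key(0).dtype

    # New keyword: extended dtypes pass through canonicalization unchanged.
    dtypes.canonicalize_dtype(key_dtype, allow_extended_dtype=True)

    # Old keyword still works for now but emits a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        dtypes.canonicalize_dtype(key_dtype, allow_opaque_dtype=True)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # With neither flag, canonicalize_dtype rejects extended dtypes:
    # dtypes.canonicalize_dtype(key_dtype)   # raises ValueError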
