From b4132b4c500564225af24316a12798e6c89e531f Mon Sep 17 00:00:00 2001 From: Jake Vanderplas Date: Mon, 24 Jul 2023 14:29:37 -0700 Subject: [PATCH] Copybara import of the project: -- b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas : Rename opaque dtype to extended dtype. This includes three deprecations: - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended) - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended) - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be subject to a standard 3-month deprecation period. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b PiperOrigin-RevId: 550674205 --- jax/_src/api.py | 14 ++++---- jax/_src/api_util.py | 6 ++-- jax/_src/array.py | 8 ++--- jax/_src/core.py | 32 ++++++++--------- jax/_src/dtypes.py | 60 ++++++++++++++++--------------- jax/_src/interpreters/mlir.py | 12 +++---- jax/_src/lax/lax.py | 42 +++++++++++----------- jax/_src/lax/slicing.py | 8 ++--- jax/_src/lax/utils.py | 2 +- jax/_src/numpy/lax_numpy.py | 8 ++--- jax/_src/numpy/util.py | 4 +-- jax/_src/prng.py | 2 +- jax/_src/test_util.py | 4 +-- jax/_src/typing.py | 4 +-- jax/core.py | 26 ++++++++++++-- jax/dtypes.py | 1 + jax/experimental/jax2tf/jax2tf.py | 10 +++--- jax/experimental/shard_map.py | 4 +-- tests/lax_test.py | 4 +-- 19 files changed, 139 insertions(+), 112 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 019887d1050d..1c853bdbd589 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -773,7 +773,7 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " f"but got {aval.dtype.name}.") - if (dtypes.is_opaque_dtype(aval.dtype) or + if (dtypes.issubdtype(aval.dtype, dtypes.extended) or dtypes.issubdtype(aval.dtype, np.integer) or dtypes.issubdtype(aval.dtype, np.bool_)): if not allow_int: @@ -788,7 +788,7 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) - if dtypes.is_opaque_dtype(aval.dtype): + if dtypes.issubdtype(aval.dtype, dtypes.extended): raise TypeError( f"{name} with output element type {aval.dtype.name}") if holomorphic: @@ -874,7 +874,7 @@ def jacfun(*args, **kwargs): def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: dispatch.check_arg(x) aval = core.get_aval(x) - if dtypes.is_opaque_dtype(aval.dtype): + if dtypes.issubdtype(aval.dtype, dtypes.extended): raise TypeError( f"jacfwd with input element type {aval.dtype.name}") if holomorphic: @@ -2585,7 +2585,7 @@ def _device_put_sharded(*xs): stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape) sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape) sharding = PmapSharding(np.array(devices), sharding_spec) - if dtypes.is_opaque_dtype(stacked_aval.dtype): + if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended): return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices) return pxla.batched_device_put(stacked_aval, sharding, xs, list(devices)) @@ -2635,7 +2635,7 @@ def _device_put_replicated(x): sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) buf = device_put(x, devices[0]) sharding = PmapSharding(np.array(devices), sharding_spec) - if dtypes.is_opaque_dtype(aval.dtype): + if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) @@ -2711,7 +2711,7 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") - self.dtype = dtype if dtypes.is_opaque_dtype(dtype) else np.dtype(dtype) + self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) if sharding is not None and not isinstance(sharding, Sharding): raise ValueError( "sharding should be an instance of `jax.sharding.Sharding`. " @@ -2750,7 +2750,7 @@ def __hash__(self): return hash((self.shape, self.dtype, named, self.sharding)) core.pytype_aval_mappings[ShapeDtypeStruct] = ( - lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True), + lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=False, named_shape=x.named_shape)) @api_boundary diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index acf72c84d8d2..cb7a41ed02a5 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -567,7 +567,7 @@ def _shaped_abstractify_slow(x): weak_type = getattr(x, 'weak_type', False) named_shape = getattr(x, 'named_shape', {}) if hasattr(x, 'dtype'): - dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True) + dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True) else: raise TypeError( f"Cannot interpret value of type {type(x)} as an abstract array; it " @@ -592,14 +592,14 @@ def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray: dtype = x.dtype dtypes.check_valid_dtype(dtype) return ShapedArray(x.shape, - dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True)) + dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) _shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify def _np_scalar_abstractify(x: np.generic) -> ShapedArray: dtype = np.dtype(x) dtypes.check_valid_dtype(dtype) return ShapedArray(np.shape(x), - dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True)) + dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)) _shaped_abstractify_handlers.update((t, _np_scalar_abstractify) for t in numpy_scalar_types) diff --git a/jax/_src/array.py b/jax/_src/array.py index d7301fd7ab47..2e2471a0c9e6 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -600,7 +600,7 @@ def make_array_from_callback( for device in sharding.addressable_devices ] aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) - if dtypes.is_opaque_dtype(aval.dtype): + if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) return ArrayImpl(aval, sharding, arrays, committed=True) @@ -661,7 +661,7 @@ def make_array_from_single_device_arrays( # All input arrays should be committed. Checking it is expensive on # single-controller systems. aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) - if dtypes.is_opaque_dtype(aval.dtype): + if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) # TODO(phawkins): ideally the cast() could be checked. Revisit this after # removing DeviceArray. @@ -785,7 +785,7 @@ def _array_global_result_handler(global_aval, out_sharding, committed, is_out_sharding_from_xla): if global_aval.dtype == dtypes.float0: return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore - if dtypes.is_opaque_dtype(global_aval.dtype): + if dtypes.issubdtype(global_aval.dtype, dtypes.extended): return global_aval.dtype._rules.global_sharded_result_handler( global_aval, out_sharding, committed, is_out_sharding_from_xla) return xc.array_result_handler( @@ -800,7 +800,7 @@ def _array_global_result_handler(global_aval, out_sharding, committed, def _array_local_result_handler(aval, sharding, indices): if aval.dtype == dtypes.float0: return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore - if dtypes.is_opaque_dtype(aval.dtype): + if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.local_sharded_result_handler( aval, sharding, indices) return xc.array_result_handler( diff --git a/jax/_src/core.py b/jax/_src/core.py index 7cba1e44f532..badb0d619df4 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1398,26 +1398,26 @@ def concrete_dim_or_error(val: Any, context=""): else: return concrete_or_error(operator.index, val, context=context) -### Opaque dtypes +### Extended dtypes # -# Opaque dtypes are JAX-specific dtypes that allow us to represent logical +# Extended dtypes are JAX-specific dtypes that allow us to represent logical # arrays of element types that do not have an obvious direct correspondence # to ("physical") arrays of basic types in a compiler. In particular, their # element types differ from those of XLA and NumPy (e.g. int32). These dtypes # are only known to JAX. Their implementation is determined by: -# a) an object representing the opaque dtype, accessible via the `dtype` +# a) an object representing the extended dtype, accessible via the `dtype` # attribute on corresponding JAX arrays and, internally, on avals such # as ShapedArrays that correspond to such JAX arrays; -# b) a set of rules, available via a private attribute on the opaque dtype +# b) a set of rules, available via a private attribute on the extended dtype # object in (a). # The rules in (b) tell JAX internals how to ground out the element # type for interaction with the compiler and runtime, e.g. when lowering # to the compiler's language. -# TODO(frostig): update inliners of the four functions below to call them +# TODO(jakevdp): remove this function once it's unused downstream. def has_opaque_dtype(x: Any) -> bool: - return dtypes.is_opaque_dtype(get_aval(x).dtype) + return dtypes.issubdtype(get_aval(x).dtype, dtypes.extended) @overload def physical_aval(aval: ShapedArray) -> ShapedArray: ... @@ -1428,7 +1428,7 @@ def physical_aval(aval: AbstractValue) -> AbstractValue: ... def physical_aval(aval): aval_dtype = getattr(aval, 'dtype', None) - if aval_dtype and dtypes.is_opaque_dtype(aval_dtype): + if aval_dtype and dtypes.issubdtype(aval_dtype, dtypes.extended): ctor = type(aval) aval_shape = getattr(aval, 'shape', None) assert aval_shape is not None, (ctor, aval) @@ -1439,14 +1439,14 @@ def physical_aval(aval): return aval def _short_dtype_name(dtype) -> str: - if dtypes.issubdtype(dtype, dtypes.opaque): + if dtypes.issubdtype(dtype, dtypes.extended): return str(dtype) else: return (dtype.name.replace('float', 'f').replace('uint' , 'u') .replace('int' , 'i').replace('complex', 'c')) def _dtype_object(dtype): - return dtype if dtypes.issubdtype(dtype, dtypes.opaque) else np.dtype(dtype) + return dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) class UnshapedArray(AbstractValue): __slots__ = ['dtype', 'weak_type'] @@ -1653,11 +1653,11 @@ def str_short(self, short_dtypes=False) -> str: _complex = concretization_function_error(complex, True) def primal_dtype_to_tangent_dtype(primal_dtype): - # TODO(frostig,mattjj): determines that all opaque dtypes have - # float0 tangent type, which works fine for all our current opaque - # dtype applications. We may some day want to delegate this - # decision to the dtype rules. - if (dtypes.is_opaque_dtype(primal_dtype) or + # TODO(frostig,mattjj): determines that all extended dtypes have + # float0 tangent type, which works fine for all our current + # extended dtype applications. We may some day want to delegate + # this decision to the dtype rules. + if (dtypes.issubdtype(primal_dtype, dtypes.extended) or not dtypes.issubdtype(primal_dtype, np.inexact)): return dtypes.float0 else: @@ -1780,12 +1780,12 @@ def __len__(self): x._data) @dataclass(frozen=True, eq=True) -class bint(dtypes.OpaqueDType): +class bint(dtypes.ExtendedDType): bound: int @property def type(self) -> type: - return dtypes.opaque + return dtypes.extended @property def name(self) -> str: diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index a18d9cc94a49..f69b7ab923da 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -46,9 +46,8 @@ FLAGS = flags.FLAGS -# TODO(jakevdp): rename opaque dtypes to something more user-friendly -class opaque(np.generic): - """Scalar class for opaque dtypes. +class extended(np.generic): + """Scalar class for extended dtypes. This is an abstract class that should never be instantiated, but rather exists for the sake of `jnp.issubdtype`. @@ -57,13 +56,13 @@ class opaque(np.generic): >>> from jax import random >>> from jax._src import dtypes >>> key = random.key(0) - >>> jnp.issubdtype(key.dtype, dtypes.opaque) + >>> jnp.issubdtype(key.dtype, dtypes.extended) True """ pass -class prng_key(opaque): +class prng_key(extended): """Scalar class for PRNG Key dtypes. This is an abstract class that should never be instantiated, but rather @@ -79,17 +78,16 @@ class prng_key(opaque): pass -# TODO(jakevdp): rename opaque dtypes to something more user-friendly -class OpaqueDType(metaclass=abc.ABCMeta): - """Abstract Base Class for opaque dtypes""" +class ExtendedDType(metaclass=abc.ABCMeta): + """Abstract Base Class for extended dtypes""" @property @abc.abstractmethod def type(self) -> type: ... +# TODO(jakevdp): remove this function once it's unused downstream. def is_opaque_dtype(dtype: Any) -> bool: - # TODO(vanderplas, frostig): remove in favor of inlining `issubdtype` - return issubdtype(dtype, opaque) + return issubdtype(dtype, extended) # fp8 support float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz @@ -185,11 +183,11 @@ def to_complex_dtype(dtype: DTypeLike) -> DType: @functools.cache -def _canonicalize_dtype(x64_enabled: bool, allow_opaque_dtype: bool, dtype: Any) -> Union[DType, OpaqueDType]: - if is_opaque_dtype(dtype): - if not allow_opaque_dtype: - raise ValueError(f"Internal: canonicalize_dtype called on opaque dtype {dtype} " - "with allow_opaque_dtype=False") +def _canonicalize_dtype(x64_enabled: bool, allow_extended_dtype: bool, dtype: Any) -> Union[DType, ExtendedDType]: + if issubdtype(dtype, extended): + if not allow_extended_dtype: + raise ValueError(f"Internal: canonicalize_dtype called on extended dtype {dtype} " + "with allow_extended_dtype=False") return dtype try: dtype_ = np.dtype(dtype) @@ -202,14 +200,20 @@ def _canonicalize_dtype(x64_enabled: bool, allow_opaque_dtype: bool, dtype: Any) return _dtype_to_32bit_dtype.get(dtype_, dtype_) @overload -def canonicalize_dtype(dtype: Any, allow_opaque_dtype: Literal[False] = False) -> DType: ... +def canonicalize_dtype(dtype: Any, allow_extended_dtype: Literal[False] = False, allow_opaque_dtype: Any = None) -> DType: ... @overload -def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DType, OpaqueDType]: ... +def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opaque_dtype: Any = None) -> Union[DType, ExtendedDType]: ... -def canonicalize_dtype(dtype: Any, allow_opaque_dtype: bool = False) -> Union[DType, OpaqueDType]: +def canonicalize_dtype(dtype: Any, allow_extended_dtype: bool = False, allow_opaque_dtype: Any = None) -> Union[DType, ExtendedDType]: """Convert from a dtype to a canonical dtype based on config.x64_enabled.""" - return _canonicalize_dtype(config.x64_enabled, allow_opaque_dtype, dtype) # type: ignore[bad-return-type] + if allow_opaque_dtype is not None: + # TODO(jakevdp): complete the deprecation cycle (Deprecated July 24 2023). + warnings.warn( + "allow_opaque_dtype argument is deprecated; use allow_extended_dtype.", + DeprecationWarning) + allow_extended_dtype = allow_opaque_dtype + return _canonicalize_dtype(config.x64_enabled, allow_extended_dtype, dtype) # type: ignore[bad-return-type] # Default dtypes corresponding to Python scalars. python_scalar_dtypes : dict[type, DType] = { @@ -315,9 +319,9 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool: This is like :func:`numpy.issubdtype`, but can handle dtype extensions such as :obj:`jax.dtypes.bfloat16`. """ - if isinstance(a, OpaqueDType): + if isinstance(a, ExtendedDType): return _issubclass(a.type, b) - elif _issubclass(b, opaque): + elif _issubclass(b, extended): return False # Canonicalizes all concrete types to np.dtype instances a = a if _is_typeclass(a) else np.dtype(a) @@ -573,18 +577,18 @@ def dtype(x: Any, *, canonicalize: bool = False) -> DType: dt = python_scalar_dtypes[x] elif type(x) in python_scalar_dtypes: dt = python_scalar_dtypes[type(x)] - elif is_opaque_dtype(getattr(x, 'dtype', None)): + elif issubdtype(getattr(x, 'dtype', None), extended): dt = x.dtype else: try: dt = np.result_type(x) except TypeError as err: raise TypeError(f"Cannot determine dtype of {x}") from err - if dt not in _jax_dtype_set and not is_opaque_dtype(dt): + if dt not in _jax_dtype_set and not issubdtype(dt, extended): raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array " "type. Only arrays of numeric types are supported by JAX.") # TODO(jakevdp): fix return type annotation and remove this ignore. - return canonicalize_dtype(dt, allow_opaque_dtype=True) if canonicalize else dt # type: ignore[return-value] + return canonicalize_dtype(dt, allow_extended_dtype=True) if canonicalize else dt # type: ignore[return-value] def _lattice_result_type(*args: Any) -> tuple[DType, bool]: dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args)) @@ -592,7 +596,7 @@ def _lattice_result_type(*args: Any) -> tuple[DType, bool]: out_dtype = dtypes[0] out_weak_type = weak_types[0] elif len(set(dtypes)) == 1 and not all(weak_types): - # Trivial promotion case. This allows opaque dtypes through. + # Trivial promotion case. This allows extended dtypes through. out_dtype = dtypes[0] out_weak_type = False elif all(weak_types) and config.jax_numpy_dtype_promotion != 'strict': @@ -632,18 +636,18 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> Union[DType, """ if len(args) == 0: raise ValueError("at least one array or dtype is required") - dtype: DType | OpaqueDType + dtype: DType | ExtendedDType dtype, weak_type = _lattice_result_type(*(float_ if arg is None else arg for arg in args)) if weak_type: dtype = canonicalize_dtype( _default_types['f' if dtype in _custom_float_dtypes else dtype.kind]) else: - dtype = canonicalize_dtype(dtype, allow_opaque_dtype=True) + dtype = canonicalize_dtype(dtype, allow_extended_dtype=True) # TODO(jakevdp): fix return type annotation and remove this ignore. return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value] def check_user_dtype_supported(dtype, fun_name=None): - if is_opaque_dtype(dtype): + if issubdtype(dtype, extended): return # Avoid using `dtype in [...]` because of numpy dtype equality overloading. if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index d22e0bbd1734..937c153e4e36 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1095,7 +1095,7 @@ def aval_to_types(aval): def _to_physical_op_sharding( aval: core.AbstractValue | None, sharding: xc.HloSharding | None ) -> xc.OpSharding | None: - if (isinstance(aval, core.ShapedArray) and dtypes.is_opaque_dtype(aval.dtype) + if (isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, dtypes.extended) and sharding is not None): return aval.dtype._rules.physical_hlo_sharding(aval, sharding).to_proto() return None if sharding is None else sharding.to_proto() # type: ignore @@ -1355,7 +1355,7 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, # broadcast_dimension[i] is the axis of the result where the axis i of # op is broadcast. # Lower a possibly-dynamic broadcast_in_dim - if dtypes.is_opaque_dtype(aval_out.dtype): # type: ignore + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): # type: ignore elt_shape = aval_out.dtype._rules.physical_element_aval( # type: ignore aval_out.dtype).shape # type: ignore trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] # type: ignore @@ -1408,7 +1408,7 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices, limit_indices, strides) -> ir.Value: - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): elt_shape = aval_out.dtype._rules.physical_element_aval( aval_out.dtype).shape trailing_zeros = [0] * len(elt_shape) @@ -1436,7 +1436,7 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *, def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, start_indices) -> ir.Value: x_aval = ctx.avals_in[0] - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): elt_shape = aval_out.dtype._rules.physical_element_aval( aval_out.dtype).shape index_avals = ctx.avals_in[1:] @@ -1471,7 +1471,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, start_indices) -> ir.Value: - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): elt_shape = aval_out.dtype._rules.physical_element_aval( aval_out.dtype).shape index_avals = ctx.avals_in[2:] @@ -1574,7 +1574,7 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out): In particular, treat casts to boolean as x != 0, rather than truncating integer values (b/209440332).""" - if (not dtypes.is_opaque_dtype(aval_out.dtype) and + if (not dtypes.issubdtype(aval_out.dtype, dtypes.extended) and aval_out.dtype == np.dtype(np.bool_)): if dtypes.issubdtype(aval_in.dtype, np.inexact): compare_type = "FLOAT" diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 72c8e181ef17..4ed66d9024ba 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -514,8 +514,8 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() # type: ignore - if (dtypes.is_opaque_dtype(new_dtype) or - dtypes.is_opaque_dtype(getattr(operand, 'dtype', None))): + if (dtypes.issubdtype(new_dtype, dtypes.extended) or + dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)): return convert_element_type_p.bind(operand, new_dtype=new_dtype, weak_type=bool(weak_type)) @@ -1201,7 +1201,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) if np.shape(fill_value): msg = "full must be called with scalar fill_value, got fill_value.shape {}." raise TypeError(msg.format(np.shape(fill_value))) - if dtypes.is_opaque_dtype(dtype): + if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(shape, fill_value, dtype) # type: ignore[union-attr] weak_type = dtype is None and dtypes.is_weakly_typed(fill_value) dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) @@ -1352,7 +1352,7 @@ def full_like(x: Union[ArrayLike, DuckTypedArray], fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape) # type: ignore[arg-type] weak_type = dtype is None and dtypes.is_weakly_typed(x) dtype = dtype or _dtype(x) - if dtypes.is_opaque_dtype(dtype): + if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type)) # If the sharding is SingleDeviceSharding then don't take the `if` branch @@ -1535,11 +1535,11 @@ def unop(result_dtype, accepted_dtypes, name): def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, - allow_opaque_dtype=False, **kwargs): + allow_extended_dtype=False, **kwargs): del kwargs assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes) for i, aval in enumerate(avals): - if allow_opaque_dtype and dtypes.is_opaque_dtype(aval.dtype): + if allow_extended_dtype and dtypes.issubdtype(aval.dtype, dtypes.extended): continue types = accepted_dtypes[i] if not any(dtypes.issubdtype(aval.dtype, t) for t in types): @@ -1601,9 +1601,9 @@ def _naryop_weak_type_rule(name, *avals, **kwargs): "taken a gradient with respect to an integer argument.") return all(aval.weak_type for aval in avals) -def naryop(result_dtype, accepted_dtypes, name, allow_opaque_dtype=False): +def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, - allow_opaque_dtype=allow_opaque_dtype) + allow_extended_dtype=allow_extended_dtype) shape_rule = partial(broadcasting_shape_rule, name) weak_type_rule = partial(_naryop_weak_type_rule, name) prim = standard_primitive(shape_rule, dtype_rule, name, @@ -2217,13 +2217,13 @@ def _compare_lower_hlo_opaque(direction: str, ctx, avals_in, aval_out, x, y): return _opaque_ne_hlo(ctx, broadcast_avals_in, aval_out, x, y) else: raise NotImplementedError( - f"HLO comparison {direction} for opaque dtype {avals_in[0].dtype}") + f"HLO comparison {direction} for extended dtype {avals_in[0].dtype}") def _compare_lower_hlo(direction: str, ctx, x, y): avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out x_dtype = avals_in[0].dtype x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape) - if dtypes.is_opaque_dtype(x_dtype): + if dtypes.issubdtype(x_dtype, dtypes.extended): return _compare_lower_hlo_opaque(direction, ctx, avals_in, aval_out, x, y) if dtypes.issubdtype(x_dtype, np.inexact): compare_type = "FLOAT" @@ -2233,11 +2233,11 @@ def _compare_lower_hlo(direction: str, ctx, x, y): compare_type = "UNSIGNED" return mlir.compare_hlo(x, y, direction, compare_type).results -eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_opaque_dtype=True) +eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True) ad.defjvp_zero(eq_p) mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ")) -ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_opaque_dtype=True) +ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True) ad.defjvp_zero(ne_p) mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE")) @@ -2263,11 +2263,11 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type): if operand.dtype != new_dtype: - if (dtypes.is_opaque_dtype(operand.dtype) and + if (dtypes.issubdtype(operand.dtype, dtypes.extended) and not isinstance(operand.dtype, core.bint)): raise ValueError( f"Cannot call convert_element_type on dtype {dtype_to_string(operand.dtype)}") - if (dtypes.is_opaque_dtype(new_dtype) and + if (dtypes.issubdtype(new_dtype, dtypes.extended) and not isinstance(new_dtype, core.bint)): raise ValueError( f"Cannot convert_element_type to dtype={dtype_to_string(new_dtype)}") @@ -2308,7 +2308,7 @@ def _convert_elt_type_folding_rule(consts, eqn): o, = eqn.outvars if (type(c) in {np.ndarray, *dtypes.python_scalar_dtypes} and isinstance(o.aval, core.UnshapedArray) and not np.shape(c) and - not dtypes.is_opaque_dtype(eqn.params['new_dtype'])): + not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended)): out = np.array(c, eqn.params['new_dtype']) if not o.aval.weak_type: return [out], None @@ -2319,8 +2319,8 @@ def _convert_elt_type_folding_rule(consts, eqn): def _convert_elt_type_fwd_rule(eqn): v, = eqn.invars - if (not dtypes.is_opaque_dtype(eqn.params['new_dtype']) and - not dtypes.is_opaque_dtype(v.aval.dtype) and + if (not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended) and + not dtypes.issubdtype(v.aval.dtype, dtypes.extended) and v.aval.dtype == eqn.params['new_dtype'] and v.aval.weak_type == eqn.params['weak_type']): return [v], None @@ -3424,7 +3424,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation): def _transpose_lower(ctx, x, *, permutation): aval_out, = ctx.avals_out - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): elt_shape = aval_out.dtype._rules.physical_element_aval( aval_out.dtype).shape trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] @@ -3555,7 +3555,7 @@ def _select_hlo_lowering(ctx, which, *cases): which_aval = ctx.avals_in[0] aval_out, = ctx.avals_out - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): return [_select_hlo_lowering_opaque(ctx, which, *cases)] if which_aval.dtype == np.dtype(np.bool_): @@ -4770,7 +4770,7 @@ def check_same_dtypes(name: str, *avals: core.UnshapedArray) -> None: """Check that dtypes agree, possibly ignoring float precision.""" # the `ignore_fp_precision` flag exists because the XLA shape inference logic # allows mixed floating point precision, but the HLO verifier often rejects it - if any(dtypes.is_opaque_dtype(aval.dtype) for aval in avals): + if any(dtypes.issubdtype(aval.dtype, dtypes.extended) for aval in avals): return # TODO(mattjj,frostig): do some checking, friend if len(avals) < 2: return @@ -4912,7 +4912,7 @@ def empty(dtype): empty_p = core.Primitive('empty') empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype)) def _empty_lower(ctx, *, dtype): - dtype = dtype if dtypes.is_opaque_dtype(dtype) else np.dtype(dtype) + dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) phys_aval = core.physical_aval(core.ShapedArray((), dtype)) return mlir.ir_constants(np.zeros(phys_aval.shape, phys_aval.dtype)) mlir.register_lowering(empty_p, _empty_lower) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index f697f9c21019..2d242b16284c 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1410,7 +1410,7 @@ def _dynamic_update_slice_lower(ctx, x, update, *start_indices): def _gather_dtype_rule(operand, indices, *, fill_value, **kwargs): if not dtypes.issubdtype(indices.dtype, np.integer): raise ValueError("indices must have an integer type") - return dtypes.canonicalize_dtype(operand.dtype, allow_opaque_dtype=True) + return dtypes.canonicalize_dtype(operand.dtype, allow_extended_dtype=True) _rank = lambda arr: len(arr.shape) @@ -1784,7 +1784,7 @@ def _gather_lower(ctx, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): aval_out, = ctx.avals_out - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): return [_gather_lower_opaque( ctx, operand, indices, dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, @@ -1835,7 +1835,7 @@ def _scatter_dtype_rule(operand, indices, updates, **kwargs): if not dtypes.issubdtype(indices.dtype, np.integer): raise ValueError("indices must have an integer type") lax.check_same_dtypes("scatter", operand, updates) - return dtypes.canonicalize_dtype(operand.dtype, allow_opaque_dtype=True) + return dtypes.canonicalize_dtype(operand.dtype, allow_extended_dtype=True) def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, @@ -2438,7 +2438,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, _scatter_reduction_computation, core.ShapedArray((), operand_dtype)) aval_out, = ctx.avals_out - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): return [_scatter_lower_opaque( ctx, operand, indices, updates, update_jaxpr=update_jaxpr, update_consts=update_consts, diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 5556e38ae4d5..d92379eaac44 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -31,7 +31,7 @@ xops = xla_client.ops -_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_opaque_dtype=True) +_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_extended_dtype=True) def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 9b13db0663a4..3defada04227 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -217,7 +217,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: def _jnp_dtype(obj: Optional[DTypeLike], *, align: bool = False, copy: bool = False) -> DType: """Similar to np.dtype, but respects JAX dtype defaults.""" - if dtypes.is_opaque_dtype(obj): + if dtypes.issubdtype(obj, dtypes.extended): return obj # type: ignore[return-value] if obj is None: obj = dtypes.float_ @@ -2038,7 +2038,7 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True, dtype = dtypes._lattice_result_type(*leaves)[0] if not weak_type: - dtype = dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True) + dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) out: ArrayLike @@ -2085,7 +2085,7 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array: dtypes.check_user_dtype_supported(dtype, "asarray") if dtype is not None: - dtype = dtypes.canonicalize_dtype(dtype, allow_opaque_dtype=True) + dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) return array(a, dtype=dtype, copy=False, order=order) # type: ignore @@ -2328,7 +2328,7 @@ def arange(start: DimSize, stop: Optional[DimSize] = None, if stop is None and step is None: start_dtype = _dtype(start) if (not dtypes.issubdtype(start_dtype, np.integer) and - not dtypes.is_opaque_dtype(start_dtype)): + not dtypes.issubdtype(start_dtype, dtypes.extended)): ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil start = ceil_(start).astype(int) # type: ignore return lax.iota(dtype, start) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 4a9f501d6280..10625b037c25 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -271,7 +271,7 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]: return [lax.asarray(arg) for arg in args] else: to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True) # type: ignore[assignment] + to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment] return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] @@ -280,7 +280,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_opaque_dtype=True) # type: ignore[assignment] + to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment] to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 8a50214bc9fe..b5360fe9229f 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -577,7 +577,7 @@ def device_put_replicated(val, aval, sharding, devices): return random_wrap(physical_result, impl=aval.dtype.impl) -class KeyTy(dtypes.OpaqueDType): +class KeyTy(dtypes.ExtendedDType): impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really _rules = KeyTyRules type = dtypes.prng_key diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index a469054c6e0c..08369c08dd40 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -915,8 +915,8 @@ def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None, def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): if not config.x64_enabled and canonicalize_dtypes: - self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_opaque_dtype=True), - _dtypes.canonicalize_dtype(_dtype(y), allow_opaque_dtype=True)) + self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True), + _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True)) else: self.assertEqual(_dtype(x), _dtype(y)) diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 88035319ed6c..4911f158d4e5 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -37,8 +37,8 @@ DType = np.dtype -# TODO(jakevdp, froystig): make OpaqueDType a protocol -OpaqueDType = Any +# TODO(jakevdp, froystig): make ExtendedDType a protocol +ExtendedDType = Any class SupportsDType(Protocol): @property diff --git a/jax/core.py b/jax/core.py index bb7ee0e98744..379c9cb82a73 100644 --- a/jax/core.py +++ b/jax/core.py @@ -104,7 +104,7 @@ gensym as gensym, get_aval as get_aval, get_referent as get_referent, - has_opaque_dtype as has_opaque_dtype, + has_opaque_dtype as _deprecated_has_opaque_dtype, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, jaxpr_as_fun as jaxpr_as_fun, @@ -185,5 +185,27 @@ symbolic_equal_dim = definitely_equal # TODO(necula): remove this API from jax._src.dtypes import ( - is_opaque_dtype as is_opaque_dtype, + is_opaque_dtype as _deprecated_is_opaque_dtype, ) + +_deprecations = { + # Added May 23, 2023: + "is_opaque_dtype": ( + "jax.core.is_opaque_dtype is deprecated. Use jnp.issubdtype(dt, dtypes.extended).", + _deprecated_is_opaque_dtype, + ), + "has_opaque_dtype": ( + "jax.core.is_opaque_dtype is deprecated. Use jnp.issubdtype(x.dtype, dtypes.extended).", + _deprecated_has_opaque_dtype, + ), +} + +import typing +if typing.TYPE_CHECKING: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +else: + from jax._src.dtypes import is_opaque_dtype as is_opaque_dtype + from jax._src.core import has_opaque_dtype as has_opaque_dtype +del typing, _deprecated_is_opaque_dtype, _deprecated_has_opaque_dtype diff --git a/jax/dtypes.py b/jax/dtypes.py index fbc3f158b01e..f2071fd4fe56 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -22,6 +22,7 @@ float0 as float0, iinfo, # TODO(phawkins): switch callers to jnp.iinfo? issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype? + extended as extended, prng_key as prng_key, result_type as result_type, scalar_type_of as scalar_type_of, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 8f35057f2cd0..4517a869a5c3 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1093,7 +1093,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal, # The float0 type is not known to TF. if jax_dtype == dtypes.float0: val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype) - if hasattr(val, 'dtype') and dtypes.is_opaque_dtype(val.dtype): + if hasattr(val, 'dtype') and dtypes.issubdtype(val.dtype, dtypes.extended): val = val.dtype._rules.physical_const(val) tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype) if do_memoize: @@ -2089,7 +2089,7 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions, def _empty(*, dtype): - if dtypes.is_opaque_dtype(dtype): + if dtypes.issubdtype(dtype, dtypes.extended): raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers return tf.constant(np.array(0, dtype=dtype)) @@ -2727,7 +2727,7 @@ def _gather(operand, start_indices, *, dimension_numbers, slice_sizes: core.Shap operand_aval = _in_avals[0] start_indices = _maybe_cast_to_int64(start_indices) - if dtypes.is_opaque_dtype(operand_aval.dtype): + if dtypes.issubdtype(operand_aval.dtype, dtypes.extended): opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] trailing_offset_dims = [len(_out_aval.shape) + i for i in range(len(opaque_shape))] dimension_numbers = dimension_numbers._replace( @@ -2768,7 +2768,7 @@ def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape, _out_aval: core.ShapedArray): start_indices = _maybe_cast_to_int64(tf.stack(start_indices)) operand_aval = _in_avals[0] - if dtypes.is_opaque_dtype(operand_aval.dtype): + if dtypes.issubdtype(operand_aval.dtype, dtypes.extended): opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] slice_sizes = (*slice_sizes, *opaque_shape) start_indices = tf.concat([start_indices, tf.zeros((len(opaque_shape),), @@ -2790,7 +2790,7 @@ def _dynamic_update_slice(operand, update, *start_indices, _out_aval: core.ShapedArray): start_indices = _maybe_cast_to_int64(tf.stack(start_indices)) operand_aval = _in_avals[0] - if dtypes.is_opaque_dtype(operand_aval.dtype): + if dtypes.issubdtype(operand_aval.dtype, dtypes.extended): opaque_shape = _jax_physical_aval(operand_aval).shape[len(operand_aval.shape):] start_indices = tf.concat([start_indices, tf.zeros((len(opaque_shape),), dtype=start_indices.dtype)], diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index ef94ac6b141a..37ddb4453348 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -525,7 +525,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, shard_proto = NamedSharding( mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore )._to_xla_hlo_sharding(aval_in.ndim) - if dtypes.is_opaque_dtype(aval_in.dtype): + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto) sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(), # type: ignore unspecified_dims=set()) @@ -540,7 +540,7 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, shard_proto = NamedSharding( mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore )._to_xla_hlo_sharding(aval_out.ndim) - if dtypes.is_opaque_dtype(aval_out.dtype): + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto) return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto.to_proto(), set()) # type: ignore diff --git a/tests/lax_test.py b/tests/lax_test.py index 54212693ad5e..eb84c81a1959 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2913,8 +2913,8 @@ def handler(arr): return handler -class FooTy(dtypes.OpaqueDType): - type = dtypes.opaque +class FooTy(dtypes.ExtendedDType): + type = dtypes.extended name = 'foo' _rules = FooTyRules