Skip to content

Commit

Permalink
Delete xla_bridge.xla.dtype_to_etype, replace it with jax.interpreters.xla.dtype_to_primitive_type.
Browse files Browse the repository at this point in the history

The new version does *not* canonicalize dtypes. We should be canonicalizing dtypes as part of tracing to a jaxpr, not in any way as part of XLA lowering. In all cases, as best I can tell, the dtypes from the callers are already canonical anyway.

jax.interpreters.xla is also a better location: I'm not even sure why we have a bunch of random things in xla_bridge any more, so it makes sense to consolidate them in xla.py along with the other registrations for things like avals.

Also delete the unused function xla_bridge.supported_numpy_dtypes.

PiperOrigin-RevId: 404246574
  • Loading branch information
hawkinsp authored and jax authors committed Oct 19, 2021
1 parent ee752b3 commit 185d7a9
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 52 deletions.
67 changes: 38 additions & 29 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ def conv_general_dilated(
padding = padtype_to_pads(
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
window_strides, padding)
preferred_element_type = (None if preferred_element_type is None else
np.dtype(preferred_element_type))
return conv_general_dilated_p.bind(
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
Expand Down Expand Up @@ -684,7 +686,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
"""
if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.symbolic_equal_dim(lhs.shape[-1], rhs.shape[0]):
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
precision=precision, preferred_element_type=preferred_element_type)
precision=precision,
preferred_element_type=preferred_element_type)
else:
raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
lhs.shape, rhs.shape))
Expand Down Expand Up @@ -722,6 +725,8 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
contract_dims_seq, batch_dims_seq = dimension_numbers
contract_dims = tuple(map(tuple, contract_dims_seq)) # type: ignore
batch_dims = tuple(map(tuple, batch_dims_seq)) # type: ignore
preferred_element_type = (None if preferred_element_type is None else
np.dtype(preferred_element_type))
return dot_general_p.bind(lhs, rhs,
dimension_numbers=(contract_dims, batch_dims),
precision=canonicalize_precision(precision),
Expand Down Expand Up @@ -3032,7 +3037,7 @@ def _convert_element_type_translation_rule(ctx, avals_in, avals_out, operand, *,
if (dtypes.issubdtype(old_dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = xops.Real(operand)
new_etype = xla_client.dtype_to_etype(new_dtype)
new_etype = xla.dtype_to_primitive_type(new_dtype)
return [xops.ConvertElementType(operand, new_element_type=new_etype)]

def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
Expand Down Expand Up @@ -3082,7 +3087,7 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):

def _bitcast_convert_type_translation_rule(ctx, avals_in, avals_out, operand, *,
new_dtype):
new_etype = xla_bridge.dtype_to_etype(new_dtype)
new_etype = xla.dtype_to_primitive_type(new_dtype)
return [xops.BitcastConvertType(operand, new_element_type=new_etype)]

bitcast_convert_type_p = standard_primitive(
Expand Down Expand Up @@ -3307,8 +3312,9 @@ def _conv_general_dilated_translation_rule(
if preferred_element_type is not None:
# Convert complex dtype to types used for real and imaginary parts
assert np.issubdtype(preferred_element_type, np.complexfloating)
preferred_element_type = xla_client.dtype_to_etype(
np.float64 if preferred_element_type == np.complex128 else np.float32)
preferred_element_type = xla.dtype_to_primitive_type(np.dtype(
np.float64 if preferred_element_type == np.complex128
else np.float32))

conv = lambda x, y: xops.ConvGeneralDilated(
x, y, window_strides, padding, lhs_dilation, rhs_dilation,
Expand All @@ -3323,7 +3329,7 @@ def _conv_general_dilated_translation_rule(
return [xops.Complex(xops.Sub(k1, k3), xops.Add(k1, k2))]

if preferred_element_type is not None:
preferred_element_type = xla_client.dtype_to_etype(preferred_element_type)
preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)

return [xops.ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
Expand Down Expand Up @@ -3666,7 +3672,7 @@ def _dot_general_translation_rule(ctx, avals_in, avals_out, lhs, rhs, *,
dimension_numbers, precision,
preferred_element_type: Optional[DType]):
if preferred_element_type is not None:
preferred_element_type = xla_client.dtype_to_etype(preferred_element_type)
preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)
return [xops.DotGeneral(lhs, rhs,
xc.make_dot_dimension_numbers(dimension_numbers),
precision_config=_precision_config(precision),
Expand All @@ -3676,14 +3682,17 @@ def _dot_general_cpu_translation_rule(ctx, avals_in, avals_out, lhs, rhs, *,
dimension_numbers, precision,
preferred_element_type: Optional[DType]):
if preferred_element_type is not None:
preferred_element_type = xla_client.dtype_to_etype(preferred_element_type)
preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)

# TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
if avals_in[0].dtype == np.float16:
lhs = xops.ConvertElementType(lhs, xla_client.dtype_to_etype(np.float32))
rhs = xops.ConvertElementType(rhs, xla_client.dtype_to_etype(np.float32))
preferred_element_type = (preferred_element_type or
xla_client.dtype_to_etype(np.float16))
lhs = xops.ConvertElementType(
lhs, xla.dtype_to_primitive_type(np.dtype(np.float32)))
rhs = xops.ConvertElementType(
rhs, xla.dtype_to_primitive_type(np.dtype(np.float32)))
preferred_element_type = (
preferred_element_type or
xla.dtype_to_primitive_type(np.dtype(np.float16)))

return [xops.DotGeneral(lhs, rhs,
xc.make_dot_dimension_numbers(dimension_numbers),
Expand Down Expand Up @@ -4738,7 +4747,7 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
intarray = partial(np.array, dtype=np.int64)
operand_dims = intarray(operand_aval.shape)
indices = xops.ConvertElementType(
indices, xb.dtype_to_etype(np.int64))
indices, xla.dtype_to_primitive_type(dtypes.canonicalize_dtype(np.int64)))
num_batch_dims = len(indices_aval.shape) - 1

upper_bound = operand_dims[intarray(dnums.start_index_map)]
Expand Down Expand Up @@ -6254,15 +6263,15 @@ def _select_and_gather_add_shape_rule(
window_dilation)

_UINT_DTYPES = {
16: np.uint16,
32: np.uint32,
64: np.uint64,
16: np.dtype(np.uint16),
32: np.dtype(np.uint32),
64: np.dtype(np.uint64),
}

_INT_DTYPES = {
16: np.int16,
32: np.int32,
64: np.int64,
16: np.dtype(np.int16),
32: np.dtype(np.int32),
64: np.dtype(np.int64),
}

def _select_and_gather_add_translation(
Expand All @@ -6272,7 +6281,7 @@ def _select_and_gather_add_translation(
c = ctx.builder
tangents_aval, operand_aval, = avals_in
dtype = operand_aval.dtype
etype = xla_client.dtype_to_etype(dtype)
etype = xla.dtype_to_primitive_type(dtype)
nbits = dtypes.finfo(dtype).bits

assert nbits <= max_bits
Expand All @@ -6287,8 +6296,8 @@ def _select_and_gather_add_translation(
# 2k-bit unsigned integer using bit tricks.
word_dtype = _UINT_DTYPES[nbits]
double_word_dtype = _UINT_DTYPES[nbits * 2]
word_type = xla_client.dtype_to_etype(word_dtype)
double_word_type = xla_client.dtype_to_etype(double_word_dtype)
word_type = xla.dtype_to_primitive_type(word_dtype)
double_word_type = xla.dtype_to_primitive_type(double_word_dtype)

# Packs two values into a tuple.
def pack(a, b):
Expand Down Expand Up @@ -6323,7 +6332,7 @@ def snd(t):
nmant = r_nbits - nexp - 1

double_word_dtype = word_dtype = _UINT_DTYPES[nbits]
word_type = xla_client.dtype_to_etype(word_dtype)
word_type = xla.dtype_to_primitive_type(word_dtype)

# Packs two values into a tuple.
def pack(a, b):
Expand Down Expand Up @@ -6497,7 +6506,7 @@ def _float_to_int_for_sort(x):
signed = bitcast_convert_type(x, signed_dtype)
unsigned = bitcast_convert_type(x, unsigned_dtype)
flipped = bitcast_convert_type(
sub(unsigned_dtype(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
sub(unsigned_dtype.type(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
return select(lt(signed, _zero(signed)), flipped, signed)

# Default comparator that sorts the operands lexicographically on the
Expand Down Expand Up @@ -6845,22 +6854,22 @@ def _rng_bit_generator_translation_rule(
# TODO(mattjj): the BitcastConvertType segfaults on GPU
# TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
u64_etype = xc.dtype_to_etype(dtypes.dtype('uint64'))
u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64'))
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
else:
key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
out_key, out_vals = xla.xla_destructure(
c, xops.RngBitGenerator(algorithm, key, xla_shape))
if key_dtype == dtypes.dtype('uint32'):
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
u32_etype = xc.dtype_to_etype(dtypes.dtype('uint32'))
u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32'))
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
else:
out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
return [out_key, out_vals]

def _convert_4xU32_to_2xU64_without_bitcast(c, key):
u64_etype = xc.dtype_to_etype(dtypes.dtype('uint64'))
u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64'))
new_key = xb.constant(c, np.zeros(2, dtype=np.dtype('uint64')),
canonicalize_types=False)
_32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
Expand All @@ -6872,7 +6881,7 @@ def _convert_4xU32_to_2xU64_without_bitcast(c, key):
return new_key

def _convert_2xU64_to_4xU32_without_bitcast(c, key):
u32_etype = xc.dtype_to_etype(dtypes.dtype('uint32'))
u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32'))
new_key = xb.constant(c, np.zeros(4, dtype=np.dtype('uint32')))
_32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
for i in [0, 1]:
Expand Down Expand Up @@ -6937,7 +6946,7 @@ def _iota_abstract_eval(*, dtype, shape, dimension):

def _iota_translation_rule(ctx, avals_in, avals_out, *, dtype, shape,
dimension):
etype = xla_client.dtype_to_etype(dtype)
etype = xla.dtype_to_primitive_type(dtype)
xla_shape = xc.Shape.array_shape(etype, shape)
return [xops.Iota(ctx.builder, xla_shape, dimension)]

Expand Down
3 changes: 2 additions & 1 deletion jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,8 @@ def _build_axis_index_lowering(c, axis_name, axis_env):
dtype=np.uint32))
mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
return xops.ConvertElementType(
unsigned_index, xla.dtype_to_primitive_type(np.dtype(np.int32)))

def _axis_index_translation_rule(ctx, avals_in, avals_out, *, axis_name):
return [_build_axis_index_lowering(ctx.builder, axis_name, ctx.axis_env)]
Expand Down
12 changes: 0 additions & 12 deletions jax/_src/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,18 +427,6 @@ def host_ids(backend=None):

### utility functions

@util.memoize
def dtype_to_etype(dtype):
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))


@util.memoize
def supported_numpy_dtypes():
return {dtypes.canonicalize_dtype(dtype)
for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}


# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/djax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ def _iota_translation_rule(c, dims, avals, operands, *, size=None):
shape = aval.shape
else:
shape = ()
etype = xc.dtype_to_etype(np.dtype('int32'))
etype = xla.dtype_to_primitive_type(np.dtype('int32'))
xla_shape = xc.Shape.array_shape(etype, (*shape, size))
return [[xops.Iota(c, xla_shape, len(shape))]]
translations[iota_p] = _iota_translation_rule
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def post_process_result(idx: int, res_aval: core.ShapedArray, res_shape: xla.Xla
if res_aval.dtype != res_shape.numpy_dtype():
res_op = xops.ConvertElementType(
res_op,
new_element_type=xla_client.dtype_to_etype(res_aval.dtype))
new_element_type=xla.dtype_to_primitive_type(res_aval.dtype))
return res_op

results = [
Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/impl_no_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def error(msg):
raise error("Unimplemented support for batch_group_count != 1 "
f"(found {batch_group_count})")

if preferred_element_type is not None and preferred_element_type != lhs.dtype:
if (preferred_element_type is not None and
preferred_element_type != lhs.dtype.as_numpy_dtype):
raise error("Unimplemented support for preferred_element_type")

lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
Expand Down
8 changes: 5 additions & 3 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ def _make_convert_element_type_harness(name,
"convert_element_type",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
lambda arg: (lax.convert_element_type_p.bind(
arg, new_dtype=new_dtype, weak_type=False)), [RandArg(shape, dtype)],
arg, new_dtype=np.dtype(new_dtype), weak_type=False)),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
new_dtype=new_dtype)
Expand Down Expand Up @@ -660,7 +661,8 @@ def _make_bitcast_convert_type_harness(name,
define(
"bitcast_convert_type",
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newdtype={np.dtype(new_dtype).name}",
lambda x: (lax.bitcast_convert_type_p.bind(x, new_dtype=new_dtype)),
lambda x: lax.bitcast_convert_type_p.bind(x,
new_dtype=np.dtype(new_dtype)),
[RandArg(shape, dtype)],
shape=shape,
dtype=dtype,
Expand Down Expand Up @@ -856,7 +858,7 @@ def _make_iota_harness(name, *, shape=(2, 3), dtype=np.float32, dimension=0):
lax.iota_p,
f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimension={dimension}",
lambda dtype, shape, dim:
(lax.iota_p.bind(dtype=dtype, shape=shape, dimension=dim)),
(lax.iota_p.bind(dtype=np.dtype(dtype), shape=shape, dimension=dim)),
[StaticArg(dtype),
StaticArg(shape),
StaticArg(dimension)],
Expand Down
6 changes: 4 additions & 2 deletions jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,8 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
convert_bool = (np.issubdtype(x_dtype, np.bool_)
and xb.get_backend(backend).platform in ('cpu', 'gpu'))
if convert_bool:
x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
x = xops.ConvertElementType(
x, xla.dtype_to_primitive_type(np.dtype(np.float32)))

tile_shape = list(xla_shape.dimensions())
shape = list(tile_shape)
Expand All @@ -1413,7 +1414,8 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
out = xops.ConvertElementType(
nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
return out

def _xmap_translation_rule_spmd(c, axis_env,
Expand Down
6 changes: 4 additions & 2 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,8 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
convert_bool = (np.issubdtype(aval.dtype, np.bool_)
and xb.get_backend(backend).platform in ('cpu', 'gpu'))
if convert_bool:
x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
x = xops.ConvertElementType(
x, xla.dtype_to_primitive_type(np.dtype(np.float32)))

xla_shape = c.get_shape(x)
dims = list(xla_shape.dimensions())
Expand All @@ -1360,7 +1361,8 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
out = xops.ConvertElementType(
nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
return out
else:
raise TypeError((aval, c.get_shape(x)))
Expand Down
28 changes: 28 additions & 0 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,34 @@ def make_op_metadata(primitive: core.Primitive,

### handlers

_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
np.dtype('bool'): xc.PrimitiveType.PRED,
np.dtype('int8'): xc.PrimitiveType.S8,
np.dtype('int16'): xc.PrimitiveType.S16,
np.dtype('int32'): xc.PrimitiveType.S32,
np.dtype('int64'): xc.PrimitiveType.S64,
np.dtype('uint8'): xc.PrimitiveType.U8,
np.dtype('uint16'): xc.PrimitiveType.U16,
np.dtype('uint32'): xc.PrimitiveType.U32,
np.dtype('uint64'): xc.PrimitiveType.U64,
np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
np.dtype('float16'): xc.PrimitiveType.F16,
np.dtype('float32'): xc.PrimitiveType.F32,
np.dtype('float64'): xc.PrimitiveType.F64,
np.dtype('complex64'): xc.PrimitiveType.C64,
np.dtype('complex128'): xc.PrimitiveType.C128,
}

def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
"""Converts a NumPy dtype into an XLA PrimitiveType."""
# Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
# but may not hash correctly. Make sure we have a true np.dtype.
assert isinstance(dtype, np.dtype), type(dtype)
try:
return _dtype_to_primitive_type[dtype]
except KeyError as err:
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err

xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))

def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
Expand Down

0 comments on commit 185d7a9

Please sign in to comment.