Skip to content

Commit

Permalink
Remove code that guards against int4 not being available.
Browse files Browse the repository at this point in the history
JAX depends on ml_dtypes >= 0.2, and int4 was added in ml_dtypes 0.2.

PiperOrigin-RevId: 607345852
  • Loading branch information
hawkinsp authored and jax authors committed Feb 15, 2024
1 parent 6a98382 commit c55f187
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 50 deletions.
9 changes: 2 additions & 7 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,12 @@
raise_to_shaped = core.raise_to_shaped

numpy_scalar_types: set[type] = { # pylint: disable=g-bare-generic
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
dtypes.int4, np.int8, np.int16, np.int32, np.int64,
dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64,
np.complex64, np.complex128,
np.bool_, np.longlong, np.intc,
} | {np.dtype(dt).type for dt in dtypes._float_types}

if dtypes.int4 is not None:
numpy_scalar_types.add(dtypes.int4)
if dtypes.uint4 is not None:
numpy_scalar_types.add(dtypes.uint4)

array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic

def canonical_concrete_aval(val, weak_type=None):
Expand Down
52 changes: 16 additions & 36 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,34 +379,20 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool:
_signed_types: list[JAXType]
_unsigned_types: list[JAXType]
_int_types: list[JAXType]
if int4 is not None:
_unsigned_types = [
np.dtype(uint4),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype(int4),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
else:
_unsigned_types = [
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
_unsigned_types = [
np.dtype(uint4),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype(int4),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]

_int_types = _unsigned_types + _signed_types

Expand Down Expand Up @@ -493,10 +479,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, = _bool_types
if int4 is not None:
_uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types # pytype: disable=bad-unpacking
else:
u1, u2, u4, u8, i1, i2, i4, i8 = _int_types # pytype: disable=bad-unpacking
_uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types
*f1_types, bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
Expand Down Expand Up @@ -740,10 +723,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
return
np_dtype = np.dtype(dtype)
if int4 is not None:
is_custom_dtype = np_dtype.type in [*_custom_float_scalar_types, int4, uint4]
else:
is_custom_dtype = np_dtype.type in _custom_float_scalar_types
is_custom_dtype = np_dtype.type in [*_custom_float_scalar_types, int4, uint4]
if np_dtype.kind not in "biufc" and not is_custom_dtype:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
msg += f" in {fun_name}" if fun_name else ""
Expand Down
9 changes: 2 additions & 7 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
_dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(dtypes.int4): partial(ir.IntegerType.get_signless, 4),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
np.dtype(dtypes.uint4): partial(ir.IntegerType.get_unsigned, 4),
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
Expand All @@ -167,13 +169,6 @@ def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

if dtypes.int4 is not None:
_dtype_to_ir_type.update({
np.dtype(dtypes.int4): partial(ir.IntegerType.get_signless, 4),
np.dtype(dtypes.uint4): partial(ir.IntegerType.get_unsigned, 4),
})


def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
# TODO Support different-size underlying dtypes to take advantage of the
Expand Down

0 comments on commit c55f187

Please sign in to comment.