Skip to content

Commit

Permalink
Require ml_dtypes>=0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 7, 2023
1 parent 3a0c135 commit 9962065
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 70 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Expand Up @@ -32,7 +32,7 @@ repos:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py # Use pyi instead
additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.1.0, numpy==1.24.3, scipy==1.10.1]
additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -23,6 +23,7 @@ Remember to align the itemized text with the first line of an item within a list
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`

* Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer.
* To fix a corner case, calls to {func}`jax.lax.cond` with five
arguments will always resolve to the "common operands" `cond`
behavior (as documented) if the second and third arguments are
Expand Down
55 changes: 22 additions & 33 deletions jax/_src/dtypes.py
Expand Up @@ -34,6 +34,15 @@
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

try:
_ml_dtypes_version = tuple(map(int, ml_dtypes.__version__.split('.')[:3]))
except:
pass
else:
if _ml_dtypes_version < (0, 2, 0):
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
f"installed version is {ml_dtypes.__version__}.")

FLAGS = flags.FLAGS

# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
Expand All @@ -43,12 +52,11 @@ def is_opaque_dtype(dtype: Any) -> bool:
return type(dtype) in opaque_dtypes

# fp8 support
# TODO(jakevdp): remove this if statement when minimum ml_dtypes version > 0.1
float8_e4m3b11fnuz: Optional[type[np.generic]] = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2

_float8_e4m3b11fnuz_dtype: Optional[np.dtype] = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2)

Expand All @@ -57,32 +65,24 @@ def is_opaque_dtype(dtype: Any) -> bool:
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)

_custom_float_scalar_types = [
float8_e4m3b11fnuz,
float8_e4m3fn,
float8_e5m2,
bfloat16,
]
_custom_float_dtypes = [
_float8_e4m3b11fnuz_dtype,
_float8_e4m3fn_dtype,
_float8_e5m2_dtype,
_bfloat16_dtype,
]

if hasattr(ml_dtypes, "float8_e4m3b11fnuz"):
float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
_custom_float_scalar_types.insert(0, float8_e4m3b11fnuz) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e4m3b11fnuz_dtype) # type: ignore[arg-type]
# 4-bit integer support
int4: type[np.generic] = ml_dtypes.int4
uint4: type[np.generic] = ml_dtypes.uint4

int4: Optional[type[np.generic]] = None
_int4_dtype: Optional[np.dtype] = None
uint4: Optional[type[np.generic]] = None
_uint4_dtype: Optional[np.dtype] = None

if hasattr(ml_dtypes, "int4"):
int4 = ml_dtypes.int4
uint4 = ml_dtypes.uint4
_int4_dtype = np.dtype(int4)
_uint4_dtype = np.dtype(uint4)
_int4_dtype: np.dtype = np.dtype(int4)
_uint4_dtype: np.dtype = np.dtype(uint4)

# Default types.
bool_: type = np.bool_
Expand Down Expand Up @@ -226,17 +226,8 @@ def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray:
dtype = _scalar_type_to_dtype(type(x), x)
return np.asarray(x, dtype)

try:
iinfo = ml_dtypes.iinfo
except AttributeError:
iinfo = np.iinfo

try:
finfo = ml_dtypes.finfo
except AttributeError as err:
_ml_dtypes_version = getattr(ml_dtypes, "__version__", "<unknown>")
raise ImportError("JAX requires package ml_dtypes>=0.1.0. "
f"Installed version is {_ml_dtypes_version}.") from err
iinfo = ml_dtypes.iinfo
finfo = ml_dtypes.finfo

def _issubclass(a: Any, b: Any) -> bool:
"""Determines if ``a`` is a subclass of ``b``.
Expand Down Expand Up @@ -285,13 +276,11 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool:
if isinstance(b, np.dtype):
return a == b
return b in [np.floating, np.inexact, np.number]
# TODO(phawkins): remove the "_int4_dtype is not None" tests after requiring
# an ml_dtypes version that has int4 and uint4.
if _int4_dtype is not None and a == _int4_dtype:
if a == _int4_dtype:
if isinstance(b, np.dtype):
return a == b
return b in [np.signedinteger, np.integer, np.number]
if _uint4_dtype is not None and a == _uint4_dtype:
if a == _uint4_dtype:
if isinstance(b, np.dtype):
return a == b
return b in [np.unsignedinteger, np.integer, np.number]
Expand Down
9 changes: 3 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -159,22 +159,19 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
return meta

bool_ = _make_scalar_type(np.bool_)
if dtypes.uint4 is not None:
uint4 = _make_scalar_type(dtypes.uint4)
uint4 = _make_scalar_type(dtypes.uint4)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
if dtypes.int4 is not None:
int4 = _make_scalar_type(dtypes.int4)
int4 = _make_scalar_type(dtypes.int4)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
if dtypes.float8_e4m3b11fnuz is not None:
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
Expand Down
15 changes: 6 additions & 9 deletions jax/_src/public_test_util.py
Expand Up @@ -47,14 +47,17 @@ def _dtype(x):
_default_tolerance = {
_dtypes.float0: 0,
np.dtype(np.bool_): 0,
np.dtype(_dtypes.int4): 0,
np.dtype(np.int8): 0,
np.dtype(np.int16): 0,
np.dtype(np.int32): 0,
np.dtype(np.int64): 0,
np.dtype(_dtypes.uint4): 0,
np.dtype(np.uint8): 0,
np.dtype(np.uint16): 0,
np.dtype(np.uint32): 0,
np.dtype(np.uint64): 0,
np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
np.dtype(_dtypes.float8_e4m3fn): 1e-1,
np.dtype(_dtypes.float8_e5m2): 1e-1,
np.dtype(_dtypes.bfloat16): 1e-2,
Expand All @@ -75,6 +78,7 @@ def default_tolerance():


default_gradient_tolerance = {
np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
np.dtype(_dtypes.float8_e4m3fn): 1e-1,
np.dtype(_dtypes.float8_e5m2): 1e-1,
np.dtype(_dtypes.bfloat16): 1e-1,
Expand All @@ -85,22 +89,15 @@ def default_tolerance():
np.dtype(np.complex128): 1e-5,
}

# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
if _dtypes.float8_e4m3b11fnuz is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))

def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
if a.dtype == b.dtype == _dtypes.float0:
np.testing.assert_array_equal(a, b, err_msg=err_msg)
return
custom_dtypes = [_dtypes.float8_e4m3fn, _dtypes.float8_e5m2, _dtypes.bfloat16]
# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
if _dtypes.float8_e4m3b11fnuz is not None:
custom_dtypes.insert(0, _dtypes.float8_e4m3b11fnuz)
custom_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn,
_dtypes.float8_e5m2, _dtypes.bfloat16]
a = a.astype(np.float32) if a.dtype in custom_dtypes else a
b = b.astype(np.float32) if b.dtype in custom_dtypes else b
kw = {}
Expand Down
19 changes: 3 additions & 16 deletions jax/numpy/__init__.py
Expand Up @@ -107,6 +107,7 @@
float16 as float16,
float32 as float32,
float64 as float64,
float8_e4m3b11fnuz as float8_e4m3b11fnuz,
float8_e4m3fn as float8_e4m3fn,
float8_e5m2 as float8_e5m2,
float_ as float_,
Expand Down Expand Up @@ -142,6 +143,7 @@
inf as inf,
inner as inner,
insert as insert,
int4 as int4,
int8 as int8,
int16 as int16,
int32 as int32,
Expand Down Expand Up @@ -237,6 +239,7 @@
triu_indices_from as triu_indices_from,
trunc as trunc,
uint as uint,
uint4 as uint4,
uint8 as uint8,
uint16 as uint16,
uint32 as uint32,
Expand All @@ -254,22 +257,6 @@
zeros_like as zeros_like,
)

# TODO(phawkins): make this import unconditional after increasing the ml_dtypes
# minimum version.
import jax._src.numpy.lax_numpy
if hasattr(jax._src.numpy.lax_numpy, "int4"):
from jax._src.numpy.lax_numpy import (
int4 as int4,
uint4 as uint4,
)
# TODO(jakevdp): make this import unconditional after increasing the minimum
# version for ml_dtypes and jaxlib
if hasattr(jax._src.numpy.lax_numpy, "float8_e4m3b11fnuz"):
from jax._src.numpy.lax_numpy import (
float8_e4m3b11fnuz as float8_e4m3b11fnuz,
)


from jax._src.numpy.index_tricks import (
c_ as c_,
index_exp as index_exp,
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Expand Up @@ -50,7 +50,7 @@ def has_ext_modules(self):
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.9',
install_requires=['scipy>=1.7', 'numpy>=1.22', 'ml_dtypes>=0.1.0'],
install_requires=['scipy>=1.7', 'numpy>=1.22', 'ml_dtypes>=0.2.0'],
extras_require={
'cuda11_pip': [
"nvidia-cublas-cu11>=11.11",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -63,7 +63,7 @@ def generate_proto(source):
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.9',
install_requires=[
'ml_dtypes>=0.1.0',
'ml_dtypes>=0.2.0',
'numpy>=1.22',
'opt_einsum',
'scipy>=1.7',
Expand Down
5 changes: 2 additions & 3 deletions tests/dtypes_test.py
Expand Up @@ -59,9 +59,8 @@
float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes
custom_float_dtypes = [np.dtype(dtypes.bfloat16)]
if _fp8_enabled:
fp8_dtypes = [np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e5m2)]
if dtypes.float8_e4m3b11fnuz is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3b11fnuz)]
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e5m2)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

Expand Down

0 comments on commit 9962065

Please sign in to comment.