From 9962065deb4319838e4217e44da9cc9adeebdde2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 7 Jul 2023 12:07:44 -0700 Subject: [PATCH] Require ml_dtypes>=0.2 --- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 1 + jax/_src/dtypes.py | 55 +++++++++++++++--------------------- jax/_src/numpy/lax_numpy.py | 9 ++---- jax/_src/public_test_util.py | 15 ++++------ jax/numpy/__init__.py | 19 ++----------- jaxlib/setup.py | 2 +- setup.py | 2 +- tests/dtypes_test.py | 5 ++-- 9 files changed, 40 insertions(+), 70 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6293a88466eb..0c76e2d41fb0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py # Use pyi instead - additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.1.0, numpy==1.24.3, scipy==1.10.1] + additional_dependencies: [types-requests==2.29.0, jaxlib==0.4.7, ml_dtypes==0.2.0, numpy==1.24.3, scipy==1.10.1] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/CHANGELOG.md b/CHANGELOG.md index c441b7aa50dd..1a2262ba54fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Remember to align the itemized text with the first line of an item within a list For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` * Breaking changes + * JAX now requires ml_dtypes version 0.2.0 or newer. * To fix a corner case, calls to {func}`jax.lax.cond` with five arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 4561cf858585..883e0602c959 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -34,6 +34,15 @@ from jax._src import traceback_util traceback_util.register_exclusion(__file__) +try: + _ml_dtypes_version = tuple(map(int, ml_dtypes.__version__.split('.')[:3])) +except: + pass +else: + if _ml_dtypes_version < (0, 2, 0): + raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; " + f"installed version is {ml_dtypes.__version__}.") + FLAGS = flags.FLAGS # TODO(frostig,mattjj): achieve this w/ a protocol instead of registry? @@ -43,12 +52,11 @@ def is_opaque_dtype(dtype: Any) -> bool: return type(dtype) in opaque_dtypes # fp8 support -# TODO(jakevdp): remove this if statement when minimum ml_dtypes version > 0.1 -float8_e4m3b11fnuz: Optional[type[np.generic]] = None +float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 -_float8_e4m3b11fnuz_dtype: Optional[np.dtype] = None +_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e5m2_dtype: np.dtype = np.dtype(float8_e5m2) @@ -57,32 +65,24 @@ def is_opaque_dtype(dtype: Any) -> bool: _bfloat16_dtype: np.dtype = np.dtype(bfloat16) _custom_float_scalar_types = [ + float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2, bfloat16, ] _custom_float_dtypes = [ + _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e5m2_dtype, _bfloat16_dtype, ] -if hasattr(ml_dtypes, "float8_e4m3b11fnuz"): - float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz - _float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz) - _custom_float_scalar_types.insert(0, float8_e4m3b11fnuz) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e4m3b11fnuz_dtype) # type: ignore[arg-type] +# 4-bit integer support +int4: type[np.generic] = ml_dtypes.int4 +uint4: type[np.generic] = ml_dtypes.uint4 -int4: Optional[type[np.generic]] = None -_int4_dtype: Optional[np.dtype] = None -uint4: Optional[type[np.generic]] = None -_uint4_dtype: Optional[np.dtype] = None - -if hasattr(ml_dtypes, "int4"): - int4 = ml_dtypes.int4 - uint4 = ml_dtypes.uint4 - _int4_dtype = np.dtype(int4) - _uint4_dtype = np.dtype(uint4) +_int4_dtype: np.dtype = np.dtype(int4) +_uint4_dtype: np.dtype = np.dtype(uint4) # Default types. bool_: type = np.bool_ @@ -226,17 +226,8 @@ def coerce_to_array(x: Any, dtype: Optional[DTypeLike] = None) -> np.ndarray: dtype = _scalar_type_to_dtype(type(x), x) return np.asarray(x, dtype) -try: - iinfo = ml_dtypes.iinfo -except AttributeError: - iinfo = np.iinfo - -try: - finfo = ml_dtypes.finfo -except AttributeError as err: - _ml_dtypes_version = getattr(ml_dtypes, "__version__", "") - raise ImportError("JAX requires package ml_dtypes>=0.1.0. " - f"Installed version is {_ml_dtypes_version}.") from err +iinfo = ml_dtypes.iinfo +finfo = ml_dtypes.finfo def _issubclass(a: Any, b: Any) -> bool: """Determines if ``a`` is a subclass of ``b``. @@ -285,13 +276,11 @@ def issubdtype(a: DTypeLike, b: DTypeLike) -> bool: if isinstance(b, np.dtype): return a == b return b in [np.floating, np.inexact, np.number] - # TODO(phawkins): remove the "_int4_dtype is not None" tests after requiring - # an ml_dtypes version that has int4 and uint4. - if _int4_dtype is not None and a == _int4_dtype: + if a == _int4_dtype: if isinstance(b, np.dtype): return a == b return b in [np.signedinteger, np.integer, np.number] - if _uint4_dtype is not None and a == _uint4_dtype: + if a == _uint4_dtype: if isinstance(b, np.dtype): return a == b return b in [np.unsignedinteger, np.integer, np.number] diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 75bcac006092..c7e2c388e784 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -159,22 +159,19 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: return meta bool_ = _make_scalar_type(np.bool_) -if dtypes.uint4 is not None: - uint4 = _make_scalar_type(dtypes.uint4) +uint4 = _make_scalar_type(dtypes.uint4) uint8 = _make_scalar_type(np.uint8) uint16 = _make_scalar_type(np.uint16) uint32 = _make_scalar_type(np.uint32) uint64 = _make_scalar_type(np.uint64) -if dtypes.int4 is not None: - int4 = _make_scalar_type(dtypes.int4) +int4 = _make_scalar_type(dtypes.int4) int8 = _make_scalar_type(np.int8) int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) -if dtypes.float8_e4m3b11fnuz is not None: - float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) +float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index d715aaeb0cad..c0d6e3754a65 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -47,14 +47,17 @@ def _dtype(x): _default_tolerance = { _dtypes.float0: 0, np.dtype(np.bool_): 0, + np.dtype(_dtypes.int4): 0, np.dtype(np.int8): 0, np.dtype(np.int16): 0, np.dtype(np.int32): 0, np.dtype(np.int64): 0, + np.dtype(_dtypes.uint4): 0, np.dtype(np.uint8): 0, np.dtype(np.uint16): 0, np.dtype(np.uint32): 0, np.dtype(np.uint64): 0, + np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e5m2): 1e-1, np.dtype(_dtypes.bfloat16): 1e-2, @@ -75,6 +78,7 @@ def default_tolerance(): default_gradient_tolerance = { + np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e5m2): 1e-1, np.dtype(_dtypes.bfloat16): 1e-1, @@ -85,11 +89,6 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } -# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required -if _dtypes.float8_e4m3b11fnuz is not None: - _default_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1 - def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -97,10 +96,8 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): if a.dtype == b.dtype == _dtypes.float0: np.testing.assert_array_equal(a, b, err_msg=err_msg) return - custom_dtypes = [_dtypes.float8_e4m3fn, _dtypes.float8_e5m2, _dtypes.bfloat16] - # TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required - if _dtypes.float8_e4m3b11fnuz is not None: - custom_dtypes.insert(0, _dtypes.float8_e4m3b11fnuz) + custom_dtypes = [_dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, + _dtypes.float8_e5m2, _dtypes.bfloat16] a = a.astype(np.float32) if a.dtype in custom_dtypes else a b = b.astype(np.float32) if b.dtype in custom_dtypes else b kw = {} diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index abf43b52d45d..a5b449a2690c 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -107,6 +107,7 @@ float16 as float16, float32 as float32, float64 as float64, + float8_e4m3b11fnuz as float8_e4m3b11fnuz, float8_e4m3fn as float8_e4m3fn, float8_e5m2 as float8_e5m2, float_ as float_, @@ -142,6 +143,7 @@ inf as inf, inner as inner, insert as insert, + int4 as int4, int8 as int8, int16 as int16, int32 as int32, @@ -237,6 +239,7 @@ triu_indices_from as triu_indices_from, trunc as trunc, uint as uint, + uint4 as uint4, uint8 as uint8, uint16 as uint16, uint32 as uint32, @@ -254,22 +257,6 @@ zeros_like as zeros_like, ) -# TODO(phawkins): make this import unconditional after increasing the ml_dtypes -# minimum version. -import jax._src.numpy.lax_numpy -if hasattr(jax._src.numpy.lax_numpy, "int4"): - from jax._src.numpy.lax_numpy import ( - int4 as int4, - uint4 as uint4, - ) -# TODO(jakevdp): make this import unconditional after increasing the minimum -# version for ml_dtypes and jaxlib -if hasattr(jax._src.numpy.lax_numpy, "float8_e4m3b11fnuz"): - from jax._src.numpy.lax_numpy import ( - float8_e4m3b11fnuz as float8_e4m3b11fnuz, - ) - - from jax._src.numpy.index_tricks import ( c_ as c_, index_exp as index_exp, diff --git a/jaxlib/setup.py b/jaxlib/setup.py index e5e53105df15..0e7370847689 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -50,7 +50,7 @@ def has_ext_modules(self): author_email='jax-dev@google.com', packages=['jaxlib', 'jaxlib.xla_extension'], python_requires='>=3.9', - install_requires=['scipy>=1.7', 'numpy>=1.22', 'ml_dtypes>=0.1.0'], + install_requires=['scipy>=1.7', 'numpy>=1.22', 'ml_dtypes>=0.2.0'], extras_require={ 'cuda11_pip': [ "nvidia-cublas-cu11>=11.11", diff --git a/setup.py b/setup.py index f3e1978160ad..e5c0f55643ea 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def generate_proto(source): package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, python_requires='>=3.9', install_requires=[ - 'ml_dtypes>=0.1.0', + 'ml_dtypes>=0.2.0', 'numpy>=1.22', 'opt_einsum', 'scipy>=1.7', diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 57ae91bd131b..e967c684f3c3 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -59,9 +59,8 @@ float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes custom_float_dtypes = [np.dtype(dtypes.bfloat16)] if _fp8_enabled: - fp8_dtypes = [np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e5m2)] - if dtypes.float8_e4m3b11fnuz is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3b11fnuz)] + fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e5m2)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes