Skip to content

Commit

Permalink
Fix test_util for new float8 type
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 8, 2023
1 parent 0ec9f3c commit 8d16519
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion jax/_src/public_test_util.py
Expand Up @@ -85,6 +85,11 @@ def default_tolerance():
np.dtype(np.complex128): 1e-5,
}

# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
if _dtypes.float8_e4m3b11fnuz is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3b11fnuz)] = 1e-1

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))

Expand All @@ -93,7 +98,9 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
np.testing.assert_array_equal(a, b, err_msg=err_msg)
return
custom_dtypes = [_dtypes.float8_e4m3fn, _dtypes.float8_e5m2, _dtypes.bfloat16]
custom_dtypes = [_dtypes.bfloat16]
# TODO(jakevdp): make this unconditional when ml_dtypes>=0.2 is required
if _dtypes.float8_e4m3b11fnuz is not None:
custom_dtypes.insert(0, _dtypes.float8_e4m3b11fnuz)
a = a.astype(np.float32) if a.dtype in custom_dtypes else a
b = b.astype(np.float32) if b.dtype in custom_dtypes else b
kw = {}
Expand Down

0 comments on commit 8d16519

Please sign in to comment.