From d1aad3d9848ec4756011613ff76dbe031f7c531f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 26 Feb 2025 21:17:21 -0800 Subject: [PATCH] [IR] Fix an error when checking for float8_e4m3fnuz type in ir.Tensor The float8_e4m3fnuz type was mistaken with float8_e4m3b11fnuz, which is a different type: https://github.com/jax-ml/ml_dtypes#float8_e4m3b11fnuz --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index fb113ee835..ddb0e80309 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -199,7 +199,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) ) if dtype.itemsize == 1 and array.dtype not in ( np.uint8, - ml_dtypes.float8_e4m3b11fnuz, + ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e5m2fnuz, ml_dtypes.float8_e5m2,