From e3a4c9acebbd157cc7d8f47385b564e3bf47f289 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Oct 2024 00:38:54 +0000 Subject: [PATCH 1/6] [IR] Support float4e2m1 --- onnxscript/ir/_core.py | 16 +++++++++++++--- onnxscript/ir/_enums.py | 6 ++++++ onnxscript/ir/_enums_test.py | 2 ++ onnxscript/ir/_type_casting.py | 15 +++++++++++++++ 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 25722d7ba1..478c117e6d 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -70,6 +70,7 @@ _enums.DataType.FLOAT8E5M2FNUZ, _enums.DataType.INT4, _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, ) ) @@ -182,7 +183,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) When the dtype is not one of the numpy native dtypes, the value needs need to be: - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits. - - ``uint8`` for uint4. + - ``uint8`` for uint4 or float4. - ``uint8`` for 8-bit data types. - ``uint16`` for bfloat16 @@ -213,6 +214,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) raise TypeError( f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}." ) + if dtype == _enums.DataType.FLOAT4E2M1: + if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn): + raise TypeError( + f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}." + ) return try: @@ -256,6 +262,8 @@ def _maybe_view_np_array_with_ml_dtypes( return array.view(ml_dtypes.int4) if dtype == _enums.DataType.UINT4: return array.view(ml_dtypes.uint4) + if dtype == _enums.DataType.FLOAT4E2M1: + return array.view(ml_dtypes.float4_e2m1fn) return array @@ -431,7 +439,7 @@ def tobytes(self) -> bytes: """ # TODO(justinchuby): Support DLPack array = self.numpy() - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}: + if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1}: # Pack the array into int4 array = _type_casting.pack_int4(array) else: @@ -609,7 +617,7 @@ def _load(self): ) # Handle the byte order correctly by always using little endian dt = np.dtype(self.dtype.numpy()).newbyteorder("<") - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}: + if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1}: # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values dt = np.dtype(np.uint8).newbyteorder("<") count = self.size // 2 + self.size % 2 @@ -622,6 +630,8 @@ def _load(self): self._array = _type_casting.unpack_int4(self._array, shape) elif self.dtype == _enums.DataType.UINT4: self._array = _type_casting.unpack_uint4(self._array, shape) + elif self.dtype == _enums.DataType.FLOAT4E2M1: + self._array = _type_casting.unpack_float4_e2m1(self._array, shape) else: self._array = self._array.reshape(shape) diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index d561ad58da..ccb7687a80 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -64,6 +64,7 @@ class DataType(enum.IntEnum): FLOAT8E5M2FNUZ = 20 UINT4 = 21 INT4 = 22 + FLOAT4E2M1 = 23 @classmethod def from_numpy(cls, dtype: np.dtype) -> DataType: @@ -150,5 +151,10 @@ def __str__(self) -> str: np.dtype(ml_dtypes.uint4): DataType.UINT4, } +# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE +_NP_TYPE_TO_DATA_TYPE.update( + {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1} if hasattr(ml_dtypes, "float4_e2m1fn") else {} +) + # ONNX DataType to Numpy dtype. _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()} diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py index 6616819205..0721aaa996 100644 --- a/onnxscript/ir/_enums_test.py +++ b/onnxscript/ir/_enums_test.py @@ -32,6 +32,8 @@ def test_enums_are_the_same_as_spec(self): self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ) self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4) self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4) + if hasattr(onnx.TensorProto, "FLOAT4E2M1"): + self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1) self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED) def test_from_numpy_takes_np_dtype_and_returns_data_type(self): diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py index 3f3611000b..d7392053a8 100644 --- a/onnxscript/ir/_type_casting.py +++ b/onnxscript/ir/_type_casting.py @@ -89,3 +89,18 @@ def unpack_int4( """ unpacked = _unpack_uint4_as_uint8(data, dims) return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4) + + +def unpack_float4e2m1( + data: npt.NDArray[np.uint8], dims: Sequence[int] +) -> npt.NDArray[ml_dtypes.float4e2m1]: + """Convert a packed float4e2m1 array to unpacked float4e2m1 array. + + Args: + data: A numpy array. + dims: The dimensions are used to reshape the unpacked buffer. + + Returns: + A numpy array of float32 reshaped to dims. + """ + return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4e2m1) From 5184490bf90966977dbb8bc2a1cfed85e1607fe3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Oct 2024 00:39:10 +0000 Subject: [PATCH 2/6] update --- onnxscript/ir/_core.py | 12 ++++++++++-- onnxscript/ir/_enums.py | 4 +++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 478c117e6d..ae380a1b98 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -439,7 +439,11 @@ def tobytes(self) -> bytes: """ # TODO(justinchuby): Support DLPack array = self.numpy() - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1}: + if self.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: # Pack the array into int4 array = _type_casting.pack_int4(array) else: @@ -617,7 +621,11 @@ def _load(self): ) # Handle the byte order correctly by always using little endian dt = np.dtype(self.dtype.numpy()).newbyteorder("<") - if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4, _enums.DataType.FLOAT4E2M1}: + if self.dtype in { + _enums.DataType.INT4, + _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, + }: # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values dt = np.dtype(np.uint8).newbyteorder("<") count = self.size // 2 + self.size % 2 diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index ccb7687a80..13789d6932 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -153,7 +153,9 @@ def __str__(self) -> str: # TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE _NP_TYPE_TO_DATA_TYPE.update( - {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1} if hasattr(ml_dtypes, "float4_e2m1fn") else {} + {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1} + if hasattr(ml_dtypes, "float4_e2m1fn") + else {} ) # ONNX DataType to Numpy dtype. From 5321313aabf5a9e9122c9fb33481d564ad74b5b8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Oct 2024 00:41:47 +0000 Subject: [PATCH 3/6] serde --- onnxscript/ir/serde.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index b454997443..fd072bbe43 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -323,6 +323,8 @@ def numpy(self) -> np.ndarray: return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims) elif dtype == _enums.DataType.UINT4: return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims) + elif dtype == _enums.DataType.FLOAT4E2M1: + return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims) else: # Otherwise convert to the correct dtype and reshape # Note we cannot use view() here because the storage dtype may not be the same size as the target @@ -369,6 +371,7 @@ def tobytes(self) -> bytes: _enums.DataType.FLOAT8E5M2FNUZ, _enums.DataType.INT4, _enums.DataType.UINT4, + _enums.DataType.FLOAT4E2M1, }: # uint4 and int4 values are already packed, even when stored as int32 # so we don't need to pack them again From d4088a34a4d0d0cff3df7f8af7463c948354ef6c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Oct 2024 18:34:42 -0700 Subject: [PATCH 4/6] Update _core.py --- onnxscript/ir/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index ae380a1b98..30d88cef99 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -639,7 +639,7 @@ def _load(self): elif self.dtype == _enums.DataType.UINT4: self._array = _type_casting.unpack_uint4(self._array, shape) elif self.dtype == _enums.DataType.FLOAT4E2M1: - self._array = _type_casting.unpack_float4_e2m1(self._array, shape) + self._array = _type_casting.unpack_float4e2m1(self._array, shape) else: self._array = self._array.reshape(shape) From 5fb716e835109f4b589824551e72b2ed3127063e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Oct 2024 18:35:13 -0700 Subject: [PATCH 5/6] Update _type_casting.py --- onnxscript/ir/_type_casting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py index d7392053a8..20bab69037 100644 --- a/onnxscript/ir/_type_casting.py +++ b/onnxscript/ir/_type_casting.py @@ -93,7 +93,7 @@ def unpack_int4( def unpack_float4e2m1( data: npt.NDArray[np.uint8], dims: Sequence[int] -) -> npt.NDArray[ml_dtypes.float4e2m1]: +) -> npt.NDArray[ml_dtypes.float4_e2m1fn]: """Convert a packed float4e2m1 array to unpacked float4e2m1 array. Args: @@ -103,4 +103,4 @@ def unpack_float4e2m1( Returns: A numpy array of float32 reshaped to dims. """ - return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4e2m1) + return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn) From 0b097169b93eca3db5c89c3fbd53e6d1a2c1d4ba Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 21 Oct 2024 23:46:55 +0000 Subject: [PATCH 6/6] more tests and fixes --- onnxscript/ir/_core_test.py | 36 ++++++++++++++++++++++++++++++++---- onnxscript/ir/_enums.py | 1 + 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 802bf39deb..0361399084 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -55,6 +55,7 @@ def test_init_requires_type_when_value_is_not_np_array(self): ("int4", np.int8, ir.DataType.INT4), ("int4_uint8", np.uint8, ir.DataType.INT4), ("uint4", np.uint8, ir.DataType.UINT4), + ("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1), ] ) def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: ir.DataType): @@ -131,34 +132,48 @@ def test_tobytes(self): tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) self.assertEqual(tensor.tobytes(), array.tobytes()) - def test_tobtyes_returns_packed_data_for_int4(self): + def test_tobytes_returns_packed_data_for_int4(self): array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.INT4) self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - def test_tobtyes_returns_packed_data_for_int4_ml_dtypes(self): + def test_tobytes_returns_packed_data_for_int4_ml_dtypes(self): array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.INT4) self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - def test_tobtyes_returns_packed_data_for_uint4(self): + def test_tobytes_returns_packed_data_for_uint4(self): array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - def test_tobtyes_returns_packed_data_for_uint4_ml_dtypes(self): + def test_tobytes_returns_packed_data_for_uint4_ml_dtypes(self): array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4) # Test odd sized array assert len(array) % 2 == 1 tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) self.assertEqual(tensor.tobytes(), b"\x10r\x0f") + def test_tobytes_returns_packed_data_for_float4e2m1(self): + array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) + # Test odd sized array + assert len(array) % 2 == 1 + tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1) + self.assertEqual(tensor.tobytes(), b"\x10r\x0f") + + def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes(self): + array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) + # Test odd sized array + assert len(array) % 2 == 1 + tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1) + self.assertEqual(tensor.tobytes(), b"\x10r\x0f") + def test_metadata(self): array = np.random.rand(1, 2).astype(np.float32) tensor = _core.Tensor(array) @@ -444,6 +459,19 @@ def test_external_tensor_complex(self, _: str, np_dtype: np.dtype): # about permission errors del tensor + def test_external_tensor_float4e2m1(self): + expected_array = np.array([0, 1, 2, 7, 15]).view(ml_dtypes.float4_e2m1fn) + tensor_proto = ir.serde.serialize_tensor( + ir.Tensor(expected_array, dtype=ir.DataType.FLOAT4E2M1) + ) + with tempfile.TemporaryDirectory() as temp_dir: + _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") + tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) + np.testing.assert_array_equal(tensor.numpy(), expected_array) + # Close the mmap file by deleting the reference to tensor so Windows doesn't complain + # about permission errors + del tensor + def test_external_tensor_empty_tensor(self): expected_array = np.array([], dtype=np.float32) tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array)) diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index 13789d6932..d0d8c19270 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -122,6 +122,7 @@ def __str__(self) -> str: DataType.FLOAT8E5M2FNUZ: 1, DataType.UINT4: 0.5, DataType.INT4: 0.5, + DataType.FLOAT4E2M1: 0.5, }