fix bf16 support
Signed-off-by: Ian Bearman <ianb@microsoft.com>
manbearian committed May 20, 2022
1 parent 5b1346e commit 0ea5952
Showing 4 changed files with 80 additions and 12 deletions.
28 changes: 25 additions & 3 deletions onnx/helper.py
@@ -14,6 +14,9 @@
from typing import Text, Sequence, Any, Optional, Dict, Union, TypeVar, Callable, Tuple, List, cast
import numpy as np # type: ignore
import warnings
from math import isnan
import struct
import sys

VersionRowType = Union[Tuple[Text, int, int, int], Tuple[Text, int, int, int, int]]
VersionTableType = List[VersionRowType]
@@ -277,6 +280,24 @@ def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
            for i in range(len(ca) * 2)]


# convert a float32 value to bfloat16, returned as an int holding the uint16 bit pattern
def float32_to_bfloat16(fval: float) -> int:
    ival = int.from_bytes(struct.pack('<f', fval), 'little')
    if isnan(fval):
        # NaN requires at least 1 significand bit set
        ival16 = 0x7FC0  # sign=0, exp=all-ones, significand=0b1000000
    else:
        # drop the bottom 16 bits, rounding the remainder to nearest-even
        rounding_bias = ((ival >> 16) & 1) + 0x7FFF
        ival16 = (ival + rounding_bias) >> 16
    # swap byte order on big-endian hosts; pack as unsigned ('<H'), since the
    # sign bit may be set and a signed pack would overflow
    if sys.byteorder == 'big':
        ival16 = int.from_bytes(struct.pack('<H', ival16), 'big')
    return ival16


def make_tensor(
        name: Text,
        data_type: int,
@@ -327,10 +348,11 @@ def make_tensor(
    if (data_type == TensorProto.COMPLEX64
            or data_type == TensorProto.COMPLEX128):
        vals = split_complex_to_pairs(vals)
-    # floa16/bfloat16 are stored as uint16
-    elif (data_type == TensorProto.FLOAT16
-            or data_type == TensorProto.BFLOAT16):
+    # float16/bfloat16 are stored as uint16
+    elif data_type == TensorProto.FLOAT16:
        vals = np.array(vals).astype(np.float16).view(dtype=np.uint16).flatten().tolist()
+    elif data_type == TensorProto.BFLOAT16:
+        vals = list(map(float32_to_bfloat16, np.array(vals).astype(np.float32).flatten().tolist()))
    field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
        mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[data_type]]
    getattr(tensor, field).extend(vals)
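As a quick sanity check on the rounding behavior, a sketch assuming a little-endian host and that the new helper is importable as onnx.helper.float32_to_bfloat16:

from onnx.helper import float32_to_bfloat16

# 1.0 is 0x3F800000 as float32; the dropped low 16 bits are zero, so the top half survives
assert float32_to_bfloat16(1.0) == 0x3F80
# 0.1 is 0x3DCCCCCD; the dropped bits 0xCCCD are past the halfway point, so round up
assert float32_to_bfloat16(0.1) == 0x3DCD
# NaN collapses to the canonical quiet NaN, keeping a significand bit set
assert float32_to_bfloat16(float('nan')) == 0x7FC0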
2 changes: 1 addition & 1 deletion onnx/mapping.py
@@ -13,7 +13,7 @@
    int(TensorProto.INT64): np.dtype('int64'),
    int(TensorProto.BOOL): np.dtype('bool'),
    int(TensorProto.FLOAT16): np.dtype('float16'),
-    int(TensorProto.BFLOAT16): np.dtype('float16'),  # native numpy does not support bfloat16
+    int(TensorProto.BFLOAT16): np.dtype('uint16'),  # native numpy does not support bfloat16
    int(TensorProto.DOUBLE): np.dtype('float64'),
    int(TensorProto.COMPLEX64): np.dtype('complex64'),
    int(TensorProto.COMPLEX128): np.dtype('complex128'),
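The dtype switch matters because bfloat16 payloads are now carried as raw uint16 bit patterns rather than silently coerced through float16. A minimal check, a sketch assuming the dict above is mapping.TENSOR_TYPE_TO_NP_TYPE as the surrounding entries suggest:

import numpy as np
from onnx import TensorProto, mapping

assert mapping.TENSOR_TYPE_TO_NP_TYPE[int(TensorProto.BFLOAT16)] == np.dtype('uint16')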
22 changes: 19 additions & 3 deletions onnx/numpy_helper.py
@@ -13,6 +13,11 @@ def combine_pairs_to_complex(fa: Sequence[int]) -> Sequence[np.complex64]:
    return [complex(fa[i * 2], fa[i * 2 + 1]) for i in range(len(fa) // 2)]


# convert an ndarray of bf16 bit patterns (carried in a 16- or 32-bit integer array) to float32
def bfloat16_to_float32(data: np.ndarray, dims) -> np.ndarray:
    return np.asarray(list(map(lambda x: x << 16, data)), dtype=np.int32).reshape(dims).view(np.float32)


def to_array(tensor: TensorProto, base_dir: Text = "") -> np.ndarray:
    """Converts a tensor def object to a numpy array.
@@ -49,19 +54,30 @@ def to_array(tensor: TensorProto, base_dir: Text = "") -> np.ndarray:
        if sys.byteorder == 'big':
            # Convert endian from little to big
            convert_endian(tensor)

        # manually convert bf16 since there's no numpy support
        if tensor_dtype == TensorProto.BFLOAT16:
            data = np.frombuffer(tensor.raw_data, dtype=np.int16)
            return bfloat16_to_float32(data, dims)

        return np.frombuffer(
            tensor.raw_data,
            dtype=np_dtype).reshape(dims)
    else:
-        # float16/bfloat16 is stored as int32 (uint16 type); Need view to get the original value
-        if (tensor_dtype == TensorProto.FLOAT16
-                or tensor_dtype == TensorProto.BFLOAT16):
+        # float16 is stored as int32 (uint16 type); need a view to get the original value
+        if tensor_dtype == TensorProto.FLOAT16:
            return (
                np.asarray(
                    tensor.int32_data,
                    dtype=np.uint16)
                .reshape(dims)
                .view(np.float16))

+        # bfloat16 is stored as int32 (uint16 type); no numpy support for bf16
+        if tensor_dtype == TensorProto.BFLOAT16:
+            data = np.asarray(tensor.int32_data, dtype=np.int32)
+            return bfloat16_to_float32(data, dims)

        data = getattr(tensor, storage_field)
        if (tensor_dtype == TensorProto.COMPLEX64
                or tensor_dtype == TensorProto.COMPLEX128):
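To see the widening path in isolation: each 16-bit pattern is shifted into the high half of a 32-bit integer and the result is reinterpreted as float32. A sketch, assuming the helper is importable as onnx.numpy_helper.bfloat16_to_float32:

import numpy as np
from onnx.numpy_helper import bfloat16_to_float32

bits = np.asarray([0x3F80, 0x4000], dtype=np.int32)  # bf16 bit patterns for 1.0 and 2.0
np.testing.assert_array_equal(
    bfloat16_to_float32(bits, (2,)),
    np.asarray([1.0, 2.0], dtype=np.float32))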
40 changes: 35 additions & 5 deletions onnx/test/helper_test.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import random
import struct

import numpy as np # type: ignore

@@ -416,7 +417,21 @@ def test_make_float16_tensor_with_raw(self) -> None:
        np.testing.assert_equal(np_array, numpy_helper.to_array(tensor))

    def test_make_bfloat16_tensor(self) -> None:
-        np_array = np.random.randn(8, 7).astype(np.float16)
+        # numpy doesn't support bf16, so we have to compute the correct result manually
+        # np_array = np.random.randn(2, 3).astype(np.float16)
+        np_array = np.array([[1.0, 2.0], [3.0, 4.0], [0.099853515625, 0.099365234375], [0.0998535081744, 0.1], [np.nan, np.inf]])
+        np_results = np.array([
+            [struct.unpack('!f', bytes.fromhex('3F800000'))[0],   # 1.0
+             struct.unpack('!f', bytes.fromhex('40000000'))[0]],  # 2.0
+            [struct.unpack('!f', bytes.fromhex('40400000'))[0],   # 3.0
+             struct.unpack('!f', bytes.fromhex('40800000'))[0]],  # 4.0
+            [struct.unpack('!f', bytes.fromhex('3DCC0000'))[0],   # round-to-nearest-even rounds down (dropped bits 0x8000, tie to even)
+             struct.unpack('!f', bytes.fromhex('3DCC0000'))[0]],  # round-to-nearest-even rounds up (dropped bits 0x8000, tie to even)
+            [struct.unpack('!f', bytes.fromhex('3DCC0000'))[0],   # round-to-nearest-even rounds down (dropped bits 0x7FFF)
+             struct.unpack('!f', bytes.fromhex('3DCD0000'))[0]],  # round-to-nearest-even rounds up (dropped bits 0xCCCD)
+            [struct.unpack('!f', bytes.fromhex('7FC00000'))[0],   # NaN
+             struct.unpack('!f', bytes.fromhex('7F800000'))[0]],  # inf
+        ])

        tensor = helper.make_tensor(
            name='test',
@@ -425,20 +440,35 @@ def test_make_bfloat16_tensor(self) -> None:
            vals=np_array
        )
        self.assertEqual(tensor.name, 'test')
-        np.testing.assert_equal(np_array, numpy_helper.to_array(tensor))
+        np.testing.assert_equal(np_results, numpy_helper.to_array(tensor))

    def test_make_bfloat16_tensor_with_raw(self) -> None:
-        np_array = np.random.randn(8, 7).astype(np.float16)
+        # numpy doesn't support bf16, so we have to compute the correct result manually
+        # np_array = np.random.randn(8, 7).astype(np.float16)
+        np_array = np.array([[1.0, 2.0], [3.0, 4.0], [0.099853515625, 0.099365234375], [0.0998535081744, 0.1], [np.nan, np.inf]])
+        np_results = np.array([
+            [struct.unpack('!f', bytes.fromhex('3F800000'))[0],   # 1.0
+             struct.unpack('!f', bytes.fromhex('40000000'))[0]],  # 2.0
+            [struct.unpack('!f', bytes.fromhex('40400000'))[0],   # 3.0
+             struct.unpack('!f', bytes.fromhex('40800000'))[0]],  # 4.0
+            [struct.unpack('!f', bytes.fromhex('3DCC0000'))[0],   # truncated (dropped bits 0x8000)
+             struct.unpack('!f', bytes.fromhex('3DCB0000'))[0]],  # truncated (dropped bits 0x8000)
+            [struct.unpack('!f', bytes.fromhex('3DCC0000'))[0],   # truncated (dropped bits 0x7FFF)
+             struct.unpack('!f', bytes.fromhex('3DCC0000'))[0]],  # truncated (dropped bits 0xCCCD)
+            [struct.unpack('!f', bytes.fromhex('7FC00000'))[0],   # NaN
+             struct.unpack('!f', bytes.fromhex('7F800000'))[0]],  # inf
+        ])

        tensor = helper.make_tensor(
            name='test',
            data_type=TensorProto.BFLOAT16,
            dims=np_array.shape,
-            vals=np_array.view(dtype=np.uint16).flatten().tobytes(),
+            # write out the top 16 bits of each fp32 value to create bf16 by truncation, no rounding
+            vals=np.array(list(map(lambda x: x >> 16, np_array.astype(np.float32).view(np.uint32).flatten()))).astype(np.uint16).tobytes(),
            raw=True
        )
        self.assertEqual(tensor.name, 'test')
-        np.testing.assert_equal(np_array, numpy_helper.to_array(tensor))
+        np.testing.assert_equal(np_results, numpy_helper.to_array(tensor))

    def test_make_sparse_tensor(self) -> None:
        values = [1.1, 2.2, 3.3, 4.4, 5.5]
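Putting the pieces together, a round trip through make_tensor and to_array; a sketch, noting that values come back as float32 holding the bf16-rounded results, not the original inputs bit-for-bit:

import numpy as np
from onnx import TensorProto, helper, numpy_helper

t = helper.make_tensor('x', TensorProto.BFLOAT16, dims=[2], vals=[1.0, 0.1])
arr = numpy_helper.to_array(t)
# 1.0 survives exactly; 0.1 rounds to the bf16 pattern 0x3DCD, i.e. 0.10009765625
np.testing.assert_array_equal(arr, np.asarray([1.0, 0.10009765625], dtype=np.float32))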
