Skip to content

Commit

Permalink
Support numpy.float16
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Mar 27, 2024
1 parent 56c1a03 commit 1fc3ed8
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 15 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ bytecount = { version = "^0.6.7", default_features = false, features = ["runtime
chrono = { version = "=0.4.34", default_features = false }
compact_str = { version = "0.7", default_features = false, features = ["serde"] }
encoding_rs = { version = "0.8", default_features = false }
half = { version = "2", default_features = false, features = ["std"] }
itoa = { version = "1", default_features = false }
itoap = { version = "1", features = ["std", "simd"] }
once_cell = { version = "1", default_features = false, features = ["race"] }
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ JSONEncodeError: Integer exceeds 53-bit range
### numpy

orjson natively serializes `numpy.ndarray` and individual
`numpy.float64`, `numpy.float32`,
`numpy.float64`, `numpy.float32`, `numpy.float16` (`numpy.half`),
`numpy.int64`, `numpy.int32`, `numpy.int16`, `numpy.int8`,
`numpy.uint64`, `numpy.uint32`, `numpy.uint16`, `numpy.uint8`,
`numpy.uintp`, `numpy.intp`, `numpy.datetime64`, and `numpy.bool`
Expand Down
36 changes: 27 additions & 9 deletions script/pynumpy
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,50 @@ os.sched_setaffinity(os.getpid(), {0, 1})

kind = sys.argv[1] if len(sys.argv) >= 1 else ""

if kind == "int32":
array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=numpy.int32)

if kind == "float16":
dtype = numpy.float16
array = numpy.random.random(size=(50000, 100)).astype(dtype)
elif kind == "float32":
dtype = numpy.float32
array = numpy.random.random(size=(50000, 100)).astype(dtype)
elif kind == "float64":
dtype = numpy.float64
array = numpy.random.random(size=(50000, 100))
assert array.dtype == numpy.float64
elif kind == "bool":
dtype = numpy.bool
array = numpy.random.choice((True, False), size=(100000, 200))
elif kind == "int8":
array = numpy.random.randint(((2**7) - 1), size=(100000, 100), dtype=numpy.int8)
dtype = numpy.int8
array = numpy.random.randint(((2**7) - 1), size=(100000, 100), dtype=dtype)
elif kind == "int16":
array = numpy.random.randint(((2**15) - 1), size=(100000, 100), dtype=numpy.int16)
dtype = numpy.int16
array = numpy.random.randint(((2**15) - 1), size=(100000, 100), dtype=dtype)
elif kind == "int32":
array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=numpy.int32)
dtype = numpy.int32
array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=dtype)
elif kind == "uint8":
array = numpy.random.randint(((2**8) - 1), size=(100000, 100), dtype=numpy.uint8)
dtype = numpy.uint8
array = numpy.random.randint(((2**8) - 1), size=(100000, 100), dtype=dtype)
elif kind == "uint16":
array = numpy.random.randint(((2**16) - 1), size=(100000, 100), dtype=numpy.uint16)
dtype = numpy.uint16
array = numpy.random.randint(((2**16) - 1), size=(100000, 100), dtype=dtype)
elif kind == "uint32":
dtype = numpy.uint32
array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=dtype)
else:
print("usage: pynumpy (bool|int16|int32|float64|int8|uint8|uint16)")
print(
"usage: pynumpy (bool|int16|int32|float16|float32|float64|int8|uint8|uint16|uint32)"
)
sys.exit(1)
proc = psutil.Process()


def default(__obj):
if isinstance(__obj, numpy.ndarray):
return __obj.tolist()
raise TypeError


headers = ("Library", "Latency (ms)", "RSS diff (MiB)", "vs. orjson")
Expand Down Expand Up @@ -92,7 +110,7 @@ def per_iter_latency(val):


def test_correctness(func):
return orjson.loads(func()) == array.tolist()
return numpy.array_equal(array, numpy.array(orjson.loads(func()), dtype=dtype))


table = []
Expand Down
69 changes: 69 additions & 0 deletions src/serialize/per_type/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool {
let scalar_types = unsafe { numpy_types.unwrap().as_ref() };
ob_type == scalar_types.float64
|| ob_type == scalar_types.float32
|| ob_type == scalar_types.float16
|| ob_type == scalar_types.int64
|| ob_type == scalar_types.int16
|| ob_type == scalar_types.int32
Expand Down Expand Up @@ -117,6 +118,7 @@ pub struct PyArrayInterface {
pub enum ItemType {
BOOL,
DATETIME64(NumpyDatetimeUnit),
F16,
F32,
F64,
I8,
Expand All @@ -137,6 +139,7 @@ impl ItemType {
let unit = NumpyDatetimeUnit::from_pyobject(ptr);
Some(ItemType::DATETIME64(unit))
}
(102, 2) => Some(ItemType::F16),
(102, 4) => Some(ItemType::F32),
(102, 8) => Some(ItemType::F64),
(105, 1) => Some(ItemType::I8),
Expand Down Expand Up @@ -312,6 +315,10 @@ impl Serialize for NumpyArray {
NumpyF32Array::new(slice!(self.data() as *const f32, self.num_items()))
.serialize(serializer)
}
ItemType::F16 => {
NumpyF16Array::new(slice!(self.data() as *const u16, self.num_items()))
.serialize(serializer)
}
ItemType::U64 => {
NumpyU64Array::new(slice!(self.data() as *const u64, self.num_items()))
.serialize(serializer)
Expand Down Expand Up @@ -439,6 +446,48 @@ impl Serialize for DataTypeF32 {
}
}

#[repr(transparent)]
struct NumpyF16Array<'a> {
data: &'a [u16],
}

impl<'a> NumpyF16Array<'a> {
fn new(data: &'a [u16]) -> Self {
Self { data }
}
}

impl<'a> Serialize for NumpyF16Array<'a> {
#[cold]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(None).unwrap();
for &each in self.data.iter() {
seq.serialize_element(&DataTypeF16 { obj: each }).unwrap();
}
seq.end()
}
}

#[repr(transparent)]
struct DataTypeF16 {
obj: u16,
}

impl Serialize for DataTypeF16 {
#[cold]
#[cfg_attr(feature = "optimize", optimize(size))]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let as_f16 = half::f16::from_bits(self.obj);
serializer.serialize_f32(as_f16.to_f32())
}
}

#[repr(transparent)]
struct NumpyU64Array<'a> {
data: &'a [u64],
Expand Down Expand Up @@ -826,6 +875,8 @@ impl Serialize for NumpyScalar {
(*(self.ptr as *mut NumpyFloat64)).serialize(serializer)
} else if ob_type == scalar_types.float32 {
(*(self.ptr as *mut NumpyFloat32)).serialize(serializer)
} else if ob_type == scalar_types.float16 {
(*(self.ptr as *mut NumpyFloat16)).serialize(serializer)
} else if ob_type == scalar_types.int64 {
(*(self.ptr as *mut NumpyInt64)).serialize(serializer)
} else if ob_type == scalar_types.int32 {
Expand Down Expand Up @@ -994,6 +1045,24 @@ impl Serialize for NumpyUint64 {
}
}

#[repr(C)]
pub struct NumpyFloat16 {
ob_refcnt: Py_ssize_t,
ob_type: *mut PyTypeObject,
value: u16,
}

impl Serialize for NumpyFloat16 {
#[cold]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let as_f16 = half::f16::from_bits(self.value);
serializer.serialize_f32(as_f16.to_f32())
}
}

#[repr(C)]
pub struct NumpyFloat32 {
ob_refcnt: Py_ssize_t,
Expand Down
2 changes: 2 additions & 0 deletions src/typeref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct NumpyTypes {
pub array: *mut PyTypeObject,
pub float64: *mut PyTypeObject,
pub float32: *mut PyTypeObject,
pub float16: *mut PyTypeObject,
pub int64: *mut PyTypeObject,
pub int32: *mut PyTypeObject,
pub int16: *mut PyTypeObject,
Expand Down Expand Up @@ -239,6 +240,7 @@ pub fn load_numpy_types() -> Box<Option<NonNull<NumpyTypes>>> {
let numpy_module_dict = PyObject_GenericGetDict(numpy, null_mut());
let types = Box::new(NumpyTypes {
array: look_up_numpy_type(numpy_module_dict, "ndarray\0"),
float16: look_up_numpy_type(numpy_module_dict, "half\0"),
float32: look_up_numpy_type(numpy_module_dict, "float32\0"),
float64: look_up_numpy_type(numpy_module_dict, "float64\0"),
int8: look_up_numpy_type(numpy_module_dict, "int8\0"),
Expand Down
103 changes: 98 additions & 5 deletions test/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@


def numpy_default(obj):
return obj.tolist()
if isinstance(obj, numpy.ndarray):
return obj.tolist()
raise TypeError


@pytest.mark.skipif(numpy is None, reason="numpy is not installed")
Expand Down Expand Up @@ -114,6 +116,94 @@ def test_numpy_array_d1_f32(self):
== b"[1.0,3.4028235e38]"
)

def test_numpy_array_d1_f16(self):
assert (
orjson.dumps(
numpy.array([-1.0, 0.0009765625, 1.0, 65504.0], numpy.float16),
option=orjson.OPT_SERIALIZE_NUMPY,
)
== b"[-1.0,0.0009765625,1.0,65504.0]"
)

def test_numpy_array_f16_roundtrip(self):
ref = [
-1.0,
-2.0,
0.000000059604645,
0.000060975552,
0.00006103515625,
0.0009765625,
0.33325195,
0.99951172,
1.0,
1.00097656,
65504.0,
]
obj = numpy.array(ref, numpy.float16) # type: ignore
serialized = orjson.dumps(
obj,
option=orjson.OPT_SERIALIZE_NUMPY,
)
deserialized = numpy.array(orjson.loads(serialized), numpy.float16) # type: ignore
assert numpy.array_equal(obj, deserialized)

def test_numpy_array_f16_edge(self):
assert (
orjson.dumps(
numpy.array(
[
numpy.inf,
numpy.NINF,
numpy.nan,
numpy.NZERO,
numpy.PZERO,
numpy.pi,
],
numpy.float16,
),
option=orjson.OPT_SERIALIZE_NUMPY,
)
== b"[null,null,null,-0.0,0.0,3.140625]"
)

def test_numpy_array_f32_edge(self):
assert (
orjson.dumps(
numpy.array(
[
numpy.inf,
numpy.NINF,
numpy.nan,
numpy.NZERO,
numpy.PZERO,
numpy.pi,
],
numpy.float32,
),
option=orjson.OPT_SERIALIZE_NUMPY,
)
== b"[null,null,null,-0.0,0.0,3.1415927]"
)

def test_numpy_array_f64_edge(self):
assert (
orjson.dumps(
numpy.array(
[
numpy.inf,
numpy.NINF,
numpy.nan,
numpy.NZERO,
numpy.PZERO,
numpy.pi,
],
numpy.float64,
),
option=orjson.OPT_SERIALIZE_NUMPY,
)
== b"[null,null,null,-0.0,0.0,3.141592653589793]"
)

def test_numpy_array_d1_f64(self):
assert (
orjson.dumps(
Expand Down Expand Up @@ -375,13 +465,10 @@ def test_numpy_array_non_contiguous_message(self):
)

def test_numpy_array_unsupported_dtype(self):
array = numpy.array([[1, 2], [3, 4]], numpy.float16) # type: ignore
array = numpy.array([[1, 2], [3, 4]], numpy.csingle) # type: ignore
with pytest.raises(orjson.JSONEncodeError) as cm:
orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY)
assert "unsupported datatype in numpy array" in str(cm)
assert orjson.dumps(
array, default=numpy_default, option=orjson.OPT_SERIALIZE_NUMPY
) == orjson.dumps(array.tolist())

def test_numpy_array_d1(self):
array = numpy.array([1])
Expand Down Expand Up @@ -602,6 +689,12 @@ def test_numpy_scalar_uint64(self):
== b"18446744073709551615"
)

def test_numpy_scalar_float16(self):
assert (
orjson.dumps(numpy.float16(1.0), option=orjson.OPT_SERIALIZE_NUMPY)
== b"1.0"
)

def test_numpy_scalar_float32(self):
assert (
orjson.dumps(numpy.float32(1.0), option=orjson.OPT_SERIALIZE_NUMPY)
Expand Down

0 comments on commit 1fc3ed8

Please sign in to comment.