Skip to content

Commit

Permalink
numpy i8, u8
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Sep 25, 2020
1 parent a5fb4a7 commit b3b2660
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 54 deletions.
4 changes: 2 additions & 2 deletions README.md
Expand Up @@ -784,8 +784,8 @@ JSONEncodeError: Integer exceeds 53-bit range
### numpy

orjson natively serializes `numpy.ndarray` and individual `numpy.float64`,
`numpy.float32`, `numpy.int64`, `numpy.int32`, `numpy.uint64`, and
`numpy.uint32` instances. Arrays may have a
`numpy.float32`, `numpy.int64`, `numpy.int32`, `numpy.int8`, `numpy.uint64`,
`numpy.uint32`, and `numpy.uint8` instances. Arrays may have a
`dtype` of `numpy.bool`, `numpy.float32`, `numpy.float64`, `numpy.int32`,
`numpy.int64`, `numpy.uint32`, `numpy.uint64`, `numpy.uintp`, or `numpy.intp`.
orjson is faster than all compared libraries at serializing
Expand Down
6 changes: 5 additions & 1 deletion pynumpy
Expand Up @@ -30,8 +30,12 @@ elif kind == "float64":
assert array.dtype == numpy.float64
elif kind == "bool":
array = numpy.random.choice((True, False), size=(100000, 200))
elif kind == "int8":
array = numpy.random.randint(((2 ** 7) - 1), size=(100000, 100), dtype=numpy.int8)
elif kind == "uint8":
array = numpy.random.randint(((2 ** 8) - 1), size=(100000, 100), dtype=numpy.uint8)
else:
print("usage: pynumpy (bool|int32|float64)")
print("usage: pynumpy (bool|int32|float64|int8|uint8)")
sys.exit(1)
proc = psutil.Process()

Expand Down
167 changes: 125 additions & 42 deletions src/serialize/numpy.rs
Expand Up @@ -19,8 +19,10 @@ pub fn is_numpy_scalar(ob_type: *mut PyTypeObject) -> bool {
|| ob_type == scalar_types.float32
|| ob_type == scalar_types.int64
|| ob_type == scalar_types.int32
|| ob_type == scalar_types.int8
|| ob_type == scalar_types.uint64
|| ob_type == scalar_types.uint32
|| ob_type == scalar_types.uint8
}
}

Expand Down Expand Up @@ -61,8 +63,10 @@ pub enum ItemType {
BOOL,
F32,
F64,
I8,
I32,
I64,
U8,
U32,
U64,
}
Expand All @@ -73,44 +77,6 @@ pub enum PyArrayError {
UnsupportedDataType,
}

#[repr(transparent)]
pub struct NumpyScalar {
pub ptr: *mut pyo3::ffi::PyObject,
}

impl NumpyScalar {
pub fn new(ptr: *mut PyObject) -> Self {
NumpyScalar { ptr }
}
}

impl<'p> Serialize for NumpyScalar {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
unsafe {
let ob_type = ob_type!(self.ptr);
let scalar_types = NUMPY_TYPES.deref_mut().as_ref().unwrap();
if ob_type == scalar_types.float64 {
(*(self.ptr as *mut NumpyFloat64)).serialize(serializer)
} else if ob_type == scalar_types.float32 {
(*(self.ptr as *mut NumpyFloat32)).serialize(serializer)
} else if ob_type == scalar_types.int64 {
(*(self.ptr as *mut NumpyInt64)).serialize(serializer)
} else if ob_type == scalar_types.int32 {
(*(self.ptr as *mut NumpyInt32)).serialize(serializer)
} else if ob_type == scalar_types.uint64 {
(*(self.ptr as *mut NumpyUint64)).serialize(serializer)
} else if ob_type == scalar_types.uint32 {
(*(self.ptr as *mut NumpyUint32)).serialize(serializer)
} else {
unreachable!()
}
}
}
}

// >>> arr = numpy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], numpy.int32)
// >>> arr.ndim
// 3
Expand All @@ -127,6 +93,7 @@ pub struct NumpyArray {
}

impl<'a> NumpyArray {
#[inline(never)]
pub fn new(ptr: *mut PyObject) -> Result<Self, PyArrayError> {
let capsule = ffi!(PyObject_GetAttr(ptr, ARRAY_STRUCT_STR));
let array = unsafe { (*(capsule as *mut PyCapsule)).pointer as *mut PyArrayInterface };
Expand Down Expand Up @@ -176,8 +143,10 @@ impl<'a> NumpyArray {
(098, 1) => Some(ItemType::BOOL),
(102, 4) => Some(ItemType::F32),
(102, 8) => Some(ItemType::F64),
(105, 1) => Some(ItemType::I8),
(105, 4) => Some(ItemType::I32),
(105, 8) => Some(ItemType::I64),
(117, 1) => Some(ItemType::U8),
(117, 4) => Some(ItemType::U32),
(117, 8) => Some(ItemType::U64),
_ => None,
Expand Down Expand Up @@ -237,6 +206,7 @@ impl Drop for NumpyArray {
}

impl<'p> Serialize for NumpyArray {
#[inline(never)]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
Expand All @@ -248,7 +218,6 @@ impl<'p> Serialize for NumpyArray {
for child in &self.children {
seq.serialize_element(child).unwrap();
}

} else {
let data_ptr = self.data();
let num_items = self.num_items();
Expand Down Expand Up @@ -277,10 +246,16 @@ impl<'p> Serialize for NumpyArray {
seq.serialize_element(&DataTypeI32 { obj: each }).unwrap();
}
}
ItemType::U64 => {
let slice: &[u64] = slice!(data_ptr as *const u64, num_items);
ItemType::I8 => {
let slice: &[i8] = slice!(data_ptr as *const i8, num_items);
for &each in slice.iter() {
seq.serialize_element(&DataTypeU64 { obj: each }).unwrap();
seq.serialize_element(&DataTypeI8 { obj: each }).unwrap();
}
}
ItemType::U8 => {
let slice: &[u8] = slice!(data_ptr as *const u8, num_items);
for &each in slice.iter() {
seq.serialize_element(&DataTypeU8 { obj: each }).unwrap();
}
}
ItemType::U32 => {
Expand All @@ -289,6 +264,12 @@ impl<'p> Serialize for NumpyArray {
seq.serialize_element(&DataTypeU32 { obj: each }).unwrap();
}
}
ItemType::U64 => {
let slice: &[u64] = slice!(data_ptr as *const u64, num_items);
for &each in slice.iter() {
seq.serialize_element(&DataTypeU64 { obj: each }).unwrap();
}
}
ItemType::BOOL => {
let slice: &[u8] = slice!(data_ptr as *const u8, num_items);
for &each in slice.iter() {
Expand Down Expand Up @@ -330,6 +311,20 @@ impl<'p> Serialize for DataTypeF64 {
}
}

#[repr(transparent)]
pub struct DataTypeI8 {
pub obj: i8,
}

impl<'p> Serialize for DataTypeI8 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_i8(self.obj)
}
}

#[repr(transparent)]
pub struct DataTypeI32 {
pub obj: i32,
Expand Down Expand Up @@ -358,6 +353,20 @@ impl<'p> Serialize for DataTypeI64 {
}
}

#[repr(transparent)]
pub struct DataTypeU8 {
pub obj: u8,
}

impl<'p> Serialize for DataTypeU8 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_u8(self.obj)
}
}

#[repr(transparent)]
pub struct DataTypeU32 {
pub obj: u32,
Expand Down Expand Up @@ -400,6 +409,64 @@ impl<'p> Serialize for DataTypeBOOL {
}
}

#[repr(transparent)]
pub struct NumpyScalar {
pub ptr: *mut pyo3::ffi::PyObject,
}

impl NumpyScalar {
pub fn new(ptr: *mut PyObject) -> Self {
NumpyScalar { ptr }
}
}

impl<'p> Serialize for NumpyScalar {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
unsafe {
let ob_type = ob_type!(self.ptr);
let scalar_types = NUMPY_TYPES.deref_mut().as_ref().unwrap();
if ob_type == scalar_types.float64 {
(*(self.ptr as *mut NumpyFloat64)).serialize(serializer)
} else if ob_type == scalar_types.float32 {
(*(self.ptr as *mut NumpyFloat32)).serialize(serializer)
} else if ob_type == scalar_types.int64 {
(*(self.ptr as *mut NumpyInt64)).serialize(serializer)
} else if ob_type == scalar_types.int32 {
(*(self.ptr as *mut NumpyInt32)).serialize(serializer)
} else if ob_type == scalar_types.int8 {
(*(self.ptr as *mut NumpyInt8)).serialize(serializer)
} else if ob_type == scalar_types.uint64 {
(*(self.ptr as *mut NumpyUint64)).serialize(serializer)
} else if ob_type == scalar_types.uint32 {
(*(self.ptr as *mut NumpyUint32)).serialize(serializer)
} else if ob_type == scalar_types.uint8 {
(*(self.ptr as *mut NumpyUint8)).serialize(serializer)
} else {
unreachable!()
}
}
}
}

#[repr(C)]
pub struct NumpyInt8 {
pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject,
pub value: i8,
}

impl<'p> Serialize for NumpyInt8 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_i8(self.value)
}
}

#[repr(C)]
pub struct NumpyInt32 {
pub ob_refcnt: Py_ssize_t,
Expand Down Expand Up @@ -432,6 +499,22 @@ impl<'p> Serialize for NumpyInt64 {
}
}

#[repr(C)]
pub struct NumpyUint8 {
pub ob_refcnt: Py_ssize_t,
pub ob_type: *mut PyTypeObject,
pub value: u8,
}

impl<'p> Serialize for NumpyUint8 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_u8(self.value)
}
}

#[repr(C)]
pub struct NumpyUint32 {
pub ob_refcnt: Py_ssize_t,
Expand Down
6 changes: 5 additions & 1 deletion src/typeref.rs
Expand Up @@ -7,13 +7,15 @@ use std::ptr::NonNull;
use std::sync::Once;

pub struct NumpyTypes {
pub array: *mut PyTypeObject,
pub float64: *mut PyTypeObject,
pub float32: *mut PyTypeObject,
pub int64: *mut PyTypeObject,
pub int32: *mut PyTypeObject,
pub int8: *mut PyTypeObject,
pub uint64: *mut PyTypeObject,
pub uint32: *mut PyTypeObject,
pub array: *mut PyTypeObject,
pub uint8: *mut PyTypeObject,
}
pub static mut HASH_SEED: u64 = 0;

Expand Down Expand Up @@ -146,10 +148,12 @@ unsafe fn load_numpy_types() -> Option<NumpyTypes> {
array: look_up_numpy_type(numpy, "ndarray\0")?.as_ptr(),
float32: look_up_numpy_type(numpy, "float32\0")?.as_ptr(),
float64: look_up_numpy_type(numpy, "float64\0")?.as_ptr(),
int8: look_up_numpy_type(numpy, "int8\0")?.as_ptr(),
int32: look_up_numpy_type(numpy, "int32\0")?.as_ptr(),
int64: look_up_numpy_type(numpy, "int64\0")?.as_ptr(),
uint32: look_up_numpy_type(numpy, "uint32\0")?.as_ptr(),
uint64: look_up_numpy_type(numpy, "uint64\0")?.as_ptr(),
uint8: look_up_numpy_type(numpy, "uint8\0")?.as_ptr(),
});
Py_XDECREF(numpy);
types
Expand Down

0 comments on commit b3b2660

Please sign in to comment.