Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog.d/51.improve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Optimize packing of `bytearray`<ISSUES_LIST>.
By special-casing `bytearray`, we can avoid an allocation and complete extra copy of the data when packing it.
This speeds up packing of `bytearray`s by roughly 1/3.
63 changes: 53 additions & 10 deletions src/codec/packstream/v1/pack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ use std::sync::OnceLock;

use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::sync::with_critical_section;
use pyo3::sync::OnceLockExt;
use pyo3::types::{PyBytes, PyDict, PyString, PyType};
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyString, PyTuple, PyType};
use pyo3::{intern, IntoPyObjectExt};

use super::super::Structure;
Expand All @@ -43,6 +44,35 @@ struct TypeMappings {

impl TypeMappings {
fn new(locals: &Bound<PyDict>) -> PyResult<Self> {
/// Remove some byte types from an iterable of types.
/// Types removed are `bytes`, `bytearray`, as those are handled specially in `pack`.
/// If the filtering fails for any reason, it returns the original input.
fn filter_bytes_types(types: Bound<PyAny>) -> Bound<PyAny> {
fn inner<'py>(types: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
fn is_of_known_bytes_types(typ: &Bound<PyType>) -> PyResult<bool> {
Ok(typ.is_subclass_of::<PyBytes>()? || typ.is_subclass_of::<PyByteArray>()?)
}

let py = types.py();
let types = types
.try_iter()?
.filter(|typ| {
let Ok(typ) = typ else {
return true;
};
let Ok(typ) = typ.downcast::<PyType>() else {
return true;
};
is_of_known_bytes_types(typ).map(|b| !b).unwrap_or(true)
})
.collect::<Result<Vec<_>, _>>()?;

Ok(PyTuple::new(py, types)?.into_any())
}

inner(&types).unwrap_or(types)
}

let py = locals.py();
Ok(Self {
none_values: locals
Expand Down Expand Up @@ -87,12 +117,15 @@ impl TypeMappings {
PyErr::new::<PyValueError, _>("Type mappings are missing MAPPING_TYPES.")
})?
.into_py_any(py)?,
bytes_types: locals
.get_item("BYTES_TYPES")?
.ok_or_else(|| {
PyErr::new::<PyValueError, _>("Type mappings are missing BYTES_TYPES.")
})?
.into_py_any(py)?,
bytes_types: filter_bytes_types(
locals
.get_item("BYTES_TYPES")?
.ok_or_else(|| {
PyErr::new::<PyValueError, _>("Type mappings are missing BYTES_TYPES.")
})?
.into_bound_py_any(py)?,
)
.unbind(),
})
}
}
Expand Down Expand Up @@ -170,8 +203,18 @@ impl<'a> PackStreamEncoder<'a> {
return self.write_string(value.extract::<&str>()?);
}

if value.is_instance(self.type_mappings.bytes_types.bind(py))? {
return self.write_bytes(value.extract::<Cow<[u8]>>()?);
if let Ok(value) = value.downcast::<PyBytes>() {
return self.write_bytes(value.as_bytes());
} else if let Ok(value) = value.downcast::<PyByteArray>() {
return with_critical_section(value, || {
// SAFETY:
// * we're holding the GIL/are attached to the Python interpreter
// * we're using a critical section to ensure exclusive access to the byte array
// * we don't interact with the interpreter/PyO3 APIs while reading the bytes
unsafe { self.write_bytes(value.as_bytes()) }
});
} else if value.is_instance(self.type_mappings.bytes_types.bind(py))? {
return self.write_bytes(&value.extract::<Cow<[u8]>>()?);
}

if value.is_instance(self.type_mappings.sequence_types.bind(py))? {
Expand Down Expand Up @@ -268,7 +311,7 @@ impl<'a> PackStreamEncoder<'a> {
Ok(())
}

fn write_bytes(&mut self, b: Cow<[u8]>) -> PyResult<()> {
fn write_bytes(&mut self, b: &[u8]) -> PyResult<()> {
let size = Self::usize_to_u64(b.len())?;
if size <= 255 {
self.buffer.extend(&[BYTES_8]);
Expand Down