Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of ReservedSpace #260

Merged
merged 5 commits into from
May 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions heed/src/reserved_space.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ impl ReservedSpace<'_> {
}

/// The total number of bytes that this memory buffer has.
#[inline]
pub fn size(&self) -> usize {
self.bytes.len()
}

/// The remaining number of bytes that this memory buffer has.
#[inline]
pub fn remaining(&self) -> usize {
self.bytes.len() - self.write_head
}
Expand All @@ -47,6 +49,7 @@ impl ReservedSpace<'_> {
/// serialization. For example, this method can be used to serialize a value, then compute a
/// checksum over the bytes, and then write that checksum to a header at the start of the
/// reserved space.
#[inline]
pub fn written_mut(&mut self) -> &mut [u8] {
let ptr = self.bytes.as_mut_ptr();
let len = self.written;
Expand All @@ -62,6 +65,7 @@ impl ReservedSpace<'_> {
///
/// After calling this function, the entire space is considered to be filled and any
/// further attempt to [`write`](std::io::Write::write) anything else will fail.
#[inline]
pub fn fill_zeroes(&mut self) {
self.bytes[self.write_head..].fill(MaybeUninit::new(0));
self.written = self.bytes.len();
Expand All @@ -79,6 +83,7 @@ impl ReservedSpace<'_> {
/// As the memory comes from within the database itself, the bytes may not yet be
/// initialized. Thus, it is up to the caller to ensure that only initialized memory is read
/// (ensured by the [`MaybeUninit`] API).
#[inline]
pub fn as_uninit_mut(&mut self) -> &mut [MaybeUninit<u8>] {
self.bytes
}
Expand All @@ -89,6 +94,7 @@ impl ReservedSpace<'_> {
/// # Safety
///
/// The caller guarantees that all bytes in the range have been initialized.
#[inline]
pub unsafe fn assume_written(&mut self, len: usize) {
debug_assert!(len <= self.bytes.len());
self.written = len;
Expand All @@ -97,24 +103,33 @@ impl ReservedSpace<'_> {
}

impl io::Write for ReservedSpace<'_> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.remaining() >= buf.len() {
let dest = unsafe { self.bytes.as_mut_ptr().add(self.write_head) };
unsafe { buf.as_ptr().copy_to_nonoverlapping(dest.cast(), buf.len()) };
self.write_head += buf.len();
self.written = usize::max(self.written, self.write_head);
Ok(buf.len())
} else {
Err(io::Error::from(io::ErrorKind::WriteZero))
}
self.write_all(buf)?;
Ok(buf.len())
}

#[inline]
fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
let count = self.write(buf)?;
debug_assert_eq!(count, buf.len());
let remaining = unsafe { self.bytes.get_unchecked_mut(self.write_head..) };

if buf.len() > remaining.len() {
return Err(io::Error::from(io::ErrorKind::WriteZero));
}

unsafe {
// SAFETY: we can always cast `T` -> `MaybeUninit<T>` as it's a transparent wrapper
let buf_uninit = std::slice::from_raw_parts(buf.as_ptr().cast(), buf.len());
remaining.as_mut_ptr().copy_from_nonoverlapping(buf_uninit.as_ptr(), buf.len());
}

self.write_head += buf.len();
self.written = usize::max(self.written, self.write_head);
Kerollmops marked this conversation as resolved.
Show resolved Hide resolved

Ok(())
}

#[inline(always)]
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
Expand All @@ -125,6 +140,7 @@ impl io::Write for ReservedSpace<'_> {
/// May only seek within the previously written space.
/// Attempts to do otherwise will result in an error.
impl io::Seek for ReservedSpace<'_> {
#[inline]
fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
let (base, offset) = match pos {
io::SeekFrom::Start(start) => (start, 0),
Expand All @@ -151,11 +167,13 @@ impl io::Seek for ReservedSpace<'_> {
Ok(new_pos)
}

#[inline]
fn rewind(&mut self) -> io::Result<()> {
self.write_head = 0;
Ok(())
}

#[inline]
fn stream_position(&mut self) -> io::Result<u64> {
Ok(self.write_head as u64)
}
Expand Down
Loading