Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ mod varint_util {
io::{self, Error},
};

use serde::Serialize;
use tokio::io::{AsyncRead, AsyncReadExt};
use serde::{de::DeserializeOwned, Serialize};
use smallvec::SmallVec;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

/// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format.
///
Expand Down Expand Up @@ -291,12 +292,38 @@ mod varint_util {
///
/// If the stream is at the end, this returns `Ok(None)`.
fn read_varint_u64(&mut self) -> impl Future<Output = io::Result<Option<u64>>>;

fn read_length_prefixed<T: DeserializeOwned>(
&mut self,
max_size: usize,
) -> impl Future<Output = io::Result<T>>;
}

impl<T: AsyncRead + Unpin> AsyncReadVarintExt for T {
fn read_varint_u64(&mut self) -> impl Future<Output = io::Result<Option<u64>>> {
read_varint_u64(self)
}

async fn read_length_prefixed<I: DeserializeOwned>(
&mut self,
max_size: usize,
) -> io::Result<I> {
let size = match self.read_varint_u64().await? {
Some(size) => size,
None => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "EOF reached")),
};

if size > max_size as u64 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Length-prefixed value too large",
));
}

let mut buf = vec![0; size as usize];
self.read_exact(&mut buf).await?;
postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
}

/// Provides a fn to write a varint to an [`io::Write`] target, as well as a
Expand All @@ -318,9 +345,38 @@ mod varint_util {
write_length_prefixed(self, value)
}
}

/// Provides a fn to write a varint to an [`io::Write`] target, as well as a
/// helper to write a length-prefixed value.
pub trait AsyncWriteVarintExt: AsyncWrite + Unpin {
/// Write a varint
fn write_varint_u64(&mut self, value: u64) -> impl Future<Output = io::Result<usize>>;
/// Write a value with a varint enoded length prefix.
fn write_length_prefixed<T: Serialize>(
&mut self,
value: T,
) -> impl Future<Output = io::Result<usize>>;
}

impl<T: AsyncWrite + Unpin> AsyncWriteVarintExt for T {
async fn write_varint_u64(&mut self, value: u64) -> io::Result<usize> {
let mut buf: SmallVec<[u8; 10]> = Default::default();
write_varint_u64_sync(&mut buf, value).unwrap();
self.write_all(&buf[..]).await?;
Ok(buf.len())
}

async fn write_length_prefixed<V: Serialize>(&mut self, value: V) -> io::Result<usize> {
let mut buf = Vec::new();
write_length_prefixed(&mut buf, value)?;
let size = buf.len();
self.write_all(&buf).await?;
Ok(size)
}
}
}
#[cfg(feature = "rpc")]
pub use varint_util::{AsyncReadVarintExt, WriteVarintExt};
pub use varint_util::{AsyncReadVarintExt, AsyncWriteVarintExt, WriteVarintExt};

mod fuse_wrapper {
use std::{
Expand Down