From ae6aaff27ced384d1c87ac65618e29ea39baf16a Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 12:00:56 +0200 Subject: [PATCH 1/5] feat: add a max message size restriction --- src/lib.rs | 121 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3d594f3..2477da9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,7 +141,7 @@ pub mod channel { /// Oneshot channel, similar to tokio's oneshot channel pub mod oneshot { - use std::{fmt::Debug, future::Future, io, pin::Pin, task}; + use std::{fmt::Debug, future::Future, pin::Pin, task}; use n0_future::future::Boxed as BoxFuture; @@ -162,7 +162,7 @@ pub mod channel { /// overhead is negligible. However, boxing can also be used for local communication, /// e.g. when applying a transform or filter to the message before sending it. pub type BoxedSender = - Box BoxFuture> + Send + Sync + 'static>; + Box BoxFuture> + Send + Sync + 'static>; /// A sender that can be wrapped in a `Box>`. /// @@ -172,7 +172,9 @@ pub mod channel { /// Remote receivers are always boxed, since for remote communication the boxing /// overhead is negligible. However, boxing can also be used for local communication, /// e.g. when applying a transform or filter to the message before receiving it. - pub trait DynSender: Future> + Send + Sync + 'static { + pub trait DynSender: + Future> + Send + Sync + 'static + { fn is_rpc(&self) -> bool; } @@ -181,7 +183,7 @@ pub mod channel { /// Remote receivers are always boxed, since for remote communication the boxing /// overhead is negligible. However, boxing can also be used for local communication, /// e.g. when applying a transform or filter to the message before receiving it. - pub type BoxedReceiver = BoxFuture>; + pub type BoxedReceiver = BoxFuture>; /// A oneshot sender. /// @@ -266,7 +268,7 @@ pub mod channel { fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll { match self.get_mut() { Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed), - Self::Boxed(rx) => Pin::new(rx).poll(cx).map_err(RecvError::Io), + Self::Boxed(rx) => Pin::new(rx).poll(cx), } } } @@ -293,7 +295,7 @@ pub mod channel { impl From for Receiver where F: FnOnce() -> Fut, - Fut: Future> + Send + 'static, + Fut: Future> + Send + 'static, { fn from(f: F) -> Self { Self::Boxed(Box::pin(f())) @@ -317,7 +319,7 @@ pub mod channel { /// /// For the rpc case, the send side can not be cloned, hence spsc instead of mpsc. pub mod spsc { - use std::{fmt::Debug, future::Future, io, pin::Pin}; + use std::{fmt::Debug, future::Future, pin::Pin}; use super::{RecvError, SendError}; use crate::RpcMessage; @@ -402,7 +404,7 @@ pub mod channel { fn send( &mut self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + '_>>; /// Try to send a message, returning as fast as possible if sending /// is not currently possible. @@ -412,7 +414,7 @@ pub mod channel { fn try_send( &mut self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + '_>>; /// Await the sender close fn closed(&mut self) -> Pin + Send + '_>>; @@ -576,16 +578,26 @@ pub mod channel { /// for local communication. #[error("receiver closed")] ReceiverClosed, + /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). + #[error("maximum message size exceeded")] + MaxMessageSizeExceeded, /// The underlying io error. This can occur for remote communication, /// due to a network error or serialization error. #[error("io error: {0}")] Io(#[from] io::Error), } + impl From for SendError { + fn from(value: postcard::Error) -> Self { + Self::Io(io::Error::new(io::ErrorKind::InvalidData, value)) + } + } + impl From for io::Error { fn from(e: SendError) -> Self { match e { SendError::ReceiverClosed => io::Error::new(io::ErrorKind::BrokenPipe, e), + SendError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e), SendError::Io(e) => e, } } @@ -602,6 +614,9 @@ pub mod channel { /// for local communication. #[error("sender closed")] SenderClosed, + /// The message exceeded the maximum allowed message size [`MAX_MESSAGE_SIZE`]. + #[error("maximum message size exceeded")] + MaxMessageSizeExceeded, /// An io error occurred. This can occur for remote communication, /// due to a network error or deserialization error. #[error("io error: {0}")] @@ -613,6 +628,7 @@ pub mod channel { match e { RecvError::Io(e) => e, RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e), + RecvError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e), } } } @@ -1108,6 +1124,12 @@ pub mod rpc { RequestError, RpcMessage, }; + /// Default max message size (16 MiB). + const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16; + + /// Error code on streams if the max message size was exceeded. + const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1; + /// Error that can occur when writing the initial message when doing a /// cross-process RPC. #[derive(Debug, thiserror::Error)] @@ -1115,21 +1137,44 @@ pub mod rpc { /// Error writing to the stream with quinn #[error("error writing to stream: {0}")] Quinn(#[from] quinn::WriteError), + /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). + #[error("maximum message size exceeded")] + MaxMessageSizeExceeded, /// Generic IO error, e.g. when serializing the message or when using /// other transports. #[error("error serializing: {0}")] Io(#[from] io::Error), } + impl From for WriteError { + fn from(value: postcard::Error) -> Self { + Self::Io(io::Error::new(io::ErrorKind::InvalidData, value)) + } + } + impl From for io::Error { fn from(e: WriteError) -> Self { match e { WriteError::Io(e) => e, + WriteError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e), WriteError::Quinn(e) => e.into(), } } } + impl From for SendError { + fn from(err: quinn::WriteError) -> Self { + match err { + quinn::WriteError::Stopped(code) + if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() => + { + SendError::MaxMessageSizeExceeded + } + _ => SendError::Io(io::Error::from(err)), + } + } + } + /// Trait to abstract over a client connection to a remote service. /// /// This isn't really that much abstracted, since the result of open_bi must @@ -1238,6 +1283,9 @@ pub mod rpc { { let RemoteSender(mut send, recv, _) = self; let msg = msg.into(); + if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE { + return Err(WriteError::MaxMessageSizeExceeded); + } let mut buf = SmallVec::<[u8; 128]>::new(); buf.write_length_prefixed(msg)?; send.write_all(&buf).await?; @@ -1248,17 +1296,24 @@ pub mod rpc { impl From for oneshot::Receiver { fn from(mut read: quinn::RecvStream) -> Self { let fut = async move { - let size = read.read_varint_u64().await?.ok_or(io::Error::new( - io::ErrorKind::UnexpectedEof, - "failed to read size", - ))?; + let size = read + .read_varint_u64() + .await? + .ok_or(RecvError::Io(io::Error::new( + io::ErrorKind::UnexpectedEof, + "failed to read size", + )))?; + if size > MAX_MESSAGE_SIZE { + read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); + return Err(RecvError::MaxMessageSizeExceeded); + } let rest = read .read_to_end(size as usize) .await .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let msg: T = postcard::from_bytes(&rest) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - io::Result::Ok(msg) + Ok(msg) }; oneshot::Receiver::from(|| fut) } @@ -1272,9 +1327,10 @@ pub mod rpc { } impl From for spsc::Receiver { - fn from(read: quinn::RecvStream) -> Self { + fn from(recv: quinn::RecvStream) -> Self { spsc::Receiver::Boxed(Box::new(QuinnReceiver { - recv: read, + recv, + buffer: Vec::new(), _marker: PhantomData, })) } @@ -1291,11 +1347,14 @@ pub mod rpc { fn from(mut writer: quinn::SendStream) -> Self { oneshot::Sender::Boxed(Box::new(move |value| { Box::pin(async move { + if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE { + return Err(SendError::MaxMessageSizeExceeded); + } // write via a small buffer to avoid allocation for small values let mut buf = SmallVec::<[u8; 128]>::new(); buf.write_length_prefixed(value)?; writer.write_all(&buf).await?; - io::Result::Ok(()) + Ok(()) }) })) } @@ -1313,6 +1372,7 @@ pub mod rpc { struct QuinnReceiver { recv: quinn::RecvStream, + buffer: Vec, _marker: std::marker::PhantomData, } @@ -1332,11 +1392,17 @@ pub mod rpc { let Some(size) = read.read_varint_u64().await? else { return Ok(None); }; - let mut buf = vec![0; size as usize]; - read.read_exact(&mut buf) + if size > MAX_MESSAGE_SIZE { + self.recv + .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()) + .ok(); + return Err(RecvError::MaxMessageSizeExceeded); + } + self.buffer.resize(size as usize, 0); + read.read_exact(&mut self.buffer) .await .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; - let msg: T = postcard::from_bytes(&buf) + let msg: T = postcard::from_bytes(&self.buffer) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; Ok(Some(msg)) }) @@ -1360,9 +1426,14 @@ pub mod rpc { } impl DynSender for QuinnSender { - fn send(&mut self, value: T) -> Pin> + Send + '_>> { + fn send( + &mut self, + value: T, + ) -> Pin> + Send + '_>> { Box::pin(async { - let value = value; + if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE { + return Err(SendError::MaxMessageSizeExceeded); + } self.buffer.clear(); self.buffer.write_length_prefixed(value)?; self.send.write_all(&self.buffer).await?; @@ -1374,10 +1445,12 @@ pub mod rpc { fn try_send( &mut self, value: T, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + '_>> { Box::pin(async { // todo: move the non-async part out of the box. Will require a new return type. - let value = value; + if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE { + return Err(SendError::MaxMessageSizeExceeded); + } self.buffer.clear(); self.buffer.write_length_prefixed(value)?; let Some(n) = now_or_never(self.send.write(&self.buffer)) else { From 7fe5de48167dd4b16d41bd56ae7a12fbb70cd9fc Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 12:12:57 +0200 Subject: [PATCH 2/5] chore: clippy --- src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2477da9..74a821e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -232,7 +232,7 @@ pub mod channel { pub async fn send(self, value: T) -> std::result::Result<(), SendError> { match self { Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed), - Sender::Boxed(f) => f(value).await.map_err(SendError::from), + Sender::Boxed(f) => f(value).await, } } } @@ -450,7 +450,7 @@ pub mod channel { Sender::Tokio(tx) => { tx.send(value).await.map_err(|_| SendError::ReceiverClosed) } - Sender::Boxed(sink) => sink.send(value).await.map_err(SendError::from), + Sender::Boxed(sink) => sink.send(value).await, } } @@ -477,7 +477,7 @@ pub mod channel { } Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false), }, - Sender::Boxed(sink) => sink.try_send(value).await.map_err(SendError::from), + Sender::Boxed(sink) => sink.try_send(value).await, } } } From 740ec1d42f22138d400b5556f9fb58daf4d5b84c Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 12:20:38 +0200 Subject: [PATCH 3/5] fix: postcard is only available in rpc module --- src/lib.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 74a821e..9ffaab6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -587,12 +587,6 @@ pub mod channel { Io(#[from] io::Error), } - impl From for SendError { - fn from(value: postcard::Error) -> Self { - Self::Io(io::Error::new(io::ErrorKind::InvalidData, value)) - } - } - impl From for io::Error { fn from(e: SendError) -> Self { match e { @@ -1152,6 +1146,12 @@ pub mod rpc { } } + impl From for SendError { + fn from(value: postcard::Error) -> Self { + Self::Io(io::Error::new(io::ErrorKind::InvalidData, value)) + } + } + impl From for io::Error { fn from(e: WriteError) -> Self { match e { From 058c0e282e6f46263437481a4703792fc194c074 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 26 Jun 2025 11:57:39 +0300 Subject: [PATCH 4/5] more tests --- src/lib.rs | 24 +++- tests/common.rs | 22 ++++ tests/{mpsc_sender.rs => mpsc_channel.rs} | 60 ++++++++- tests/oneshot_channel.rs | 148 ++++++++++++++++++++++ 4 files changed, 247 insertions(+), 7 deletions(-) create mode 100644 tests/common.rs rename tests/{mpsc_sender.rs => mpsc_channel.rs} (61%) create mode 100644 tests/oneshot_channel.rs diff --git a/src/lib.rs b/src/lib.rs index 4b3e8be..31acfc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1142,6 +1142,9 @@ pub mod rpc { /// Error code on streams if the max message size was exceeded. const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1; + /// Error code on streams if the sender tried to send an message that could not be postcard serialized. + const ERROR_CODE_INVALID_POSTCARD: u32 = 2; + /// Error that can occur when writing the initial message when doing a /// cross-process RPC. #[derive(Debug, thiserror::Error)] @@ -1364,12 +1367,23 @@ pub mod rpc { fn from(mut writer: quinn::SendStream) -> Self { oneshot::Sender::Boxed(Box::new(move |value| { Box::pin(async move { - if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE { + let size = match postcard::experimental::serialized_size(&value) { + Ok(size) => size, + Err(e) => { + writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); + return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e))); + } + }; + if size as u64 > MAX_MESSAGE_SIZE { + writer.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); return Err(SendError::MaxMessageSizeExceeded); } // write via a small buffer to avoid allocation for small values let mut buf = SmallVec::<[u8; 128]>::new(); - buf.write_length_prefixed(value)?; + if let Err(e) = buf.write_length_prefixed(value) { + writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); + return Err(e.into()); + } writer.write_all(&buf).await?; Ok(()) }) @@ -1445,11 +1459,15 @@ pub mod rpc { ) -> Pin> + Send + Sync + '_>> { Box::pin(async { if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE { + self.send.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); return Err(SendError::MaxMessageSizeExceeded); } let value = value; self.buffer.clear(); - self.buffer.write_length_prefixed(value)?; + if let Err(e) = self.buffer.write_length_prefixed(value) { + self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); + return Err(e.into()); + } self.send.write_all(&self.buffer).await?; self.buffer.clear(); Ok(()) diff --git a/tests/common.rs b/tests/common.rs new file mode 100644 index 0000000..df51151 --- /dev/null +++ b/tests/common.rs @@ -0,0 +1,22 @@ +use std::{ + io::{self, ErrorKind}, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + time::Duration, +}; + +use irpc::{ + channel::{mpsc, RecvError, SendError}, + util::{make_client_endpoint, make_server_endpoint, AsyncWriteVarintExt}, +}; +use quinn::Endpoint; +use testresult::TestResult; +use tokio::{task::JoinHandle, time::timeout}; + +pub fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { + let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); + let (server, cert) = make_server_endpoint(addr)?; + let client = make_client_endpoint(addr, &[cert.as_slice()])?; + let port = server.local_addr()?.port(); + let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); + Ok((server, client, server_addr)) +} diff --git a/tests/mpsc_sender.rs b/tests/mpsc_channel.rs similarity index 61% rename from tests/mpsc_sender.rs rename to tests/mpsc_channel.rs index e8382bb..21d5261 100644 --- a/tests/mpsc_sender.rs +++ b/tests/mpsc_channel.rs @@ -1,16 +1,16 @@ use std::{ - io::ErrorKind, + io::{self, ErrorKind}, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, time::Duration, }; use irpc::{ - channel::{mpsc, SendError}, - util::{make_client_endpoint, make_server_endpoint}, + channel::{mpsc, RecvError, SendError}, + util::{make_client_endpoint, make_server_endpoint, AsyncWriteVarintExt}, }; use quinn::Endpoint; use testresult::TestResult; -use tokio::time::timeout; +use tokio::{task::JoinHandle, time::timeout}; fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); @@ -115,3 +115,55 @@ async fn mpsc_sender_clone_drop_error() -> TestResult<()> { third_client.await?; Ok(()) } + +/// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side. +#[tokio::test] +async fn mpsc_max_message_size_send() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server: JoinHandle> = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; + let mut recv = mpsc::Receiver::>::from(recv); + while let Some(_) = recv.recv().await? {} + return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send = mpsc::Sender::>::from(send); + // this one should work! + send.send(vec![0u8; 1024 * 1024]).await?; + // this one should fail! + let Err(cause) = send.send(vec![0u8; 1024 * 1024 * 32]).await else { + panic!("client should have failed due to max message size"); + }; + assert!(matches!(cause, SendError::MaxMessageSizeExceeded)); + let Err(cause) = server.await? else { + panic!("server should have failed due to max message size"); + }; + assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); + Ok(()) +} + +/// Checks that the max message size is enforced on receiver side. +#[tokio::test] +async fn mpsc_max_message_size_recv() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server: JoinHandle> = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; + let mut recv = mpsc::Receiver::>::from(recv); + while let Some(_) = recv.recv().await? {} + return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (mut send, _) = conn.open_bi().await?; + // this one should work! + send.write_length_prefixed(vec![0u8; 1024 * 1024]).await?; + // this one should fail on receive! + send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]).await.ok(); + let Err(cause) = server.await? else { + panic!("server should have failed due to max message size"); + }; + assert!(matches!(cause, RecvError::MaxMessageSizeExceeded)); + Ok(()) +} \ No newline at end of file diff --git a/tests/oneshot_channel.rs b/tests/oneshot_channel.rs new file mode 100644 index 0000000..1fe4693 --- /dev/null +++ b/tests/oneshot_channel.rs @@ -0,0 +1,148 @@ +use std::{ + io::{self, ErrorKind}, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, +}; + +use irpc::{ + channel::{oneshot, RecvError, SendError}, + util::{make_client_endpoint, make_server_endpoint, AsyncWriteVarintExt}, +}; +use quinn::Endpoint; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use testresult::TestResult; +use tokio::task::JoinHandle; + +mod common; +use common::*; + +#[derive(Debug)] +struct NoSer(u64); + +#[derive(Debug, thiserror::Error)] +#[error("Cannot serialize odd number: {0}")] +struct OddNumberError(u64); + +impl Serialize for NoSer { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if self.0 % 2 == 1 { + Err(serde::ser::Error::custom(OddNumberError(self.0))) + } else { + serializer.serialize_u64(self.0) + } + } +} + +impl<'de> Deserialize<'de> for NoSer { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = u64::deserialize(deserializer)?; + if value % 2 != 0 { + Err(serde::de::Error::custom(OddNumberError(value))) + } else { + Ok(NoSer(value)) + } + } +} + +/// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side. +#[tokio::test] +async fn oneshot_max_message_size_send() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server: JoinHandle> = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; + let recv = oneshot::Receiver::>::from(recv); + recv.await?; + return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send = oneshot::Sender::>::from(send); + // this one should fail! + let Err(cause) = send.send(vec![0u8; 1024 * 1024 * 32]).await else { + panic!("client should have failed due to max message size"); + }; + assert!(matches!(cause, SendError::MaxMessageSizeExceeded)); + let Err(cause) = server.await? else { + panic!("server should have failed due to max message size"); + }; + assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); + Ok(()) +} + +/// Checks that the max message size is enforced on receiver side. +#[tokio::test] +async fn oneshot_max_message_size_recv() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server: JoinHandle> = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; + let recv = oneshot::Receiver::>::from(recv); + recv.await?; + return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (mut send, _) = conn.open_bi().await?; + // this one should fail on receive! + send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]).await.ok(); + let Err(cause) = server.await? else { + panic!("server should have failed due to max message size"); + }; + assert!(matches!(cause, RecvError::MaxMessageSizeExceeded)); + Ok(()) +} + +/// Checks that trying to send a message that cannot be serialized results in an error on the sender side and a connection reset on the receiver side. +#[tokio::test] +async fn oneshot_serialize_error_send() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server: JoinHandle> = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; + let recv = oneshot::Receiver::::from(recv); + recv.await?; + return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send = oneshot::Sender::::from(send); + // this one should fail! + let Err(cause) = send.send(NoSer(1)).await else { + panic!("client should have failed due to serialization error"); + }; + assert!(matches!(cause, SendError::Io(e) if e.kind() == ErrorKind::InvalidData)); + let Err(cause) = server.await? else { + panic!("server should have failed due to serialization error"); + }; + println!("Server error: {:?}", cause); + assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); + Ok(()) +} + +#[tokio::test] +async fn oneshot_serialize_error_recv() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server: JoinHandle> = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; + let recv = oneshot::Receiver::::from(recv); + recv.await?; + return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (mut send, _) = conn.open_bi().await?; + // this one should fail on receive! + send.write_length_prefixed(1u64).await?; + send.finish()?; + let Err(cause) = server.await? else { + panic!("server should have failed due to serialization error"); + }; + println!("Server error: {:?}", cause); + assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData)); + Ok(()) +} \ No newline at end of file From d7ec2b36f4ebb12c9bc2dd88567f3f6529dd717f Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 26 Jun 2025 12:11:39 +0300 Subject: [PATCH 5/5] Add more tests for the unlikely case that postcard ser/de fails on send or recv side --- src/lib.rs | 22 ++++++-- tests/common.rs | 47 +++++++++++++---- tests/mpsc_channel.rs | 108 ++++++++++++++++++++++++++++---------- tests/oneshot_channel.rs | 109 ++++++++++++++------------------------- 4 files changed, 176 insertions(+), 110 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 31acfc4..e049553 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1371,11 +1371,16 @@ pub mod rpc { Ok(size) => size, Err(e) => { writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); - return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e))); + return Err(SendError::Io(io::Error::new( + io::ErrorKind::InvalidData, + e, + ))); } }; if size as u64 > MAX_MESSAGE_SIZE { - writer.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); + writer + .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()) + .ok(); return Err(SendError::MaxMessageSizeExceeded); } // write via a small buffer to avoid allocation for small values @@ -1458,8 +1463,17 @@ pub mod rpc { value: T, ) -> Pin> + Send + Sync + '_>> { Box::pin(async { - if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE { - self.send.reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); + let size = match postcard::experimental::serialized_size(&value) { + Ok(size) => size, + Err(e) => { + self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); + return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e))); + } + }; + if size as u64 > MAX_MESSAGE_SIZE { + self.send + .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()) + .ok(); return Err(SendError::MaxMessageSizeExceeded); } let value = value; diff --git a/tests/common.rs b/tests/common.rs index df51151..cf89a74 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -1,16 +1,9 @@ -use std::{ - io::{self, ErrorKind}, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - time::Duration, -}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; -use irpc::{ - channel::{mpsc, RecvError, SendError}, - util::{make_client_endpoint, make_server_endpoint, AsyncWriteVarintExt}, -}; +use irpc::util::{make_client_endpoint, make_server_endpoint}; use quinn::Endpoint; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use testresult::TestResult; -use tokio::{task::JoinHandle, time::timeout}; pub fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); @@ -20,3 +13,37 @@ pub fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAdd let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); Ok((server, client, server_addr)) } + +#[derive(Debug)] +pub struct NoSer(pub u64); + +#[derive(Debug, thiserror::Error)] +#[error("Cannot serialize odd number: {0}")] +pub struct OddNumberError(u64); + +impl Serialize for NoSer { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + if self.0 % 2 == 1 { + Err(serde::ser::Error::custom(OddNumberError(self.0))) + } else { + serializer.serialize_u64(self.0) + } + } +} + +impl<'de> Deserialize<'de> for NoSer { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = u64::deserialize(deserializer)?; + if value % 2 != 0 { + Err(serde::de::Error::custom(OddNumberError(value))) + } else { + Ok(NoSer(value)) + } + } +} diff --git a/tests/mpsc_channel.rs b/tests/mpsc_channel.rs index 21d5261..28d7aa6 100644 --- a/tests/mpsc_channel.rs +++ b/tests/mpsc_channel.rs @@ -1,25 +1,18 @@ use std::{ io::{self, ErrorKind}, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, time::Duration, }; use irpc::{ channel::{mpsc, RecvError, SendError}, - util::{make_client_endpoint, make_server_endpoint, AsyncWriteVarintExt}, + util::AsyncWriteVarintExt, }; use quinn::Endpoint; use testresult::TestResult; -use tokio::{task::JoinHandle, time::timeout}; +use tokio::time::timeout; -fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { - let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); - let (server, cert) = make_server_endpoint(addr)?; - let client = make_client_endpoint(addr, &[cert.as_slice()])?; - let port = server.local_addr()?.port(); - let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); - Ok((server, client, server_addr)) -} +mod common; +use common::*; /// Checks that all clones of a `Sender` will get the closed signal as soon as /// a send fails with an io error. @@ -116,17 +109,27 @@ async fn mpsc_sender_clone_drop_error() -> TestResult<()> { Ok(()) } +async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> { + let conn = server + .accept() + .await + .unwrap() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn + .accept_bi() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let mut recv = mpsc::Receiver::>::from(recv); + while recv.recv().await?.is_some() {} + Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())) +} + /// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side. #[tokio::test] async fn mpsc_max_message_size_send() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; - let server: JoinHandle> = tokio::spawn(async move { - let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; - let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; - let mut recv = mpsc::Receiver::>::from(recv); - while let Some(_) = recv.recv().await? {} - return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); - }); + let server = tokio::spawn(vec_receiver(server)); let conn = client.connect(server_addr, "localhost")?.await?; let (send, _) = conn.open_bi().await?; let send = mpsc::Sender::>::from(send); @@ -148,22 +151,73 @@ async fn mpsc_max_message_size_send() -> TestResult<()> { #[tokio::test] async fn mpsc_max_message_size_recv() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; - let server: JoinHandle> = tokio::spawn(async move { - let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; - let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; - let mut recv = mpsc::Receiver::>::from(recv); - while let Some(_) = recv.recv().await? {} - return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); - }); + let server = tokio::spawn(vec_receiver(server)); let conn = client.connect(server_addr, "localhost")?.await?; let (mut send, _) = conn.open_bi().await?; // this one should work! send.write_length_prefixed(vec![0u8; 1024 * 1024]).await?; // this one should fail on receive! - send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]).await.ok(); + send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]) + .await + .ok(); let Err(cause) = server.await? else { panic!("server should have failed due to max message size"); }; assert!(matches!(cause, RecvError::MaxMessageSizeExceeded)); Ok(()) -} \ No newline at end of file +} + +async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> { + let conn = server + .accept() + .await + .unwrap() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn + .accept_bi() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let mut recv = mpsc::Receiver::::from(recv); + while recv.recv().await?.is_some() {} + Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())) +} + +/// Checks that a serialization error is caught and propagated to the receiver. +#[tokio::test] +async fn mpsc_serialize_error_send() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server = tokio::spawn(noser_receiver(server)); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send = mpsc::Sender::::from(send); + // this one should work! + send.send(NoSer(0)).await?; + // this one should fail! + let Err(cause) = send.send(NoSer(1)).await else { + panic!("client should have failed due to serialization error"); + }; + assert!(matches!(cause, SendError::Io(e) if e.kind() == ErrorKind::InvalidData)); + let Err(cause) = server.await? else { + panic!("server should have failed due to serialization error"); + }; + assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); + Ok(()) +} + +#[tokio::test] +async fn mpsc_serialize_error_recv() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + let server = tokio::spawn(noser_receiver(server)); + let conn = client.connect(server_addr, "localhost")?.await?; + let (mut send, _) = conn.open_bi().await?; + // this one should work! + send.write_length_prefixed(0u64).await?; + // this one should fail on receive! + send.write_length_prefixed(1u64).await.ok(); + let Err(cause) = server.await? else { + panic!("server should have failed due to serialization error"); + }; + assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData)); + Ok(()) +} diff --git a/tests/oneshot_channel.rs b/tests/oneshot_channel.rs index 1fe4693..3988721 100644 --- a/tests/oneshot_channel.rs +++ b/tests/oneshot_channel.rs @@ -1,65 +1,36 @@ -use std::{ - io::{self, ErrorKind}, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, -}; +use std::io::{self, ErrorKind}; use irpc::{ channel::{oneshot, RecvError, SendError}, - util::{make_client_endpoint, make_server_endpoint, AsyncWriteVarintExt}, + util::AsyncWriteVarintExt, }; use quinn::Endpoint; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use testresult::TestResult; -use tokio::task::JoinHandle; mod common; use common::*; -#[derive(Debug)] -struct NoSer(u64); - -#[derive(Debug, thiserror::Error)] -#[error("Cannot serialize odd number: {0}")] -struct OddNumberError(u64); - -impl Serialize for NoSer { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - if self.0 % 2 == 1 { - Err(serde::ser::Error::custom(OddNumberError(self.0))) - } else { - serializer.serialize_u64(self.0) - } - } -} - -impl<'de> Deserialize<'de> for NoSer { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let value = u64::deserialize(deserializer)?; - if value % 2 != 0 { - Err(serde::de::Error::custom(OddNumberError(value))) - } else { - Ok(NoSer(value)) - } - } +async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> { + let conn = server + .accept() + .await + .unwrap() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn + .accept_bi() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let recv = oneshot::Receiver::>::from(recv); + recv.await?; + Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())) } /// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side. #[tokio::test] async fn oneshot_max_message_size_send() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; - let server: JoinHandle> = tokio::spawn(async move { - let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; - let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; - let recv = oneshot::Receiver::>::from(recv); - recv.await?; - return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); - }); + let server = tokio::spawn(vec_receiver(server)); let conn = client.connect(server_addr, "localhost")?.await?; let (send, _) = conn.open_bi().await?; let send = oneshot::Sender::>::from(send); @@ -79,17 +50,13 @@ async fn oneshot_max_message_size_send() -> TestResult<()> { #[tokio::test] async fn oneshot_max_message_size_recv() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; - let server: JoinHandle> = tokio::spawn(async move { - let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; - let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; - let recv = oneshot::Receiver::>::from(recv); - recv.await?; - return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); - }); + let server = tokio::spawn(vec_receiver(server)); let conn = client.connect(server_addr, "localhost")?.await?; let (mut send, _) = conn.open_bi().await?; // this one should fail on receive! - send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]).await.ok(); + send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]) + .await + .ok(); let Err(cause) = server.await? else { panic!("server should have failed due to max message size"); }; @@ -97,17 +64,27 @@ async fn oneshot_max_message_size_recv() -> TestResult<()> { Ok(()) } +async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> { + let conn = server + .accept() + .await + .unwrap() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let (_, recv) = conn + .accept_bi() + .await + .map_err(|e| RecvError::Io(e.into()))?; + let recv = oneshot::Receiver::::from(recv); + recv.await?; + Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())) +} + /// Checks that trying to send a message that cannot be serialized results in an error on the sender side and a connection reset on the receiver side. #[tokio::test] async fn oneshot_serialize_error_send() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; - let server: JoinHandle> = tokio::spawn(async move { - let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; - let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; - let recv = oneshot::Receiver::::from(recv); - recv.await?; - return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); - }); + let server = tokio::spawn(noser_receiver(server)); let conn = client.connect(server_addr, "localhost")?.await?; let (send, _) = conn.open_bi().await?; let send = oneshot::Sender::::from(send); @@ -127,13 +104,7 @@ async fn oneshot_serialize_error_send() -> TestResult<()> { #[tokio::test] async fn oneshot_serialize_error_recv() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; - let server: JoinHandle> = tokio::spawn(async move { - let conn = server.accept().await.unwrap().await.map_err(|e| RecvError::Io(e.into()))?; - let (_, recv) = conn.accept_bi().await.map_err(|e| RecvError::Io(e.into()))?; - let recv = oneshot::Receiver::::from(recv); - recv.await?; - return Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())); - }); + let server = tokio::spawn(noser_receiver(server)); let conn = client.connect(server_addr, "localhost")?.await?; let (mut send, _) = conn.open_bi().await?; // this one should fail on receive! @@ -145,4 +116,4 @@ async fn oneshot_serialize_error_recv() -> TestResult<()> { println!("Server error: {:?}", cause); assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData)); Ok(()) -} \ No newline at end of file +}