diff --git a/src/lib.rs b/src/lib.rs
index 0372441..e049553 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -141,7 +141,7 @@ pub mod channel {
     /// Oneshot channel, similar to tokio's oneshot channel
     pub mod oneshot {
-        use std::{fmt::Debug, future::Future, io, pin::Pin, task};
+        use std::{fmt::Debug, future::Future, pin::Pin, task};
 
         use n0_future::future::Boxed as BoxFuture;
 
@@ -162,7 +162,7 @@ pub mod channel {
         /// overhead is negligible. However, boxing can also be used for local communication,
         /// e.g. when applying a transform or filter to the message before sending it.
         pub type BoxedSender<T> =
-            Box<dyn FnOnce(T) -> BoxFuture<io::Result<()>> + Send + Sync + 'static>;
+            Box<dyn FnOnce(T) -> BoxFuture<Result<(), SendError>> + Send + Sync + 'static>;
 
         /// A sender that can be wrapped in a `Box<dyn DynSender<T>>`.
         ///
@@ -172,7 +172,9 @@ pub mod channel {
         /// Remote receivers are always boxed, since for remote communication the boxing
         /// overhead is negligible. However, boxing can also be used for local communication,
         /// e.g. when applying a transform or filter to the message before receiving it.
-        pub trait DynSender<T>: Future<Output = io::Result<()>> + Send + Sync + 'static {
+        pub trait DynSender<T>:
+            Future<Output = std::result::Result<(), SendError>> + Send + Sync + 'static
+        {
             fn is_rpc(&self) -> bool;
         }
 
@@ -181,7 +183,7 @@ pub mod channel {
         /// Remote receivers are always boxed, since for remote communication the boxing
         /// overhead is negligible. However, boxing can also be used for local communication,
         /// e.g. when applying a transform or filter to the message before receiving it.
-        pub type BoxedReceiver<T> = BoxFuture<io::Result<T>>;
+        pub type BoxedReceiver<T> = BoxFuture<Result<T, RecvError>>;
 
         /// A oneshot sender.
         ///
@@ -230,7 +232,7 @@ pub mod channel {
             pub async fn send(self, value: T) -> std::result::Result<(), SendError> {
                 match self {
                     Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed),
-                    Sender::Boxed(f) => f(value).await.map_err(SendError::from),
+                    Sender::Boxed(f) => f(value).await,
                 }
             }
         }
@@ -266,7 +268,7 @@ pub mod channel {
             fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
                 match self.get_mut() {
                     Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed),
-                    Self::Boxed(rx) => Pin::new(rx).poll(cx).map_err(RecvError::Io),
+                    Self::Boxed(rx) => Pin::new(rx).poll(cx),
                 }
             }
         }
@@ -293,7 +295,7 @@ pub mod channel {
         impl<T, F, Fut> From<F> for Receiver<T>
         where
             F: FnOnce() -> Fut,
-            Fut: Future<Output = io::Result<T>> + Send + 'static,
+            Fut: Future<Output = Result<T, RecvError>> + Send + 'static,
         {
             fn from(f: F) -> Self {
                 Self::Boxed(Box::pin(f()))
             }
         }
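Note: with the boxed sender and receiver now yielding `Result<_, SendError>` and `Result<_, RecvError>` directly, a locally boxed sender no longer routes its failures through `io::Error`. A minimal sketch of the "transform before sending" case the doc comment mentions, assuming the definitions above (paths abbreviated, all names local to the sketch):

    // forward into a plain tokio oneshot, doubling the value on the way
    let (tx, _rx) = tokio::sync::oneshot::channel::<u64>();
    let boxed: BoxedSender<u64> = Box::new(move |value| {
        Box::pin(async move {
            // tokio's oneshot send only fails if the receiver is gone
            tx.send(value * 2).map_err(|_| SendError::ReceiverClosed)
        })
    });
    let sender = Sender::Boxed(boxed);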
@@ -317,7 +319,7 @@ pub mod channel {
     /// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc.
     pub mod mpsc {
-        use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc};
+        use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};
 
         use super::{RecvError, SendError};
         use crate::RpcMessage;
@@ -398,7 +400,7 @@ pub mod channel {
             fn send(
                 &self,
                 value: T,
-            ) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>>;
+            ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>>;
 
             /// Try to send a message, returning as fast as possible if sending
             /// is not currently possible.
@@ -408,7 +410,7 @@ pub mod channel {
             fn try_send(
                 &self,
                 value: T,
-            ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>>;
+            ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>>;
 
             /// Await the sender close
             fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + Sync + '_>>;
@@ -458,7 +460,7 @@ pub mod channel {
                     Sender::Tokio(tx) => {
                         tx.send(value).await.map_err(|_| SendError::ReceiverClosed)
                     }
-                    Sender::Boxed(sink) => sink.send(value).await.map_err(SendError::from),
+                    Sender::Boxed(sink) => sink.send(value).await,
                 }
             }
 
@@ -492,7 +494,7 @@ pub mod channel {
                         }
                         Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Ok(false),
                     },
-                    Sender::Boxed(sink) => sink.try_send(value).await.map_err(SendError::from),
+                    Sender::Boxed(sink) => sink.try_send(value).await,
                 }
             }
         }
@@ -593,6 +595,9 @@ pub mod channel {
         /// for local communication.
         #[error("receiver closed")]
         ReceiverClosed,
+        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
+        #[error("maximum message size exceeded")]
+        MaxMessageSizeExceeded,
         /// The underlying io error. This can occur for remote communication,
         /// due to a network error or serialization error.
         #[error("io error: {0}")]
@@ -603,6 +608,7 @@ pub mod channel {
         fn from(e: SendError) -> Self {
             match e {
                 SendError::ReceiverClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
+                SendError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
                 SendError::Io(e) => e,
             }
         }
@@ -619,6 +625,9 @@ pub mod channel {
         /// for local communication.
         #[error("sender closed")]
         SenderClosed,
+        /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
+        #[error("maximum message size exceeded")]
+        MaxMessageSizeExceeded,
         /// An io error occurred. This can occur for remote communication,
         /// due to a network error or deserialization error.
         #[error("io error: {0}")]
@@ -630,6 +639,7 @@ pub mod channel {
             match e {
                 RecvError::Io(e) => e,
                 RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
+                RecvError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
             }
         }
     }
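Note: both conversions map `MaxMessageSizeExceeded` to `io::ErrorKind::InvalidData`, so callers that still work in terms of `io::Error` keep getting a sensible error kind, while callers matching on the typed error can tell an oversized message apart from a dead peer or a transport failure. A sketch of a call site under these definitions (`sender` and `big_value` are placeholders):

    match sender.send(big_value).await {
        Ok(()) => {}
        Err(SendError::MaxMessageSizeExceeded) => {
            // the message was larger than MAX_MESSAGE_SIZE; split it up
        }
        Err(SendError::ReceiverClosed) => {
            // the receiving side is gone
        }
        Err(SendError::Io(e)) => {
            // network, transport or serialization failure
        }
    }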
#[error("error serializing: {0}")] Io(#[from] io::Error), } + impl From for WriteError { + fn from(value: postcard::Error) -> Self { + Self::Io(io::Error::new(io::ErrorKind::InvalidData, value)) + } + } + + impl From for SendError { + fn from(value: postcard::Error) -> Self { + Self::Io(io::Error::new(io::ErrorKind::InvalidData, value)) + } + } + impl From for io::Error { fn from(e: WriteError) -> Self { match e { WriteError::Io(e) => e, + WriteError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e), WriteError::Quinn(e) => e.into(), } } } + impl From for SendError { + fn from(err: quinn::WriteError) -> Self { + match err { + quinn::WriteError::Stopped(code) + if code == ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into() => + { + SendError::MaxMessageSizeExceeded + } + _ => SendError::Io(io::Error::from(err)), + } + } + } + /// Trait to abstract over a client connection to a remote service. /// /// This isn't really that much abstracted, since the result of open_bi must @@ -1256,6 +1304,9 @@ pub mod rpc { { let RemoteSender(mut send, recv, _) = self; let msg = msg.into(); + if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE { + return Err(WriteError::MaxMessageSizeExceeded); + } let mut buf = SmallVec::<[u8; 128]>::new(); buf.write_length_prefixed(msg)?; send.write_all(&buf).await?; @@ -1266,17 +1317,24 @@ pub mod rpc { impl From for oneshot::Receiver { fn from(mut read: quinn::RecvStream) -> Self { let fut = async move { - let size = read.read_varint_u64().await?.ok_or(io::Error::new( - io::ErrorKind::UnexpectedEof, - "failed to read size", - ))?; + let size = read + .read_varint_u64() + .await? + .ok_or(RecvError::Io(io::Error::new( + io::ErrorKind::UnexpectedEof, + "failed to read size", + )))?; + if size > MAX_MESSAGE_SIZE { + read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); + return Err(RecvError::MaxMessageSizeExceeded); + } let rest = read .read_to_end(size as usize) .await .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let msg: T = postcard::from_bytes(&rest) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - io::Result::Ok(msg) + Ok(msg) }; oneshot::Receiver::from(|| fut) } @@ -1309,11 +1367,30 @@ pub mod rpc { fn from(mut writer: quinn::SendStream) -> Self { oneshot::Sender::Boxed(Box::new(move |value| { Box::pin(async move { + let size = match postcard::experimental::serialized_size(&value) { + Ok(size) => size, + Err(e) => { + writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); + return Err(SendError::Io(io::Error::new( + io::ErrorKind::InvalidData, + e, + ))); + } + }; + if size as u64 > MAX_MESSAGE_SIZE { + writer + .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()) + .ok(); + return Err(SendError::MaxMessageSizeExceeded); + } // write via a small buffer to avoid allocation for small values let mut buf = SmallVec::<[u8; 128]>::new(); - buf.write_length_prefixed(value)?; + if let Err(e) = buf.write_length_prefixed(value) { + writer.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok(); + return Err(e.into()); + } writer.write_all(&buf).await?; - io::Result::Ok(()) + Ok(()) }) })) } @@ -1353,6 +1430,12 @@ pub mod rpc { let Some(size) = read.read_varint_u64().await? 
@@ -1353,6 +1430,12 @@ pub mod rpc {
             let Some(size) = read.read_varint_u64().await? else {
                 return Ok(None);
             };
+            if size > MAX_MESSAGE_SIZE {
+                self.recv
+                    .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
+                    .ok();
+                return Err(RecvError::MaxMessageSizeExceeded);
+            }
             let mut buf = vec![0; size as usize];
             read.read_exact(&mut buf)
                 .await
@@ -1378,11 +1461,27 @@ pub mod rpc {
         fn send(
             &mut self,
             value: T,
-        ) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
+        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + Sync + '_>> {
             Box::pin(async {
+                let size = match postcard::experimental::serialized_size(&value) {
+                    Ok(size) => size,
+                    Err(e) => {
+                        self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
+                        return Err(SendError::Io(io::Error::new(io::ErrorKind::InvalidData, e)));
+                    }
+                };
+                if size as u64 > MAX_MESSAGE_SIZE {
+                    self.send
+                        .reset(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
+                        .ok();
+                    return Err(SendError::MaxMessageSizeExceeded);
+                }
                 let value = value;
                 self.buffer.clear();
-                self.buffer.write_length_prefixed(value)?;
+                if let Err(e) = self.buffer.write_length_prefixed(value) {
+                    self.send.reset(ERROR_CODE_INVALID_POSTCARD.into()).ok();
+                    return Err(e.into());
+                }
                 self.send.write_all(&self.buffer).await?;
                 self.buffer.clear();
                 Ok(())
@@ -1392,8 +1491,11 @@ pub mod rpc {
         fn try_send(
             &mut self,
             value: T,
-        ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
+        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + Sync + '_>> {
             Box::pin(async {
+                if postcard::experimental::serialized_size(&value)? as u64 > MAX_MESSAGE_SIZE {
+                    return Err(SendError::MaxMessageSizeExceeded);
+                }
                 // todo: move the non-async part out of the box. Will require a new return type.
                 let value = value;
                 self.buffer.clear();
@@ -1434,7 +1536,7 @@ pub mod rpc {
         fn send(
             &self,
             value: T,
-        ) -> Pin<Box<dyn Future<Output = io::Result<()>> + Send + Sync + '_>> {
+        ) -> Pin<Box<dyn Future<Output = Result<(), SendError>> + Send + '_>> {
             Box::pin(async {
                 let mut guard = self.0.lock().await;
                 let sender = std::mem::take(guard.deref_mut());
@@ -1446,7 +1548,9 @@ pub mod rpc {
                         }
                         res
                     }
-                    QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
+                    QuinnSenderState::Closed => {
+                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
+                    }
                 }
             })
         }
@@ -1454,7 +1558,7 @@ pub mod rpc {
         fn try_send(
             &self,
             value: T,
-        ) -> Pin<Box<dyn Future<Output = io::Result<bool>> + Send + Sync + '_>> {
+        ) -> Pin<Box<dyn Future<Output = Result<bool, SendError>> + Send + '_>> {
             Box::pin(async {
                 let mut guard = self.0.lock().await;
                 let sender = std::mem::take(guard.deref_mut());
@@ -1466,7 +1570,9 @@ pub mod rpc {
                         }
                         res
                     }
-                    QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
+                    QuinnSenderState::Closed => {
+                        Err(io::Error::from(io::ErrorKind::BrokenPipe).into())
+                    }
                 }
             })
         }
diff --git a/tests/common.rs b/tests/common.rs
new file mode 100644
index 0000000..cf89a74
--- /dev/null
+++ b/tests/common.rs
@@ -0,0 +1,49 @@
+use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
+
+use irpc::util::{make_client_endpoint, make_server_endpoint};
+use quinn::Endpoint;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use testresult::TestResult;
+
+pub fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> {
+    let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into();
+    let (server, cert) = make_server_endpoint(addr)?;
+    let client = make_client_endpoint(addr, &[cert.as_slice()])?;
+    let port = server.local_addr()?.port();
+    let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into();
+    Ok((server, client, server_addr))
+}
+
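+/// Serde helper that succeeds for even numbers and fails for odd ones, so
+/// tests can provoke postcard errors on either the send or the receive side.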
+#[derive(Debug)]
+pub struct NoSer(pub u64);
+
+#[derive(Debug, thiserror::Error)]
+#[error("Cannot serialize odd number: {0}")]
+pub struct OddNumberError(u64);
+
+impl Serialize for NoSer {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        if self.0 % 2 == 1 {
+            Err(serde::ser::Error::custom(OddNumberError(self.0)))
+        } else {
+            serializer.serialize_u64(self.0)
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for NoSer {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = u64::deserialize(deserializer)?;
+        if value % 2 != 0 {
+            Err(serde::de::Error::custom(OddNumberError(value)))
+        } else {
+            Ok(NoSer(value))
+        }
+    }
+}
diff --git a/tests/mpsc_channel.rs b/tests/mpsc_channel.rs
new file mode 100644
index 0000000..28d7aa6
--- /dev/null
+++ b/tests/mpsc_channel.rs
@@ -0,0 +1,223 @@
+use std::{
+    io::{self, ErrorKind},
+    time::Duration,
+};
+
+use irpc::{
+    channel::{mpsc, RecvError, SendError},
+    util::AsyncWriteVarintExt,
+};
+use quinn::Endpoint;
+use testresult::TestResult;
+use tokio::time::timeout;
+
+mod common;
+use common::*;
+
+/// Checks that all clones of a `Sender` will get the closed signal as soon as
+/// a send fails with an io error.
+#[tokio::test]
+async fn mpsc_sender_clone_closed_error() -> TestResult<()> {
+    tracing_subscriber::fmt::try_init().ok();
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    // accept a single bidi stream on a single connection, then immediately stop it
+    let server = tokio::spawn(async move {
+        let conn = server.accept().await.unwrap().await?;
+        let (_, mut recv) = conn.accept_bi().await?;
+        recv.stop(1u8.into())?;
+        TestResult::Ok(())
+    });
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (send, _) = conn.open_bi().await?;
+    let send1 = mpsc::Sender::<Vec<u8>>::from(send);
+    let send2 = send1.clone();
+    let send3 = send1.clone();
+    let second_client = tokio::spawn(async move {
+        send2.closed().await;
+    });
+    let third_client = tokio::spawn(async move {
+        // this should fail with an io error, since the stream was stopped
+        loop {
+            match send3.send(vec![1, 2, 3]).await {
+                Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break,
+                _ => {}
+            };
+        }
+    });
+    // send until we get an error because the remote side stopped the stream
+    while send1.send(vec![1, 2, 3]).await.is_ok() {}
+    match send1.send(vec![4, 5, 6]).await {
+        Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => {}
+        e => panic!("Expected SendError::Io with kind BrokenPipe, got {:?}", e),
+    };
+    // check that closed signal was received by the second sender
+    second_client.await?;
+    // check that the third sender will get the right kind of io error eventually
+    third_client.await?;
+    // server should finish without errors
+    server.await??;
+    Ok(())
+}
+
+/// Checks that all clones of a `Sender` will get the closed signal as soon as
+/// a send future gets dropped before completing.
+#[tokio::test]
+async fn mpsc_sender_clone_drop_error() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    // accept a single bidi stream on a single connection, then read indefinitely
+    // until we get an error or the stream is finished
+    let server = tokio::spawn(async move {
+        let conn = server.accept().await.unwrap().await?;
+        let (_, mut recv) = conn.accept_bi().await?;
+        let mut buf = vec![0u8; 1024];
+        while let Ok(Some(_)) = recv.read(&mut buf).await {}
+        TestResult::Ok(())
+    });
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (send, _) = conn.open_bi().await?;
+    let send1 = mpsc::Sender::<Vec<u8>>::from(send);
+    let send2 = send1.clone();
+    let send3 = send1.clone();
+    let second_client = tokio::spawn(async move {
+        send2.closed().await;
+    });
+    let third_client = tokio::spawn(async move {
+        // this should fail with an io error, since the stream was stopped
+        loop {
+            match send3.send(vec![1, 2, 3]).await {
+                Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break,
+                _ => {}
+            };
+        }
+    });
+    // send a lot of data with a tiny timeout, this will cause the send future to be dropped
+    loop {
+        let send_future = send1.send(vec![0u8; 1024 * 1024]);
+        // not sure if there is a better way. I want to poll the future a few times so it has time to
+        // start sending, but don't want to give it enough time to complete.
+        // I don't think now_or_never would work, since it wouldn't have time to start sending
+        if timeout(Duration::from_micros(1), send_future)
+            .await
+            .is_err()
+        {
+            break;
+        }
+    }
+    server.await??;
+    second_client.await?;
+    third_client.await?;
+    Ok(())
+}
+
+async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> {
+    let conn = server
+        .accept()
+        .await
+        .unwrap()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let (_, recv) = conn
+        .accept_bi()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let mut recv = mpsc::Receiver::<Vec<u8>>::from(recv);
+    while recv.recv().await?.is_some() {}
+    Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
+}
+
+/// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side.
+#[tokio::test]
+async fn mpsc_max_message_size_send() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(vec_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (send, _) = conn.open_bi().await?;
+    let send = mpsc::Sender::<Vec<u8>>::from(send);
+    // this one should work!
+    send.send(vec![0u8; 1024 * 1024]).await?;
+    // this one should fail!
+    let Err(cause) = send.send(vec![0u8; 1024 * 1024 * 32]).await else {
+        panic!("client should have failed due to max message size");
+    };
+    assert!(matches!(cause, SendError::MaxMessageSizeExceeded));
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to max message size");
+    };
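+    // the sender resets its stream with ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED
+    // before failing, so the receiver observes an io error with kind
+    // ConnectionReset rather than a clean end of stream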
+    assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
+    Ok(())
+}
+
+/// Checks that the max message size is enforced on the receiver side.
+#[tokio::test]
+async fn mpsc_max_message_size_recv() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(vec_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (mut send, _) = conn.open_bi().await?;
+    // this one should work!
+    send.write_length_prefixed(vec![0u8; 1024 * 1024]).await?;
+    // this one should fail on receive!
+    send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32])
+        .await
+        .ok();
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to max message size");
+    };
+    assert!(matches!(cause, RecvError::MaxMessageSizeExceeded));
+    Ok(())
+}
+
+async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> {
+    let conn = server
+        .accept()
+        .await
+        .unwrap()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let (_, recv) = conn
+        .accept_bi()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let mut recv = mpsc::Receiver::<NoSer>::from(recv);
+    while recv.recv().await?.is_some() {}
+    Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
+}
+
+/// Checks that a serialization error is caught and propagated to the receiver.
+#[tokio::test]
+async fn mpsc_serialize_error_send() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(noser_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (send, _) = conn.open_bi().await?;
+    let send = mpsc::Sender::<NoSer>::from(send);
+    // this one should work!
+    send.send(NoSer(0)).await?;
+    // this one should fail!
+    let Err(cause) = send.send(NoSer(1)).await else {
+        panic!("client should have failed due to serialization error");
+    };
+    assert!(matches!(cause, SendError::Io(e) if e.kind() == ErrorKind::InvalidData));
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to serialization error");
+    };
+    assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
+    Ok(())
+}
+
+#[tokio::test]
+async fn mpsc_serialize_error_recv() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(noser_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (mut send, _) = conn.open_bi().await?;
+    // this one should work!
+    send.write_length_prefixed(0u64).await?;
+    // this one should fail on receive!
+    send.write_length_prefixed(1u64).await.ok();
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to serialization error");
+    };
+    assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData));
+    Ok(())
+}
diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs
deleted file mode 100644
index e8382bb..0000000
--- a/tests/mpsc_sender.rs
+++ /dev/null
@@ -1,117 +0,0 @@
-use std::{
-    io::ErrorKind,
-    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
-    time::Duration,
-};
-
-use irpc::{
-    channel::{mpsc, SendError},
-    util::{make_client_endpoint, make_server_endpoint},
-};
-use quinn::Endpoint;
-use testresult::TestResult;
-use tokio::time::timeout;
-
-fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> {
-    let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into();
-    let (server, cert) = make_server_endpoint(addr)?;
-    let client = make_client_endpoint(addr, &[cert.as_slice()])?;
-    let port = server.local_addr()?.port();
-    let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into();
-    Ok((server, client, server_addr))
-}
-
-/// Checks that all clones of a `Sender` will get the closed signal as soon as
-/// a send fails with an io error.
-#[tokio::test]
-async fn mpsc_sender_clone_closed_error() -> TestResult<()> {
-    tracing_subscriber::fmt::try_init().ok();
-    let (server, client, server_addr) = create_connected_endpoints()?;
-    // accept a single bidi stream on a single connection, then immediately stop it
-    let server = tokio::spawn(async move {
-        let conn = server.accept().await.unwrap().await?;
-        let (_, mut recv) = conn.accept_bi().await?;
-        recv.stop(1u8.into())?;
-        TestResult::Ok(())
-    });
-    let conn = client.connect(server_addr, "localhost")?.await?;
-    let (send, _) = conn.open_bi().await?;
-    let send1 = mpsc::Sender::<Vec<u8>>::from(send);
-    let send2 = send1.clone();
-    let send3 = send1.clone();
-    let second_client = tokio::spawn(async move {
-        send2.closed().await;
-    });
-    let third_client = tokio::spawn(async move {
-        // this should fail with an io error, since the stream was stopped
-        loop {
-            match send3.send(vec![1, 2, 3]).await {
-                Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break,
-                _ => {}
-            };
-        }
-    });
-    // send until we get an error because the remote side stopped the stream
-    while send1.send(vec![1, 2, 3]).await.is_ok() {}
-    match send1.send(vec![4, 5, 6]).await {
-        Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => {}
-        e => panic!("Expected SendError::Io with kind BrokenPipe, got {:?}", e),
-    };
-    // check that closed signal was received by the second sender
-    second_client.await?;
-    // check that the third sender will get the right kind of io error eventually
-    third_client.await?;
-    // server should finish without errors
-    server.await??;
-    Ok(())
-}
-
-/// Checks that all clones of a `Sender` will get the closed signal as soon as
-/// a send future gets dropped before completing.
-#[tokio::test]
-async fn mpsc_sender_clone_drop_error() -> TestResult<()> {
-    let (server, client, server_addr) = create_connected_endpoints()?;
-    // accept a single bidi stream on a single connection, then read indefinitely
-    // until we get an error or the stream is finished
-    let server = tokio::spawn(async move {
-        let conn = server.accept().await.unwrap().await?;
-        let (_, mut recv) = conn.accept_bi().await?;
-        let mut buf = vec![0u8; 1024];
-        while let Ok(Some(_)) = recv.read(&mut buf).await {}
-        TestResult::Ok(())
-    });
-    let conn = client.connect(server_addr, "localhost")?.await?;
-    let (send, _) = conn.open_bi().await?;
-    let send1 = mpsc::Sender::<Vec<u8>>::from(send);
-    let send2 = send1.clone();
-    let send3 = send1.clone();
-    let second_client = tokio::spawn(async move {
-        send2.closed().await;
-    });
-    let third_client = tokio::spawn(async move {
-        // this should fail with an io error, since the stream was stopped
-        loop {
-            match send3.send(vec![1, 2, 3]).await {
-                Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break,
-                _ => {}
-            };
-        }
-    });
-    // send a lot of data with a tiny timeout, this will cause the send future to be dropped
-    loop {
-        let send_future = send1.send(vec![0u8; 1024 * 1024]);
-        // not sure if there is a better way. I want to poll the future a few times so it has time to
-        // start sending, but don't want to give it enough time to complete.
-        // I don't think now_or_never would work, since it wouldn't have time to start sending
-        if timeout(Duration::from_micros(1), send_future)
-            .await
-            .is_err()
-        {
-            break;
-        }
-    }
-    server.await??;
-    second_client.await?;
-    third_client.await?;
-    Ok(())
-}
diff --git a/tests/oneshot_channel.rs b/tests/oneshot_channel.rs
new file mode 100644
index 0000000..3988721
--- /dev/null
+++ b/tests/oneshot_channel.rs
@@ -0,0 +1,119 @@
+use std::io::{self, ErrorKind};
+
+use irpc::{
+    channel::{oneshot, RecvError, SendError},
+    util::AsyncWriteVarintExt,
+};
+use quinn::Endpoint;
+use testresult::TestResult;
+
+mod common;
+use common::*;
+
+async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> {
+    let conn = server
+        .accept()
+        .await
+        .unwrap()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let (_, recv) = conn
+        .accept_bi()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let recv = oneshot::Receiver::<Vec<u8>>::from(recv);
+    recv.await?;
+    Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
+}
+
+/// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side.
+#[tokio::test]
+async fn oneshot_max_message_size_send() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(vec_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (send, _) = conn.open_bi().await?;
+    let send = oneshot::Sender::<Vec<u8>>::from(send);
+    // this one should fail!
+    let Err(cause) = send.send(vec![0u8; 1024 * 1024 * 32]).await else {
+        panic!("client should have failed due to max message size");
+    };
+    assert!(matches!(cause, SendError::MaxMessageSizeExceeded));
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to max message size");
+    };
+    assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
+    Ok(())
+}
+
+/// Checks that the max message size is enforced on the receiver side.
+#[tokio::test]
+async fn oneshot_max_message_size_recv() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(vec_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (mut send, _) = conn.open_bi().await?;
+    // this one should fail on receive!
+    send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32])
+        .await
+        .ok();
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to max message size");
+    };
+    assert!(matches!(cause, RecvError::MaxMessageSizeExceeded));
+    Ok(())
+}
+
+async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> {
+    let conn = server
+        .accept()
+        .await
+        .unwrap()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let (_, recv) = conn
+        .accept_bi()
+        .await
+        .map_err(|e| RecvError::Io(e.into()))?;
+    let recv = oneshot::Receiver::<NoSer>::from(recv);
+    recv.await?;
+    Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
+}
+
+/// Checks that trying to send a message that cannot be serialized results in an error on the sender side and a connection reset on the receiver side.
+#[tokio::test]
+async fn oneshot_serialize_error_send() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(noser_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (send, _) = conn.open_bi().await?;
+    let send = oneshot::Sender::<NoSer>::from(send);
+    // this one should fail!
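+    // (NoSer(1) is odd, so postcard serialization fails; the sender resets the
+    // stream with ERROR_CODE_INVALID_POSTCARD, which the server observes as a
+    // ConnectionReset io error below)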
+    let Err(cause) = send.send(NoSer(1)).await else {
+        panic!("client should have failed due to serialization error");
+    };
+    assert!(matches!(cause, SendError::Io(e) if e.kind() == ErrorKind::InvalidData));
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to serialization error");
+    };
+    println!("Server error: {:?}", cause);
+    assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
+    Ok(())
+}
+
+#[tokio::test]
+async fn oneshot_serialize_error_recv() -> TestResult<()> {
+    let (server, client, server_addr) = create_connected_endpoints()?;
+    let server = tokio::spawn(noser_receiver(server));
+    let conn = client.connect(server_addr, "localhost")?.await?;
+    let (mut send, _) = conn.open_bi().await?;
+    // this one should fail on receive!
+    send.write_length_prefixed(1u64).await?;
+    send.finish()?;
+    let Err(cause) = server.await? else {
+        panic!("server should have failed due to serialization error");
+    };
+    println!("Server error: {:?}", cause);
+    assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData));
+    Ok(())
+}