diff --git a/irpc-iroh/src/lib.rs b/irpc-iroh/src/lib.rs index 5851ded..bd1702c 100644 --- a/irpc-iroh/src/lib.rs +++ b/irpc-iroh/src/lib.rs @@ -8,7 +8,7 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler}, }; use irpc::{ - channel::RecvError, + channel::oneshot, rpc::{ Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, MAX_MESSAGE_SIZE, @@ -257,7 +257,7 @@ pub async fn read_request_raw( ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(), b"request exceeded max message size", ); - return Err(RecvError::MaxMessageSizeExceeded.into()); + return Err(oneshot::RecvError::MaxMessageSizeExceeded.into()); } let mut buf = vec![0; size as usize]; recv.read_exact(&mut buf) diff --git a/src/lib.rs b/src/lib.rs index a01738a..8d0ad3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -371,13 +371,47 @@ pub mod channel { /// Oneshot channel, similar to tokio's oneshot channel pub mod oneshot { - use std::{fmt::Debug, future::Future, pin::Pin, task}; + use std::{fmt::Debug, future::Future, io, pin::Pin, task}; use n0_future::future::Boxed as BoxFuture; - use super::{RecvError, SendError}; + use super::SendError; use crate::util::FusedOneshotReceiver; + /// Error when receiving a oneshot or mpsc message. For local communication, + /// the only thing that can go wrong is that the sender has been closed. + /// + /// For rpc communication, there can be any number of errors, so this is a + /// generic io error. + #[derive(Debug, thiserror::Error)] + pub enum RecvError { + /// The sender has been closed. This is the only error that can occur + /// for local communication. + #[error("sender closed")] + SenderClosed, + /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). + /// + /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE + #[error("maximum message size exceeded")] + MaxMessageSizeExceeded, + /// An io error occurred. This can occur for remote communication, + /// due to a network error or deserialization error. + #[error("io error: {0}")] + Io(#[from] io::Error), + } + + impl From for io::Error { + fn from(e: RecvError) -> Self { + match e { + RecvError::Io(e) => e, + RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e), + RecvError::MaxMessageSizeExceeded => { + io::Error::new(io::ErrorKind::InvalidData, e) + } + } + } + } + /// Create a local oneshot sender and receiver pair. /// /// This is currently using a tokio channel pair internally. @@ -586,9 +620,38 @@ pub mod channel { /// /// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc. pub mod mpsc { - use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, sync::Arc}; + use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc}; - use super::{RecvError, SendError}; + use super::SendError; + + /// Error when receiving a oneshot or mpsc message. For local communication, + /// the only thing that can go wrong is that the sender has been closed. + /// + /// For rpc communication, there can be any number of errors, so this is a + /// generic io error. + #[derive(Debug, thiserror::Error)] + pub enum RecvError { + /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). + /// + /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE + #[error("maximum message size exceeded")] + MaxMessageSizeExceeded, + /// An io error occurred. This can occur for remote communication, + /// due to a network error or deserialization error. + #[error("io error: {0}")] + Io(#[from] io::Error), + } + + impl From for io::Error { + fn from(e: RecvError) -> Self { + match e { + RecvError::Io(e) => e, + RecvError::MaxMessageSizeExceeded => { + io::Error::new(io::ErrorKind::InvalidData, e) + } + } + } + } /// Create a local mpsc sender and receiver pair, with the given buffer size. /// @@ -1067,38 +1130,6 @@ pub mod channel { } } } - - /// Error when receiving a oneshot or mpsc message. For local communication, - /// the only thing that can go wrong is that the sender has been closed. - /// - /// For rpc communication, there can be any number of errors, so this is a - /// generic io error. - #[derive(Debug, thiserror::Error)] - pub enum RecvError { - /// The sender has been closed. This is the only error that can occur - /// for local communication. - #[error("sender closed")] - SenderClosed, - /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). - /// - /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE - #[error("maximum message size exceeded")] - MaxMessageSizeExceeded, - /// An io error occurred. This can occur for remote communication, - /// due to a network error or deserialization error. - #[error("io error: {0}")] - Io(#[from] io::Error), - } - - impl From for io::Error { - fn from(e: RecvError) -> Self { - match e { - RecvError::Io(e) => e, - RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e), - RecvError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e), - } - } - } } /// A wrapper for a message with channels to send and receive it. @@ -1682,8 +1713,10 @@ pub enum Error { Request(#[from] RequestError), #[error("send error: {0}")] Send(#[from] channel::SendError), - #[error("recv error: {0}")] - Recv(#[from] channel::RecvError), + #[error("mpsc recv error: {0}")] + MpscRecv(#[from] channel::mpsc::RecvError), + #[error("oneshot recv error: {0}")] + OneshotRecv(#[from] channel::oneshot::RecvError), #[cfg(feature = "rpc")] #[error("recv error: {0}")] Write(#[from] rpc::WriteError), @@ -1697,7 +1730,8 @@ impl From for io::Error { match e { Error::Request(e) => e.into(), Error::Send(e) => e.into(), - Error::Recv(e) => e.into(), + Error::MpscRecv(e) => e.into(), + Error::OneshotRecv(e) => e.into(), #[cfg(feature = "rpc")] Error::Write(e) => e.into(), } @@ -1772,7 +1806,7 @@ pub mod rpc { channel::{ mpsc::{self, DynReceiver, DynSender}, none::NoSender, - oneshot, RecvError, SendError, + oneshot, SendError, }, util::{now_or_never, AsyncReadVarintExt, WriteVarintExt}, LocalSender, RequestError, RpcMessage, Service, @@ -1970,16 +2004,13 @@ pub mod rpc { impl From for oneshot::Receiver { fn from(mut read: quinn::RecvStream) -> Self { let fut = async move { - let size = read - .read_varint_u64() - .await? - .ok_or(RecvError::Io(io::Error::new( - io::ErrorKind::UnexpectedEof, - "failed to read size", - )))?; + let size = read.read_varint_u64().await?.ok_or(io::Error::new( + io::ErrorKind::UnexpectedEof, + "failed to read size", + ))?; if size > MAX_MESSAGE_SIZE { read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok(); - return Err(RecvError::MaxMessageSizeExceeded); + return Err(oneshot::RecvError::MaxMessageSizeExceeded); } let rest = read .read_to_end(size as usize) @@ -2076,7 +2107,12 @@ pub mod rpc { fn recv( &mut self, ) -> Pin< - Box, RecvError>> + Send + Sync + '_>, + Box< + dyn Future, mpsc::RecvError>> + + Send + + Sync + + '_, + >, > { Box::pin(async { let read = &mut self.recv; @@ -2087,7 +2123,7 @@ pub mod rpc { self.recv .stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()) .ok(); - return Err(RecvError::MaxMessageSizeExceeded); + return Err(mpsc::RecvError::MaxMessageSizeExceeded); } let mut buf = vec![0; size as usize]; read.read_exact(&mut buf) @@ -2375,7 +2411,7 @@ pub mod rpc { ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(), b"request exceeded max message size", ); - return Err(RecvError::MaxMessageSizeExceeded.into()); + return Err(mpsc::RecvError::MaxMessageSizeExceeded.into()); } let mut buf = vec![0; size as usize]; recv.read_exact(&mut buf) diff --git a/tests/mpsc_channel.rs b/tests/mpsc_channel.rs index d3982b6..4717fd1 100644 --- a/tests/mpsc_channel.rs +++ b/tests/mpsc_channel.rs @@ -6,7 +6,10 @@ use std::{ }; use irpc::{ - channel::{mpsc, RecvError, SendError}, + channel::{ + mpsc::{self, Receiver, RecvError}, + SendError, + }, util::AsyncWriteVarintExt, }; use quinn::Endpoint; @@ -122,7 +125,7 @@ async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> { .accept_bi() .await .map_err(|e| RecvError::Io(e.into()))?; - let mut recv = mpsc::Receiver::>::from(recv); + let mut recv = Receiver::>::from(recv); while recv.recv().await?.is_some() {} Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())) } @@ -145,7 +148,7 @@ async fn mpsc_max_message_size_send() -> TestResult<()> { let Err(cause) = server.await? else { panic!("server should have failed due to max message size"); }; - assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); + assert!(matches!(cause, mpsc::RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); Ok(()) } @@ -165,24 +168,24 @@ async fn mpsc_max_message_size_recv() -> TestResult<()> { let Err(cause) = server.await? else { panic!("server should have failed due to max message size"); }; - assert!(matches!(cause, RecvError::MaxMessageSizeExceeded)); + assert!(matches!(cause, mpsc::RecvError::MaxMessageSizeExceeded)); Ok(()) } -async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> { +async fn noser_receiver(server: Endpoint) -> Result<(), mpsc::RecvError> { let conn = server .accept() .await .unwrap() .await - .map_err(|e| RecvError::Io(e.into()))?; + .map_err(|e| mpsc::RecvError::Io(e.into()))?; let (_, recv) = conn .accept_bi() .await - .map_err(|e| RecvError::Io(e.into()))?; + .map_err(|e| mpsc::RecvError::Io(e.into()))?; let mut recv = mpsc::Receiver::::from(recv); while recv.recv().await?.is_some() {} - Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into())) + Err(mpsc::RecvError::Io(io::ErrorKind::UnexpectedEof.into())) } /// Checks that a serialization error is caught and propagated to the receiver. @@ -203,7 +206,7 @@ async fn mpsc_serialize_error_send() -> TestResult<()> { let Err(cause) = server.await? else { panic!("server should have failed due to serialization error"); }; - assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); + assert!(matches!(cause, mpsc::RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset)); Ok(()) } @@ -220,6 +223,6 @@ async fn mpsc_serialize_error_recv() -> TestResult<()> { let Err(cause) = server.await? else { panic!("server should have failed due to serialization error"); }; - assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData)); + assert!(matches!(cause, mpsc::RecvError::Io(e) if e.kind() == ErrorKind::InvalidData)); Ok(()) } diff --git a/tests/oneshot_channel.rs b/tests/oneshot_channel.rs index 922edbc..72202e9 100644 --- a/tests/oneshot_channel.rs +++ b/tests/oneshot_channel.rs @@ -3,7 +3,10 @@ use std::io::{self, ErrorKind}; use irpc::{ - channel::{oneshot, RecvError, SendError}, + channel::{ + oneshot::{self, RecvError}, + SendError, + }, util::AsyncWriteVarintExt, }; use quinn::Endpoint;