Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions irpc-iroh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use iroh::{
protocol::{AcceptError, ProtocolHandler},
};
use irpc::{
channel::RecvError,
channel::oneshot,
rpc::{
Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED,
MAX_MESSAGE_SIZE,
Expand Down Expand Up @@ -257,7 +257,7 @@ pub async fn read_request_raw<R: DeserializeOwned + 'static>(
ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
b"request exceeded max message size",
);
return Err(RecvError::MaxMessageSizeExceeded.into());
return Err(oneshot::RecvError::MaxMessageSizeExceeded.into());
}
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
Expand Down
138 changes: 87 additions & 51 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,47 @@ pub mod channel {

/// Oneshot channel, similar to tokio's oneshot channel
pub mod oneshot {
use std::{fmt::Debug, future::Future, pin::Pin, task};
use std::{fmt::Debug, future::Future, io, pin::Pin, task};

use n0_future::future::Boxed as BoxFuture;

use super::{RecvError, SendError};
use super::SendError;
use crate::util::FusedOneshotReceiver;

/// Error when receiving a oneshot or mpsc message. For local communication,
/// the only thing that can go wrong is that the sender has been closed.
///
/// For rpc communication, there can be any number of errors, so this is a
/// generic io error.
#[derive(Debug, thiserror::Error)]
pub enum RecvError {
/// The sender has been closed. This is the only error that can occur
/// for local communication.
#[error("sender closed")]
SenderClosed,
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
///
/// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
#[error("maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An io error occurred. This can occur for remote communication,
/// due to a network error or deserialization error.
#[error("io error: {0}")]
Io(#[from] io::Error),
}

impl From<RecvError> for io::Error {
fn from(e: RecvError) -> Self {
match e {
RecvError::Io(e) => e,
RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
RecvError::MaxMessageSizeExceeded => {
io::Error::new(io::ErrorKind::InvalidData, e)
}
}
}
}

/// Create a local oneshot sender and receiver pair.
///
/// This is currently using a tokio channel pair internally.
Expand Down Expand Up @@ -586,9 +620,38 @@ pub mod channel {
///
/// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc.
pub mod mpsc {
use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};

use super::{RecvError, SendError};
use super::SendError;

/// Error when receiving a oneshot or mpsc message. For local communication,
/// the only thing that can go wrong is that the sender has been closed.
///
/// For rpc communication, there can be any number of errors, so this is a
/// generic io error.
#[derive(Debug, thiserror::Error)]
pub enum RecvError {
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
///
/// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
#[error("maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An io error occurred. This can occur for remote communication,
/// due to a network error or deserialization error.
#[error("io error: {0}")]
Io(#[from] io::Error),
}

impl From<RecvError> for io::Error {
fn from(e: RecvError) -> Self {
match e {
RecvError::Io(e) => e,
RecvError::MaxMessageSizeExceeded => {
io::Error::new(io::ErrorKind::InvalidData, e)
}
}
}
}

/// Create a local mpsc sender and receiver pair, with the given buffer size.
///
Expand Down Expand Up @@ -1067,38 +1130,6 @@ pub mod channel {
}
}
}

/// Error when receiving a oneshot or mpsc message. For local communication,
/// the only thing that can go wrong is that the sender has been closed.
///
/// For rpc communication, there can be any number of errors, so this is a
/// generic io error.
#[derive(Debug, thiserror::Error)]
pub enum RecvError {
/// The sender has been closed. This is the only error that can occur
/// for local communication.
#[error("sender closed")]
SenderClosed,
/// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]).
///
/// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE
#[error("maximum message size exceeded")]
MaxMessageSizeExceeded,
/// An io error occurred. This can occur for remote communication,
/// due to a network error or deserialization error.
#[error("io error: {0}")]
Io(#[from] io::Error),
}

impl From<RecvError> for io::Error {
fn from(e: RecvError) -> Self {
match e {
RecvError::Io(e) => e,
RecvError::SenderClosed => io::Error::new(io::ErrorKind::BrokenPipe, e),
RecvError::MaxMessageSizeExceeded => io::Error::new(io::ErrorKind::InvalidData, e),
}
}
}
}

/// A wrapper for a message with channels to send and receive it.
Expand Down Expand Up @@ -1682,8 +1713,10 @@ pub enum Error {
Request(#[from] RequestError),
#[error("send error: {0}")]
Send(#[from] channel::SendError),
#[error("recv error: {0}")]
Recv(#[from] channel::RecvError),
#[error("mpsc recv error: {0}")]
MpscRecv(#[from] channel::mpsc::RecvError),
#[error("oneshot recv error: {0}")]
OneshotRecv(#[from] channel::oneshot::RecvError),
#[cfg(feature = "rpc")]
#[error("recv error: {0}")]
Write(#[from] rpc::WriteError),
Expand All @@ -1697,7 +1730,8 @@ impl From<Error> for io::Error {
match e {
Error::Request(e) => e.into(),
Error::Send(e) => e.into(),
Error::Recv(e) => e.into(),
Error::MpscRecv(e) => e.into(),
Error::OneshotRecv(e) => e.into(),
#[cfg(feature = "rpc")]
Error::Write(e) => e.into(),
}
Expand Down Expand Up @@ -1772,7 +1806,7 @@ pub mod rpc {
channel::{
mpsc::{self, DynReceiver, DynSender},
none::NoSender,
oneshot, RecvError, SendError,
oneshot, SendError,
},
util::{now_or_never, AsyncReadVarintExt, WriteVarintExt},
LocalSender, RequestError, RpcMessage, Service,
Expand Down Expand Up @@ -1970,16 +2004,13 @@ pub mod rpc {
impl<T: DeserializeOwned> From<quinn::RecvStream> for oneshot::Receiver<T> {
fn from(mut read: quinn::RecvStream) -> Self {
let fut = async move {
let size = read
.read_varint_u64()
.await?
.ok_or(RecvError::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to read size",
)))?;
let size = read.read_varint_u64().await?.ok_or(io::Error::new(
io::ErrorKind::UnexpectedEof,
"failed to read size",
))?;
if size > MAX_MESSAGE_SIZE {
read.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into()).ok();
return Err(RecvError::MaxMessageSizeExceeded);
return Err(oneshot::RecvError::MaxMessageSizeExceeded);
}
let rest = read
.read_to_end(size as usize)
Expand Down Expand Up @@ -2076,7 +2107,12 @@ pub mod rpc {
fn recv(
&mut self,
) -> Pin<
Box<dyn Future<Output = std::result::Result<Option<T>, RecvError>> + Send + Sync + '_>,
Box<
dyn Future<Output = std::result::Result<Option<T>, mpsc::RecvError>>
+ Send
+ Sync
+ '_,
>,
> {
Box::pin(async {
let read = &mut self.recv;
Expand All @@ -2087,7 +2123,7 @@ pub mod rpc {
self.recv
.stop(ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into())
.ok();
return Err(RecvError::MaxMessageSizeExceeded);
return Err(mpsc::RecvError::MaxMessageSizeExceeded);
}
let mut buf = vec![0; size as usize];
read.read_exact(&mut buf)
Expand Down Expand Up @@ -2375,7 +2411,7 @@ pub mod rpc {
ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(),
b"request exceeded max message size",
);
return Err(RecvError::MaxMessageSizeExceeded.into());
return Err(mpsc::RecvError::MaxMessageSizeExceeded.into());
}
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
Expand Down
23 changes: 13 additions & 10 deletions tests/mpsc_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use std::{
};

use irpc::{
channel::{mpsc, RecvError, SendError},
channel::{
mpsc::{self, Receiver, RecvError},
SendError,
},
util::AsyncWriteVarintExt,
};
use quinn::Endpoint;
Expand Down Expand Up @@ -122,7 +125,7 @@ async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> {
.accept_bi()
.await
.map_err(|e| RecvError::Io(e.into()))?;
let mut recv = mpsc::Receiver::<Vec<u8>>::from(recv);
let mut recv = Receiver::<Vec<u8>>::from(recv);
while recv.recv().await?.is_some() {}
Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
}
Expand All @@ -145,7 +148,7 @@ async fn mpsc_max_message_size_send() -> TestResult<()> {
let Err(cause) = server.await? else {
panic!("server should have failed due to max message size");
};
assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
assert!(matches!(cause, mpsc::RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
Ok(())
}

Expand All @@ -165,24 +168,24 @@ async fn mpsc_max_message_size_recv() -> TestResult<()> {
let Err(cause) = server.await? else {
panic!("server should have failed due to max message size");
};
assert!(matches!(cause, RecvError::MaxMessageSizeExceeded));
assert!(matches!(cause, mpsc::RecvError::MaxMessageSizeExceeded));
Ok(())
}

async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> {
async fn noser_receiver(server: Endpoint) -> Result<(), mpsc::RecvError> {
let conn = server
.accept()
.await
.unwrap()
.await
.map_err(|e| RecvError::Io(e.into()))?;
.map_err(|e| mpsc::RecvError::Io(e.into()))?;
let (_, recv) = conn
.accept_bi()
.await
.map_err(|e| RecvError::Io(e.into()))?;
.map_err(|e| mpsc::RecvError::Io(e.into()))?;
let mut recv = mpsc::Receiver::<NoSer>::from(recv);
while recv.recv().await?.is_some() {}
Err(RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
Err(mpsc::RecvError::Io(io::ErrorKind::UnexpectedEof.into()))
}

/// Checks that a serialization error is caught and propagated to the receiver.
Expand All @@ -203,7 +206,7 @@ async fn mpsc_serialize_error_send() -> TestResult<()> {
let Err(cause) = server.await? else {
panic!("server should have failed due to serialization error");
};
assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
assert!(matches!(cause, mpsc::RecvError::Io(e) if e.kind() == ErrorKind::ConnectionReset));
Ok(())
}

Expand All @@ -220,6 +223,6 @@ async fn mpsc_serialize_error_recv() -> TestResult<()> {
let Err(cause) = server.await? else {
panic!("server should have failed due to serialization error");
};
assert!(matches!(cause, RecvError::Io(e) if e.kind() == ErrorKind::InvalidData));
assert!(matches!(cause, mpsc::RecvError::Io(e) if e.kind() == ErrorKind::InvalidData));
Ok(())
}
5 changes: 4 additions & 1 deletion tests/oneshot_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
use std::io::{self, ErrorKind};

use irpc::{
channel::{oneshot, RecvError, SendError},
channel::{
oneshot::{self, RecvError},
SendError,
},
util::AsyncWriteVarintExt,
};
use quinn::Endpoint;
Expand Down