diff --git a/Cargo.lock b/Cargo.lock index 2eb398c..70fe0a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1655,6 +1655,7 @@ dependencies = [ "rustls", "serde", "smallvec", + "testresult", "thiserror 2.0.12", "thousands", "tokio", @@ -3300,6 +3301,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "testresult" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614b328ff036a4ef882c61570f72918f7e9c5bee1da33f8e7f91e01daee7e56c" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 616d9c6..09c5076 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ tokio = { workspace = true, features = ["full"] } thousands = "0.2.0" # macro tests trybuild = "1.0.104" +testresult = "0.4.1" [features] # enable the remote transport diff --git a/examples/compute.rs b/examples/compute.rs index 0ec3446..95e4f45 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -123,7 +123,7 @@ impl ComputeActor { tx, inner, span, .. } = fib; let _entered = span.enter(); - let mut sender = tx; + let sender = tx; let mut a = 0u64; let mut b = 1u64; while a <= inner.max { @@ -144,7 +144,7 @@ impl ComputeActor { } = mult; let _entered = span.enter(); let mut receiver = rx; - let mut sender = tx; + let sender = tx; let multiplier = inner.initial; while let Some(num) = receiver.recv().await? { sender.send(multiplier * num).await?; @@ -260,7 +260,7 @@ async fn local() -> anyhow::Result<()> { println!("Local: 5^2 = {}", rx.await?); // Test Sum - let (mut tx, rx) = api.sum().await?; + let (tx, rx) = api.sum().await?; tx.send(1).await?; tx.send(2).await?; tx.send(3).await?; @@ -276,7 +276,7 @@ async fn local() -> anyhow::Result<()> { println!(); // Test Multiply - let (mut in_tx, mut out_rx) = api.multiply(3).await?; + let (in_tx, mut out_rx) = api.multiply(3).await?; in_tx.send(2).await?; in_tx.send(4).await?; in_tx.send(6).await?; @@ -311,7 +311,7 @@ async fn remote() -> anyhow::Result<()> { println!("Remote: 4^2 = {}", rx.await?); // Test Sum - let (mut tx, rx) = api.sum().await?; + let (tx, rx) = api.sum().await?; tx.send(4).await?; tx.send(5).await?; tx.send(6).await?; @@ -327,7 +327,7 @@ async fn remote() -> anyhow::Result<()> { println!(); // Test Multiply - let (mut in_tx, mut out_rx) = api.multiply(5).await?; + let (in_tx, mut out_rx) = api.multiply(5).await?; in_tx.send(1).await?; in_tx.send(2).await?; in_tx.send(3).await?; @@ -380,7 +380,7 @@ async fn bench(api: ComputeApi, n: u64) -> anyhow::Result<()> { // Sequential streaming (using Multiply instead of MultiplyUpdate) { let t0 = std::time::Instant::now(); - let (mut send, mut recv) = api.multiply(2).await?; + let (send, mut recv) = api.multiply(2).await?; let handle = tokio::task::spawn(async move { for i in 0..n { send.send(i).await?; diff --git a/examples/derive.rs b/examples/derive.rs index e03f39f..0f482e4 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -111,7 +111,7 @@ impl StorageActor { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; for (key, value) in &self.state { if tx.send(format!("{key}={value}")).await.is_err() { break; @@ -172,7 +172,7 @@ async fn client_demo(api: StorageApi) -> Result<()> { let value = api.get("hello".to_string()).await?; println!("get: hello = {:?}", value); - let (mut tx, rx) = api.set_many().await?; + let (tx, rx) = api.set_many().await?; for i in 0..3 { tx.send((format!("key{i}"), format!("value{i}"))).await?; } diff --git a/examples/storage.rs b/examples/storage.rs index d73f29f..29b07b1 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -104,7 +104,7 @@ impl StorageActor { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; for (key, value) in &self.state { if tx.send(format!("{key}={value}")).await.is_err() { break; diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 88944a7..6558bb3 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -218,7 +218,7 @@ mod storage { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; let values = { let state = self.state.lock().unwrap(); // TODO: use async lock to not clone here. diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index f348654..b381cb7 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -141,7 +141,7 @@ mod storage { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; for (key, value) in &self.state { if tx.send(format!("{key}={value}")).await.is_err() { break; diff --git a/src/lib.rs b/src/lib.rs index 3d594f3..27dbaf1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -317,7 +317,7 @@ pub mod channel { /// /// For the rpc case, the send side can not be cloned, hence spsc instead of mpsc. pub mod spsc { - use std::{fmt::Debug, future::Future, io, pin::Pin}; + use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc}; use super::{RecvError, SendError}; use crate::RpcMessage; @@ -332,15 +332,11 @@ pub mod channel { /// Single producer, single consumer sender. /// - /// For the local case, this wraps a tokio::sync::mpsc::Sender. However, - /// due to the fact that a stream to a remote service can not be cloned, - /// this can also not be cloned. - /// - /// This forces you to use senders in a linear way, passing out references - /// to the sender to other tasks instead of cloning it. + /// For the local case, this wraps a tokio::sync::mpsc::Sender. + #[derive(Clone)] pub enum Sender { Tokio(tokio::sync::mpsc::Sender), - Boxed(Box>), + Boxed(Arc>), } impl Sender { @@ -354,7 +350,7 @@ pub mod channel { } } - pub async fn closed(&mut self) + pub async fn closed(&self) where T: RpcMessage, { @@ -369,7 +365,7 @@ pub mod channel { where T: RpcMessage, { - futures_util::sink::unfold(self, |mut sink, value| async move { + futures_util::sink::unfold(self, |sink, value| async move { sink.send(value).await?; Ok(sink) }) @@ -393,16 +389,16 @@ pub mod channel { } } - /// A sender that can be wrapped in a `Box>`. + /// A sender that can be wrapped in a `Arc>`. pub trait DynSender: Debug + Send + Sync + 'static { /// Send a message. /// /// For the remote case, if the message can not be completely sent, /// this must return an error and disable the channel. fn send( - &mut self, + &self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + Sync + '_>>; /// Try to send a message, returning as fast as possible if sending /// is not currently possible. @@ -410,12 +406,12 @@ pub mod channel { /// For the remote case, it must be guaranteed that the message is /// either completely sent or not at all. fn try_send( - &mut self, + &self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + Sync + '_>>; /// Await the sender close - fn closed(&mut self) -> Pin + Send + '_>>; + fn closed(&self) -> Pin + Send + Sync + '_>>; /// True if this is a remote sender fn is_rpc(&self) -> bool; @@ -425,7 +421,14 @@ pub mod channel { pub trait DynReceiver: Debug + Send + Sync + 'static { fn recv( &mut self, - ) -> Pin, RecvError>> + Send + '_>>; + ) -> Pin< + Box< + dyn Future, RecvError>> + + Send + + Sync + + '_, + >, + >; } impl Debug for Sender { @@ -443,7 +446,14 @@ pub mod channel { impl Sender { /// Send a message and yield until either it is sent or an error occurs. - pub async fn send(&mut self, value: T) -> std::result::Result<(), SendError> { + /// + /// ## Cancellation safety + /// + /// If the future is dropped before completion, and if this is a remote sender, + /// then the sender will be closed and further sends will return an [`io::Error`] + /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// future until completion if you want to reuse the sender or any clone afterwards. + pub async fn send(&self, value: T) -> std::result::Result<(), SendError> { match self { Sender::Tokio(tx) => { tx.send(value).await.map_err(|_| SendError::ReceiverClosed) @@ -466,6 +476,13 @@ pub mod channel { /// all. /// /// Returns true if the message was sent. + /// + /// ## Cancellation safety + /// + /// If the future is dropped before completion, and if this is a remote sender, + /// then the sender will be closed and further sends will return an [`io::Error`] + /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn try_send(&mut self, value: T) -> std::result::Result { match self { Sender::Tokio(tx) => match tx.try_send(value) { @@ -505,7 +522,7 @@ pub mod channel { #[cfg(feature = "stream")] pub fn into_stream( self, - ) -> impl n0_future::Stream> + Send + 'static + ) -> impl n0_future::Stream> + Send + Sync + 'static { n0_future::stream::unfold(self, |mut recv| async move { recv.recv().await.transpose().map(|msg| (msg, recv)) @@ -1089,7 +1106,9 @@ pub mod rpc { #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))] pub mod rpc { //! Module for cross-process RPC using [`quinn`]. - use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc}; + use std::{ + fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc, + }; use n0_future::{future::Boxed as BoxFuture, task::JoinSet}; use quinn::ConnectionError; @@ -1303,11 +1322,13 @@ pub mod rpc { impl From for spsc::Sender { fn from(write: quinn::SendStream) -> Self { - spsc::Sender::Boxed(Box::new(QuinnSender { - send: write, - buffer: SmallVec::new(), - _marker: PhantomData, - })) + spsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new( + QuinnSenderState::Open(QuinnSenderInner { + send: write, + buffer: SmallVec::new(), + _marker: PhantomData, + }), + )))) } } @@ -1325,8 +1346,9 @@ pub mod rpc { impl DynReceiver for QuinnReceiver { fn recv( &mut self, - ) -> Pin, RecvError>> + Send + '_>> - { + ) -> Pin< + Box, RecvError>> + Send + Sync + '_>, + > { Box::pin(async { let read = &mut self.recv; let Some(size) = read.read_varint_u64().await? else { @@ -1347,20 +1369,17 @@ pub mod rpc { fn drop(&mut self) {} } - struct QuinnSender { + struct QuinnSenderInner { send: quinn::SendStream, buffer: SmallVec<[u8; 128]>, _marker: std::marker::PhantomData, } - impl Debug for QuinnSender { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("QuinnSender").finish() - } - } - - impl DynSender for QuinnSender { - fn send(&mut self, value: T) -> Pin> + Send + '_>> { + impl QuinnSenderInner { + fn send( + &mut self, + value: T, + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { let value = value; self.buffer.clear(); @@ -1374,7 +1393,7 @@ pub mod rpc { fn try_send( &mut self, value: T, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { // todo: move the non-async part out of the box. Will require a new return type. let value = value; @@ -1390,20 +1409,81 @@ pub mod rpc { }) } - fn closed(&mut self) -> Pin + Send + '_>> { + fn closed(&mut self) -> Pin + Send + Sync + '_>> { Box::pin(async move { self.send.stopped().await.ok(); }) } + } - fn is_rpc(&self) -> bool { - true + #[derive(Default)] + enum QuinnSenderState { + Open(QuinnSenderInner), + #[default] + Closed, + } + + struct QuinnSender(tokio::sync::Mutex>); + + impl Debug for QuinnSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QuinnSender").finish() } } - impl Drop for QuinnSender { - fn drop(&mut self) { - self.send.finish().ok(); + impl DynSender for QuinnSender { + fn send( + &self, + value: T, + ) -> Pin> + Send + Sync + '_>> { + Box::pin(async { + let mut guard = self.0.lock().await; + let sender = std::mem::take(guard.deref_mut()); + match sender { + QuinnSenderState::Open(mut sender) => { + let res = sender.send(value).await; + if res.is_ok() { + *guard = QuinnSenderState::Open(sender); + } + res + } + QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()), + } + }) + } + + fn try_send( + &self, + value: T, + ) -> Pin> + Send + Sync + '_>> { + Box::pin(async { + let mut guard = self.0.lock().await; + let sender = std::mem::take(guard.deref_mut()); + match sender { + QuinnSenderState::Open(mut sender) => { + let res = sender.try_send(value).await; + if res.is_ok() { + *guard = QuinnSenderState::Open(sender); + } + res + } + QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()), + } + }) + } + + fn closed(&self) -> Pin + Send + Sync + '_>> { + Box::pin(async { + let mut guard = self.0.lock().await; + match guard.deref_mut() { + QuinnSenderState::Open(sender) => sender.closed().await, + QuinnSenderState::Closed => {} + } + }) + } + + fn is_rpc(&self) -> bool { + true } } diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs new file mode 100644 index 0000000..c9a58da --- /dev/null +++ b/tests/mpsc_sender.rs @@ -0,0 +1,117 @@ +use std::{ + io::ErrorKind, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + time::Duration, +}; + +use irpc::{ + channel::{spsc, SendError}, + util::{make_client_endpoint, make_server_endpoint}, +}; +use quinn::Endpoint; +use testresult::TestResult; +use tokio::time::timeout; + +fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { + let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); + let (server, cert) = make_server_endpoint(addr)?; + let client = make_client_endpoint(addr, &[cert.as_slice()])?; + let port = server.local_addr()?.port(); + let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); + Ok((server, client, server_addr)) +} + +/// Checks that all clones of a `Sender` will get the closed signal as soon as +/// a send fails with an io error. +#[tokio::test] +async fn mpsc_sender_clone_closed_error() -> TestResult<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client, server_addr) = create_connected_endpoints()?; + // accept a single bidi stream on a single connection, then immediately stop it + let server = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await?; + let (_, mut recv) = conn.accept_bi().await?; + recv.stop(1u8.into())?; + TestResult::Ok(()) + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send1 = spsc::Sender::>::from(send); + let send2 = send1.clone(); + let send3 = send1.clone(); + let second_client = tokio::spawn(async move { + send2.closed().await; + }); + let third_client = tokio::spawn(async move { + // this should fail with an io error, since the stream was stopped + loop { + match send3.send(vec![1, 2, 3]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break, + _ => {} + }; + } + }); + // send until we get an error because the remote side stopped the stream + while send1.send(vec![1, 2, 3]).await.is_ok() {} + match send1.send(vec![4, 5, 6]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => {} + e => panic!("Expected SendError::Io with kind BrokenPipe, got {:?}", e), + }; + // check that closed signal was received by the second sender + second_client.await?; + // check that the third sender will get the right kind of io error eventually + third_client.await?; + // server should finish without errors + server.await??; + Ok(()) +} + +/// Checks that all clones of a `Sender` will get the closed signal as soon as +/// a send future gets dropped before completing. +#[tokio::test] +async fn mpsc_sender_clone_drop_error() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + // accept a single bidi stream on a single connection, then read indefinitely + // until we get an error or the stream is finished + let server = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await?; + let (_, mut recv) = conn.accept_bi().await?; + let mut buf = vec![0u8; 1024]; + while let Ok(Some(_)) = recv.read(&mut buf).await {} + TestResult::Ok(()) + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send1 = spsc::Sender::>::from(send); + let send2 = send1.clone(); + let send3 = send1.clone(); + let second_client = tokio::spawn(async move { + send2.closed().await; + }); + let third_client = tokio::spawn(async move { + // this should fail with an io error, since the stream was stopped + loop { + match send3.send(vec![1, 2, 3]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break, + _ => {} + }; + } + }); + // send a lot of data with a tiny timeout, this will cause the send future to be dropped + loop { + let send_future = send1.send(vec![0u8; 1024 * 1024]); + // not sure if there is a better way. I want to poll the future a few times so it has time to + // start sending, but don't want to give it enough time to complete. + // I don't think now_or_never would work, since it wouldn't have time to start sending + if timeout(Duration::from_micros(1), send_future) + .await + .is_err() + { + break; + } + } + server.await??; + second_client.await?; + third_client.await?; + Ok(()) +}