diff --git a/examples/compute.rs b/examples/compute.rs index 2b56a77..7908425 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -55,12 +55,12 @@ struct Multiply { // The actor that processes requests struct ComputeActor { - recv: tokio::sync::mpsc::Receiver, + recv: irpc::channel::mpsc::Receiver, } impl ComputeActor { pub fn local() -> ComputeApi { - let (tx, rx) = tokio::sync::mpsc::channel(128); + let (tx, rx) = irpc::channel::mpsc::channel(128); let actor = Self { recv: rx }; n0_future::task::spawn(actor.run()); ComputeApi { @@ -69,7 +69,7 @@ impl ComputeActor { } async fn run(mut self) { - while let Some(msg) = self.recv.recv().await { + while let Ok(Some(msg)) = self.recv.recv().await { n0_future::task::spawn(async move { if let Err(cause) = Self::handle(msg).await { eprintln!("Error: {cause}"); diff --git a/src/lib.rs b/src/lib.rs index b11b9ff..a01738a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -305,7 +305,10 @@ use self::{ }, sealed::Sealed, }; +use crate::channel::SendError; +#[cfg(test)] +mod tests; #[cfg(feature = "rpc")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))] pub mod util; @@ -462,9 +465,7 @@ pub mod channel { Sender::Boxed(f) => f(value).await, } } - } - impl Sender { /// Check if this is a remote sender pub fn is_rpc(&self) -> bool where @@ -477,6 +478,45 @@ pub mod channel { } } + impl Sender { + /// Applies a filter before sending. + /// + /// Messages that don't pass the filter are dropped. + pub fn with_filter(self, f: impl Fn(&T) -> bool + Send + Sync + 'static) -> Sender { + self.with_filter_map(move |u| if f(&u) { Some(u) } else { None }) + } + + /// Applies a transform before sending. + pub fn with_map(self, f: F) -> Sender + where + F: Fn(U) -> T + Send + Sync + 'static, + U: Send + Sync + 'static, + { + self.with_filter_map(move |u| Some(f(u))) + } + + /// Applies a filter and transform before sending. + /// + /// Messages that don't pass the filter are dropped. + pub fn with_filter_map(self, f: F) -> Sender + where + F: Fn(U) -> Option + Send + Sync + 'static, + U: Send + Sync + 'static, + { + let inner: BoxedSender = Box::new(move |value| { + let opt = f(value); + Box::pin(async move { + if let Some(v) = opt { + self.send(v).await + } else { + Ok(()) + } + }) + }); + Sender::Boxed(inner) + } + } + impl crate::sealed::Sealed for Sender {} impl crate::Sender for Sender {} @@ -546,10 +586,9 @@ pub mod channel { /// /// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc. pub mod mpsc { - use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc}; + use std::{fmt::Debug, future::Future, marker::PhantomData, pin::Pin, sync::Arc}; use super::{RecvError, SendError}; - use crate::RpcMessage; /// Create a local mpsc sender and receiver pair, with the given buffer size. /// @@ -562,12 +601,20 @@ pub mod channel { /// Single producer, single consumer sender. /// /// For the local case, this wraps a tokio::sync::mpsc::Sender. - #[derive(Clone)] pub enum Sender { Tokio(tokio::sync::mpsc::Sender), Boxed(Arc>), } + impl Clone for Sender { + fn clone(&self) -> Self { + match self { + Self::Tokio(tx) => Self::Tokio(tx.clone()), + Self::Boxed(inner) => Self::Boxed(inner.clone()), + } + } + } + impl Sender { pub fn is_rpc(&self) -> bool where @@ -579,20 +626,10 @@ pub mod channel { } } - pub async fn closed(&self) - where - T: RpcMessage, - { - match self { - Sender::Tokio(tx) => tx.closed().await, - Sender::Boxed(sink) => sink.closed().await, - } - } - #[cfg(feature = "stream")] pub fn into_sink(self) -> impl n0_future::Sink + Send + 'static where - T: RpcMessage, + T: Send + Sync + 'static, { futures_util::sink::unfold(self, |sink, value| async move { sink.send(value).await?; @@ -601,6 +638,58 @@ pub mod channel { } } + impl Sender { + /// Applies a filter before sending. + /// + /// Messages that don't pass the filter are dropped. + /// + /// If you want to combine multiple filters and maps with minimal + /// overhead, use `with_filter_map` directly. + pub fn with_filter(self, f: F) -> Sender + where + F: Fn(&T) -> bool + Send + Sync + 'static, + { + self.with_filter_map(move |u| if f(&u) { Some(u) } else { None }) + } + + /// Applies a transform before sending. + /// + /// If you want to combine multiple filters and maps with minimal + /// overhead, use `with_filter_map` directly. + pub fn with_map(self, f: F) -> Sender + where + F: Fn(U) -> T + Send + Sync + 'static, + U: Send + Sync + 'static, + { + self.with_filter_map(move |u| Some(f(u))) + } + + /// Applies a filter and transform before sending. + /// + /// Any combination of filters and maps can be expressed using + /// a single filter_map. + pub fn with_filter_map(self, f: F) -> Sender + where + F: Fn(U) -> Option + Send + Sync + 'static, + U: Send + Sync + 'static, + { + let inner: Arc> = Arc::new(FilterMapSender { + f, + sender: self, + _p: PhantomData, + }); + Sender::Boxed(inner) + } + + /// Future that resolves when the sender is closed + pub async fn closed(&self) { + match self { + Sender::Tokio(tx) => tx.closed().await, + Sender::Boxed(sink) => sink.closed().await, + } + } + } + impl From> for Sender { fn from(tx: tokio::sync::mpsc::Sender) -> Self { Self::Tokio(tx) @@ -673,7 +762,7 @@ pub mod channel { } } - impl Sender { + impl Sender { /// Send a message and yield until either it is sent or an error occurs. /// /// ## Cancellation safety @@ -734,7 +823,7 @@ pub mod channel { Boxed(Box>), } - impl Receiver { + impl Receiver { /// Receive a message /// /// Returns Ok(None) if the sender has been dropped or the remote end has @@ -748,6 +837,41 @@ pub mod channel { } } + /// Map messages, transforming them from type T to type U. + pub fn map(self, f: F) -> Receiver + where + F: Fn(T) -> U + Send + Sync + 'static, + U: Send + Sync + 'static, + { + self.filter_map(move |u| Some(f(u))) + } + + /// Filter messages, only passing through those for which the predicate returns true. + /// + /// Messages that don't pass the filter are dropped. + pub fn filter(self, f: F) -> Receiver + where + F: Fn(&T) -> bool + Send + Sync + 'static, + { + self.filter_map(move |u| if f(&u) { Some(u) } else { None }) + } + + /// Filter and map messages, only passing through those for which the function returns Some. + /// + /// Messages that don't pass the filter are dropped. + pub fn filter_map(self, f: F) -> Receiver + where + U: Send + Sync + 'static, + F: Fn(T) -> Option + Send + Sync + 'static, + { + let inner: Box> = Box::new(FilterMapReceiver { + f, + receiver: self, + _p: PhantomData, + }); + Receiver::Boxed(inner) + } + #[cfg(feature = "stream")] pub fn into_stream( self, @@ -789,6 +913,107 @@ pub mod channel { } } + struct FilterMapSender { + f: F, + sender: Sender, + _p: PhantomData, + } + + impl Debug for FilterMapSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FilterMapSender").finish_non_exhaustive() + } + } + + impl DynSender for FilterMapSender + where + F: Fn(U) -> Option + Send + Sync + 'static, + T: Send + Sync + 'static, + U: Send + Sync + 'static, + { + fn send( + &self, + value: U, + ) -> Pin> + Send + '_>> { + Box::pin(async move { + if let Some(v) = (self.f)(value) { + self.sender.send(v).await + } else { + Ok(()) + } + }) + } + + fn try_send( + &self, + value: U, + ) -> Pin> + Send + '_>> { + Box::pin(async move { + if let Some(v) = (self.f)(value) { + self.sender.try_send(v).await + } else { + Ok(true) + } + }) + } + + fn is_rpc(&self) -> bool { + self.sender.is_rpc() + } + + fn closed(&self) -> Pin + Send + Sync + '_>> { + match self { + FilterMapSender { + sender: Sender::Tokio(tx), + .. + } => Box::pin(tx.closed()), + FilterMapSender { + sender: Sender::Boxed(sink), + .. + } => sink.closed(), + } + } + } + + struct FilterMapReceiver { + f: F, + receiver: Receiver, + _p: PhantomData, + } + + impl Debug for FilterMapReceiver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FilterMapReceiver").finish_non_exhaustive() + } + } + + impl DynReceiver for FilterMapReceiver + where + F: Fn(T) -> Option + Send + Sync + 'static, + T: Send + Sync + 'static, + U: Send + Sync + 'static, + { + fn recv( + &mut self, + ) -> Pin< + Box< + dyn Future, RecvError>> + + Send + + Sync + + '_, + >, + > { + Box::pin(async move { + while let Some(msg) = self.receiver.recv().await? { + if let Some(v) = (self.f)(msg) { + return Ok(Some(v)); + } + } + Ok(None) + }) + } + } + impl crate::sealed::Sealed for Receiver {} impl crate::Receiver for Receiver {} } @@ -1045,8 +1270,9 @@ impl Client { } /// Creates a new client from a `tokio::sync::mpsc::Sender`. - pub fn local(tx: tokio::sync::mpsc::Sender) -> Self { - tx.into() + pub fn local(tx: impl Into>) -> Self { + let tx: crate::channel::mpsc::Sender = tx.into(); + Self(ClientInner::Local(tx), PhantomData) } /// Get the local sender. This is useful if you don't care about remote @@ -1406,7 +1632,7 @@ impl Client { #[derive(Debug)] pub(crate) enum ClientInner { - Local(tokio::sync::mpsc::Sender), + Local(crate::channel::mpsc::Sender), #[cfg(feature = "rpc")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))] Remote(Box), @@ -1498,7 +1724,7 @@ impl From for io::Error { /// [`WithChannels`]. #[derive(Debug)] #[repr(transparent)] -pub struct LocalSender(tokio::sync::mpsc::Sender); +pub struct LocalSender(crate::channel::mpsc::Sender); impl Clone for LocalSender { fn clone(&self) -> Self { @@ -1508,6 +1734,12 @@ impl Clone for LocalSender { impl From> for LocalSender { fn from(tx: tokio::sync::mpsc::Sender) -> Self { + Self(tx.into()) + } +} + +impl From> for LocalSender { + fn from(tx: crate::channel::mpsc::Sender) -> Self { Self(tx) } } @@ -2168,77 +2400,24 @@ pub enum Request { impl LocalSender { /// Send a message to the service - pub fn send(&self, value: impl Into>) -> SendFut + pub fn send( + &self, + value: impl Into>, + ) -> impl Future> + Send + 'static where T: Channels, S::Message: From>, { let value: S::Message = value.into().into(); - SendFut::new(self.0.clone(), value) + self.send_raw(value) } /// Send a message to the service without the type conversion magic - pub fn send_raw(&self, value: S::Message) -> SendFut { - SendFut::new(self.0.clone(), value) - } -} - -mod send_fut { - use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, - }; - - use tokio::sync::mpsc::Sender; - use tokio_util::sync::PollSender; - - use crate::channel::SendError; - - pub struct SendFut { - poll_sender: PollSender, - value: Option, - } - - impl SendFut { - pub fn new(sender: Sender, value: T) -> Self { - Self { - poll_sender: PollSender::new(sender), - value: Some(value), - } - } - } - - impl Future for SendFut { - type Output = std::result::Result<(), SendError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - // Safely extract the value - let value = match this.value.take() { - Some(v) => v, - None => return Poll::Ready(Ok(())), // Already completed - }; - - // Try to reserve capacity - match this.poll_sender.poll_reserve(cx) { - Poll::Ready(Ok(())) => { - // Send the item - this.poll_sender.send_item(value).ok(); - Poll::Ready(Ok(())) - } - Poll::Ready(Err(_)) => { - // Channel is closed - Poll::Ready(Err(SendError::ReceiverClosed)) - } - Poll::Pending => { - // Restore the value and wait - this.value = Some(value); - Poll::Pending - } - } - } + pub fn send_raw( + &self, + value: S::Message, + ) -> impl Future> + Send + 'static { + let x = self.0.clone(); + async move { x.send(value).await } } } -use send_fut::SendFut; diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..fa57c13 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,28 @@ +use std::vec; + +#[tokio::test] +async fn test_map_filter() { + use crate::channel::mpsc; + let (tx, rx) = mpsc::channel::(100); + // *2, filter multipes of 4, *3 if multiple of 8 + // + // the transforms are applied in reverse order! + let tx = tx + .with_filter_map(|x: u64| if x % 8 == 0 { Some(x * 3) } else { None }) + .with_filter(|x| x % 4 == 0) + .with_map(|x: u64| x * 2); + for i in 0..100 { + tx.send(i).await.ok(); + } + drop(tx); + // /24, filter multiples of 3, /2 if even + let mut rx = rx + .map(|x: u64| x / 24) + .filter(|x| x % 3 == 0) + .filter_map(|x: u64| if x % 2 == 0 { Some(x / 2) } else { None }); + let mut res = vec![]; + while let Ok(Some(x)) = rx.recv().await { + res.push(x); + } + assert_eq!(res, vec![0, 3, 6, 9, 12]); +}