diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index f22137404..ca106ebc4 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -108,16 +108,24 @@ pub enum TxStatus { /// The transmit end of an M-typed channel. #[async_trait] pub trait Tx: std::fmt::Debug { + /// Post a message; returning failed deliveries on the return channel, if provided. + /// If provided, the sender is dropped when the message has been + /// enqueued at the channel endpoint. + /// + /// Users should use the `try_post`, and `post` variants directly. + fn do_post(&self, message: M, return_channel: Option>>); + /// Enqueue a `message` on the local end of the channel. The /// message is either delivered, or we eventually discover that /// the channel has failed and it will be sent back on `return_channel`. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`. - // TODO: Consider making return channel optional to indicate that the log can be dropped. - fn try_post(&self, message: M, return_channel: oneshot::Sender>); + fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + self.do_post(message, Some(return_channel)); + } /// Enqueue a message to be sent on the channel. fn post(&self, message: M) { - self.try_post(message, oneshot::channel().0); + self.do_post(message, None); } /// Send a message synchronously, returning when the messsage has @@ -176,10 +184,12 @@ impl MpscTx { #[async_trait] impl Tx for MpscTx { - fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + fn do_post(&self, message: M, return_channel: Option>>) { if let Err(mpsc::error::SendError(message)) = self.tx.send(message) { - if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) { - tracing::warn!("failed to deliver SendError: {}", m); + if let Some(return_channel) = return_channel { + return_channel + .send(SendError(ChannelError::Closed, message)) + .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m)); } } } @@ -744,13 +754,13 @@ enum ChannelTxKind { #[async_trait] impl Tx for ChannelTx { - fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + fn do_post(&self, message: M, return_channel: Option>>) { match &self.inner { - ChannelTxKind::Local(tx) => tx.try_post(message, return_channel), - ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel), - ChannelTxKind::MetaTls(tx) => tx.try_post(message, return_channel), - ChannelTxKind::Sim(tx) => tx.try_post(message, return_channel), - ChannelTxKind::Unix(tx) => tx.try_post(message, return_channel), + ChannelTxKind::Local(tx) => tx.do_post(message, return_channel), + ChannelTxKind::Tcp(tx) => tx.do_post(message, return_channel), + ChannelTxKind::MetaTls(tx) => tx.do_post(message, return_channel), + ChannelTxKind::Sim(tx) => tx.do_post(message, return_channel), + ChannelTxKind::Unix(tx) => tx.do_post(message, return_channel), } } diff --git a/hyperactor/src/channel/local.rs b/hyperactor/src/channel/local.rs index bc157a2d7..0554de8c0 100644 --- a/hyperactor/src/channel/local.rs +++ b/hyperactor/src/channel/local.rs @@ -72,19 +72,23 @@ pub struct LocalTx { #[async_trait] impl Tx for LocalTx { - fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + fn do_post(&self, message: M, return_channel: Option>>) { let data: Data = match bincode::serialize(&message) { Ok(data) => data, Err(err) => { - if let Err(m) = return_channel.send(SendError(err.into(), message)) { - tracing::warn!("failed to deliver SendError: {}", m); + if let Some(return_channel) = return_channel { + return_channel + .send(SendError(err.into(), message)) + .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m)); } return; } }; if self.tx.send(data).is_err() { - if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) { - tracing::warn!("failed to deliver SendError: {}", m); + if let Some(return_channel) = return_channel { + return_channel + .send(SendError(ChannelError::Closed, message)) + .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m)); } } } diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index ac064fe53..f52b556ab 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -1030,10 +1030,14 @@ impl Tx for NetTx { &self.status } - fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + fn do_post(&self, message: M, return_channel: Option>>) { tracing::trace!(name = "post", "sending message to {}", self.dest); - if let Err(err) = self.sender.send((message, return_channel, RealClock.now())) { - let _ = err.0.1.send(SendError(ChannelError::Closed, err.0.0)); + + let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0); + if let Err(mpsc::error::SendError((message, return_channel, _))) = + self.sender.send((message, return_channel, RealClock.now())) + { + let _ = return_channel.send(SendError(ChannelError::Closed, message)); } } } diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index e7609ab87..39c1238b2 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -277,12 +277,14 @@ pub(crate) struct SimRx { #[async_trait] impl Tx for SimTx { - fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + fn do_post(&self, message: M, return_channel: Option>>) { let data = match Serialized::serialize(&message) { Ok(data) => data, Err(err) => { - if let Err(m) = return_channel.send(SendError(err.into(), message)) { - tracing::warn!("failed to deliver SendError: {}", m); + if let Some(return_channel) = return_channel { + return_channel + .send(SendError(err.into(), message)) + .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m)); } return; @@ -311,14 +313,20 @@ impl Tx for SimTx { _ => handle.send_event(event), }; if let Err(err) = result { - if let Err(m) = return_channel.send(SendError(err.into(), message)) { - tracing::warn!("failed to deliver SendError: {}", m); + if let Some(return_channel) = return_channel { + return_channel + .send(SendError(err.into(), message)) + .unwrap_or_else(|m| { + tracing::warn!("failed to deliver SendError: {}", m) + }); } } } Err(err) => { - if let Err(m) = return_channel.send(SendError(err.into(), message)) { - tracing::warn!("failed to deliver SendError: {}", m); + if let Some(return_channel) = return_channel { + return_channel + .send(SendError(err.into(), message)) + .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m)); } } }