From a1707ba000684867b75cab480056460d4f790899 Mon Sep 17 00:00:00 2001 From: Dave Kwon Date: Wed, 5 Nov 2025 17:51:05 -0800 Subject: [PATCH] try_post to have SendError return_channel (#1568) Summary: Refactor `Tx::try_post` so that it no longer returns a `Result`. Every failure, including an immediate enqueue failure, is now reported as a `SendError` on `return_channel`, so callers no longer have to maintain a second error-handling path. Differential Revision: D84834014 --- .../src/channels/transports/local.md | 1 - .../hyperactor-book/src/channels/tx_rx.md | 9 +- .../src/mailboxes/mailbox_client.md | 23 ++-- hyperactor/benches/main.rs | 4 +- hyperactor/src/channel.rs | 49 ++++--- hyperactor/src/channel/local.rs | 34 ++--- hyperactor/src/channel/net.rs | 125 +++++++++++------- hyperactor/src/channel/sim.rs | 76 +++++------ hyperactor/src/mailbox.rs | 25 ++-- hyperactor_mesh/src/bootstrap.rs | 3 +- 10 files changed, 183 insertions(+), 166 deletions(-) diff --git a/docs/source/books/hyperactor-book/src/channels/transports/local.md b/docs/source/books/hyperactor-book/src/channels/transports/local.md index 17886e2b0..fdfbd6bae 100644 --- a/docs/source/books/hyperactor-book/src/channels/transports/local.md +++ b/docs/source/books/hyperactor-book/src/channels/transports/local.md @@ -12,4 +12,3 @@ **notes:** - `Tx::send` completes after local enqueue (oneshot dropped). -- if the receiver is dropped, `try_post` fails immediately with `Err(SendError(ChannelError::Closed, message))`. 
diff --git a/docs/source/books/hyperactor-book/src/channels/tx_rx.md b/docs/source/books/hyperactor-book/src/channels/tx_rx.md index 4aafec1e3..16d18b176 100644 --- a/docs/source/books/hyperactor-book/src/channels/tx_rx.md +++ b/docs/source/books/hyperactor-book/src/channels/tx_rx.md @@ -22,7 +22,7 @@ Under the hood, network transports use a length-prefixed, multipart frame with c ```rust #[async_trait] pub trait Tx: std::fmt::Debug { - fn try_post(&self, message: M, return_channel: oneshot::Sender) -> Result<(), SendError>; + fn try_post(&self, message: M, return_channel: oneshot::Sender>); fn post(&self, message: M); async fn send(&self, message: M) -> Result<(), SendError>; fn addr(&self) -> ChannelAddr; @@ -32,8 +32,7 @@ pub trait Tx: std::fmt::Debug { - **`try_post(message, return_channel)`** Enqueues locally. - - Immediate failure → `Err(SendError(ChannelError::Closed, message))`. - - `Ok(())` means queued; if delivery later fails, the original message is sent back on `return_channel`. + - If enqueueing fails immediately or delivery later fails, the original message is sent back on `return_channel` wrapped in a `SendError`. - **`post(message)`** Fire-and-forget wrapper around `try_post`. The caller should monitor `status()` for health instead of relying on return values. @@ -91,7 +90,7 @@ pub trait Rx: std::fmt::Debug { ### Failure semantics - **Closed receiver:** `recv()` returns `Err(ChannelError::Closed)`. - **Network transports:** disconnects trigger exponential backoff reconnects; unacked messages are retried. If recovery ultimately fails (e.g., connection cannot be re-established within the delivery timeout window), the client closes and returns all undelivered/unacked messages via their `return_channel`. `status()` flips to `Closed`. -- **Local transport:** no delayed return path; if the receiver is gone, `try_post` fails immediately with `Err(SendError(ChannelError::Closed, message))`. +- **Local transport:** no delayed return path. 
- **Network disconnects (EOF/I/O error/temporary break):** the client reconnects with exponential backoff and resends any unacked messages; the server deduplicates by `seq`. - **Delivery timeout:** see [Size & time limits](#size--time-limits). @@ -104,7 +103,7 @@ pub trait Rx: std::fmt::Debug { Concrete channel implementations that satisfy `Tx` / `Rx`: -- **Local** — in-process only; uses `tokio::sync::mpsc`. No network framing/acks. `try_post` fails immediately if the receiver is gone. +- **Local** — in-process only; uses `tokio::sync::mpsc`. No network framing/acks. _Dial/serve:_ `serve_local::()`, `ChannelAddr::Local(_)`. - **TCP** — `tokio::net::TcpStream` with 8-byte BE length-prefixed frames; `seq`/`ack` for exactly-once into the server queue; reconnects with backoff. diff --git a/docs/source/books/hyperactor-book/src/mailboxes/mailbox_client.md b/docs/source/books/hyperactor-book/src/mailboxes/mailbox_client.md index b2021fb3d..3725a4981 100644 --- a/docs/source/books/hyperactor-book/src/mailboxes/mailbox_client.md +++ b/docs/source/books/hyperactor-book/src/mailboxes/mailbox_client.md @@ -124,23 +124,18 @@ impl MailboxClient { let return_handle_0 = return_handle.clone(); tokio::spawn(async move { let result = return_receiver.await; - if let Ok(message) = result { - let _ = return_handle_0.send(Undeliverable(message)); - } else { - // Sender dropped, this task can end. + if let Ok(SendError(e, message)) = result { + message.undeliverable( + DeliveryError::BrokenLink(format!( + "failed to enqueue in MailboxClient when processing buffer: {e}" + )), + return_handle_0, + ); } }); // Send the message for transmission. - let return_handle_1 = return_handle.clone(); - async move { - if let Err(SendError(_, envelope)) = tx.try_post(envelope, return_channel) { - // Failed to enqueue. 
- envelope.undeliverable( - DeliveryError::BrokenLink("failed to enqueue in MailboxClient".to_string()), - return_handle_1.clone(), - ); - } - } + tx.try_post(envelope, return_channel); + future::ready(()) }); let this = Self { buffer, diff --git a/hyperactor/benches/main.rs b/hyperactor/benches/main.rs index bae362d65..5514a8ac2 100644 --- a/hyperactor/benches/main.rs +++ b/hyperactor/benches/main.rs @@ -150,9 +150,7 @@ fn bench_message_rates(c: &mut Criterion) { Vec::with_capacity(rate as usize); for _ in 0..rate { let (return_sender, return_receiver) = oneshot::channel(); - if let Err(e) = tx.try_post(message.clone(), return_sender) { - panic!("Failed to send message: {:?}", e); - } + tx.try_post(message.clone(), return_sender); let handle = tokio::spawn(async move { _ = tokio::time::timeout( diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index 43db21e46..08c7b9a5a 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -109,28 +109,24 @@ pub enum TxStatus { pub trait Tx: std::fmt::Debug { /// Enqueue a `message` on the local end of the channel. The /// message is either delivered, or we eventually discover that - /// the channel has failed and it will be sent back on `return_handle`. - // TODO: the return channel should be SendError directly, and we should drop - // the returned result. + /// the channel has failed and it will be sent back on `return_channel`. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`. - fn try_post(&self, message: M, return_channel: oneshot::Sender) -> Result<(), SendError>; + // TODO: Consider making return channel optional to indicate that the log can be dropped. + fn try_post(&self, message: M, return_channel: oneshot::Sender>); - /// Enqueue a message to be sent on the channel. The caller is expected to monitor - /// the channel status for failures. + /// Enqueue a message to be sent on the channel. 
fn post(&self, message: M) { - // We ignore errors here because the caller is meant to monitor the channel's - // status, rather than rely on this function to report errors. - let _ignore = self.try_post(message, oneshot::channel().0); + self.try_post(message, oneshot::channel().0); } /// Send a message synchronously, returning when the messsage has /// been delivered to the remote end of the channel. async fn send(&self, message: M) -> Result<(), SendError> { let (tx, rx) = oneshot::channel(); - self.try_post(message, tx)?; + self.try_post(message, tx); match rx.await { // Channel was closed; the message was not delivered. - Ok(m) => Err(SendError(ChannelError::Closed, m)), + Ok(err) => Err(err), // Channel was dropped; the message was successfully enqueued // on the remote end of the channel. @@ -179,14 +175,12 @@ impl MpscTx { #[async_trait] impl Tx for MpscTx { - fn try_post( - &self, - message: M, - _return_channel: oneshot::Sender, - ) -> Result<(), SendError> { - self.tx - .send(message) - .map_err(|mpsc::error::SendError(message)| SendError(ChannelError::Closed, message)) + fn try_post(&self, message: M, return_channel: oneshot::Sender>) { + if let Err(mpsc::error::SendError(message)) = self.tx.send(message) { + if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) { + tracing::warn!("failed to deliver SendError: {}", m); + } + } } fn addr(&self) -> ChannelAddr { @@ -749,7 +743,7 @@ enum ChannelTxKind { #[async_trait] impl Tx for ChannelTx { - fn try_post(&self, message: M, return_channel: oneshot::Sender) -> Result<(), SendError> { + fn try_post(&self, message: M, return_channel: oneshot::Sender>) { match &self.inner { ChannelTxKind::Local(tx) => tx.try_post(message, return_channel), ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel), @@ -1054,7 +1048,7 @@ mod tests { let addr = listen_addr.clone(); sends.spawn(async move { let tx = dial::(addr).unwrap(); - tx.try_post(message, oneshot::channel().0).unwrap(); + 
tx.post(message); }); } @@ -1089,7 +1083,7 @@ mod tests { let (listen_addr, rx) = crate::channel::serve::(addr).unwrap(); let tx = dial::(listen_addr).unwrap(); - tx.try_post(123, oneshot::channel().0).unwrap(); + tx.post(123); drop(rx); // New transmits should fail... but there is buffering, etc., @@ -1099,12 +1093,15 @@ mod tests { let start = RealClock.now(); let result = loop { - let result = tx.try_post(123, oneshot::channel().0); - if result.is_err() || start.elapsed() > Duration::from_secs(10) { + let (return_tx, return_rx) = oneshot::channel(); + tx.try_post(123, return_tx); + let result = return_rx.await; + + if result.is_ok() || start.elapsed() > Duration::from_secs(10) { break result; } }; - assert_matches!(result, Err(SendError(ChannelError::Closed, 123))); + assert_matches!(result, Ok(SendError(ChannelError::Closed, 123))); } } @@ -1137,7 +1134,7 @@ mod tests { for addr in addrs() { let (listen_addr, mut rx) = crate::channel::serve::(addr).unwrap(); let tx = crate::channel::dial(listen_addr).unwrap(); - tx.try_post(123, oneshot::channel().0).unwrap(); + tx.post(123); assert_eq!(rx.recv().await.unwrap(), 123); } } diff --git a/hyperactor/src/channel/local.rs b/hyperactor/src/channel/local.rs index 6fc2eef65..a6735fe7f 100644 --- a/hyperactor/src/channel/local.rs +++ b/hyperactor/src/channel/local.rs @@ -72,18 +72,21 @@ pub struct LocalTx { #[async_trait] impl Tx for LocalTx { - fn try_post( - &self, - message: M, - _return_channel: oneshot::Sender, - ) -> Result<(), SendError> { + fn try_post(&self, message: M, return_channel: oneshot::Sender>) { let data: Data = match bincode::serialize(&message) { Ok(data) => data, - Err(err) => return Err(SendError(err.into(), message)), + Err(err) => { + if let Err(m) = return_channel.send(SendError(err.into(), message)) { + tracing::warn!("failed to deliver SendError: {}", m); + } + return; + } }; - self.tx - .send(data) - .map_err(|_| SendError(ChannelError::Closed, message)) + if self.tx.send(data).is_err() { + 
if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) { + tracing::warn!("failed to deliver SendError: {}", m); + } + } } fn addr(&self) -> ChannelAddr { @@ -167,7 +170,7 @@ mod tests { async fn test_local_basic() { let (tx, mut rx) = local::new::(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 123); } @@ -178,15 +181,14 @@ mod tests { let tx = local::dial::(port).unwrap(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 123); drop(rx); - assert_matches!( - tx.try_post(123, unused_return_channel()), - Err(SendError(ChannelError::Closed, 123)) - ); + let (return_tx, return_rx) = oneshot::channel(); + tx.try_post(123, return_tx); + assert_matches!(return_rx.await, Ok(SendError(ChannelError::Closed, 123))); } #[tokio::test] @@ -194,7 +196,7 @@ mod tests { let (port, mut rx) = local::serve::(); let tx = local::dial::(port).unwrap(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 123); drop(rx); diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index e9051487e..7f74af2e3 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -88,7 +88,6 @@ use tokio_util::net::Listener; use tokio_util::sync::CancellationToken; use super::*; -use crate::Message; use crate::RemoteMessage; use crate::clock::Clock; use crate::clock::RealClock; @@ -176,8 +175,8 @@ fn serialize_bincode( /// A Tx implemented on top of a Link. The Tx manages the link state, /// reconnections, etc. 
#[derive(Debug)] -pub(crate) struct NetTx { - sender: mpsc::UnboundedSender<(M, oneshot::Sender, Instant)>, +pub(crate) struct NetTx { + sender: mpsc::UnboundedSender<(M, oneshot::Sender>, Instant)>, dest: ChannelAddr, status: watch::Receiver, } @@ -202,7 +201,7 @@ impl NetTx { // hard to maintain. async fn run( link: impl Link, - mut receiver: mpsc::UnboundedReceiver<(M, oneshot::Sender, Instant)>, + mut receiver: mpsc::UnboundedReceiver<(M, oneshot::Sender>, Instant)>, notify: watch::Sender, ) { // If we can't deliver a message within this limit consider @@ -215,7 +214,7 @@ impl NetTx { // When this message was written to the stream. None means it is not // written yet. sent_at: Option, - return_channel: oneshot::Sender, + return_channel: oneshot::Sender>, } impl fmt::Display for QueuedMessage { @@ -304,7 +303,12 @@ impl NetTx { pub(crate) fn try_return(self) { match serde_multipart::deserialize_bincode::>(self.message) { Ok(Frame::Message(_, msg)) => { - let _ = self.return_channel.send(msg); + if let Err(m) = self + .return_channel + .send(SendError(ChannelError::Closed, msg)) + { + tracing::warn!("failed to deliver SendError: {}", m); + } } Ok(_) => { tracing::debug!( @@ -364,7 +368,7 @@ impl NetTx { fn push_back( &mut self, - (message, return_channel, received_at): (M, oneshot::Sender, Instant), + (message, return_channel, received_at): (M, oneshot::Sender>, Instant), ) -> Result<(), String> { assert!( self.deque.back().is_none_or(|msg| msg.seq < self.next_seq), @@ -987,7 +991,9 @@ impl NetTx { .chain(outbox.deque.drain(..)) .for_each(|queued| queued.try_return()); while let Ok((msg, return_channel, _)) = receiver.try_recv() { - let _ = return_channel.send(msg); + if let Err(m) = return_channel.send(SendError(ChannelError::Closed, msg)) { + tracing::warn!("failed to deliver SendError: {}", m); + } } } _ => (), @@ -1024,11 +1030,11 @@ impl Tx for NetTx { &self.status } - fn try_post(&self, message: M, return_channel: oneshot::Sender) -> Result<(), SendError> { 
+ fn try_post(&self, message: M, return_channel: oneshot::Sender>) { tracing::trace!(name = "post", "sending message to {}", self.dest); - self.sender - .send((message, return_channel, RealClock.now())) - .map_err(|err| SendError(ChannelError::Closed, err.0.0)) + if let Err(err) = self.sender.send((message, return_channel, RealClock.now())) { + let _ = err.0.1.send(SendError(ChannelError::Closed, err.0.0)); + } } } @@ -2455,6 +2461,7 @@ pub(crate) mod meta { #[cfg(test)] mod tests { + use std::assert_matches::assert_matches; use std::marker::PhantomData; use std::sync::RwLock; use std::sync::atomic::AtomicBool; @@ -2502,21 +2509,30 @@ mod tests { // channel. { let tx = crate::channel::dial::(addr.clone()).unwrap(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 123); } { - let tx = dial::(addr).unwrap(); - tx.try_post(321, unused_return_channel()).unwrap(); - tx.try_post(111, unused_return_channel()).unwrap(); - tx.try_post(444, unused_return_channel()).unwrap(); + let tx = dial::(addr.clone()).unwrap(); + tx.try_post(321, unused_return_channel()); + tx.try_post(111, unused_return_channel()); + tx.try_post(444, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 321); assert_eq!(rx.recv().await.unwrap(), 111); assert_eq!(rx.recv().await.unwrap(), 444); } + { + let tx = dial::(addr).unwrap(); + drop(rx); + + let (return_tx, return_rx) = oneshot::channel(); + tx.try_post(123, return_tx); + assert_matches!(return_rx.await, Ok(SendError(ChannelError::Closed, 123))); + } + Ok(()) } @@ -2537,14 +2553,14 @@ mod tests { // Dial the channel before we actually serve it. 
let addr = ChannelAddr::Unix(socket_addr.clone()); let tx = crate::channel::dial::(addr.clone()).unwrap(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); let (_, mut rx) = net::unix::serve::(socket_addr).unwrap(); assert_eq!(rx.recv().await.unwrap(), 123); - tx.try_post(321, unused_return_channel()).unwrap(); - tx.try_post(111, unused_return_channel()).unwrap(); - tx.try_post(444, unused_return_channel()).unwrap(); + tx.try_post(321, unused_return_channel()); + tx.try_post(111, unused_return_channel()); + tx.try_post(444, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 321); assert_eq!(rx.recv().await.unwrap(), 111); @@ -2554,27 +2570,36 @@ mod tests { } #[tracing_test::traced_test] - #[async_timed_test(timeout_secs = 30)] + #[async_timed_test(timeout_secs = 60)] // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }) #[cfg_attr(not(feature = "fb"), ignore)] async fn test_tcp_basic() { let (addr, mut rx) = tcp::serve::("[::1]:0".parse().unwrap()).unwrap(); { let tx = dial::(addr.clone()).unwrap(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 123); } { - let tx = dial::(addr).unwrap(); - tx.try_post(321, unused_return_channel()).unwrap(); - tx.try_post(111, unused_return_channel()).unwrap(); - tx.try_post(444, unused_return_channel()).unwrap(); + let tx = dial::(addr.clone()).unwrap(); + tx.try_post(321, unused_return_channel()); + tx.try_post(111, unused_return_channel()); + tx.try_post(444, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 321); assert_eq!(rx.recv().await.unwrap(), 111); assert_eq!(rx.recv().await.unwrap(), 444); } + + { + let tx = dial::(addr).unwrap(); + drop(rx); + + let (return_tx, return_rx) = oneshot::channel(); + tx.try_post(123, return_tx); + 
assert_matches!(return_rx.await, Ok(SendError(ChannelError::Closed, 123))); + } } // The message size is limited by CODEC_MAX_FRAME_LENGTH. @@ -2595,17 +2620,16 @@ mod tests { { // Leave some headroom because Tx will wrap the payload in Frame::Message. let message = "a".repeat(default_size_in_bytes - 1024); - tx.try_post(message.clone(), unused_return_channel()) - .unwrap(); + tx.try_post(message.clone(), unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), message); } // Bigger than the default size will fail. { let (return_channel, return_receiver) = oneshot::channel(); let message = "a".repeat(default_size_in_bytes + 1024); - tx.try_post(message.clone(), return_channel).unwrap(); + tx.try_post(message.clone(), return_channel); let returned = return_receiver.await.unwrap(); - assert_eq!(message, returned); + assert_eq!(message, returned.1); } } @@ -2624,7 +2648,7 @@ mod tests { let (addr, mut net_rx) = tcp::serve::("[::1]:0".parse().unwrap()).unwrap(); let net_tx = dial::(addr.clone()).unwrap(); let (tx, rx) = oneshot::channel(); - net_tx.try_post(1, tx).unwrap(); + net_tx.try_post(1, tx); assert_eq!(net_rx.recv().await.unwrap(), 1); drop(net_rx); // Using `is_err` to confirm the message is delivered/acked is confusing, @@ -2645,19 +2669,28 @@ mod tests { let (local_addr, mut rx) = net::meta::serve::(meta_addr).unwrap(); { let tx = dial::(local_addr.clone()).unwrap(); - tx.try_post(123, unused_return_channel()).unwrap(); + tx.try_post(123, unused_return_channel()); } assert_eq!(rx.recv().await.unwrap(), 123); { - let tx = dial::(local_addr).unwrap(); - tx.try_post(321, unused_return_channel()).unwrap(); - tx.try_post(111, unused_return_channel()).unwrap(); - tx.try_post(444, unused_return_channel()).unwrap(); + let tx = dial::(local_addr.clone()).unwrap(); + tx.try_post(321, unused_return_channel()); + tx.try_post(111, unused_return_channel()); + tx.try_post(444, unused_return_channel()); assert_eq!(rx.recv().await.unwrap(), 321); 
assert_eq!(rx.recv().await.unwrap(), 111); assert_eq!(rx.recv().await.unwrap(), 444); } + + { + let tx = dial::(local_addr).unwrap(); + drop(rx); + + let (return_tx, return_rx) = oneshot::channel(); + tx.try_post(123, return_tx); + assert_matches!(return_rx.await, Ok(SendError(ChannelError::Closed, 123))); + } } #[tokio::test] @@ -3250,7 +3283,7 @@ mod tests { let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1)); let mut tx_receiver = tx.status().clone(); let (return_channel, _return_receiver) = oneshot::channel(); - tx.try_post(123, return_channel).unwrap(); + tx.try_post(123, return_channel); verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await; } @@ -3305,7 +3338,7 @@ mod tests { async fn net_tx_send(tx: &NetTx, msgs: &[u64]) { for msg in msgs { - tx.try_post(*msg, unused_return_channel()).unwrap(); + tx.try_post(*msg, unused_return_channel()); } } @@ -3568,7 +3601,7 @@ mod tests { // Verify sent-and-ack a message. This is necessary for the test to // trigger a connection. let (return_channel_tx, return_channel_rx) = oneshot::channel(); - net_tx.try_post(100, return_channel_tx).unwrap(); + net_tx.try_post(100, return_channel_tx); let (mut reader, mut writer) = take_receiver(&receiver_storage).await; verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await; // ack it @@ -3600,7 +3633,7 @@ mod tests { .unwrap(); let (return_channel_tx, return_channel_rx) = oneshot::channel(); - net_tx.try_post(101, return_channel_tx).unwrap(); + net_tx.try_post(101, return_channel_tx); // Verify the message is sent to Rx. 
verify_message(&mut reader, (1u64, 101u64), line!()).await; // although we did not ack the message after it is sent, since we already @@ -3624,7 +3657,7 @@ mod tests { let tx = NetTx::::new(link); let mut tx_status = tx.status().clone(); // send a message - tx.try_post(100, unused_return_channel()).unwrap(); + tx.try_post(100, unused_return_channel()); let (mut reader, writer) = take_receiver(&receiver_storage).await; // Confirm message is sent to rx. verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await; @@ -3642,7 +3675,7 @@ mod tests { assert!(!tx_status.has_changed().unwrap()); assert_eq!(*tx_status.borrow(), TxStatus::Active); - tx.try_post(101, unused_return_channel()).unwrap(); + tx.try_post(101, unused_return_channel()); // Confirm message is sent to rx. verify_message(&mut reader, (1u64, 101u64), line!()).await; @@ -3718,7 +3751,7 @@ mod tests { RealClock .sleep(Duration::from_micros(rand::random::() % 100)) .await; - tx.try_post(message, unused_return_channel()).unwrap(); + tx.try_post(message, unused_return_channel()); } tracing::debug!("NetTx sent all messages"); // It is important to return tx instead of dropping it here, because @@ -3786,7 +3819,7 @@ mod tests { RealClock .sleep(Duration::from_micros(rand::random::() % 100)) .await; - tx.try_post(message, unused_return_channel()).unwrap(); + tx.try_post(message, unused_return_channel()); } RealClock.sleep(Duration::from_secs(5)).await; tracing::debug!("NetTx sent all messages"); @@ -3883,7 +3916,7 @@ mod tests { .map(char::from) .collect::(); for _ in 0..total_num_msgs { - let _ = tx2.try_post(random_string.clone(), unused_return_channel()); + tx2.try_post(random_string.clone(), unused_return_channel()); } })); } diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index 3d7207880..e7609ab87 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -277,10 +277,16 @@ pub(crate) struct SimRx { #[async_trait] impl Tx for SimTx { - fn 
try_post(&self, message: M, _return_handle: oneshot::Sender) -> Result<(), SendError> { + fn try_post(&self, message: M, return_channel: oneshot::Sender>) { let data = match Serialized::serialize(&message) { Ok(data) => data, - Err(err) => return Err(SendError(err.into(), message)), + Err(err) => { + if let Err(m) = return_channel.send(SendError(err.into(), message)) { + tracing::warn!("failed to deliver SendError: {}", m); + } + + return; + } }; let envelope = (&message as &dyn Any) @@ -297,16 +303,24 @@ impl Tx for SimTx { handle.sample_latency(sender.proc_id(), dest.proc_id()), )); - match &self.src_addr { + let result = match &self.src_addr { Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent { event, time: RealClock.now(), }), _ => handle.send_event(event), + }; + if let Err(err) = result { + if let Err(m) = return_channel.send(SendError(err.into(), message)) { + tracing::warn!("failed to deliver SendError: {}", m); + } + } + } + Err(err) => { + if let Err(m) = return_channel.send(SendError(err.into(), message)) { + tracing::warn!("failed to deliver SendError: {}", m); } } - .map_err(|err: SimNetError| SendError(ChannelError::from(err), message)), - Err(err) => Err(SendError(ChannelError::from(err), message)), } } @@ -430,7 +444,7 @@ mod tests { ); let msg = MessageEnvelope::new(sender, PortId(dest, 0), data.clone(), Attrs::new()); - tx.try_post(msg, oneshot::channel().0).unwrap(); + tx.post(msg); assert_eq!(*rx.recv().await.unwrap().data(), data); } @@ -509,16 +523,12 @@ mod tests { ); // This message will be delievered at simulator time = 100 seconds - tx.try_post( - MessageEnvelope::new( - controller, - PortId(dest, 0), - Serialized::serialize(&456).unwrap(), - Attrs::new(), - ), - oneshot::channel().0, - ) - .unwrap(); + tx.post(MessageEnvelope::new( + controller, + PortId(dest, 0), + Serialized::serialize(&456).unwrap(), + Attrs::new(), + )); { // Allow simnet to run tokio::task::yield_now().await; @@ -594,29 +604,19 @@ mod tests { 
tokio::time::advance(tokio::time::Duration::from_secs(5)).await; { // Send client message - client_tx - .try_post( - MessageEnvelope::new( - client.clone(), - PortId(dest.clone(), 0), - Serialized::serialize(&456).unwrap(), - Attrs::new(), - ), - oneshot::channel().0, - ) - .unwrap(); + client_tx.post(MessageEnvelope::new( + client.clone(), + PortId(dest.clone(), 0), + Serialized::serialize(&456).unwrap(), + Attrs::new(), + )); // Send system message - controller_tx - .try_post( - MessageEnvelope::new( - controller.clone(), - PortId(dest.clone(), 0), - Serialized::serialize(&456).unwrap(), - Attrs::new(), - ), - oneshot::channel().0, - ) - .unwrap(); + controller_tx.post(MessageEnvelope::new( + controller.clone(), + PortId(dest.clone(), 0), + Serialized::serialize(&456).unwrap(), + Attrs::new(), + )); // Allow some time for simnet to run RealClock.sleep(tokio::time::Duration::from_secs(1)).await; } diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index d59e36ce3..d2ee6a549 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -70,6 +70,7 @@ use std::collections::BTreeMap; use std::collections::BTreeSet; use std::fmt; use std::fmt::Debug; +use std::future; use std::future::Future; use std::ops::Bound::Excluded; use std::pin::Pin; @@ -1100,30 +1101,24 @@ impl MailboxClient { let tx_monitoring = CancellationToken::new(); let buffer = Buffer::new(move |envelope, return_handle| { let tx = Arc::clone(&tx); - let (return_channel, return_receiver) = oneshot::channel(); + let (return_channel, return_receiver) = + oneshot::channel::>(); // Set up for delivery failure. let return_handle_0 = return_handle.clone(); tokio::spawn(async move { let result = return_receiver.await; - if let Ok(message) = result { - let _ = return_handle_0.send(Undeliverable(message)); - } else { - // Sender dropped, this task can end. - } - }); - // Send the message for transmission. 
- let return_handle_1 = return_handle.clone(); - async move { - if let Err(SendError(e, envelope)) = tx.try_post(envelope, return_channel) { - // Failed to enqueue. - envelope.undeliverable( + if let Ok(SendError(e, message)) = result { + message.undeliverable( DeliveryError::BrokenLink(format!( "failed to enqueue in MailboxClient when processing buffer: {e}" )), - return_handle_1.clone(), + return_handle_0, ); } - } + }); + // Send the message for transmission. + tx.try_post(envelope, return_channel); + future::ready(()) }); let this = Self { buffer, diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index ce262a161..cd55b59fd 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -47,7 +47,6 @@ use hyperactor::config::ConfigAttr; use hyperactor::config::global as config; use hyperactor::context; use hyperactor::declare_attrs; -use hyperactor::host; use hyperactor::host::Host; use hyperactor::host::HostError; use hyperactor::host::ProcHandle; @@ -2127,7 +2126,7 @@ async fn bootstrap_v0_proc_mesh() -> anyhow::Error { tx.try_post( Process2Allocator(bootstrap_index, Process2AllocatorMessage::Hello(serve_addr)), rtx, - )?; + ); tokio::spawn(exit_if_missed_heartbeat(bootstrap_index, bootstrap_addr)); let _ = entered.exit();