Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,24 @@ pub enum TxStatus {
/// The transmit end of an M-typed channel.
#[async_trait]
pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
/// Post a message; returning failed deliveries on the return channel, if provided.
/// If provided, the sender is dropped when the message has been
/// enqueued at the channel endpoint.
///
/// Users should use the `try_post`, and `post` variants directly.
fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);

/// Enqueue a `message` on the local end of the channel. The
/// message is either delivered, or we eventually discover that
/// the channel has failed and it will be sent back on `return_channel`.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
// TODO: Consider making return channel optional to indicate that the log can be dropped.
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>);
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
self.do_post(message, Some(return_channel));
}

/// Enqueue a message to be sent on the channel.
fn post(&self, message: M) {
self.try_post(message, oneshot::channel().0);
self.do_post(message, None);
}

/// Send a message synchronously, returning when the messsage has
Expand Down Expand Up @@ -176,10 +184,12 @@ impl<M: RemoteMessage> MpscTx<M> {

#[async_trait]
impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) {
tracing::warn!("failed to deliver SendError: {}", m);
if let Some(return_channel) = return_channel {
return_channel
.send(SendError(ChannelError::Closed, message))
.unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
}
}
}
Expand Down Expand Up @@ -744,13 +754,13 @@ enum ChannelTxKind<M: RemoteMessage> {

#[async_trait]
impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
match &self.inner {
ChannelTxKind::Local(tx) => tx.try_post(message, return_channel),
ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel),
ChannelTxKind::MetaTls(tx) => tx.try_post(message, return_channel),
ChannelTxKind::Sim(tx) => tx.try_post(message, return_channel),
ChannelTxKind::Unix(tx) => tx.try_post(message, return_channel),
ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
ChannelTxKind::Tcp(tx) => tx.do_post(message, return_channel),
ChannelTxKind::MetaTls(tx) => tx.do_post(message, return_channel),
ChannelTxKind::Sim(tx) => tx.do_post(message, return_channel),
ChannelTxKind::Unix(tx) => tx.do_post(message, return_channel),
}
}

Expand Down
14 changes: 9 additions & 5 deletions hyperactor/src/channel/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,23 @@ pub struct LocalTx<M: RemoteMessage> {

#[async_trait]
impl<M: RemoteMessage> Tx<M> for LocalTx<M> {
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
let data: Data = match bincode::serialize(&message) {
Ok(data) => data,
Err(err) => {
if let Err(m) = return_channel.send(SendError(err.into(), message)) {
tracing::warn!("failed to deliver SendError: {}", m);
if let Some(return_channel) = return_channel {
return_channel
.send(SendError(err.into(), message))
.unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
}
return;
}
};
if self.tx.send(data).is_err() {
if let Err(m) = return_channel.send(SendError(ChannelError::Closed, message)) {
tracing::warn!("failed to deliver SendError: {}", m);
if let Some(return_channel) = return_channel {
return_channel
.send(SendError(ChannelError::Closed, message))
.unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
}
}
}
Expand Down
10 changes: 7 additions & 3 deletions hyperactor/src/channel/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,10 +1030,14 @@ impl<M: RemoteMessage> Tx<M> for NetTx<M> {
&self.status
}

fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
tracing::trace!(name = "post", "sending message to {}", self.dest);
if let Err(err) = self.sender.send((message, return_channel, RealClock.now())) {
let _ = err.0.1.send(SendError(ChannelError::Closed, err.0.0));

let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
if let Err(mpsc::error::SendError((message, return_channel, _))) =
self.sender.send((message, return_channel, RealClock.now()))
{
let _ = return_channel.send(SendError(ChannelError::Closed, message));
}
}
}
Expand Down
22 changes: 15 additions & 7 deletions hyperactor/src/channel/sim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,14 @@ pub(crate) struct SimRx<M: RemoteMessage> {

#[async_trait]
impl<M: RemoteMessage + Any> Tx<M> for SimTx<M> {
fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
let data = match Serialized::serialize(&message) {
Ok(data) => data,
Err(err) => {
if let Err(m) = return_channel.send(SendError(err.into(), message)) {
tracing::warn!("failed to deliver SendError: {}", m);
if let Some(return_channel) = return_channel {
return_channel
.send(SendError(err.into(), message))
.unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
}

return;
Expand Down Expand Up @@ -311,14 +313,20 @@ impl<M: RemoteMessage + Any> Tx<M> for SimTx<M> {
_ => handle.send_event(event),
};
if let Err(err) = result {
if let Err(m) = return_channel.send(SendError(err.into(), message)) {
tracing::warn!("failed to deliver SendError: {}", m);
if let Some(return_channel) = return_channel {
return_channel
.send(SendError(err.into(), message))
.unwrap_or_else(|m| {
tracing::warn!("failed to deliver SendError: {}", m)
});
}
}
}
Err(err) => {
if let Err(m) = return_channel.send(SendError(err.into(), message)) {
tracing::warn!("failed to deliver SendError: {}", m);
if let Some(return_channel) = return_channel {
return_channel
.send(SendError(err.into(), message))
.unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
}
}
}
Expand Down
Loading