Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 11 additions & 21 deletions hyperactor/src/channel/net/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,34 +640,24 @@ async fn run<M: RemoteMessage>(
let _ = notify.send(TxStatus::Closed);

match conn {
Conn::Connected { write_state, .. } => {
let write_half = match write_state {
WriteState::Writing(mut frame_writer, ()) => {
if let Err(err) = frame_writer.send().await {
tracing::info!(
parent: &span,
dest = %dest,
error = %err,
session_id = session_id,
"write error during cleanup"
);
}
Some(frame_writer.complete())
}
WriteState::Idle(writer) => Some(writer),
WriteState::Broken => None,
};

if let Some(mut w) = write_half {
if let Err(err) = w.shutdown().await {
Conn::Connected {
mut write_state, ..
} => {
if let WriteState::Writing(frame_writer, ()) = &mut write_state {
if let Err(err) = frame_writer.send().await {
tracing::info!(
parent: &span,
dest = %dest,
error = %err,
session_id = session_id,
"failed to shutdown NetTx write stream during cleanup"
"write error during cleanup"
);
}
};
if let Some(mut w) = write_state.into_writer() {
// Try to shutdown the connection gracefully. This is a best effort
// operation, and we don't care if it fails.
let _ = w.shutdown().await;
}
}
Conn::Disconnected(_) => (),
Expand Down
15 changes: 15 additions & 0 deletions hyperactor/src/channel/net/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,21 @@ impl<W: AsyncWrite + Unpin, F: Buf, T> WriteState<W, F, T> {
Self::Broken => panic!("illegal state"),
}
}

/// Consume the state and return the underlying writer, if the
/// stream is not broken.
///
/// For `Idle`, this returns the stored writer. For `Writing`,
/// this assumes no more frames will be sent and calls
/// `complete()` to recover the writer. For `Broken`, this returns
/// `None`.
pub fn into_writer(self) -> Option<W> {
match self {
Self::Idle(w) => Some(w),
Self::Writing(w, _) => Some(w.complete()),
Self::Broken => None,
}
}
}

#[cfg(test)]
Expand Down
16 changes: 13 additions & 3 deletions hyperactor/src/channel/net/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt as _;
use tokio::io::ReadHalf;
use tokio::io::WriteHalf;
use tokio::sync::mpsc;
Expand Down Expand Up @@ -80,7 +81,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {

/// Handles a server side stream created during the `listen` loop.
async fn process<M: RemoteMessage>(
&mut self,
mut self,
session_id: u64,
tx: mpsc::Sender<M>,
cancel_token: CancellationToken,
Expand Down Expand Up @@ -421,6 +422,12 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
};
}

if let Some(mut w) = self.write_state.into_writer() {
// Try to shutdown the connection gracefully. This is a best effort
// operation, and we don't care if it fails.
let _ = w.shutdown().await;
}

(final_next, final_result)
}

Expand Down Expand Up @@ -524,14 +531,17 @@ impl SessionManager {
}
};

let source = conn.source.clone();
let dest = conn.dest.clone();

let next = session_var.take().await;
let (next, res) = conn.process(session_id, tx, cancel_token, next).await;
session_var.put(next).await;

if let Err(ref err) = res {
tracing::info!(
source = %conn.source,
dest = %conn.dest,
source = %source,
dest = %dest,
error = ?err,
session_id = session_id,
"process encountered an error"
Expand Down
Loading