Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion hyperactor/src/channel/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ enum Frame<M> {
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
enum NetRxResponse {
Ack(u64),
/// This channel is closed with the given reason. NetTx should stop reconnecting.
/// This session is rejected with the given reason. NetTx should stop reconnecting.
Reject(String),
/// This channel is closed.
Closed,
}

fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
Expand Down Expand Up @@ -1612,6 +1614,9 @@ mod tests {
handle.await.unwrap().unwrap();
// mpsc is closed too and there should be no unread message left.
assert!(rx.recv().await.is_none());
// should send NetRxResponse::Closed before stopping server.
let bytes = reader.next().await.unwrap().unwrap();
assert!(deserialize_response(bytes).unwrap().is_closed());
// No more acks from server.
assert!(reader.next().await.unwrap().is_none());
};
Expand Down Expand Up @@ -1646,6 +1651,9 @@ mod tests {
handle.await.unwrap().unwrap();
// mpsc is closed too and there should be no unread message left.
assert!(rx.recv().await.is_none());
// should send NetRxResponse::Closed before stopping server.
let bytes = reader.next().await.unwrap().unwrap();
assert!(deserialize_response(bytes).unwrap().is_closed());
// No more acks from server.
assert!(reader.next().await.unwrap().is_none());
}
Expand Down Expand Up @@ -2385,4 +2393,41 @@ mod tests {
let bytes = reader.next().await.unwrap().unwrap();
assert!(deserialize_response(bytes).unwrap().is_reject());
}

#[async_timed_test(timeout_secs = 60)]
// TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
#[cfg_attr(not(fbcode_build), ignore)]
async fn test_stop_net_tx_after_stopping_net_rx() {
hyperactor_telemetry::initialize_logging_for_test();

let config = config::global::lock();
let _guard =
config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(300));
let (addr, mut rx) = tcp::serve::<u64>("[::1]:0".parse().unwrap()).unwrap();
let socket_addr = match addr {
ChannelAddr::Tcp(a) => a,
_ => panic!("unexpected channel type"),
};
let tx = tcp::dial::<u64>(socket_addr);
// NetTx will not establish a connection until it sends the 1st message.
// Without a live connection, NetTx cannot received the Closed message
// from NetRx. Therefore, we need to send a message to establish the
//connection.
tx.send(100).await.unwrap();
assert_eq!(rx.recv().await.unwrap(), 100);
// Drop rx will close the NetRx server.
rx.2.stop("testing");
assert!(rx.recv().await.is_err());

// NetTx will only read from the stream when it needs to send a message
// or wait for an ack. Therefore we need to send a message to trigger that.
tx.post(101);
let mut watcher = tx.status().clone();
// When NetRx exits, it should notify NetTx to exit as well.
let _ = watcher.wait_for(|val| *val == TxStatus::Closed).await;
// wait_for could return Err due to race between when watch's sender was
// dropped and when wait_for was called. So we still need to do an
// equality check.
assert_eq!(*watcher.borrow(), TxStatus::Closed);
}
}
46 changes: 24 additions & 22 deletions hyperactor/src/channel/net/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,34 +640,24 @@ async fn run<M: RemoteMessage>(
let _ = notify.send(TxStatus::Closed);

match conn {
Conn::Connected { write_state, .. } => {
let write_half = match write_state {
WriteState::Writing(mut frame_writer, ()) => {
if let Err(err) = frame_writer.send().await {
tracing::info!(
parent: &span,
dest = %dest,
error = %err,
session_id = session_id,
"write error during cleanup"
);
}
Some(frame_writer.complete())
}
WriteState::Idle(writer) => Some(writer),
WriteState::Broken => None,
};

if let Some(mut w) = write_half {
if let Err(err) = w.shutdown().await {
Conn::Connected {
mut write_state, ..
} => {
if let WriteState::Writing(frame_writer, ()) = &mut write_state {
if let Err(err) = frame_writer.send().await {
tracing::info!(
parent: &span,
dest = %dest,
error = %err,
session_id = session_id,
"failed to shutdown NetTx write stream during cleanup"
"write error during cleanup"
);
}
};
if let Some(mut w) = write_state.into_writer() {
// Try to shutdown the connection gracefully. This is a best effort
// operation, and we don't care if it fails.
let _ = w.shutdown().await;
}
}
Conn::Disconnected(_) => (),
Expand Down Expand Up @@ -953,7 +943,19 @@ where
);
(State::Closing {
deliveries: Deliveries{outbox, unacked},
reason: format!("{log_id}: {error_msg}"),
reason: error_msg,
}, Conn::reconnect_with_default())
}
NetRxResponse::Closed => {
let msg = "server closed the channel".to_string();
tracing::info!(
dest = %link.dest(),
session_id = session_id,
"{}", msg
);
(State::Closing {
deliveries: Deliveries{outbox, unacked},
reason: msg,
}, Conn::reconnect_with_default())
}
}
Expand Down
15 changes: 15 additions & 0 deletions hyperactor/src/channel/net/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,21 @@ impl<W: AsyncWrite + Unpin, F: Buf, T> WriteState<W, F, T> {
Self::Broken => panic!("illegal state"),
}
}

/// Consume the state and return the underlying writer, if the
/// stream is not broken.
///
/// For `Idle`, this returns the stored writer. For `Writing`,
/// this assumes no more frames will be sent and calls
/// `complete()` to recover the writer. For `Broken`, this returns
/// `None`.
pub fn into_writer(self) -> Option<W> {
match self {
Self::Idle(w) => Some(w),
Self::Writing(w, _) => Some(w.complete()),
Self::Broken => None,
}
}
}

#[cfg(test)]
Expand Down
40 changes: 31 additions & 9 deletions hyperactor/src/channel/net/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt as _;
use tokio::io::ReadHalf;
use tokio::io::WriteHalf;
use tokio::sync::mpsc;
Expand Down Expand Up @@ -80,15 +81,19 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {

/// Handles a server side stream created during the `listen` loop.
async fn process<M: RemoteMessage>(
&mut self,
mut self,
session_id: u64,
tx: mpsc::Sender<M>,
cancel_token: CancellationToken,
mut next: Next,
) -> (Next, Result<(), anyhow::Error>) {
#[derive(Debug)]
enum RejectConn {
Yes(String),
/// Reject the connection due to the given error.
EncounterError(String),
/// The server is being closed.
ServerClosing,
/// Do not reject the connection.
No,
}

Expand Down Expand Up @@ -169,7 +174,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
// Have a tick to abort select! call to make sure the ack for the last message can get the chance
// to be sent as a result of time interval being reached.
_ = RealClock.sleep_until(last_ack_time + ack_time_interval), if next.ack < next.seq => {},
_ = cancel_token.cancelled() => break (next, Ok(()), RejectConn::No),
_ = cancel_token.cancelled() => break (next, Ok(()), RejectConn::ServerClosing),
bytes_result = self.reader.next() => {
rcv_raw_frame_count += 1;
// First handle transport-level I/O errors, and EOFs.
Expand Down Expand Up @@ -230,7 +235,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
break (
next,
Err(anyhow::anyhow!("{log_id}: unexpected init frame")),
RejectConn::Yes("expect Frame::Message; got Frame::Int".to_string()),
RejectConn::EncounterError("expect Frame::Message; got Frame::Int".to_string()),
)
},
// Ignore retransmits.
Expand Down Expand Up @@ -259,7 +264,7 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
break (
next,
Err(anyhow::anyhow!(format!("{log_id}: {error_msg}"))),
RejectConn::Yes(error_msg),
RejectConn::EncounterError(error_msg),
)
}
match self.send_with_buffer_metric(session_id, &tx, message).await {
Expand Down Expand Up @@ -390,12 +395,20 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
}

if self.write_state.is_idle()
&& let RejectConn::Yes(reason) = reject_conn
&& matches!(
reject_conn,
RejectConn::EncounterError(_) | RejectConn::ServerClosing
)
{
let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() else {
panic!("illegal state");
};
if let Ok(data) = serialize_response(NetRxResponse::Reject(reason)) {
let rsp = match reject_conn {
RejectConn::EncounterError(reason) => NetRxResponse::Reject(reason),
RejectConn::ServerClosing => NetRxResponse::Closed,
RejectConn::No => panic!("illegal state"),
};
if let Ok(data) = serialize_response(rsp) {
match FrameWrite::new(
writer,
data,
Expand All @@ -421,6 +434,12 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
};
}

if let Some(mut w) = self.write_state.into_writer() {
// Try to shutdown the connection gracefully. This is a best effort
// operation, and we don't care if it fails.
let _ = w.shutdown().await;
}

(final_next, final_result)
}

Expand Down Expand Up @@ -524,14 +543,17 @@ impl SessionManager {
}
};

let source = conn.source.clone();
let dest = conn.dest.clone();

let next = session_var.take().await;
let (next, res) = conn.process(session_id, tx, cancel_token, next).await;
session_var.put(next).await;

if let Err(ref err) = res {
tracing::info!(
source = %conn.source,
dest = %conn.dest,
source = %source,
dest = %dest,
error = ?err,
session_id = session_id,
"process encountered an error"
Expand Down