Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 22 additions & 34 deletions hyperactor_mesh/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ use crate::logging::OutputTarget;
use crate::logging::StreamFwder;
use crate::proc_mesh::mesh_agent::ProcMeshAgent;
use crate::resource;
use crate::resource::Status;
use crate::v1;
use crate::v1::host_mesh::mesh_agent::HostAgentMode;
use crate::v1::host_mesh::mesh_agent::HostMeshAgent;
Expand Down Expand Up @@ -1202,12 +1201,15 @@ impl BootstrapProcHandle {
(out, err)
}

/// Sends a StopAll message to the ProcMeshAgent, which should exit the process.
/// Waits for the successful state change of the process. If the process
/// doesn't reach a terminal state, returns Err.
async fn send_stop_all(
&self,
cx: &impl context::Actor,
agent: ActorRef<ProcMeshAgent>,
timeout: Duration,
) -> anyhow::Result<()> {
) -> anyhow::Result<ProcStatus> {
// For all of the messages and replies in this function:
// if the proc is already dead, then the message will be undeliverable,
// which should be ignored.
Expand All @@ -1217,25 +1219,12 @@ impl BootstrapProcHandle {
let mut agent_port = agent.port();
agent_port.return_undeliverable(false);
agent_port.send(cx, resource::StopAll {})?;
let (reply_port, mut rx) = cx.mailbox().open_port::<Vec<(usize, Status)>>();
let mut reply_port = reply_port.bind();
reply_port.return_undeliverable(false);
// Similar to above, if we cannot query for the stopped actors, just
// proceed with SIGTERM.
let mut agent_port = agent.port();
agent_port.return_undeliverable(false);
agent_port.send(cx, resource::GetAllRankStatus { reply: reply_port })?;
// If there's a timeout waiting for a reply, continue with SIGTERM.
let statuses = RealClock.timeout(timeout, rx.recv()).await??;
let has_failure = statuses.iter().any(|(_rank, status)| status.is_failure());

if has_failure {
Err(anyhow::anyhow!(
"StopAll had some actors that failed: {:?}",
statuses,
))
} else {
Ok(())
// The agent handling Stop should exit the process, if it doesn't within
// the time window, we escalate to SIGTERM.
match RealClock.timeout(timeout, self.wait()).await {
Ok(Ok(st)) => Ok(st),
Ok(Err(e)) => Err(anyhow::anyhow!("agent did not exit the process: {:?}", e)),
Err(_) => Err(anyhow::anyhow!("agent did not exit the process in time")),
}
}
}
Expand Down Expand Up @@ -1361,20 +1350,19 @@ impl hyperactor::host::ProcHandle for BootstrapProcHandle {
// they are in the Ready state and have an Agent we can message.
let agent = self.agent_ref();
if let Some(agent) = agent {
if let Err(e) = self.send_stop_all(cx, agent.clone(), timeout).await {
// Variety of possible errors, proceed with SIGTERM.
tracing::warn!(
"ProcMeshAgent {} could not successfully stop all actors: {}",
agent.actor_id(),
e,
);
match self.send_stop_all(cx, agent.clone(), timeout).await {
Ok(st) => return Ok(st),
Err(e) => {
// Variety of possible errors, proceed with SIGTERM.
tracing::warn!(
"ProcMeshAgent {} could not successfully stop all actors: {}",
agent.actor_id(),
e,
);
}
}
// Even if the StopAll message and response is fully effective, we
// still want to send SIGTERM to actually exit the process and free
// any leftover resources. No actor should be running at this
// point.
}
// After the stop all actors message may be successful, we still need
// If the stop all actors message was unsuccessful, we need
// to actually stop the process.
let _ = self.mark_stopping();

Expand Down Expand Up @@ -1690,7 +1678,7 @@ impl BootstrapProcManager {
/// Return the current [`ProcStatus`] for the given [`ProcId`], if
/// the proc is known to this manager.
///
/// This querprocies the live [`BootstrapProcHandle`] stored in the
/// This queries the live [`BootstrapProcHandle`] stored in the
/// manager's internal map. It provides an immediate snapshot of
/// lifecycle state (`Starting`, `Running`, `Stopping`, `Stopped`,
/// etc.).
Expand Down
97 changes: 22 additions & 75 deletions hyperactor_mesh/src/proc_mesh/mesh_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSu
resource::StopAll { cast = true },
resource::GetState<ActorState> { cast = true },
resource::GetRankStatus { cast = true },
resource::GetAllRankStatus { cast = true },
]
)]
pub struct ProcMeshAgent {
Expand Down Expand Up @@ -599,6 +598,10 @@ impl Handler<resource::Stop> for ProcMeshAgent {
}
}

/// Handles `StopAll` by coordinating an orderly stop of child actors and then
/// exiting the process. This handler never returns to the caller: it calls
/// `std::process::exit(0/1)` after shutdown. Any sender must *not* expect a
/// reply or send any further message, and should watch `ProcStatus` instead.
#[async_trait]
impl Handler<resource::StopAll> for ProcMeshAgent {
async fn handle(
Expand All @@ -610,19 +613,26 @@ impl Handler<resource::StopAll> for ProcMeshAgent {
// By passing in the self context, destroy_and_wait will stop this agent
// last, after all others are stopped.
let stop_result = self.destroy_and_wait_except_current(cx, timeout).await;
// Exit here to cleanup all remaining resources held by the process.
// This means ProcMeshAgent will never run cleanup or any other code
// from exiting its root actor. Senders of this message should never
// send any further messages or expect a reply.
match stop_result {
Ok(_) => {
for (_, actor_state) in self.actor_states.iter_mut() {
// Mark all actors as stopped.
actor_state.stopped = true;
}
Ok(())
Ok((stopped_actors, aborted_actors)) => {
// No need to clean up any state, the process is exiting.
tracing::info!(
actor = %cx.self_id(),
"exiting process after receiving StopAll message on ProcMeshAgent. \
stopped actors = {:?}, aborted actors = {:?}",
stopped_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
aborted_actors.into_iter().map(|a| a.to_string()).collect::<Vec<_>>(),
);
std::process::exit(0);
}
Err(e) => {
tracing::error!(actor = %cx.self_id(), "failed to stop all actors on ProcMeshAgent: {:?}", e);
std::process::exit(1);
}
Err(e) => Err(anyhow::anyhow!(
"failed to StopAll on {}: {:?}",
cx.self_id(),
e
)),
}
}
}
Expand Down Expand Up @@ -696,69 +706,6 @@ impl Handler<resource::GetRankStatus> for ProcMeshAgent {
}
}

#[async_trait]
impl Handler<resource::GetAllRankStatus> for ProcMeshAgent {
async fn handle(
&mut self,
cx: &Context<Self>,
get_rank_status: resource::GetAllRankStatus,
) -> anyhow::Result<()> {
use crate::resource::Status;

let mut ranks = Vec::new();
for (_name, state) in self.actor_states.iter() {
match state {
ActorInstanceState {
spawn: Ok(actor_id),
create_rank,
stopped,
} => {
if *stopped {
ranks.push((*create_rank, resource::Status::Stopped));
} else {
let supervision_events = self
.supervision_events
.get(actor_id)
.map_or_else(Vec::new, |a| a.clone());
ranks.push((
*create_rank,
if supervision_events.is_empty() {
resource::Status::Running
} else {
resource::Status::Failed(format!(
"because of supervision events: {:?}",
supervision_events
))
},
));
}
}
ActorInstanceState {
spawn: Err(e),
create_rank,
..
} => {
ranks.push((*create_rank, Status::Failed(e.to_string())));
}
}
}

let result = get_rank_status.reply.send(cx, ranks);
// Ignore errors, because returning Err from here would cause the ProcMeshAgent
// to be stopped, which would prevent querying and spawning other actors.
// This only means some actor that requested the state of an actor failed to receive it.
if let Err(e) = result {
tracing::warn!(
actor = %cx.self_id(),
"failed to send GetRankStatus reply to {} due to error: {}",
get_rank_status.reply.port_id().actor_id(),
e
);
}
Ok(())
}
}

#[async_trait]
impl Handler<resource::GetState<ActorState>> for ProcMeshAgent {
async fn handle(
Expand Down
20 changes: 0 additions & 20 deletions hyperactor_mesh/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,26 +217,6 @@ impl GetRankStatus {
}
}

/// Get the status of all resources across the mesh.
#[derive(
Clone,
Debug,
Serialize,
Deserialize,
Named,
Handler,
HandleClient,
RefClient,
Bind,
Unbind
)]
pub struct GetAllRankStatus {
/// Returns the status and rank of all resources.
/// TODO: migrate to a ValueOverlay.
#[binding(include)]
pub reply: PortRef<Vec<(usize, Status)>>,
}

/// The state of a resource.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
Expand Down