diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index e8d49c535..138f5a9a4 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -72,7 +72,6 @@ use crate::logging::OutputTarget; use crate::logging::StreamFwder; use crate::proc_mesh::mesh_agent::ProcMeshAgent; use crate::resource; -use crate::resource::Status; use crate::v1; use crate::v1::host_mesh::mesh_agent::HostAgentMode; use crate::v1::host_mesh::mesh_agent::HostMeshAgent; @@ -1202,12 +1201,15 @@ impl BootstrapProcHandle { (out, err) } + /// Sends a `StopAll` message to the ProcMeshAgent, which should exit the process, + /// then waits up to `timeout` for the process to reach a terminal state. + /// Returns `Err` if the wait fails or the timeout elapses first. async fn send_stop_all( &self, cx: &impl context::Actor, agent: ActorRef, timeout: Duration, - ) -> anyhow::Result<()> { + ) -> anyhow::Result { // For all of the messages and replies in this function: // if the proc is already dead, then the message will be undeliverable, // which should be ignored. @@ -1217,25 +1219,12 @@ impl BootstrapProcHandle { let mut agent_port = agent.port(); agent_port.return_undeliverable(false); agent_port.send(cx, resource::StopAll {})?; - let (reply_port, mut rx) = cx.mailbox().open_port::>(); - let mut reply_port = reply_port.bind(); - reply_port.return_undeliverable(false); - // Similar to above, if we cannot query for the stopped actors, just - // proceed with SIGTERM. - let mut agent_port = agent.port(); - agent_port.return_undeliverable(false); - agent_port.send(cx, resource::GetAllRankStatus { reply: reply_port })?; - // If there's a timeout waiting for a reply, continue with SIGTERM. 
- let statuses = RealClock.timeout(timeout, rx.recv()).await??; - let has_failure = statuses.iter().any(|(_rank, status)| status.is_failure()); - - if has_failure { - Err(anyhow::anyhow!( - "StopAll had some actors that failed: {:?}", - statuses, - )) - } else { - Ok(()) + // The agent handling StopAll should exit the process; if it doesn't within + // the time window, we return Err so the caller can escalate to SIGTERM. + match RealClock.timeout(timeout, self.wait()).await { + Ok(Ok(st)) => Ok(st), + Ok(Err(e)) => Err(anyhow::anyhow!("agent did not exit the process: {:?}", e)), + Err(_) => Err(anyhow::anyhow!("agent did not exit the process in time")), } } } @@ -1361,20 +1350,19 @@ impl hyperactor::host::ProcHandle for BootstrapProcHandle { // they are in the Ready state and have an Agent we can message. let agent = self.agent_ref(); if let Some(agent) = agent { - if let Err(e) = self.send_stop_all(cx, agent.clone(), timeout).await { - // Variety of possible errors, proceed with SIGTERM. - tracing::warn!( - "ProcMeshAgent {} could not successfully stop all actors: {}", - agent.actor_id(), - e, - ); + match self.send_stop_all(cx, agent.clone(), timeout).await { + Ok(st) => return Ok(st), + Err(e) => { + // Variety of possible errors, proceed with SIGTERM. + tracing::warn!( + "ProcMeshAgent {} could not successfully stop all actors: {}", + agent.actor_id(), + e, + ); + } } - // Even if the StopAll message and response is fully effective, we - // still want to send SIGTERM to actually exit the process and free - // any leftover resources. No actor should be running at this - // point. } - // After the stop all actors message may be successful, we still need + // If the stop all actors message was unsuccessful, we need // to actually stop the process. let _ = self.mark_stopping(); @@ -1690,7 +1678,7 @@ impl BootstrapProcManager { /// Return the current [`ProcStatus`] for the given [`ProcId`], if /// the proc is known to this manager. 
/// - /// This querprocies the live [`BootstrapProcHandle`] stored in the + /// This queries the live [`BootstrapProcHandle`] stored in the /// manager's internal map. It provides an immediate snapshot of /// lifecycle state (`Starting`, `Running`, `Stopping`, `Stopped`, /// etc.). diff --git a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs index 88c940506..a6a410637 100644 --- a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs @@ -222,7 +222,6 @@ pub(crate) fn update_event_actor_id(mut event: ActorSupervisionEvent) -> ActorSu resource::StopAll { cast = true }, resource::GetState { cast = true }, resource::GetRankStatus { cast = true }, - resource::GetAllRankStatus { cast = true }, ] )] pub struct ProcMeshAgent { @@ -599,6 +598,10 @@ impl Handler for ProcMeshAgent { } } +/// Handles `StopAll` by coordinating an orderly stop of child actors and then +/// exiting the process. This handler never returns to the caller: it calls +/// `std::process::exit(0/1)` after shutdown. Any sender must *not* expect a +/// reply or send any further message, and should watch `ProcStatus` instead. #[async_trait] impl Handler for ProcMeshAgent { async fn handle( @@ -610,19 +613,26 @@ impl Handler for ProcMeshAgent { // By passing in the self context, destroy_and_wait will stop this agent // last, after all others are stopped. let stop_result = self.destroy_and_wait_except_current(cx, timeout).await; + // Exit here to cleanup all remaining resources held by the process. + // This means ProcMeshAgent will never run cleanup or any other code + // from exiting its root actor. Senders of this message should never + // send any further messages or expect a reply. match stop_result { - Ok(_) => { - for (_, actor_state) in self.actor_states.iter_mut() { - // Mark all actors as stopped. 
- actor_state.stopped = true; - } - Ok(()) + Ok((stopped_actors, aborted_actors)) => { + // No need to clean up any state, the process is exiting. + tracing::info!( + actor = %cx.self_id(), + "exiting process after receiving StopAll message on ProcMeshAgent. \ + stopped actors = {:?}, aborted actors = {:?}", + stopped_actors.into_iter().map(|a| a.to_string()).collect::>(), + aborted_actors.into_iter().map(|a| a.to_string()).collect::>(), + ); + std::process::exit(0); + } + Err(e) => { + tracing::error!(actor = %cx.self_id(), "failed to stop all actors on ProcMeshAgent: {:?}", e); + std::process::exit(1); } - Err(e) => Err(anyhow::anyhow!( - "failed to StopAll on {}: {:?}", - cx.self_id(), - e - )), } } } @@ -696,69 +706,6 @@ impl Handler for ProcMeshAgent { } } -#[async_trait] -impl Handler for ProcMeshAgent { - async fn handle( - &mut self, - cx: &Context, - get_rank_status: resource::GetAllRankStatus, - ) -> anyhow::Result<()> { - use crate::resource::Status; - - let mut ranks = Vec::new(); - for (_name, state) in self.actor_states.iter() { - match state { - ActorInstanceState { - spawn: Ok(actor_id), - create_rank, - stopped, - } => { - if *stopped { - ranks.push((*create_rank, resource::Status::Stopped)); - } else { - let supervision_events = self - .supervision_events - .get(actor_id) - .map_or_else(Vec::new, |a| a.clone()); - ranks.push(( - *create_rank, - if supervision_events.is_empty() { - resource::Status::Running - } else { - resource::Status::Failed(format!( - "because of supervision events: {:?}", - supervision_events - )) - }, - )); - } - } - ActorInstanceState { - spawn: Err(e), - create_rank, - .. - } => { - ranks.push((*create_rank, Status::Failed(e.to_string()))); - } - } - } - - let result = get_rank_status.reply.send(cx, ranks); - // Ignore errors, because returning Err from here would cause the ProcMeshAgent - // to be stopped, which would prevent querying and spawning other actors. 
- // This only means some actor that requested the state of an actor failed to receive it. - if let Err(e) = result { - tracing::warn!( - actor = %cx.self_id(), - "failed to send GetRankStatus reply to {} due to error: {}", - get_rank_status.reply.port_id().actor_id(), - e - ); - } - Ok(()) - } -} - #[async_trait] impl Handler> for ProcMeshAgent { async fn handle( diff --git a/hyperactor_mesh/src/resource.rs b/hyperactor_mesh/src/resource.rs index 5d9144c1e..32809beaa 100644 --- a/hyperactor_mesh/src/resource.rs +++ b/hyperactor_mesh/src/resource.rs @@ -217,26 +217,6 @@ impl GetRankStatus { } } -/// Get the status of all resources across the mesh. -#[derive( - Clone, - Debug, - Serialize, - Deserialize, - Named, - Handler, - HandleClient, - RefClient, - Bind, - Unbind -)] -pub struct GetAllRankStatus { - /// Returns the status and rank of all resources. - /// TODO: migrate to a ValueOverlay. - #[binding(include)] - pub reply: PortRef>, -} - /// The state of a resource. #[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)] pub struct State {