diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index c3bc4ceee..ccf4c89b5 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -517,11 +517,9 @@ impl Proc { .next() } - // Iterating over a proc's root actors signaling each to stop. - // Return the root actor IDs and status observers. - async fn destroy( - &mut self, - ) -> Result>, anyhow::Error> { + /// Iterating over a proc's root actors signaling each to stop. + /// Return the root actor IDs and status observers. + pub fn destroy(&self) -> Result>, anyhow::Error> { tracing::debug!("{}: proc stopping", self.proc_id()); let mut statuses = HashMap::new(); @@ -558,7 +556,7 @@ impl Proc { timeout: Duration, skip_waiting: Option<&ActorId>, ) -> Result<(Vec, Vec), anyhow::Error> { - let mut statuses = self.destroy().await?; + let mut statuses = self.destroy()?; let waits: Vec<_> = statuses .iter_mut() .filter(|(actor_id, _)| Some(*actor_id) != skip_waiting) diff --git a/hyperactor_mesh/src/proc_mesh/mod.rs b/hyperactor_mesh/src/proc_mesh/mod.rs index afbe42881..7196a0e48 100644 --- a/hyperactor_mesh/src/proc_mesh/mod.rs +++ b/hyperactor_mesh/src/proc_mesh/mod.rs @@ -414,6 +414,8 @@ impl ProcMesh { /// An event stream of proc events. Each ProcMesh can produce only one such /// stream, returning None after the first call. pub fn events(&mut self) -> Option { + let (stop_alloc_tx, stop_alloc_rx) = tokio::sync::mpsc::unbounded_channel::<()>(); + self.event_state.take().map(|event_state| ProcEvents { event_state, ranks: self @@ -422,6 +424,8 @@ impl ProcMesh { .enumerate() .map(|(rank, (proc_id, _))| (proc_id.clone(), rank)) .collect(), + stop_alloc_tx, + stop_alloc_rx, }) } pub fn shape(&self) -> &Shape { @@ -457,6 +461,8 @@ impl fmt::Display for ProcEvent { pub struct ProcEvents { event_state: EventState, ranks: HashMap, + stop_alloc_tx: tokio::sync::mpsc::UnboundedSender<()>, + stop_alloc_rx: tokio::sync::mpsc::UnboundedReceiver<()>, } impl ProcEvents { @@ -491,9 +497,19 @@ impl ProcEvents { }; break Some(ProcEvent::Crashed(*rank, actor_status.to_string())) } + Some(_) = self.stop_alloc_rx.recv() => { + if let Err(err) = self.event_state.alloc.stop_and_wait().await { + tracing::error!("failed to stop alloc: {}", err); + } + break None; + } } } } + + pub fn stop_alloc_tx(&self) -> &tokio::sync::mpsc::UnboundedSender<()> { + &self.stop_alloc_tx + } } /// Spawns from shared ([`Arc`]) proc meshes, providing [`ActorMesh`]es with @@ -697,4 +713,26 @@ mod tests { assert!(events.next().await.is_none()); } + + #[tracing_test::traced_test] + #[tokio::test] + async fn test_proc_mesh_stop() { + let alloc_spec = AllocSpec { + shape: shape! { replica = 4 }, + constraints: Default::default(), + }; + let alloc = LocalAllocator.allocate(alloc_spec).await.unwrap(); + let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); + + let _ = proc_mesh.spawn::("foo", &()).await.unwrap(); + let _ = proc_mesh.spawn::("bar", &()).await.unwrap(); + + let mut proc_state = proc_mesh.events().unwrap(); + let stop_sender = proc_state.stop_alloc_tx(); + stop_sender.send(()).unwrap(); + + while (proc_state.next().await).is_some() {} + + assert!(logs_contain("4 actors stopped")); + } } diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 27c0dde2a..24fd83044 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -34,6 +34,8 @@ use crate::shape::PyShape; pub struct PyProcMesh { pub inner: Arc, keepalive: Keepalive, + stop_signal_sender: tokio::sync::mpsc::UnboundedSender<()>, + stop_observer: tokio::sync::watch::Receiver, } fn allocate_proc_mesh<'py>(py: Python<'py>, alloc: &PyAlloc) -> PyResult> { let alloc = match alloc.take() { @@ -75,19 +77,25 @@ impl PyProcMesh { /// Create a new [`PyProcMesh`] with a monitor that crashes the /// process on any proc failure. fn monitored(mut proc_mesh: ProcMesh, world_id: WorldId) -> Self { - let monitor = tokio::spawn(Self::monitor_proc_mesh( - proc_mesh.events().unwrap(), - world_id, - )); + let events = proc_mesh.events().unwrap(); + let stop_signal_sender = events.stop_alloc_tx().clone(); + let (stopped_sender, stop_observer) = tokio::sync::watch::channel(false); + let monitor = tokio::spawn(Self::monitor_proc_mesh(events, world_id, stopped_sender)); Self { inner: Arc::new(proc_mesh), keepalive: Keepalive::new(monitor), + stop_signal_sender, + stop_observer, } } /// Monitor the proc mesh for crashes. If a proc crashes, we print the reason /// to stderr and exit with code 1. - async fn monitor_proc_mesh(mut events: ProcEvents, world_id: WorldId) { + async fn monitor_proc_mesh( + mut events: ProcEvents, + world_id: WorldId, + stopped_sender: tokio::sync::watch::Sender, + ) { while let Some(event) = events.next().await { match event { // A graceful stop should not be cause for alarm, but @@ -99,6 +107,7 @@ impl PyProcMesh { } } } + let _ = stopped_sender.send(true); } } @@ -173,6 +182,27 @@ impl PyProcMesh { } } + fn stop(&mut self) -> PyResult<()> { + self.stop_signal_sender.send(()).map_err(|err| { + PyException::new_err(format!("Failed to send stop signal to alloc: {}", err)) + })?; + self.inner.client_proc().destroy().map_err(|err| { + PyException::new_err(format!("Failed to destroy client proc: {}", err)) + })?; + Ok(()) + } + + fn wait_for_stop<'py>(&mut self, py: Python<'py>) -> PyResult> { + let mut stop_observer = self.stop_observer.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + stop_observer + .wait_for(|stopped| *stopped) + .await + .map_err(|err| PyException::new_err(format!("Failed to wait for stop: {}", err)))?; + Ok(()) + }) + } + fn __repr__(&self) -> PyResult { Ok(format!("", self.inner)) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi index cc6da3a94..d25bf919a 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi @@ -74,4 +74,20 @@ class ProcMesh: """ ... + def stop(self) -> None: + """ + Signals the `alloc` to stop all `Proc`s and the + `client`'s `Proc` to stop. + This returns immediately after the signal is sent. + + Call `await wait_for_stop()` to wait until all the `Proc`s have completed stopping. + """ + ... + + async def wait_for_stop(self) -> None: + """ + Wait for all `Proc`s in the `alloc` and the `client`'s `Proc` to stop. + """ + ... + def __repr__(self) -> str: ... diff --git a/python/monarch/proc_mesh.py b/python/monarch/proc_mesh.py index 26fb280e8..2b1ba8230 100644 --- a/python/monarch/proc_mesh.py +++ b/python/monarch/proc_mesh.py @@ -14,6 +14,7 @@ Any, cast, Dict, + Generator, List, Optional, Sequence, @@ -214,6 +215,12 @@ async def sync_workspace(self) -> None: ) await self._rsync_mesh_client.sync_workspace() + def stop(self) -> None: + self._proc_mesh.stop() + + def __await__(self) -> Generator[None, None, None]: + return self._proc_mesh.wait_for_stop().__await__() + async def local_proc_mesh_nonblocking( *, gpus: Optional[int] = None, hosts: int = 1