diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index 15021d587..605b51a6b 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -3622,7 +3622,7 @@ mod tests { // stored `BootstrapProcManager`, which does a // `Command::spawn()` to launch a new OS child process for // that proc. - let host_mesh = HostMesh::allocate(&instance, Box::new(alloc), "test", None) + let mut host_mesh = HostMesh::allocate(&instance, Box::new(alloc), "test", None) .await .unwrap(); diff --git a/hyperactor_mesh/src/v1/actor_mesh.rs b/hyperactor_mesh/src/v1/actor_mesh.rs index 58d4bca30..fb663a925 100644 --- a/hyperactor_mesh/src/v1/actor_mesh.rs +++ b/hyperactor_mesh/src/v1/actor_mesh.rs @@ -789,7 +789,7 @@ mod tests { let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false); let instance = testing::instance().await; - let host_mesh = testing::host_mesh(extent!(host = 4)).await; + let mut host_mesh = testing::host_mesh(extent!(host = 4)).await; let proc_mesh = host_mesh .spawn(instance, "test", Extent::unity()) .await diff --git a/hyperactor_mesh/src/v1/host_mesh.rs b/hyperactor_mesh/src/v1/host_mesh.rs index 5b5edebb4..26f4bd0cd 100644 --- a/hyperactor_mesh/src/v1/host_mesh.rs +++ b/hyperactor_mesh/src/v1/host_mesh.rs @@ -476,7 +476,7 @@ impl HostMesh { /// table and sends SIGKILL to any procs it spawned—tying proc /// lifetimes to their hosts and preventing leaks. #[hyperactor::instrument(fields(host_mesh=self.name.to_string()))] - pub async fn shutdown(&self, cx: &impl hyperactor::context::Actor) -> anyhow::Result<()> { + pub async fn shutdown(&mut self, cx: &impl hyperactor::context::Actor) -> anyhow::Result<()> { tracing::info!(name = "HostMeshStatus", status = "Shutdown::Attempt"); let mut failed_hosts = vec![]; for host in self.current_ref.values() { @@ -501,6 +501,13 @@ impl HostMesh { failed_hosts ); } + + match &mut self.allocation { + HostMeshAllocation::ProcMesh { proc_mesh, .. } => { + proc_mesh.stop(cx).await?; + } + HostMeshAllocation::Owned { .. } => {} + } Ok(()) } } @@ -1325,7 +1332,7 @@ mod tests { let instance = testing::instance().await; for alloc in testing::allocs(extent!(replicas = 4)).await { - let host_mesh = HostMesh::allocate(instance, alloc, "test", None) + let mut host_mesh = HostMesh::allocate(instance, alloc, "test", None) .await .unwrap(); diff --git a/monarch_hyperactor/src/v1/host_mesh.rs b/monarch_hyperactor/src/v1/host_mesh.rs index c73c1ffcb..029f7b27c 100644 --- a/monarch_hyperactor/src/v1/host_mesh.rs +++ b/monarch_hyperactor/src/v1/host_mesh.rs @@ -197,12 +197,22 @@ impl PyHostMesh { match self { PyHostMesh::Owned(inner) => { let instance = instance.clone(); - let mesh_borrow = inner.0.borrow().map_err(anyhow::Error::from)?; + let mesh_borrow = inner.0.clone(); let fut = async move { - instance_dispatch!(instance, |cx_instance| { - mesh_borrow.shutdown(cx_instance).await - })?; - Ok(()) + match mesh_borrow.take().await { + Ok(mut mesh) => { + instance_dispatch!(instance, |cx_instance| { + mesh.shutdown(cx_instance).await + })?; + Ok(()) + } + Err(_) => { + // Don't return an exception, silently ignore the stop request + // because it was already done. + tracing::info!("shutdown was already called on host mesh"); + Ok(()) + } + } }; PyPythonTask::new(fut) } diff --git a/monarch_hyperactor/src/v1/logging.rs b/monarch_hyperactor/src/v1/logging.rs index b462342ed..be6e9462d 100644 --- a/monarch_hyperactor/src/v1/logging.rs +++ b/monarch_hyperactor/src/v1/logging.rs @@ -511,7 +511,7 @@ mod tests { #[tokio::test] async fn test_world_smoke() { - let (proc, instance, host_mesh, proc_mesh) = test_world().await.expect("world failed"); + let (proc, instance, mut host_mesh, proc_mesh) = test_world().await.expect("world failed"); assert_eq!( host_mesh.region().num_ranks(), @@ -534,7 +534,7 @@ mod tests { #[tokio::test] async fn spawn_respects_forwarding_flag() { - let (_, instance, host_mesh, proc_mesh) = test_world().await.expect("world failed"); + let (_, instance, mut host_mesh, proc_mesh) = test_world().await.expect("world failed"); let py_instance = PyInstance::from(&instance); let py_proc_mesh = PyProcMesh::new_owned(proc_mesh); @@ -591,7 +591,7 @@ mod tests { #[tokio::test] async fn set_mode_behaviors() { - let (_proc, instance, host_mesh, proc_mesh) = test_world().await.expect("world failed"); + let (_proc, instance, mut host_mesh, proc_mesh) = test_world().await.expect("world failed"); let py_instance = PyInstance::from(&instance); let py_proc_mesh = PyProcMesh::new_owned(proc_mesh); @@ -706,7 +706,7 @@ mod tests { #[tokio::test] async fn flush_behaviors() { - let (_proc, instance, host_mesh, proc_mesh) = test_world().await.expect("world failed"); + let (_proc, instance, mut host_mesh, proc_mesh) = test_world().await.expect("world failed"); let py_instance = PyInstance::from(&instance); let py_proc_mesh = PyProcMesh::new_owned(proc_mesh);