diff --git a/monarch_hyperactor/src/runtime.rs b/monarch_hyperactor/src/runtime.rs index 661f1a008..e8761845a 100644 --- a/monarch_hyperactor/src/runtime.rs +++ b/monarch_hyperactor/src/runtime.rs @@ -21,9 +21,6 @@ use pyo3::Python; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyAnyMethods; -use pyo3::types::PyCFunction; -use pyo3::types::PyDict; -use pyo3::types::PyTuple; use pyo3_async_runtimes::TaskLocals; use tokio::task; @@ -63,7 +60,10 @@ pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::ru }) } +#[pyfunction] pub fn shutdown_tokio_runtime() { + // It is important to not hold the GIL while calling this function. + // Other runtime threads may be waiting to acquire it and we will never get to shutdown. if let Some(x) = INSTANCE.write().unwrap().take() { x.shutdown_timeout(Duration::from_secs(1)); } @@ -84,17 +84,9 @@ pub fn initialize(py: Python) -> Result<()> { ); IS_MAIN_THREAD.set(true); - let closure = PyCFunction::new_closure( - py, - None, - None, - |_args: &Bound<'_, PyTuple>, _kwargs: Option<&Bound<'_, PyDict>>| { - shutdown_tokio_runtime(); - }, - ) - .unwrap(); - let atexit = py.import("atexit").unwrap(); - atexit.call_method1("register", (closure,)).unwrap(); + let atexit = py.import("atexit")?; + let shutdown_fn = wrap_pyfunction!(shutdown_tokio_runtime, py)?; + atexit.call_method1("register", (shutdown_fn,))?; Ok(()) }