Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions monarch_hyperactor/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ use pyo3::Python;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyAnyMethods;
use pyo3::types::PyCFunction;
use pyo3::types::PyDict;
use pyo3::types::PyTuple;
use pyo3_async_runtimes::TaskLocals;
use tokio::task;

Expand Down Expand Up @@ -63,7 +60,10 @@ pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::ru
})
}

#[pyfunction]
pub fn shutdown_tokio_runtime() {
// It is important to not hold the GIL while calling this function.
// Other runtime threads may be waiting to acquire it and we will never get to shutdown.
if let Some(x) = INSTANCE.write().unwrap().take() {
x.shutdown_timeout(Duration::from_secs(1));
}
Expand All @@ -84,17 +84,9 @@ pub fn initialize(py: Python) -> Result<()> {
);
IS_MAIN_THREAD.set(true);

let closure = PyCFunction::new_closure(
py,
None,
None,
|_args: &Bound<'_, PyTuple>, _kwargs: Option<&Bound<'_, PyDict>>| {
shutdown_tokio_runtime();
},
)
.unwrap();
let atexit = py.import("atexit").unwrap();
atexit.call_method1("register", (closure,)).unwrap();
let atexit = py.import("atexit")?;
let shutdown_fn = wrap_pyfunction!(shutdown_tokio_runtime, py)?;
atexit.call_method1("register", (shutdown_fn,))?;
Ok(())
}

Expand Down