diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py index 00b56f37d..3d75c04c0 100644 --- a/python/monarch/_src/actor/v1/proc_mesh.py +++ b/python/monarch/_src/actor/v1/proc_mesh.py @@ -258,6 +258,9 @@ async def task( await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client) + # If the user has passed the setup lambda, we need to call + # it here before any of the other python actors are spawned so + # that the environment variables are set up before cuda init. if setup_actor is not None: await setup_actor.setup.call() @@ -267,9 +270,9 @@ async def task( if setup is not None: from monarch._src.actor.proc_mesh import SetupActor # noqa - # If the user has passed the setup lambda, we need to call - # it here before any of the other actors are spawned so that - # the environment variables are set up before cuda init. + # The SetupActor needs to be spawned outside of `task` for now, + # since spawning a python actor requires a blocking call to + # pickle the proc mesh, and we can't do that from the tokio runtime. setup_actor = pm._spawn_nonblocking_on( hy_proc_mesh, "setup", SetupActor, setup ) @@ -402,8 +405,9 @@ def stop(self) -> Future[None]: instance = context().actor_instance._as_rust() async def _stop_nonblocking(instance: HyInstance) -> None: + pm = await self._proc_mesh await PythonTask.spawn_blocking(lambda: self._logging_manager.flush()) - await (await self._proc_mesh).stop_nonblocking(instance) + await pm.stop_nonblocking(instance) self._stopped = True return Future(coro=_stop_nonblocking(instance)) diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 8cd3e6bfc..cd0d91b40 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -149,6 +149,9 @@ async def test_mesh_passed_to_mesh(): proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2}) f = proc.spawn("from", From) t = proc.spawn("to", To) + # Make sure t is initialized before sending to f. Otherwise + # f might call t.whoami before t.__init__. + await t.whoami.call() all = [y for x in f.fetch.stream(t) for y in await x] assert len(all) == 4 assert all[0] != all[1] @@ -160,6 +163,9 @@ async def test_mesh_passed_to_mesh_on_different_proc_mesh(): proc2 = fake_in_process_host().spawn_procs(per_host={"gpus": 2}) f = proc.spawn("from", From) t = proc2.spawn("to", To) + # Make sure t is initialized before sending to f. Otherwise + # f might call t.whoami before t.__init__. + await t.whoami.call() all = [y for x in f.fetch.stream(t) for y in await x] assert len(all) == 4 assert all[0] != all[1]