From c09c82f9431a0c6a076a8d067ac5b1d48320e862 Mon Sep 17 00:00:00 2001 From: Joe Zuntz Date: Fri, 6 Aug 2021 09:41:39 +0100 Subject: [PATCH 1/5] Add mode to use non-world comm and return to calling code without exiting at end of loop --- dask_mpi/core.py | 50 ++++++++++++++++++++++++++++++---- dask_mpi/tests/core_no_exit.py | 31 +++++++++++++++++++++ dask_mpi/tests/test_no_exit.py | 20 ++++++++++++++ 3 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 dask_mpi/tests/core_no_exit.py create mode 100644 dask_mpi/tests/test_no_exit.py diff --git a/dask_mpi/core.py b/dask_mpi/core.py index 0688d07..083c992 100644 --- a/dask_mpi/core.py +++ b/dask_mpi/core.py @@ -20,13 +20,21 @@ def initialize( protocol=None, worker_class="distributed.Worker", worker_options=None, + comm=None, + exit=True, ): """ Initialize a Dask cluster using mpi4py Using mpi4py, MPI rank 0 launches the Scheduler, MPI rank 1 passes through to the client script, and all other MPI ranks launch workers. All MPI ranks other than - MPI rank 1 block while their event loops run and exit once shut down. + MPI rank 1 block while their event loops run. + + In normal operation these ranks exit once rank 1 ends. If exit=False is set they + instead return a bool indicating whether they are the client and should execute + more client code, or a worker/scheduler who should not. In this case the user is + responsible for the client calling send_close_signal when work is complete, and + checking the returned value to choose further actions. Parameters ---------- @@ -51,10 +59,22 @@ def initialize( Class to use when creating workers worker_options : dict Options to pass to workers + comm: mpi4py.MPI.Intracomm + Optional MPI communicator to use instead of COMM_WORLD + exit: bool + Whether to call sys.exit on the workers and schedulers when the event + loop completes. + + Returns + ------- + is_client: bool + Only returned if exit=False. 
Indicates whether this rank should continue + to run client code (True), or if it acts as a scheduler or worker (False). """ - from mpi4py import MPI + if comm is None: + from mpi4py import MPI + comm = MPI.COMM_WORLD - comm = MPI.COMM_WORLD rank = comm.Get_rank() loop = IOLoop.current() @@ -75,7 +95,10 @@ async def run_scheduler(): await scheduler.finished() asyncio.get_event_loop().run_until_complete(run_scheduler()) - sys.exit() + if exit: + sys.exit() + else: + return False else: scheduler_address = comm.bcast(None, root=0) @@ -83,7 +106,9 @@ async def run_scheduler(): comm.Barrier() if rank == 1: - atexit.register(send_close_signal) + if exit: + atexit.register(send_close_signal) + return True else: async def run_worker(): @@ -106,10 +131,23 @@ async def run_worker(): await worker.finished() asyncio.get_event_loop().run_until_complete(run_worker()) - sys.exit() + if exit: + sys.exit() + else: + return False def send_close_signal(): + """ + The client can call this function to explicitly stop + the event loop. + + This is not needed in normal usage, where it is run + automatically when the client code exits python. + + You only need to call this manually when using exit=False + in initialize. 
+ """ async def stop(dask_scheduler): await dask_scheduler.close() await gen.sleep(0.1) diff --git a/dask_mpi/tests/core_no_exit.py b/dask_mpi/tests/core_no_exit.py new file mode 100644 index 0000000..7bf952d --- /dev/null +++ b/dask_mpi/tests/core_no_exit.py @@ -0,0 +1,31 @@ +from time import sleep + +from distributed import Client +from distributed.metrics import time + +from dask_mpi import initialize, send_close_signal + +from mpi4py.MPI import COMM_WORLD as world + +# Split our MPI world into two pieces, one consisting just of +# the old rank 3 process and the other with everything else +new_comm_assignment = 1 if world.rank == 3 else 0 +comm = world.Split(new_comm_assignment) + +if world.rank != 3: + # run tests with rest of comm + is_client = initialize(comm=comm, exit=False) + + if is_client: + with Client() as c: + c.submit(lambda x: x + 1, 10).result() == 11 + c.submit(lambda x: x + 1, 20).result() == 21 + send_close_signal() + + + +# check that our original comm is intact +world.Barrier() +x = 100 if world.rank == 0 else 200 +x = world.bcast(x) +assert x == 100 diff --git a/dask_mpi/tests/test_no_exit.py b/dask_mpi/tests/test_no_exit.py new file mode 100644 index 0000000..65c27d5 --- /dev/null +++ b/dask_mpi/tests/test_no_exit.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import, division, print_function + +import os +import subprocess +import sys + +import pytest + +pytest.importorskip("mpi4py") + + +def test_no_exit(mpirun): + script_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "core_no_exit.py" + ) + + p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file]) + + p.communicate() + assert p.returncode == 0 From cd8496ddf49b3e9c6522067b877ca6f3f91dae49 Mon Sep 17 00:00:00 2001 From: Joe Zuntz Date: Fri, 6 Aug 2021 09:52:16 +0100 Subject: [PATCH 2/5] appease flake8 --- dask_mpi/tests/core_no_exit.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dask_mpi/tests/core_no_exit.py 
b/dask_mpi/tests/core_no_exit.py index 7bf952d..ba8daf0 100644 --- a/dask_mpi/tests/core_no_exit.py +++ b/dask_mpi/tests/core_no_exit.py @@ -1,10 +1,5 @@ -from time import sleep - from distributed import Client -from distributed.metrics import time - from dask_mpi import initialize, send_close_signal - from mpi4py.MPI import COMM_WORLD as world # Split our MPI world into two pieces, one consisting just of @@ -22,8 +17,6 @@ c.submit(lambda x: x + 1, 20).result() == 21 send_close_signal() - - # check that our original comm is intact world.Barrier() x = 100 if world.rank == 0 else 200 From d652180a00c89e2522d4db6941da7dcf7d4171ef Mon Sep 17 00:00:00 2001 From: Joe Zuntz Date: Fri, 6 Aug 2021 09:54:52 +0100 Subject: [PATCH 3/5] run black --- dask_mpi/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dask_mpi/core.py b/dask_mpi/core.py index 083c992..22499dc 100644 --- a/dask_mpi/core.py +++ b/dask_mpi/core.py @@ -73,6 +73,7 @@ def initialize( """ if comm is None: from mpi4py import MPI + comm = MPI.COMM_WORLD rank = comm.Get_rank() @@ -148,6 +149,7 @@ def send_close_signal(): You only need to call this manually when using exit=False in initialize. 
""" + async def stop(dask_scheduler): await dask_scheduler.close() await gen.sleep(0.1) From 33314df5840731d00fa1a40b559bd29548b5e84b Mon Sep 17 00:00:00 2001 From: Joe Zuntz Date: Fri, 6 Aug 2021 09:57:31 +0100 Subject: [PATCH 4/5] run isort --- dask_mpi/tests/core_no_exit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dask_mpi/tests/core_no_exit.py b/dask_mpi/tests/core_no_exit.py index ba8daf0..fde7baf 100644 --- a/dask_mpi/tests/core_no_exit.py +++ b/dask_mpi/tests/core_no_exit.py @@ -1,7 +1,8 @@ from distributed import Client -from dask_mpi import initialize, send_close_signal from mpi4py.MPI import COMM_WORLD as world +from dask_mpi import initialize, send_close_signal + # Split our MPI world into two pieces, one consisting just of # the old rank 3 process and the other with everything else new_comm_assignment = 1 if world.rank == 3 else 0 From b0ada4aab48af8ac9bde633a33a46a7b941c75ea Mon Sep 17 00:00:00 2001 From: Joe Zuntz Date: Fri, 6 Aug 2021 10:03:33 +0100 Subject: [PATCH 5/5] import send_close_signal to main --- dask_mpi/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_mpi/__init__.py b/dask_mpi/__init__.py index 5701ae0..ad7560e 100644 --- a/dask_mpi/__init__.py +++ b/dask_mpi/__init__.py @@ -1,5 +1,5 @@ from ._version import get_versions -from .core import initialize +from .core import initialize, send_close_signal __version__ = get_versions()["version"] del get_versions