diff --git a/dask_mpi/__init__.py b/dask_mpi/__init__.py
index 5701ae0..ad7560e 100644
--- a/dask_mpi/__init__.py
+++ b/dask_mpi/__init__.py
@@ -1,5 +1,5 @@
 from ._version import get_versions
-from .core import initialize
+from .core import initialize, send_close_signal
 
 __version__ = get_versions()["version"]
 del get_versions
diff --git a/dask_mpi/core.py b/dask_mpi/core.py
index 0688d07..22499dc 100644
--- a/dask_mpi/core.py
+++ b/dask_mpi/core.py
@@ -20,13 +20,21 @@ def initialize(
     protocol=None,
     worker_class="distributed.Worker",
     worker_options=None,
+    comm=None,
+    exit=True,
 ):
     """
     Initialize a Dask cluster using mpi4py
 
     Using mpi4py, MPI rank 0 launches the Scheduler, MPI rank 1 passes through to the
     client script, and all other MPI ranks launch workers.  All MPI ranks other than
-    MPI rank 1 block while their event loops run and exit once shut down.
+    MPI rank 1 block while their event loops run.
+
+    In normal operation these ranks exit once rank 1 ends. If exit=False is set they
+    instead return a bool indicating whether this rank is the client, which should
+    run more client code, or a worker/scheduler, which should not. In this case the
+    user is responsible for calling send_close_signal from the client when its work
+    is complete, and for checking the returned value to choose further actions.
 
     Parameters
     ----------
@@ -51,10 +59,23 @@ def initialize(
         Class to use when creating workers
     worker_options : dict
         Options to pass to workers
+    comm: mpi4py.MPI.Intracomm
+        Optional MPI communicator to use instead of COMM_WORLD
+    exit: bool
+        Whether to call sys.exit on the workers and schedulers when the event
+        loop completes.
+
+    Returns
+    -------
+    is_client: bool
+        Only returned if exit=False. Indicates whether this rank should continue
+        to run client code (True) or act as a scheduler or worker (False).
     """
-    from mpi4py import MPI
+    if comm is None:
+        from mpi4py import MPI
+
+        comm = MPI.COMM_WORLD
 
-    comm = MPI.COMM_WORLD
     rank = comm.Get_rank()
 
     loop = IOLoop.current()
@@ -75,7 +96,10 @@ async def run_scheduler():
             await scheduler.finished()
 
         asyncio.get_event_loop().run_until_complete(run_scheduler())
-        sys.exit()
+        if exit:
+            sys.exit()
+        else:
+            return False
     else:
         scheduler_address = comm.bcast(None, root=0)
 
@@ -83,7 +107,9 @@
     comm.Barrier()
 
     if rank == 1:
-        atexit.register(send_close_signal)
+        if exit:
+            atexit.register(send_close_signal)
+        return True
     else:
 
         async def run_worker():
@@ -106,10 +132,24 @@ async def run_worker():
             await worker.finished()
 
         asyncio.get_event_loop().run_until_complete(run_worker())
-        sys.exit()
+        if exit:
+            sys.exit()
+        else:
+            return False
 
 
 def send_close_signal():
+    """
+    The client can call this function to explicitly stop
+    the event loop.
+
+    This is not needed in normal usage, where it is run
+    automatically when the client code exits Python.
+
+    You only need to call this manually when using exit=False
+    in initialize.
+    """
+
     async def stop(dask_scheduler):
         await dask_scheduler.close()
         await gen.sleep(0.1)
diff --git a/dask_mpi/tests/core_no_exit.py b/dask_mpi/tests/core_no_exit.py
new file mode 100644
index 0000000..fde7baf
--- /dev/null
+++ b/dask_mpi/tests/core_no_exit.py
@@ -0,0 +1,25 @@
+from distributed import Client
+from mpi4py.MPI import COMM_WORLD as world
+
+from dask_mpi import initialize, send_close_signal
+
+# Split our MPI world into two pieces, one consisting just of
+# the old rank 3 process and the other with everything else
+new_comm_assignment = 1 if world.rank == 3 else 0
+comm = world.Split(new_comm_assignment)
+
+if world.rank != 3:
+    # run tests with rest of comm
+    is_client = initialize(comm=comm, exit=False)
+
+    if is_client:
+        with Client() as c:
+            assert c.submit(lambda x: x + 1, 10).result() == 11
+            assert c.submit(lambda x: x + 1, 20).result() == 21
+        send_close_signal()
+
+# check that our original comm is intact
+world.Barrier()
+x = 100 if world.rank == 0 else 200
+x = world.bcast(x)
+assert x == 100
diff --git a/dask_mpi/tests/test_no_exit.py b/dask_mpi/tests/test_no_exit.py
new file mode 100644
index 0000000..65c27d5
--- /dev/null
+++ b/dask_mpi/tests/test_no_exit.py
@@ -0,0 +1,20 @@
+from __future__ import absolute_import, division, print_function
+
+import os
+import subprocess
+import sys
+
+import pytest
+
+pytest.importorskip("mpi4py")
+
+
+def test_no_exit(mpirun):
+    script_file = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "core_no_exit.py"
+    )
+
+    p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])
+
+    p.communicate()
+    assert p.returncode == 0