Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dask_mpi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._version import get_versions
from .core import initialize
from .core import initialize, send_close_signal

__version__ = get_versions()["version"]
del get_versions
52 changes: 46 additions & 6 deletions dask_mpi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ def initialize(
protocol=None,
worker_class="distributed.Worker",
worker_options=None,
comm=None,
exit=True,
):
"""
Initialize a Dask cluster using mpi4py

Using mpi4py, MPI rank 0 launches the Scheduler, MPI rank 1 passes through to the
client script, and all other MPI ranks launch workers. All MPI ranks other than
MPI rank 1 block while their event loops run and exit once shut down.
MPI rank 1 block while their event loops run.

In normal operation these ranks exit once rank 1 ends. If exit=False is set they
instead return an bool indicating whether they are the client and should execute
more client code, or a worker/scheduler who should not. In this case the user is
responsible for the client calling send_close_signal when work is complete, and
checking the returned value to choose further actions.

Parameters
----------
Expand All @@ -51,10 +59,23 @@ def initialize(
Class to use when creating workers
worker_options : dict
Options to pass to workers
comm: mpi4py.MPI.Intracomm
Optional MPI communicator to use instead of COMM_WORLD
exit: bool
Whether to call sys.exit on the workers and schedulers when the event
loop completes.

Returns
-------
is_client: bool
Only returned if exit=False. Inidcates whether this rank should continue
to run client code (True), or if it acts as a scheduler or worker (False).
"""
from mpi4py import MPI
if comm is None:
from mpi4py import MPI

comm = MPI.COMM_WORLD

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
loop = IOLoop.current()

Expand All @@ -75,15 +96,20 @@ async def run_scheduler():
await scheduler.finished()

asyncio.get_event_loop().run_until_complete(run_scheduler())
sys.exit()
if exit:
sys.exit()
else:
return False

else:
scheduler_address = comm.bcast(None, root=0)
dask.config.set(scheduler_address=scheduler_address)
comm.Barrier()

if rank == 1:
atexit.register(send_close_signal)
if exit:
atexit.register(send_close_signal)
return True
else:

async def run_worker():
Expand All @@ -106,10 +132,24 @@ async def run_worker():
await worker.finished()

asyncio.get_event_loop().run_until_complete(run_worker())
sys.exit()
if exit:
sys.exit()
else:
return False


def send_close_signal():
"""
The client can call this function to explicitly stop
the event loop.

This is not needed in normal usage, where it is run
automatically when the client code exits python.

You only need to call this manually when using exit=False
in initialize.
"""

async def stop(dask_scheduler):
await dask_scheduler.close()
await gen.sleep(0.1)
Expand Down
25 changes: 25 additions & 0 deletions dask_mpi/tests/core_no_exit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from distributed import Client
from mpi4py.MPI import COMM_WORLD as world

from dask_mpi import initialize, send_close_signal

# Split our MPI world into two pieces, one consisting just of
# the old rank 3 process and the other with everything else
new_comm_assignment = 1 if world.rank == 3 else 0
comm = world.Split(new_comm_assignment)

if world.rank != 3:
# run tests with rest of comm
is_client = initialize(comm=comm, exit=False)

if is_client:
with Client() as c:
c.submit(lambda x: x + 1, 10).result() == 11
c.submit(lambda x: x + 1, 20).result() == 21
send_close_signal()

# check that our original comm is intact
world.Barrier()
x = 100 if world.rank == 0 else 200
x = world.bcast(x)
assert x == 100
20 changes: 20 additions & 0 deletions dask_mpi/tests/test_no_exit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import absolute_import, division, print_function

import os
import subprocess
import sys

import pytest

pytest.importorskip("mpi4py")


def test_no_exit(mpirun):
script_file = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "core_no_exit.py"
)

p = subprocess.Popen(mpirun + ["-np", "4", sys.executable, script_file])

p.communicate()
assert p.returncode == 0