diff --git a/dask_mpi/cli.py b/dask_mpi/cli.py index ae8b048..ae0cd84 100644 --- a/dask_mpi/cli.py +++ b/dask_mpi/cli.py @@ -1,12 +1,11 @@ import click import asyncio -from mpi4py import MPI + from dask.distributed import Scheduler, Worker, Nanny from distributed.cli.utils import check_python_3 -comm = MPI.COMM_WORLD -rank = comm.Get_rank() +from mpi4py import MPI @click.command() @@ -72,34 +71,40 @@ def main( protocol, ): + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + if rank == 0 and scheduler: - async def run(): + async def run_scheduler(): async with Scheduler( interface=interface, protocol=protocol, - scheduler_file=scheduler_file, dashboard_address=dashboard_address, - port=scheduler_port, + scheduler_file=scheduler_file, ) as s: + comm.Barrier() await s.finished() + asyncio.get_event_loop().run_until_complete(run_scheduler()) + else: + comm.Barrier() - async def run(): + async def run_worker(): WorkerType = Nanny if nanny else Worker async with WorkerType( - scheduler_file=scheduler_file, interface=interface, protocol=protocol, nthreads=nthreads, memory_limit=memory_limit, local_directory=local_directory, name=rank, + scheduler_file=scheduler_file, ) as worker: await worker.finished() - asyncio.get_event_loop().run_until_complete(run()) + asyncio.get_event_loop().run_until_complete(run_worker()) def go():