In [1]:
try:
    import inotify
except ImportError as e:
    !pip install inotify

In [2]:
from dask.distributed import Client
import subprocess as subp
import os
import dask
import inotify.adapters


In [3]:
# Start a local Dask cluster provided by dask_cuda.
# NOTE(review): LocalCUDACluster presumably starts one worker per visible GPU;
# the next cell's output reports 2 workers — confirm against the GPUs available.
from dask_cuda import LocalCUDACluster
cluster = LocalCUDACluster()

In [4]:
# Connect a Dask client to the cluster; `c` is handed to Dask_MPI_Demo below.
# The bare `c` on the last line displays the client/cluster summary.
c = Client(cluster)
c

0,1
Client  Scheduler: tcp://127.0.0.1:38083  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 2  Cores: 2  Memory: 50.39 GB


In [5]:
from hello_mpi import MPI_World
import random
from dask.distributed import wait

In [6]:
class Dask_MPI_Demo:
    """Demo that joins the workers of a Dask cluster into one MPI session.

    Lifecycle:
      1. ``__init__`` launches an ``ompi-server`` subprocess and reads its URI.
      2. ``init`` submits one ``MPI_World`` setup task per Dask worker; the
         first worker plays the "server" role, the remaining ones "clients".
      3. ``build_session`` opens the server port, connects every client to it
         and merges the clients.
      4. ``finalize`` finalizes every ``MPI_World``.

    The ``ompi-server`` subprocess is not stopped automatically; call
    ``shutdown()`` (or the legacy ``__dealloc__()`` alias) to stop it.

    Every task that operates on an ``MPI_World`` is pinned to the worker that
    created it, so the world object is never moved between workers.

    The trailing ``r`` argument of the ``func_*_`` helpers is ignored: callers
    pass ``random.random()`` so each submission gets a distinct Dask task key
    (``Client.submit`` otherwise treats calls with identical arguments as the
    same task).
    """

    def __init__(self, client, uriFile="ompi.server.uri"):
        """Remember the Dask client and start ``ompi-server``.

        :param client: ``dask.distributed.Client`` used to submit every task.
        :param uriFile: path of the file ``ompi-server`` writes its URI to.
        """
        self.client = client
        self.uriFile = uriFile
        # Filled in by init(); defined up front so the other methods can tell
        # whether a session exists.
        self.server = None
        self.serverWorker = None
        self.clients = None
        # Filled in by create_ompi_server_().
        self.mpiServer = None
        self.mpiServerUri = None

        print("Starting ompi-server")
        self.create_ompi_server_()

    def shutdown(self):
        """Stop the ``ompi-server`` subprocess and delete its URI file.

        Safe to call more than once.
        """
        proc = getattr(self, "mpiServer", None)
        if proc is not None:
            if proc.poll() is None:
                proc.kill()
            proc.wait()  # reap the child so it does not linger as a zombie
            self.mpiServer = None
        uri_file = getattr(self, "uriFile", None)
        if uri_file is not None:
            try:
                os.remove(uri_file)
            except FileNotFoundError:
                pass

    # Backward-compatible alias. ``__dealloc__`` is a Cython hook; plain Python
    # never invokes it, so cleanup has to be requested explicitly.
    __dealloc__ = shutdown

    def create_ompi_server_(self):
        """Launch ``ompi-server`` and wait until it has written its URI.

        Sets ``self.mpiServer`` (the ``Popen`` handle), ``self.mpiServerUri``
        and the ``OMPI_MCA_pmix_server_uri`` environment variable of this
        process.

        :raises RuntimeError: if ``ompi-server`` exits before writing its URI,
            or writes an empty URI.
        """
        cmd = ["ompi-server", "--no-daemonize", "-r", self.uriFile]

        watcher = inotify.adapters.Inotify()

        # Create/truncate the URI file so it exists and can be watched.
        with open(self.uriFile, 'w'):
            pass

        watcher.add_watch(self.uriFile)
        try:
            # Launch without a shell: the Popen handle is ompi-server itself
            # (so kill() reaches it) and the file name needs no quoting.
            proc = subp.Popen(cmd)

            # With yield_nones=True the generator yields None when a poll
            # period passes without events; use that to notice a server that
            # died before writing its URI instead of blocking forever.
            for event in watcher.event_gen(yield_nones=True):
                if event is None:
                    if proc.poll() is not None:
                        raise RuntimeError(
                            "ompi-server exited with code %d before writing %s"
                            % (proc.returncode, self.uriFile))
                    continue
                (_, type_names, _path, _filename) = event
                if "IN_CLOSE_WRITE" in type_names:
                    break
        finally:
            watcher.remove_watch(self.uriFile)

        with open(self.uriFile, "r") as fp:
            mpiServerUri = fp.read().rstrip()
        if not mpiServerUri:
            proc.kill()
            proc.wait()
            raise RuntimeError("ompi-server wrote an empty URI to %s" % self.uriFile)

        self.mpiServer = proc
        self.mpiServerUri = mpiServerUri

        os.environ["OMPI_MCA_pmix_server_uri"] = self.mpiServerUri

    @staticmethod
    def func_parse_host_port_(address):
        """Split a Dask worker address such as ``tcp://10.0.0.1:8786``.

        :param address: address string, with or without a ``scheme://`` prefix.
        :returns: ``(host, port)`` with ``port`` as ``int``. Bracketed IPv6
            hosts (``tcp://[::1]:8786``) are returned without the brackets.
        :raises ValueError: if the address has no ``:port`` part.
        """
        if '://' in address:
            address = address.rsplit('://', 1)[1]
        # Split on the last ':' so IPv6 hosts, which contain ':', still parse.
        host, port = address.rsplit(':', 1)
        return host.strip('[]'), int(port)

    @staticmethod
    def func_init_(workerId, nWorkers, ompiServerUri):
        """Create and initialise an ``MPI_World`` on the current Dask worker.

        Runs as a Dask task. Points this worker process at the ompi-server
        via ``OMPI_MCA_pmix_server_uri`` before initialising MPI.

        :param workerId: index of this worker (0 is the server).
        :param nWorkers: total number of participating workers.
        :param ompiServerUri: URI read from the ompi-server URI file.
        :returns: the initialised ``MPI_World``.
        :raises ValueError: if ``ompiServerUri`` is None.
        """
        if ompiServerUri is None:
            raise ValueError("ompiServerUri is mandatory!")
        os.environ["OMPI_MCA_pmix_server_uri"] = ompiServerUri
        w = dask.distributed.get_worker()
        print("Hello World! from ip=%s worker=%s/%d uri=%s" %
              (w.address, w.name, nWorkers, ompiServerUri))

        world = MPI_World(workerId, nWorkers)
        world.init()
        world.create_builder()

        # Report completion only once the MPI setup above has actually run.
        print("Worker=%s finished" % w.name)
        return world

    @staticmethod
    def func_build_session_(world, r):
        """Task: call ``world.new_session()`` (``r`` is ignored)."""
        world.new_session()

    @staticmethod
    def func_open_server_port_(world, r):
        """Task: open the server port on the server world (``r`` is ignored)."""
        world.open_server_port()

    @staticmethod
    def func_get_server_port_(world, r):
        """Task: fetch the server port on a client world (``r`` is ignored)."""
        world.get_server_port()

    @staticmethod
    def func_connect_to_server_(world, r):
        """Task: connect a client world to the server (``r`` is ignored)."""
        world.connect_to_server()

    @staticmethod
    def func_connect_to_client_(world, clientId, r):
        """Task: have the server world connect to client ``clientId``."""
        world.connect_to_client(clientId)

    @staticmethod
    def func_merge_clients_(world, r):
        """Task: call ``world.merge_clients()`` (``r`` is ignored)."""
        world.merge_clients()

    @staticmethod
    def func_get_rank_(world, r):
        """Task: return ``world.rank()`` (``r`` is ignored)."""
        return world.rank()

    @staticmethod
    def func_finalize_(world, r):
        """Task: call ``world.finalize()`` (``r`` is ignored)."""
        world.finalize()

    def get_workers_(self):
        """Return ``(host, port)`` for every worker in ``Client.has_what()``."""
        return [Dask_MPI_Demo.func_parse_host_port_(addr)
                for addr in self.client.has_what().keys()]

    def init(self):
        """Submit one ``MPI_World`` setup task per Dask worker.

        The first worker hosts the server world (id 0); every other worker
        hosts a client world with ids 1..N-1. Stores the futures in
        ``self.server`` and ``self.clients`` (a list of
        ``(id, worker, future)`` tuples).

        :raises RuntimeError: if the client sees no workers.
        """
        workers = self.get_workers_()
        if not workers:
            raise RuntimeError("No Dask workers are connected to the client")
        n_workers = len(workers)

        self.serverWorker = workers[0]
        self.server = self.client.submit(Dask_MPI_Demo.func_init_,
                                         0, n_workers, self.mpiServerUri,
                                         workers=[self.serverWorker])

        self.clients = [(idx, worker, self.client.submit(Dask_MPI_Demo.func_init_,
                                                         idx,
                                                         n_workers,
                                                         self.mpiServerUri,
                                                         workers=[worker]))
                        for idx, worker in enumerate(workers[1:], start=1)]

    def build_session(self):
        """Open the server port, connect each client to it, then merge.

        Blocks until every step has finished and re-raises the first task
        failure.

        :raises RuntimeError: if ``init()`` has not been called.
        """
        if self.server is None or self.clients is None:
            raise RuntimeError("init() must be called before build_session()")

        print("Building server")
        self.client.submit(Dask_MPI_Demo.func_open_server_port_, self.server,
                           random.random(), workers=[self.serverWorker]).result()
        for _idx, worker, world in self.clients:
            self.client.submit(Dask_MPI_Demo.func_get_server_port_, world,
                               random.random(), workers=[worker]).result()

        print("Connecting clients to server")
        for idx, worker, world in self.clients:
            # Submit the client-side connect and the server-side accept before
            # waiting on either (presumably each blocks until its peer arrives).
            # gather() waits for both and, unlike wait(), raises on failure.
            connect = self.client.submit(Dask_MPI_Demo.func_connect_to_server_,
                                         world, random.random(),
                                         workers=[worker])
            accept = self.client.submit(Dask_MPI_Demo.func_connect_to_client_,
                                        self.server, idx, random.random(),
                                        workers=[self.serverWorker])
            self.client.gather([connect, accept])

        print("Merging client ranks")
        for _idx, worker, world in self.clients:
            self.client.submit(Dask_MPI_Demo.func_merge_clients_, world,
                               random.random(), workers=[worker]).result()

    def finalize(self):
        """Finalize every ``MPI_World`` and forget the session.

        Blocks until all finalize tasks have completed; does nothing if no
        session is active.
        """
        futures = [self.client.submit(Dask_MPI_Demo.func_finalize_, world,
                                      random.random(), workers=[worker])
                   for _idx, worker, world in (self.clients or [])]
        if self.server is not None:
            futures.append(self.client.submit(Dask_MPI_Demo.func_finalize_,
                                              self.server, random.random(),
                                              workers=[self.serverWorker]))
        # Wait for completion: dropping the only reference to a pending future
        # lets Dask release the task, so it might otherwise never run.
        self.client.gather(futures)
        self.clients = None
        self.server = None

    def get_client_ranks(self):
        """Return the MPI rank of every client world, in client order."""
        futures = [self.client.submit(Dask_MPI_Demo.func_get_rank_, world,
                                      random.random(), workers=[worker])
                   for _idx, worker, world in self.clients]
        return self.client.gather(futures)

    def get_server_rank(self):
        """Return the MPI rank of the server world."""
        return self.client.submit(Dask_MPI_Demo.func_get_rank_, self.server,
                                  random.random(),
                                  workers=[self.serverWorker]).result()
        

In [7]:
# Constructing the demo launches ompi-server and waits for its URI; init() then
# submits one MPI_World setup task per Dask worker (the first one is the server).
demo = Dask_MPI_Demo(c)
demo.init()

Starting ompi-server


In [8]:
# Open the MPI server port, connect each client worker to it, and merge ranks.
demo.build_session()

Building server
Connecting clients to server
Merging client ranks


In [9]:
# Show the MPI rank held by each client world once the session is merged.
print(f"Client Ranks: {demo.get_client_ranks()}")

Client Ranks: [1]


In [10]:
# Show the MPI rank held by the server world.
print(f"Server Rank: {demo.get_server_rank()}")

Server Rank: 0


In [11]:
# Finalize every MPI_World (clients and server) and drop the demo's references.
# NOTE(review): this does not stop the ompi-server process launched by the
# constructor; `__dealloc__` is not a hook that plain Python calls automatically,
# so that process keeps running until it is stopped explicitly.
demo.finalize()