In [1]:
from dask.distributed import Client
import os

import dask

## Build Dask Cluster

If you already have a Dask cluster running, set its scheduler address below. Otherwise, leave it as `None` and a local cluster will be created.

In [2]:
# Address of an already-running scheduler, e.g. "tcp://10.2.168.161:8786".
# Keep it as None to start a local CUDA cluster instead.
scheduler_address = None

if scheduler_address is not None:
    # Attach to the existing cluster.
    c = Client(scheduler_address)
else:
    # No scheduler given: start a LocalCUDACluster and connect to it.
    from dask_cuda import LocalCUDACluster
    cluster = LocalCUDACluster()
    c = Client(cluster)
c

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


0,1
Client  Scheduler: tcp://127.0.0.1:46718  Dashboard: http://127.0.0.1:38116/status,Cluster  Workers: 8  Cores: 8  Memory: 540.95 GB


## Imports

In [3]:
from nccl_example import nccl, inject_comms_on_handle
from cuml.common.handle import Handle

In [4]:
from dask import delayed
import dask.dataframe as dd
from dask.distributed import wait
from dask.distributed import get_worker

import numba.cuda
import cudf
import numpy as np

import random

import asyncio
import ucp

import uuid

## Helper functions

In [5]:
from tornado import gen
from dask.distributed import default_client
from toolz import first
import logging
import dask.dataframe as dd

import dask_cudf
import numpy as np
import cudf
import pandas as pd

from dask.distributed import wait


def parse_host_port(address):
    """
    Split a worker address into a ``(host, port)`` tuple.

    Parameters
    ----------
    address : str
        Address of the form ``"host:port"`` or ``"scheme://host:port"``
        (e.g. ``"tcp://127.0.0.1:8786"``).

    Returns
    -------
    (str, int)
        The host exactly as written in ``address`` (a bracketed IPv6 host
        keeps its brackets) and the port converted to ``int``.

    Raises
    ------
    ValueError
        If ``address`` has no ``:port`` suffix or the port is not an integer.
    """
    if '://' in address:
        # Drop the scheme prefix, keeping only "host:port".
        address = address.rsplit('://', 1)[1]
    # Split on the LAST colon only, so hosts that themselves contain ':'
    # (e.g. "[::1]") do not break the two-value unpacking.
    host, port = address.rsplit(':', 1)
    return host, int(port)

import dask_cudf


@gen.coroutine
def extract_ddf_partitions(ddf):
    """
    Given a Dask cuDF, compute its partitions on the cluster and return a
    list of ((host, port), future) pairs, one per key reported by who_has.
    """
    client = default_client()

    # Turn each partition into a future and wait until all are computed.
    parts = client.compute(ddf.to_delayed())
    yield wait(parts)

    part_by_key = {str(part.key): part for part in parts}
    who_has = yield client.who_has(parts)

    gpu_data = []
    for key, holders in who_has.items():
        # Only the first worker listed for a key is used.
        location = parse_host_port(first(holders))
        gpu_data.append((location, part_by_key[key]))

    yield wait(gpu_data)

    raise gen.Return(gpu_data)
    
    
def create_df(f, m, n, c):
    """
    Generates a cudf of the given size with sklearn's make_blobs

    Parameters
    ----------
    f : int
        Partition number; offsets the row index so this frame covers
        rows [f * m, f * m + m).
    m : int
        Number of samples (rows) to generate.
    n : int
        Number of features (columns) to generate.
    c : int
        Passed to make_blobs as ``centers``.

    Returns
    -------
    cudf.DataFrame with ``n`` float64 columns keyed 0..n-1.
    """
    # Import from the public package path: the private
    # `sklearn.datasets.samples_generator` module was deprecated and has been
    # removed from newer scikit-learn releases.
    from sklearn.datasets import make_blobs
    # Labels are not needed here; random_state=0 keeps the data reproducible.
    X, _ = make_blobs(n_samples=m, centers=c, n_features=n, random_state=0)
    ret = cudf.DataFrame([(i,
                           X[:, i].astype(np.float64)) for i in range(n)],
                         index=cudf.dataframe.RangeIndex(f * m,
                                                         f * m + m, 1))
    return ret

def get_meta(df):
    """
    Return a zero-row slice of ``df`` (``df.iloc[:0]``), used as ``meta``
    when assembling a Dask collection from futures.
    """
    return df.iloc[:0]

def gen_dask_cudf(nrows, ncols, clusters):
    """
    Build a Dask cuDF with one create_df partition per worker.

    Uses the notebook-level client ``c``. Each worker listed by
    ``c.has_what()`` gets one partition of ``nrows`` rows and ``ncols``
    columns, with ``clusters`` forwarded to create_df.
    """
    worker_addrs = c.has_what().keys()

    # Submit one create_df task per worker (gpu), pinned to that worker;
    # the enumeration index is the partition number.
    partitions = [
        c.submit(create_df, part_idx, nrows, ncols, clusters, workers=[addr])
        for part_idx, addr in enumerate(worker_addrs)
    ]
    # Block until every partition exists.
    wait(partitions)

    meta = c.submit(get_meta, partitions[0]).result()
    return dask_cudf.from_delayed(partitions, meta=meta)

def to_dask_cudf(futures):
    """
    Convert a list of futures containing dfs back into a Dask collection,
    skipping futures whose result type is NoneType.
    """
    none_type = type(None)
    dfs = list(filter(lambda fut: fut.type != none_type, futures))
    # The zero-row meta frame is derived from the first remaining future.
    meta = c.submit(get_meta, dfs[0]).result()
    return dd.from_delayed(dfs, meta=meta)


async def connection_func(ep, listener):
    """
    Connection callback passed to ucp.start_listener; it only reports the
    endpoint that connected.
    """
    message = "connection received from " + str(ep)
    print(message)

    



## Reusable NCCL Base Class

In [6]:
class CommsBase:
    """
    Builds and tears down per-worker communication state on a Dask cluster.

    Depending on the flags given to the constructor, ``init()`` creates an
    NCCL communicator on every worker (``comms_coll``) and/or UCP listeners
    and endpoints between workers (``comms_p2p``), then injects both into
    one cuML ``Handle`` per worker (stored in ``self.handles``).
    """

    def __init__(self, comms_coll = True, comms_p2p = False):
        """
        Parameters
        ----------
        comms_coll : bool
            Initialize NCCL communicators in ``init()``.
        comms_p2p : bool
            Initialize UCP listeners and endpoints in ``init()``.
        """
        self.client = default_client()
        self.comms_coll = comms_coll
        self.comms_p2p = comms_p2p

        # Used to identify this distinct session on workers
        self.sessionId = uuid.uuid4().bytes

        # Raw scheduler addresses and their parsed (host, port) form.
        self.worker_addresses = self.get_workers_()
        self.workers = list(map(lambda x: parse_host_port(x), self.worker_addresses))

    def __dealloc__(self):
        # NOTE(review): `__dealloc__` is a Cython hook; the Python interpreter
        # does not call it automatically, so callers must invoke destroy()
        # explicitly. Kept so existing explicit callers keep working.
        self.destroy()

    def get_workers_(self, parse_address = True):
        """
        Return the list of worker addresses reported by the client
        (the keys of ``client.has_what()``), unparsed.

        ``parse_address`` is accepted but not used.
        """
        return list(self.client.has_what().keys())

    def worker_ranks(self):
        """
        Builds a dictionary of { (worker_address, worker_port) : worker_rank }
        """
        # nccl_clique entries are (rank, worker, future).
        return {entry[1]: entry[0] for entry in self.nccl_clique}

    def worker_ports(self):
        """
        Builds a dictionary of { (worker_address, worker_port) : ucp_listener_port }
        """
        return dict(self.ucp_ports)

    def worker_info(self):
        """
        Builds a dictionary keyed by (worker_address, worker_port).

        Values are (worker_rank, ucp_port) when both comms modes are enabled,
        the rank alone when only ``comms_coll`` is set, and the UCP port alone
        when only ``comms_p2p`` is set. Returns None when neither is enabled.
        """
        ranks = self.worker_ranks() if self.comms_coll else None
        ports = self.worker_ports() if self.comms_p2p else None

        if self.comms_coll and self.comms_p2p:
            return {k: (ranks[k], ports[k]) for k in ranks}
        elif self.comms_coll:
            return ranks
        elif self.comms_p2p:
            return ports
        return None

    @staticmethod
    def func_init_nccl(workerId, nWorkers, uniqueId):
        """
        Initialize ncclComm_t on worker
        """
        n = nccl()
        n.init(nWorkers, uniqueId, workerId)
        return n

    @staticmethod
    def func_get_ucp_port(sessionId, r):
        """
        Return the port assigned to a UCP listener on worker

        ``r`` is not used by the function body.
        """
        dask_worker = get_worker()
        port = dask_worker.data[sessionId].port
        return port

    @staticmethod
    async def ucp_create_listener(sessionId, r):
        """
        Start a UCP listener on this worker and keep it in the worker's data
        dict under ``sessionId`` until the listener coroutine finishes.

        ``r`` is not used by the function body.
        """
        dask_worker = get_worker()
        if sessionId in dask_worker.data:
            print("Listener already started for sessionId=" + str(sessionId))
        else:
            ucp.init()
            listener =  ucp.start_listener(connection_func, 0, is_coroutine=True)

            dask_worker.data[sessionId] = listener
            task = asyncio.create_task(listener.coroutine)

            # Stay alive until the listener's coroutine completes.
            while not task.done():
                await task
                await asyncio.sleep(1)

            # Listener finished: shut UCP down and drop the session entry.
            ucp.fin()
            del dask_worker.data[sessionId]

    @staticmethod
    def ucp_stop_listener(sessionId, r):
        """
        Stop the UCP listener stored under ``sessionId`` on this worker,
        or print a message if none is found.
        """
        dask_worker = get_worker()
        if sessionId in dask_worker.data:
            listener = dask_worker.data[sessionId]
            ucp.stop_listener(listener)
        else:
            print("Listener not found with sessionId=" + str(sessionId))

    def create_ucp_listeners(self):
        """
        Build a UCP listener on each worker. Since this async function is long-running, the listener is
        placed in the worker's data dict.

        NOTE: This is not the most ideal design because the worker's data dict could be serialized at
        any point, which would cause an error. Need to sync w/ the Dask team to see if there's a better
        way to do this.
        """
        # wait=False: do not block on the long-running listener coroutine.
        for w in self.worker_addresses:
            self.client.run(CommsBase.ucp_create_listener, self.sessionId, random.random(), workers = [w], wait = False)

    def get_ucp_ports(self):
        """
        Fetch the UCP listener ports attached to this session and store them
        in ``self.ucp_ports`` as a list of (worker, port) pairs.
        """
        self.ucp_ports = [(w, self.client.submit(CommsBase.func_get_ucp_port, self.sessionId, random.random(), workers = [w]).result())
                          for w in self.workers]

    def stop_ucp_listeners(self):
        """
        Stops the UCP listeners attached to this session
        """
        # Use this instance's client rather than a notebook-level global.
        a = [self.client.submit(CommsBase.ucp_stop_listener, self.sessionId, random.random(), workers=[w])
             for w in self.workers]
        wait(a)

    @staticmethod
    def func_build_handle(nccl_comm, eps, nWorkers, workerId):
        """
        Create a cuML Handle on the worker and inject the NCCL communicator,
        UCP worker and endpoints into it.
        """
        ucp_worker = ucp.get_ucp_worker()

        handle = Handle()
        inject_comms_on_handle(handle, nccl_comm, ucp_worker, eps, nWorkers, workerId)
        return handle

    def init_nccl(self):
        """
        Use nccl-py to initialize ncclComm_t on each worker and
        store the futures for this instance.
        """
        self.uniqueId = nccl.get_unique_id()

        # Reference the helper through this class, not a subclass, so the
        # base class does not depend on names defined later.
        self.nccl_clique = [(idx, worker, self.client.submit(CommsBase.func_init_nccl,
                                                             idx,
                                                             len(self.workers),
                                                             self.uniqueId,
                                                             workers=[worker]))
                            for idx, worker in enumerate(self.workers)]

    def init_ucp(self):
        """
        Use ucx-py to initialize ucp endpoints so that every
        worker can communicate, point-to-point, with every other worker
        """
        self.create_ucp_listeners()
        self.get_ucp_ports()
        self.ucp_create_endpoints()

    def init(self):
        """
        Initialize the enabled comms and build one Handle per worker in
        ``self.handles`` as (rank, worker, future) tuples.
        """
        if self.comms_coll:
            self.init_nccl()

        if self.comms_p2p:
            self.init_ucp()

        # Combine ucp ports w/ nccl ranks
        # NOTE(review): the code below reads both self.ucp_endpoints and
        # self.nccl_clique, so it only works when comms_coll and comms_p2p
        # are both enabled.
        eps_futures = dict(self.ucp_endpoints)

        self.handles = [(wid, w, self.client.submit(CommsBase.func_build_handle, f, eps_futures[w], len(self.workers), wid, workers = [w]))
                        for wid, w, f in self.nccl_clique]

    @staticmethod
    async def func_ucp_create_endpoints(sessionId, worker_info, r):
        """
        Runs on each worker to create ucp endpoints to all other workers

        The endpoints are stored, indexed by rank, in the worker's data dict
        under ``str(sessionId) + "_eps"``; the slot for this worker's own
        rank stays None.
        """
        dask_worker = get_worker()
        local_address = parse_host_port(dask_worker.address)

        eps = [None]*len(worker_info)

        for k in worker_info:
            if k != local_address:
                ip, port = k
                rank, ucp_port = worker_info[k]
                ep = await ucp.get_endpoint(ip.encode(), ucp_port, timeout = 1)
                eps[rank] = ep
        dask_worker.data[str(sessionId) + "_eps"] = eps

    @staticmethod
    def func_get_endpoints(sessionId, r):
        """
        Fetches (and removes) the endpoints from the worker's data dict
        """
        dask_worker = get_worker()
        eps = dask_worker.data[str(sessionId)+"_eps"]
        del dask_worker.data[str(sessionId)+"_eps"]
        return eps

    def ucp_create_endpoints(self):
        """
        Create endpoints on every worker, then collect them as futures in
        ``self.ucp_endpoints`` as a list of (worker, future) pairs.
        """
        worker_info = self.worker_info()

        for w in self.worker_addresses:
            self.client.run(CommsBase.func_ucp_create_endpoints, self.sessionId, worker_info, random.random(), workers = [w], wait = True)

        ret = [(w, self.client.submit(CommsBase.func_get_endpoints, self.sessionId, random.random(), workers=[w])) for w in self.workers]
        wait(ret)

        self.ucp_endpoints = ret

    @staticmethod
    def func_destroy_nccl(nccl_comm, r):
        """
        Destroys NCCL communicator on worker
        """
        nccl_comm.destroy()

    def destroy_nccl(self):
        """
        Destroys all NCCL communicators on workers
        """
        a = [self.client.submit(CommsBase.func_destroy_nccl, f, random.random(), workers=[w]) for wid, w, f in self.nccl_clique]
        wait(a)

    @staticmethod
    def func_destroy_ep(eps, r):
        """
        Destroys UCP endpoints on worker

        None entries in ``eps`` are skipped.
        """
        for ep in eps:
            if ep is not None:
                ucp.destroy_ep(ep)

    def destroy_eps(self):
        """
        Destroys all UCP endpoints on all workers
        """
        a = [self.client.submit(CommsBase.func_destroy_ep, f, random.random(), workers = [w]) for w, f in self.ucp_endpoints]
        wait(a)

    def destroy_ucp(self):
        """
        Destroys the UCP endpoints, then stops the UCP listeners.
        """
        self.destroy_eps()
        self.stop_ucp_listeners()

    def destroy(self):
        """
        Drop the handles and tear down the comms state held by this session.
        """
        self.handles = None

        if self.comms_p2p:
            self.destroy_ucp()
            self.ucp_ports = None
            self.ucp_endpoints = None

        if self.comms_coll:
            # TODO: Figure out why this fails when UCP + NCCL are both used
#             self.destroy_nccl()
            self.nccl_clique = None

## Dask-cuML OPG KMeans Implementation

In [7]:
from cuml.cluster import KMeans as cumlKMeans

class KMeans(CommsBase):
    """
    Multi-worker KMeans built on CommsBase.

    One cuML KMeans instance is created per worker, each bound to that
    worker's comms Handle. fit/predict submit one task per partition of
    the input Dask cuDF.
    """

    def __init__(self, n_clusters = 8, init_method = "random", verbose = 0):
        """
        Set up NCCL + UCP comms and build the per-worker models.

        Parameters
        ----------
        n_clusters : int
            Forwarded to the cuML KMeans constructor.
        init_method : str
            Forwarded to the cuML KMeans constructor as ``init``.
        verbose : int
            Forwarded to the cuML KMeans constructor.
        """
        super(KMeans, self).__init__(comms_coll = True, comms_p2p = True)
        self.init_(n_clusters = n_clusters, init_method = init_method, verbose = verbose)

    def init_(self, n_clusters, init_method, verbose = 0):
        """
        Creates local kmeans instance on each worker
        """
        self.init()

        # Use this instance's client (not a notebook-level global); each
        # model is built on the worker that owns its handle.
        self.kmeans = [(w, self.client.submit(KMeans.func_build_kmeans_,
                                              a, n_clusters, init_method, verbose, i,
                                              workers=[w])) for i, w, a in self.handles]
        wait(self.kmeans)

    @staticmethod
    def func_build_kmeans_(handle, n_clusters, init_method, verbose, wid):
        """
        Create local KMeans instance on worker

        ``wid`` is not used by the function body.
        """
        return cumlKMeans(handle = handle, init = init_method, n_clusters = n_clusters, verbose = verbose)

    @staticmethod
    def func_fit(model, df, wid):
        """Fit ``model`` on the partition ``df`` and return the result of ``model.fit``."""
        return model.fit(df)

    @staticmethod
    def func_predict(model, df, wid):
        """Return ``model.predict(df)`` for the partition ``df``."""
        return model.predict(df)

    def run_model_func_on_dask_cudf(self, func, X):
        """
        Run ``func(model, partition, rank)`` for every partition of ``X`` and
        return the list of resulting futures once all have finished.
        """
        gpu_futures = self.client.sync(extract_ddf_partitions, X)

        worker_model_map = dict(self.kmeans)
        worker_rank_map = self.worker_ranks()

        # Pin each task to the worker that holds both the model and the
        # partition, consistent with every other submit in this file.
        futures = [self.client.submit(func,                   # Function to run on worker
                                      worker_model_map[w],    # Model instance
                                      part,                   # Input DataFrame partition
                                      worker_rank_map[w],     # Worker ID
                                      workers=[w])
                   for w, part in gpu_futures]
        wait(futures)
        return futures

    def fit(self, X):
        """Fit the per-worker models on ``X`` and return ``self``."""
        self.run_model_func_on_dask_cudf(KMeans.func_fit, X)
        return self

    def predict(self, X):
        """Predict on ``X`` and return the results as a Dask collection."""
        f = self.run_model_func_on_dask_cudf(KMeans.func_predict, X)
        return to_dask_cudf(f)

    def fit_predict(self, X):
        """Equivalent to ``self.fit(X).predict(X)``."""
        return self.fit(X).predict(X)

## Execute End-To-End Example

In [8]:
# Size of the synthetic dataset used in the demo below.
n_samples_per_worker = 10   # rows generated on each worker
n_features = 50             # columns per row
n_clusters = 10             # blob centers for make_blobs and clusters for KMeans

Create a Dask cuDF using sklearn's `make_blobs` for testing

In [9]:
# One partition of n_samples_per_worker rows is created on every worker.
X = gen_dask_cudf(n_samples_per_worker, n_features, n_clusters)

First, a Dask-cuML `KMeans` instance is created, which initializes its own NCCL clique

In [10]:
# The constructor sets up NCCL and UCP comms and builds one model per worker.
demo = KMeans(n_clusters, init_method = "random", verbose = 1)

Print out the ranks assigned to the workers in the NCCL clique

In [11]:
# Mapping of (host, port) -> NCCL rank for every worker.
demo.worker_ranks()

{('127.0.0.1', 37141): 0,
 ('127.0.0.1', 39300): 1,
 ('127.0.0.1', 39305): 2,
 ('127.0.0.1', 39349): 3,
 ('127.0.0.1', 39415): 4,
 ('127.0.0.1', 41404): 5,
 ('127.0.0.1', 41590): 6,
 ('127.0.0.1', 41620): 7}

Verify we have one cuDF partition per worker

In [12]:
# Keys currently held by each worker; look for one create_df entry per worker.
c.has_what()

{'tcp://127.0.0.1:37141': ('func_get_endpoints-213ba02b18b277caacdafa1607e0648d',
  'func_build_kmeans_-10ce05c3ac583aef53e6e8ba78ebdaba',
  'create_df-14e81bf95a9b4429e55a40d3db9bdc5e',
  'func_build_handle-4d186a1b7aa303b1041b99ee15cdffcd',
  'func_init_nccl-43ef591fa7f9de497cd33568b50cec93'),
 'tcp://127.0.0.1:39300': ('func_init_nccl-d10a6c3596b064632bd5d9fee14e7332',
  'func_build_handle-be5f29e285d4976eaa5d60b6ffdeb785',
  'create_df-01846fbe2c680b2ec06cddce6288ace1',
  'func_get_endpoints-77e49dc6ba2c0123a9714787c6721204',
  'func_build_kmeans_-4fbf7cff2b285a3d049efdf782abde5c'),
 'tcp://127.0.0.1:39305': ('create_df-db0b1ab4a3df18d1c2752b47a9e05a02',
  'func_get_endpoints-f345f42380219bbe28d5f8880b3fb031',
  'func_build_kmeans_-fed32eed3cf763ca6dbddf421acccdc2',
  'func_init_nccl-c4fc3fb5343136e13d8043a8641b4a35',
  'func_build_handle-96ff7342a8e8a59f8418d72745712d61'),
 'tcp://127.0.0.1:39349': ('func_init_nccl-c374cbfc684464757de8aceb48a1162e',
  'func_build_handle-6e73faaa37

In [13]:
# Show the (host, port) -> NCCL rank mapping again.
demo.worker_ranks()

{('127.0.0.1', 37141): 0,
 ('127.0.0.1', 39300): 1,
 ('127.0.0.1', 39305): 2,
 ('127.0.0.1', 39349): 3,
 ('127.0.0.1', 39415): 4,
 ('127.0.0.1', 41404): 5,
 ('127.0.0.1', 41590): 6,
 ('127.0.0.1', 41620): 7}

Fit the KMeans MNMG model

In [14]:
# Runs func_fit on every partition of X; returns the KMeans instance itself.
demo.fit(X)

<__main__.KMeans at 0x7ff9c03d5ba8>

Predict labels for the same inputs we trained on

In [15]:
# Runs func_predict on every partition and wraps the futures in a Dask collection.
result = demo.predict(X)

In [16]:
# Print the summary of the Dask collection (task and partition counts).
print(str(result))

<dask_cudf.Series | 16 tasks | 8 npartitions>


In [17]:
# Gather the predicted labels with .compute() and print them.
print(str(result.compute()))

0    1
1    0
2    0
3    1
4    1
5    2
6    3
7    2
8    0
9    1
[70 more rows]
dtype: int32


In [18]:
# Tear down this session's comms: drops the handles, destroys the UCP
# endpoints and stops the UCP listeners (see CommsBase.destroy).
demo.destroy()