In [None]:
import os

import cupy as cp
import dask
import dask_cudf
import numpy as np
from cudf import DataFrame
from cuml.dask.common.utils import persist_across_workers
from cuml.dask.neighbors import NearestNeighbors
from dask.distributed import wait
from dask_cuda import LocalCUDACluster
from distributed import Client

In [None]:
# Dask-CUDA / UCX configuration.
#
# Dask reads DASK_* environment variables when `dask.config` is first imported,
# which already happened in the imports cell (via `distributed`). Child
# processes started later inherit `os.environ`, but this process's Dask config
# does not see the new values unless it is refreshed explicitly (done below).

# Optional settings, currently disabled. The transports are also enabled
# directly through the LocalCUDACluster arguments in the next cell.
# os.environ["DASK_RMM__POOL_SIZE"] = "500M"
# os.environ["DASK_UCX__TCP"] = "True"
# os.environ["DASK_UCX__NVLINK"] = "True"
# os.environ["DASK_UCX__INFINIBAND"] = "True"

os.environ["DASK_UCX__CUDA_COPY"] = "True"
# NOTE(review): assumes the InfiniBand interface on this host is named "ib0" — confirm.
os.environ["DASK_UCX__NET_DEVICES"] = "ib0"

# Re-read the environment so the client process picks up the settings above.
dask.config.refresh()

In [None]:
# Start one local GPU worker. UCX is enabled for the TCP, NVLink and
# InfiniBand transports, and the worker is initialised with a 500M RMM pool.
cluster = LocalCUDACluster(
    n_workers=1,
    # NOTE(review): dask-cuda defaults to one thread per worker; 512 threads
    # lets up to that many tasks run concurrently on the single GPU — confirm
    # this is intended.
    threads_per_worker=512,
    rmm_pool_size="500M",
    enable_tcp_over_ucx=True,
    enable_nvlink=True,
    enable_infiniband=True,
)
# Attach a client in this notebook process to the cluster.
client = Client(cluster)

In [None]:
def distribute_data(client, np_array, n_workers=None, partitions_per_worker=1):
    """Copy a host array to the GPU and spread it across Dask workers.

    The array is copied to device memory with CuPy, wrapped in a cuDF
    DataFrame, split into a dask_cudf DataFrame and persisted on the selected
    workers. The call blocks until the persisted partitions are ready.

    Parameters
    ----------
    client : distributed.Client
        Client connected to the cluster.
    np_array : array-like
        Host data to distribute. Presumably a 2-D numpy array with one row
        per sample — confirm for other inputs.
    n_workers : int, optional
        Use only the first ``n_workers`` workers reported by
        ``client.has_what()``. ``None`` (default) uses every worker.
    partitions_per_worker : int, default 1
        Number of partitions created for each selected worker.

    Returns
    -------
    dask_cudf.DataFrame
        The persisted, distributed DataFrame.

    Raises
    ------
    ValueError
        If ``n_workers`` or ``partitions_per_worker`` is smaller than 1, or
        if the cluster reports no workers.
    """
    # Validate up front. Previously ``n_workers=0`` was falsy and silently
    # meant "all workers", and zero partitions failed obscurely downstream.
    if n_workers is not None and n_workers < 1:
        raise ValueError(f"n_workers must be None or >= 1, got {n_workers}")
    if partitions_per_worker < 1:
        raise ValueError(
            f"partitions_per_worker must be >= 1, got {partitions_per_worker}"
        )

    # Worker addresses known to the scheduler (keys of has_what()).
    workers = list(client.has_what().keys())
    if n_workers is not None:
        workers = workers[:n_workers]
    if not workers:
        raise ValueError("no Dask workers available on the cluster")

    n_partitions = partitions_per_worker * len(workers)

    # Host -> device copy, then CuPy array -> cuDF DataFrame.
    cp_array = cp.array(np_array)
    cudf_df = DataFrame(cp_array)

    # cuDF DataFrame -> dask_cudf DataFrame, pinned to the selected workers.
    dask_cudf_df = dask_cudf.from_cudf(cudf_df, npartitions=n_partitions)
    dask_cudf_df, = persist_across_workers(client, [dask_cudf_df], workers=workers)
    wait(dask_cudf_df)
    return dask_cudf_df

In [None]:
# Build a random 3-D point cloud to use as the KNN index. The search further
# below also uses these same points as its queries.
n_points = 65536
seed = 42  # fixed seed so Restart & Run All reproduces the same data
rng = np.random.default_rng(seed)
index = rng.random((n_points, 3), dtype=np.float32)

# Distribute the index onto the first Dask worker
dist_index = distribute_data(client, index, n_workers=1)

In [None]:
n_neighbors = 16

# Create the distributed cuML k-nearest-neighbors model on the same client.
model = NearestNeighbors(client=client, n_neighbors=n_neighbors)
# Fit the model on the distributed index.
model.fit(dist_index)

# Query with the index itself. Each point will presumably be returned as its
# own nearest neighbour at distance 0; drop that column if self-matches are
# unwanted.
distances_lazy, indices_lazy = model.kneighbors(dist_index)

# `Client.compute` returns Futures by default. `sync=True` waits for the
# computation and gathers the concrete results back into this process.
distances, indices = client.compute([distances_lazy, indices_lazy], sync=True)