## UMAP MNMG benchmark (runtime & trustworthiness)

In [1]:
from cuml.dask.manifold import UMAP as UMAP_MNMG
from cuml.manifold import UMAP
from cuml.metrics import trustworthiness

from dask_cuda import LocalCUDACluster
from dask.distributed import Client
import dask.array as da

from sklearn.datasets import make_blobs
import time
import numpy as np

In [2]:
def _partition_sizes(n_samples, n_parts):
    """Return the row count of each of ``n_parts`` contiguous partitions.

    Every partition receives ``n_samples // n_parts`` rows and the last one
    additionally absorbs the remainder, so the sizes always sum to
    ``n_samples``.

    Raises
    ------
    ValueError
        If there are fewer samples than partitions (some partitions would
        be empty).
    ZeroDivisionError
        If ``n_parts`` is zero.
    """
    base, remainder = divmod(n_samples, n_parts)
    if base == 0:
        raise ValueError(
            "n_samples ({}) must be at least n_parts ({})".format(n_samples, n_parts))
    sizes = [base] * n_parts
    # The remainder is relative to the number of partitions, not to the
    # partition size; otherwise the chunks do not add up to n_samples.
    sizes[-1] += remainder
    return tuple(sizes)


def _fit_local_umap(train_data, args):
    """Build a local (non-distributed) UMAP model from ``args`` and fit it on ``train_data``."""
    local_model = UMAP(n_components=2, n_neighbors=args['n_neighbors'],
                       n_epochs=args['n_epochs'], random_state=args['random_state'])
    local_model.fit(train_data)
    return local_model


def benchmark(args):
    """Benchmark distributed UMAP inference on a synthetic blob dataset.

    A local UMAP model is trained on a random subset of the data, wrapped in
    the distributed UMAP estimator, and used to transform the full dataset
    split into Dask partitions. Only the distributed ``compute()`` of the
    transform is timed.

    Parameters
    ----------
    args : dict
        Benchmark configuration. Keys read here: ``n_samples``,
        ``n_features``, ``centers``, ``sampling_ratio``, ``n_parts``,
        ``n_neighbors``, ``n_epochs``, ``random_state`` and ``n_iter``.

    Returns
    -------
    tuple
        ``(mean duration in seconds, variance of the durations,
        mean trustworthiness score)`` over ``n_iter`` runs.

    Raises
    ------
    ValueError
        If ``n_samples`` is smaller than ``n_parts``.
    """
    n_samples = args['n_samples']

    # Generate dataset (the blob labels are not needed)
    X, _ = make_blobs(n_samples=n_samples, n_features=args['n_features'],
                      centers=args['centers'])

    # Number of samples for local train
    n_sampling = int(n_samples * args['sampling_ratio'])

    # Generate local train data. Sample without replacement whenever possible
    # so the training subset holds distinct rows; a ratio above 1 still falls
    # back to sampling with replacement.
    selection = np.random.choice(n_samples, n_sampling,
                                 replace=n_sampling > n_samples)
    lX = X[selection]

    # Generate partitioning of distributed data for inference
    chunks = _partition_sizes(n_samples, args['n_parts'])
    dX = da.from_array(X, chunks=(chunks, -1))

    # Warm up (used to prevent statistical anomalies in time measurement
    # due to first time initialization)
    UMAP_MNMG(_fit_local_umap(lX, args)).transform(dX).compute()

    # Measure and average runtime and trustworthiness across multiple runs
    durations = []
    trust_scores = []
    for _ in range(args['n_iter']):
        # Train local model, pass it on and order distributed inference
        model = UMAP_MNMG(_fit_local_umap(lX, args))
        lazy_transformed = model.transform(dX)

        # Perform distributed inference and measure time; perf_counter is
        # monotonic and high resolution, unlike the wall clock time.time()
        start = time.perf_counter()
        transformed = lazy_transformed.compute()
        durations.append(time.perf_counter() - start)

        # Compute trustworthiness score
        trust_scores.append(trustworthiness(X, transformed,
                                            n_neighbors=args['n_neighbors']))

    durations = np.array(durations)
    trust_scores = np.array(trust_scores)

    # Return runtime average and variance as well as trustworthiness score average
    return durations.mean(), durations.var(), trust_scores.mean()

### Parameter definitions:
- **n_samples** : number of samples
- **n_features** : number of features
- **centers** : number of blobs to generate the dataset
- **n_neighbors** : number of neighbors used to generate the fuzzy simplicial set in UMAP
- **n_epochs** : number of iterations during UMAP optimization step
- **random_state** : random seed used in UMAP
- **n_parts** : number of partitions into which the dataset is divided, also number of workers/GPUs to be used
- **sampling_ratio** : ratio of samples used during local training
- **n_iter** : number of timed runs over which runtime and trustworthiness are averaged

In [None]:
import warnings
from itertools import product

# Silence library warnings so that only the benchmark results are printed
warnings.filterwarnings('ignore')

# Settings shared by every run; 'n_parts', 'n_samples' and 'sampling_ratio'
# are overwritten for each configuration below
args = {'n_features':100, 'centers':300,
        'n_neighbors':15, 'n_epochs':5000, 'random_state': 42,
        'sampling_ratio':0.1, 'n_iter': 3}

# Start the Dask cluster and client. The context managers shut both down
# when the block exits, including when a benchmark run raises.
with LocalCUDACluster(n_workers=8, threads_per_worker=1) as cluster, \
        Client(cluster) as client:
    # Comparing runtime and trustworthiness with different number of
    # partitions, number of samples and ratio of local train
    for n_parts, n_samples, sampling_ratio in product(
            [3, 8], [100000, 500000, 1000000], [0.001, 0.005]):
        args['n_parts'] = n_parts
        args['n_samples'] = n_samples
        args['sampling_ratio'] = sampling_ratio
        duration_mean, duration_var, trust = benchmark(args)
        print(f"n_parts: {n_parts}, n_samples: {n_samples}, "
              f"sampling_ratio: {sampling_ratio}, "
              f"duration avg - var: {duration_mean:.2f} - {duration_var:.2f}, "
              f"trustworthiness: {trust:.2f}")