# XGBoost.Dask in many threads

Sometimes we want to train many large XGBoost models in parallel. In this example we use:

1.  The `xgboost.dask` module for large, distributed training runs
2.  Optuna for hyperparameter optimization
3.  A thread pool to run many of these in parallel
4.  Coiled to launch Dask clusters (though you can swap in your favorite Dask deployment technology)
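The overall shape of the approach, many independent training runs launched from a thread pool, can be sketched with the standard library alone. This is a minimal sketch, not the notebook's code: `train_and_score` is a hypothetical stand-in that returns a made-up score, and `PARAM_GRID` is an illustrative list. In the real workflow each call would instead train with `xgboost.dask` against a shared Dask client, and Optuna would propose the parameters. Threads are enough here because each run mostly waits on the cluster rather than computing locally.

```python
from concurrent.futures import ThreadPoolExecutor

# Illustrative parameter sets; in the notebook, Optuna would propose these.
PARAM_GRID = [
    {"max_depth": 4, "learning_rate": 0.3},
    {"max_depth": 6, "learning_rate": 0.1},
    {"max_depth": 8, "learning_rate": 0.05},
]


def train_and_score(params):
    """Hypothetical stand-in for one training run on the cluster.

    A real version would train with xgboost.dask on a shared Dask client and
    return a validation metric; this returns a made-up score so the sketch runs.
    """
    return abs(params["max_depth"] - 6) + params["learning_rate"]


def run_in_parallel(param_sets, max_workers=4):
    # Each thread runs one (blocking) training job; the jobs overlap in time.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        scores = list(pool.map(train_and_score, param_sets))
    # Pair every parameter set with its score and keep the lowest score.
    return min(zip(param_sets, scores), key=lambda pair: pair[1])


best_params, best_score = run_in_parallel(PARAM_GRID)
print(best_params, best_score)
```

`pool.map` returns results in submission order, so each score lines up with the parameter set that produced it even if the runs finish out of order.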

Using `xgboost.dask` from many threads took a couple of small tweaks across several projects. This notebook resulted in the following PRs and issues:

-  https://github.com/dask/distributed/issues/7377
-  https://github.com/dask/dask/pull/9723
-  https://github.com/dask/distributed/pull/7369
-  https://github.com/dmlc/xgboost/pull/8558 (mostly cosmetic, not necessary)
-  A change in Coiled to make `package_sync` thread-safe, expected to be released by 2022-12-07

In [1]:
import datetime
import threading
from concurrent.futures import ThreadPoolExecutor

from distributed import Client
import dask.dataframe as dd
from coiled import Cluster
import coiled

import optuna
from dask_ml.metrics import mean_squared_error as lazy_mse
import xgboost as xgb
from xgboost.dask import DaskDMatrix

from dask_ml.datasets import make_classification_df
from dask_ml.model_selection import train_test_split, KFold
from dask_ml.preprocessing import OneHotEncoder
import dask.array as da
from s3fs import S3FileSystem

import pandas as pd

In [2]:
import dask
import dask.distributed
import coiled

print("coiled:", coiled.__version__)
print("dask:", dask.__version__)
print("dask.distributed:", dask.distributed.__version__)
print("optuna:", optuna.__version__)
print("xgboost:", xgb.__version__)

coiled: 0.2.58
dask: 2022.12.1
dask.distributed: 2022.12.1
optuna: 3.1.0.dev
xgboost: 1.7.2


In [3]:
Q3 = 1415.0  # third quartile (75th percentile) of trip_time, taken from earlier EDA

## Load data

In [4]:
BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
    }

In [5]:
def load_data():
    print("loading data")
    ddf= dd.read_parquet("s3://prefect-dask-examples/nyc-uber-lyft/processed_files.parquet")
    ddf = ddf.assign(accessible_vehicle = 1)
    print("Make accessible feature")
    ddf.accessible_vehicle = ddf.accessible_vehicle.where(ddf.on_scene_datetime.isnull(),0)  # Only applies if the vehicle is wheelchair accessible
    ddf = ddf.assign(pickup_month = ddf.pickup_datetime.dt.month)
    ddf = ddf.assign(pickup_dow = ddf.pickup_datetime.dt.dayofweek)
    ddf = ddf.assign(pickup_hour = ddf.pickup_datetime.dt.hour)
    
    ddf = ddf.drop(columns=['on_scene_datetime', 'request_datetime',
                            'pickup_datetime', 'dispatching_base_num',
                            'originating_base_num', 'shared_request_flag',
                           'shared_match_flag','dropoff_datetime',
                            'base_passenger_fare', 'bcf', 'sales_tax',
                            'tips', 'driver_pay', 'access_a_ride_flag',
                            'wav_match_flag',
                           ]
                  )

    ddf = ddf.dropna(how="any")
    ddf = ddf.repartition(partition_size="128MB").persist()
    ddf = ddf.reset_index(drop=True)

    original_rowcount = len(ddf.index)

    # Remove outliers.
    # Based on our earlier EDA, we set the lower bound at zero, consistent with the
    # domain knowledge that no trip can have a negative duration. The upper bound
    # follows the 1.5 * IQR rule, with the lower bound (0) standing in for Q1.
    lower_bound = 0
    upper_bound = Q3 + (1.5*(Q3 - lower_bound))
    
    ddf = ddf.loc[(ddf['trip_time'] >= lower_bound) & (ddf['trip_time'] <= upper_bound)]
    
    ddf = ddf.repartition(partition_size="128MB").persist()
    print(f"Fraction of dataset left after removing outliers:  {len(ddf.index) / original_rowcount}")

    return ddf
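The outlier filter in `load_data` keeps trips with `0 <= trip_time <= Q3 + 1.5 * (Q3 - 0)`. This is the 1.5 × IQR rule with the lower bound of 0 standing in for Q1, so with `Q3 = 1415.0` the cutoff works out to 1415 + 1.5 × 1415 = 3537.5. The short sketch below reproduces that arithmetic on a plain list; the sample trip times are made up, and `tukey_upper_fence` is a hypothetical helper for the conventional rule.

```python
Q3 = 1415.0      # third quartile of trip_time, as hard-coded in the notebook
LOWER_BOUND = 0  # no trip can have a negative duration

# The notebook's upper bound: 1.5 * IQR rule with the lower bound standing in for Q1
upper_bound = Q3 + 1.5 * (Q3 - LOWER_BOUND)


def tukey_upper_fence(q1, q3, k=1.5):
    """Hypothetical helper: the conventional upper fence Q3 + k * (Q3 - Q1)."""
    return q3 + k * (q3 - q1)


# Made-up trip times; the filter keeps values inside [LOWER_BOUND, upper_bound]
trip_times = [-5, 0, 600, 1415, 3537, 3538, 9000]
kept = [t for t in trip_times if LOWER_BOUND <= t <= upper_bound]

print(upper_bound)  # 3537.5
print(kept)         # [0, 600, 1415, 3537]
```

`tukey_upper_fence(0, Q3)` returns the same 3537.5, which makes the substitution explicit; with a positive Q1 the conventional fence would be lower.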

In [6]:
def get_superborough(df):
    PUSuperborough = [BOROUGH_MAPPING.get(i) for i in df.PUBorough.tolist()]
    DOSuperborough = [BOROUGH_MAPPING.get(i) for i in df.DOBorough.tolist()]
    cross_superborough = ["N" if i==j else "Y" for (i,j) in zip(PUSuperborough, DOSuperborough)]
    return df.assign(CrossSuperborough = cross_superborough)
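To see what `get_superborough` computes, here is the same list logic applied to plain Python lists. A minimal sketch: `cross_flags` is a hypothetical helper, and the names "Somewhere" and "Elsewhere" are made up to show one edge case. `dict.get` returns `None` for names missing from `BOROUGH_MAPPING`, and two `None` values compare equal, so a trip between two unmapped boroughs is flagged as not crossing.

```python
# Copy of BOROUGH_MAPPING from the notebook
BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}


def cross_flags(pu_boroughs, do_boroughs, mapping=BOROUGH_MAPPING):
    """Hypothetical helper mirroring the list logic inside get_superborough."""
    # Map each borough to its superborough; dict.get returns None for unmapped names.
    pu = [mapping.get(b) for b in pu_boroughs]
    do = [mapping.get(b) for b in do_boroughs]
    # "N" if pickup and drop-off share a superborough, "Y" if the trip crosses one.
    return ["N" if p == d else "Y" for p, d in zip(pu, do)]


flags = cross_flags(
    ["Bronx", "Brooklyn", "Manhattan", "Somewhere"],
    ["Manhattan", "Manhattan", "Staten Island", "Elsewhere"],
)
print(flags)
```

The first pair (Bronx to Manhattan) gives "N", matching the `CrossSuperborough` values in the `ddf.head()` output further down.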

In [7]:
def make_taxi_data(ddf):
    print("Load taxi data")
    taxi_df = pd.read_csv("data/taxi+_zone_lookup.csv", usecols=["LocationID", "Borough"])

    ddf = dd.merge(ddf, taxi_df, left_on="PULocationID", right_on="LocationID", how="inner")
    ddf = ddf.rename(columns={"Borough": "PUBorough"})
    ddf = ddf.drop(columns="LocationID")

    ddf = dd.merge(ddf, taxi_df, left_on="DOLocationID", right_on="LocationID", how="inner")
    ddf = ddf.rename(columns={"Borough": "DOBorough"})
    ddf = ddf.drop(columns="LocationID")  
    
    print("Make superboroughs")
    ddf = ddf.map_partitions(get_superborough)
    ddf['airport_fee'] = ddf['airport_fee'].replace("None", 0)
    ddf['airport_fee'] = ddf['airport_fee'].replace('nan', 0)
    ddf['airport_fee'] = ddf['airport_fee'].astype(float)
    ddf['airport_fee'] = ddf['airport_fee'].fillna(0)

    ddf = ddf.repartition(partition_size="128MB").persist()

    print("Make categoricals")
    categories = ['hvfhs_license_num', 'PULocationID', "DOLocationID", 'wav_request_flag',
                  'accessible_vehicle', 'pickup_month', 'pickup_dow', 'pickup_hour', 
                  'PUBorough', 'DOBorough', 'CrossSuperborough'
                 ]
    ddf[categories] = ddf[categories].astype('category')
    ddf = ddf.categorize(columns=categories)
    ddf = ddf.repartition(partition_size="128MB")
    return ddf
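`make_taxi_data` joins the zone lookup twice, once per end of the trip, renaming `Borough` and dropping `LocationID` after each join so the second join does not produce clashing column names. The sketch below shows the same pattern in pandas on toy frames (all IDs and boroughs are illustrative, and `add_boroughs` is a hypothetical helper); `dd.merge` mirrors the pandas `merge` API. With `how="inner"`, a trip whose location ID is missing from the lookup is dropped, which is what happens to the row with ID 99 here.

```python
import pandas as pd

# Toy stand-ins for the trips table and the taxi zone lookup (illustrative values)
trips = pd.DataFrame({"PULocationID": [1, 2, 99], "DOLocationID": [2, 3, 1]})
zones = pd.DataFrame({"LocationID": [1, 2, 3], "Borough": ["EWR", "Queens", "Bronx"]})


def add_boroughs(df, lookup):
    """Hypothetical helper: attach pickup and drop-off boroughs via two joins."""
    # Pickup side: join on PULocationID, then rename the joined Borough column
    out = df.merge(lookup, left_on="PULocationID", right_on="LocationID", how="inner")
    out = out.rename(columns={"Borough": "PUBorough"}).drop(columns="LocationID")
    # Drop-off side: same lookup, joined on DOLocationID
    out = out.merge(lookup, left_on="DOLocationID", right_on="LocationID", how="inner")
    out = out.rename(columns={"Borough": "DOBorough"}).drop(columns="LocationID")
    return out


result = add_boroughs(trips, zones)
print(result)
```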

## Test Loading Dataset

In [8]:
cluster = coiled.Cluster(
    worker_vm_types=["m6i.4xlarge"],
    scheduler_vm_types=["m6i.2xlarge"],
    package_sync=True,  # replicate the local Python package environment on the cluster
    name="dask-engineering-f799f650-0",
    shutdown_on_close=True,  # shut the cluster down when it is closed
    show_widget=False,
    n_workers=20,
    use_best_zone=True,
    account="dask-engineering",
    )
client = Client(cluster)
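The cell above uses Coiled, but any Dask deployment works. A minimal sketch, assuming only `distributed` is installed, of swapping in a `LocalCluster` for small local tests; the worker and thread counts are arbitrary. Dask collections use the most recently created `Client` as their default scheduler, so later cells only need `client` to exist.

```python
from distributed import Client, LocalCluster

# Small in-process cluster for local testing (sizes are arbitrary).
# processes=False runs the workers as threads inside this Python process.
cluster = LocalCluster(n_workers=2, threads_per_worker=2, processes=False)
client = Client(cluster)

# Block until both workers have registered with the scheduler.
client.wait_for_workers(n_workers=2)
```

Call `client.close()` and `cluster.close()` when you are done with it.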

In [9]:
client

Connection method: Cluster object | Cluster type: coiled.ClusterBeta
Dashboard: http://3.144.205.198:8787
Scheduler: tls://10.0.22.159:8786 (started 44 minutes ago)
Workers: 20 | Total threads: 320 | Total memory: 1.19 TiB
Each worker: 16 threads, 60.86-60.89 GiB memory


In [10]:
ddf = load_data()
ddf = make_taxi_data(ddf)

loading data
Make accessible feature
Fraction of dataset left after removing outliers:  0.9842347215803948
Load taxi data
Make superboroughs
Make categoricals


In [11]:
ddf.head()

|   | hvfhs_license_num | PULocationID | DOLocationID | trip_miles | trip_time | tolls | congestion_surcharge | airport_fee | wav_request_flag | accessible_vehicle | pickup_month | pickup_dow | pickup_hour | PUBorough | DOBorough | CrossSuperborough |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | HV0003 | 47 | 152 | 4.32 | 1279 | 0.0 | 0.0 | 0.0 | N | 0 | 4 | 1 | 6 | Bronx | Manhattan | N |
| 1 | HV0003 | 47 | 152 | 5.56 | 1547 | 0.0 | 0.0 | 0.0 | N | 0 | 4 | 1 | 7 | Bronx | Manhattan | N |
| 2 | HV0003 | 47 | 152 | 5.49 | 1153 | 0.0 | 0.0 | 0.0 | N | 0 | 4 | 1 | 12 | Bronx | Manhattan | N |
| 3 | HV0005 | 47 | 152 | 5.88 | 1080 | 0.02 | 0.0 | 0.0 | N | 1 | 4 | 1 | 12 | Bronx | Manhattan | N |
| 4 | HV0003 | 47 | 152 | 6.53 | 1372 | 0.0 | 0.0 | 0.0 | N | 0 | 4 | 1 | 13 | Bronx | Manhattan | N |


In [12]:
ddf.columns.tolist()

['hvfhs_license_num',
 'PULocationID',
 'DOLocationID',
 'trip_miles',
 'trip_time',
 'tolls',
 'congestion_surcharge',
 'airport_fee',
 'wav_request_flag',
 'accessible_vehicle',
 'pickup_month',
 'pickup_dow',
 'pickup_hour',
 'PUBorough',
 'DOBorough',
 'CrossSuperborough']

In [13]:
ddf.dtypes

hvfhs_license_num       category
PULocationID            category
DOLocationID            category
trip_miles               float64
trip_time                  int64
tolls                    float64
congestion_surcharge     float64
airport_fee              float64
wav_request_flag        category
accessible_vehicle      category
pickup_month            category
pickup_dow              category
pickup_hour             category
PUBorough               category
DOBorough               category
CrossSuperborough       category
dtype: object

In [14]:
ddf.to_parquet("s3://prefect-dask-examples/nyc-uber-lyft/feature_table_fixed_upper_bound.parquet")

In [None]:
client.shutdown()

In [None]:
client.restart()  # restarts the workers; run instead of, not after, the shutdown cell above