# XGBoost.Dask in many threads

Sometimes we want to train many large XGBoost models in parallel. This example does so with:

1.  The `xgboost.dask` module to do large training runs
2.  Optuna to do hyperparameter optimization
3.  A thread pool, to run many of these in parallel
4.  Coiled to launch Dask clusters (but you could swap in your favorite Dask deployment technology as you like)
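
The thread-pool part of this setup can be sketched with the standard library alone. In the sketch below, `train_one_model` and `run_trials` are hypothetical names, and the training step is a stand-in that draws a fake score instead of talking to a Dask cluster or calling `xgboost.dask`. It only illustrates how several independent runs can be driven from one pool of threads.

```python
import random
import threading
from concurrent.futures import ThreadPoolExecutor


def train_one_model(trial_id: int) -> dict:
    """Hypothetical stand-in for one training run.

    In the real workflow a thread would fit a model on a Dask cluster here;
    this placeholder only samples parameters and a fake score.
    """
    rng = random.Random(trial_id)  # per-trial generator, no shared state
    params = {
        "max_depth": rng.randint(2, 10),
        "learning_rate": rng.uniform(0.01, 0.3),
    }
    return {
        "trial": trial_id,
        "thread": threading.current_thread().name,
        "params": params,
        "score": rng.random(),  # placeholder for a validation metric
    }


def run_trials(n_trials: int, max_workers: int) -> list:
    """Run the trials concurrently and return results in trial order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(train_one_model, range(n_trials)))


results = run_trials(n_trials=8, max_workers=4)
best = min(results, key=lambda r: r["score"])  # lower score is treated as better
print(best["trial"], best["params"])
```

Because each thread mostly waits on remote work in the real setting, threads (rather than processes) are a reasonable fit; the placeholder here finishes instantly, so it says nothing about actual speedups.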

Using `xgboost.dask` from many threads took a couple of small tweaks across projects. This notebook resulted in the following PRs and issues:

-  https://github.com/dask/distributed/issues/7377
-  https://github.com/dask/dask/pull/9723
-  https://github.com/dask/distributed/pull/7369
-  https://github.com/dmlc/xgboost/pull/8558 (mostly cosmetic, not necessary)
-  A change in Coiled to make `package_sync` thread-safe, which should be released by 2022-12-07

In [1]:
import datetime
import threading
from concurrent.futures import ThreadPoolExecutor

import coiled
import dask.array as da
import dask.dataframe as dd
import optuna
import pandas as pd
import xgboost as xgb
from coiled import Cluster
from dask_ml.datasets import make_classification_df
from dask_ml.metrics import mean_squared_error as lazy_mse
from dask_ml.model_selection import KFold, train_test_split
from dask_ml.preprocessing import OneHotEncoder
from distributed import Client
from s3fs import S3FileSystem
# from sklearn.metrics import mean_squared_error
from xgboost.dask import DaskDMatrix

In [2]:
import dask
import dask.distributed  # make sure the submodule is loaded before reading its version

print("coiled:", coiled.__version__)
print("dask:", dask.__version__)
print("dask.distributed:", dask.distributed.__version__)
print("optuna:", optuna.__version__)
print("xgboost:", xgb.__version__)

coiled: 0.2.58
dask: 2022.12.1
dask.distributed: 2022.12.1
optuna: 3.1.0.dev
xgboost: 1.7.2


### Load data

In [3]:
BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
    }

In [4]:
def load_data():
    print("loading data")
    ddf = dd.read_parquet("s3://prefect-dask-examples/nyc-uber-lyft/processed_files.parquet")
    ddf = ddf.assign(accessible_vehicle=1)
    print("Make accessible feature")
    # on_scene_datetime only applies if the vehicle is wheelchair accessible.
    # Series.where keeps the 1 where on_scene_datetime is null and writes 0 elsewhere.
    ddf.accessible_vehicle = ddf.accessible_vehicle.where(ddf.on_scene_datetime.isnull(), 0)
    ddf = ddf.assign(pickup_month = ddf.pickup_datetime.dt.month)
    ddf = ddf.assign(pickup_dow = ddf.pickup_datetime.dt.dayofweek)
    ddf = ddf.assign(pickup_hour = ddf.pickup_datetime.dt.hour)
    
    ddf = ddf.drop(columns=['on_scene_datetime', 'request_datetime',
                            'pickup_datetime', 'dispatching_base_num',
                            'originating_base_num', 'shared_request_flag',
                           'shared_match_flag','dropoff_datetime',
                           ]
                  )

    ddf = ddf.dropna(how="any")
    ddf = ddf.repartition(partition_size="128MB")
    ddf = ddf.reset_index(drop=True)

    original_rowcount = len(ddf.index)

    # Remove outliers
    # Based on our earlier EDA, we set the lower bound at zero, consistent with our
    # domain knowledge that no trip can have a negative duration. The upper bound is
    # Q3 + 1.5 * (Q3 - lower_bound); trips outside [lower_bound, upper_bound] are dropped.
    lower_bound = 0
    Q3 = ddf['trip_time'].quantile(0.75)
    upper_bound = Q3 + (1.5*(Q3 - lower_bound))
    
    ddf = ddf.loc[(ddf['trip_time'] >= lower_bound) & (ddf['trip_time'] <= upper_bound)]
    
    ddf = ddf.repartition(partition_size="128MB")
    print(f"Fraction of dataset left after removing outliers:  {len(ddf.index) / original_rowcount}")

    return ddf
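
The outlier rule in `load_data` can be checked by hand on a tiny list. The sketch below uses made-up trip durations and the standard library's `statistics.quantiles` in place of the Dask quantile, so quantile conventions (and therefore exact values) may differ from what Dask computes on the real data.

```python
from statistics import quantiles

trip_times = [100, 200, 300, 400, 10_000]  # made-up durations in seconds

lower_bound = 0
# quantiles(..., n=4) returns [Q1, Q2, Q3]; "inclusive" interpolates
# between the smallest and largest values in the data.
q3 = quantiles(trip_times, n=4, method="inclusive")[2]
upper_bound = q3 + 1.5 * (q3 - lower_bound)

# Keep only trips inside [lower_bound, upper_bound], as the filter does.
kept = [t for t in trip_times if lower_bound <= t <= upper_bound]
fraction_kept = len(kept) / len(trip_times)
print(q3, upper_bound, fraction_kept)
```

With the lower bound fixed at zero, the upper bound reduces to 2.5 times Q3, so only the single very long trip is dropped here.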

In [5]:
def get_superborough(df):
    PUSuperborough = [BOROUGH_MAPPING.get(i) for i in df.PUBorough.tolist()]
    DOSuperborough = [BOROUGH_MAPPING.get(i) for i in df.DOBorough.tolist()]
    cross_superborough = ["N" if i==j else "Y" for (i,j) in zip(PUSuperborough, DOSuperborough)]
    return df.assign(CrossSuperborough = cross_superborough)
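
To see what `get_superborough` computes per row, here is a minimal sketch of the same comparison on plain strings. The helper name `crosses_superborough` is hypothetical; the mapping is copied from the cell above so the block runs on its own.

```python
BOROUGH_MAPPING = {
    "Manhattan": "Superborough 1",
    "Bronx": "Superborough 1",
    "EWR": "Superborough 1",
    "Brooklyn": "Superborough 2",
    "Queens": "Superborough 2",
    "Staten Island": "Superborough 3",
    "Unknown": "Unknown",
}


def crosses_superborough(pickup_borough: str, dropoff_borough: str) -> str:
    """Return "Y" if pickup and dropoff map to different superboroughs, else "N"."""
    pickup = BOROUGH_MAPPING.get(pickup_borough)
    dropoff = BOROUGH_MAPPING.get(dropoff_borough)
    return "N" if pickup == dropoff else "Y"


for pair in [("Manhattan", "Bronx"), ("Queens", "Manhattan"), ("Brooklyn", "Queens")]:
    print(pair, crosses_superborough(*pair))
```

One detail worth knowing: `dict.get` returns `None` for a borough that is missing from the mapping, so two unmapped boroughs compare equal and are flagged "N".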

In [6]:
def make_taxi_data(ddf):
    print("Load taxi data")
    taxi_df = pd.read_csv("data/taxi+_zone_lookup.csv", usecols=["LocationID", "Borough"])

    ddf = dd.merge(ddf, taxi_df, left_on="PULocationID", right_on="LocationID", how="inner")
    ddf = ddf.rename(columns={"Borough": "PUBorough"})
    ddf = ddf.drop(columns="LocationID")

    ddf = dd.merge(ddf, taxi_df, left_on="DOLocationID", right_on="LocationID", how="inner")
    ddf = ddf.rename(columns={"Borough": "DOBorough"})
    ddf = ddf.drop(columns="LocationID")  
    
    print("Make superboroughs")
    ddf = ddf.map_partitions(get_superborough)
    ddf['airport_fee'] = ddf['airport_fee'].replace("None", 0)
    ddf['airport_fee'] = ddf['airport_fee'].replace('nan', 0)
    ddf['airport_fee'] = ddf['airport_fee'].astype(float)
    ddf['airport_fee'] = ddf['airport_fee'].fillna(0)

    print("Drop unneeded cols")
    to_drop = ['base_passenger_fare', 'bcf', 'sales_tax', 'tips',
               'driver_pay', 'access_a_ride_flag', 'wav_match_flag'
              ]
    ddf2 = ddf.drop(columns=to_drop)
    ddf2 = ddf2.repartition(partition_size="100MB")

    print("Make categoricals")
    categories = ['hvfhs_license_num', 'PULocationID', "DOLocationID", 'wav_request_flag',
                  'accessible_vehicle', 'pickup_month', 'pickup_dow', 'pickup_hour', 
                  'PUBorough', 'DOBorough', 'CrossSuperborough'
                 ]
    ddf2[categories] = ddf2[categories].astype('category')
    ddf2 = ddf2.categorize(columns=categories)
    ddf2 = ddf2.repartition(partition_size="128MB")
    return ddf2

## Test Loading Dataset

In [7]:
cluster = coiled.Cluster(
    worker_vm_types=["m6i.4xlarge"],
    scheduler_vm_types=["m6i.2xlarge"],
    package_sync=True,  # copy local packages to the cluster
    name="dask-engineering-f799f650-0",
    shutdown_on_close=True,  # shut the cluster down when the client closes
    show_widget=False,
    n_workers=20,
    use_best_zone=True,
    account="dask-engineering",
    )
client = Client(cluster)

In [8]:
client

Client
  Connection method: Cluster object
  Cluster type: coiled.ClusterBeta
  Dashboard: http://18.218.247.182:8787

Cluster info
  Workers: 20
  Total threads: 320
  Total memory: 1.19 TiB

Scheduler
  Comm: tls://10.0.23.186:8786
  Dashboard: http://10.0.23.186:8787/status
  Started: 40 minutes ago

Workers (20)
  Total threads: 16 each
  Memory: 60.87 to 60.89 GiB each
  Local directory: /scratch/dask-worker-space/worker-<id>

In [9]:
ddf = load_data()
ddf = make_taxi_data(ddf)

loading data
Make accessible feature
Fraction of dataset left after removing outliers:  0.9976846042771939
Load taxi data
Make superboroughs
Drop unneeded cols
Make categoricals


In [10]:
ddf.head()

|   | hvfhs_license_num | PULocationID | DOLocationID | trip_miles | trip_time | tolls | congestion_surcharge | airport_fee | wav_request_flag | accessible_vehicle | pickup_month | pickup_dow | pickup_hour | PUBorough | DOBorough | CrossSuperborough |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | HV0003 | 179 | 107 | 5.77 | 2566 | 0.0 | 2.75 | 0.0 | N | 0 | 4 | 1 | 8 | Queens | Manhattan | Y |
| 1 | HV0003 | 179 | 107 | 7.11 | 1657 | 0.0 | 0.75 | 0.0 | N | 0 | 4 | 6 | 21 | Queens | Manhattan | Y |
| 2 | HV0003 | 107 | 107 | 0.69 | 295 | 0.0 | 2.75 | 0.0 | N | 0 | 4 | 1 | 21 | Manhattan | Manhattan | N |
| 3 | HV0003 | 107 | 107 | 2.91 | 365 | 0.0 | 2.75 | 0.0 | N | 0 | 4 | 1 | 22 | Manhattan | Manhattan | N |
| 4 | HV0003 | 107 | 107 | 0.37 | 258 | 0.0 | 2.75 | 0.0 | N | 0 | 4 | 1 | 22 | Manhattan | Manhattan | N |


In [11]:
ddf.columns.tolist()

['hvfhs_license_num',
 'PULocationID',
 'DOLocationID',
 'trip_miles',
 'trip_time',
 'tolls',
 'congestion_surcharge',
 'airport_fee',
 'wav_request_flag',
 'accessible_vehicle',
 'pickup_month',
 'pickup_dow',
 'pickup_hour',
 'PUBorough',
 'DOBorough',
 'CrossSuperborough']

In [12]:
ddf.dtypes

hvfhs_license_num       category
PULocationID            category
DOLocationID            category
trip_miles               float64
trip_time                  int64
tolls                    float64
congestion_surcharge     float64
airport_fee              float64
wav_request_flag        category
accessible_vehicle      category
pickup_month            category
pickup_dow              category
pickup_hour             category
PUBorough               category
DOBorough               category
CrossSuperborough       category
dtype: object

In [13]:
ddf.to_parquet("s3://prefect-dask-examples/nyc-uber-lyft/feature_table.parquet", overwrite=True)

In [14]:
client.shutdown()

2023-01-01 08:45:10,304 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
