# XGBoost.Dask in many threads

Sometimes we want to train many large XGBoost models in parallel. This example does so with:

1.  The `xgboost.dask` module to run large training jobs
2.  Optuna for hyperparameter optimization
3.  A thread pool to run many of these in parallel
4.  Coiled to launch Dask clusters (though you could swap in any other Dask deployment technology)
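The thread-pool fan-out in item 3 can be sketched with the standard library alone. This is a minimal illustration, not the notebook's own code: `train_one` and `run_all` are hypothetical names, and `train_one` stands in for a function that would run an Optuna study against `xgboost.dask`. Here it only sleeps and returns a fake score. Because each job mostly waits on a remote Dask cluster rather than using the local CPU, threads are enough to overlap the jobs.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def train_one(name: str, n_trials: int) -> dict:
    """Hypothetical stand-in for one Optuna study driving xgboost.dask.

    A real version would block while remote Dask workers train;
    here we only simulate that waiting time with sleep().
    """
    time.sleep(0.1 * n_trials)  # simulate blocking on remote work
    return {"name": name, "n_trials": n_trials, "best_score": 1.0 / n_trials}


def run_all(jobs: dict, max_workers: int = 4) -> dict:
    """Run every job in a thread pool and collect the results by job name."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(train_one, name, n): name for name, n in jobs.items()}
        for future in as_completed(futures):
            # .result() re-raises any exception raised inside the worker thread
            results[futures[future]] = future.result()
    return results


if __name__ == "__main__":
    jobs = {"study-a": 2, "study-b": 3, "study-c": 1}
    start = time.perf_counter()
    results = run_all(jobs)
    elapsed = time.perf_counter() - start
    for name in sorted(results):
        print(name, results[name]["best_score"])
    # Sequentially the sleeps would add up to about 0.6 s
    print(f"elapsed: {elapsed:.2f}s")
```

Collecting results with `as_completed` means a failure in one study surfaces as soon as it finishes, without waiting for slower jobs.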

Using `xgboost.dask` from many threads took a couple of small tweaks across projects. This notebook led to the following issues and PRs:

-  https://github.com/dask/distributed/issues/7377
-  https://github.com/dask/dask/pull/9723
-  https://github.com/dask/distributed/pull/7369
-  https://github.com/dmlc/xgboost/pull/8558 (mostly cosmetic, not necessary)
-  A change in Coiled to make `package_sync` thread-safe, expected to be released by 2022-12-07

In [1]:
import datetime
import threading
from concurrent.futures import ThreadPoolExecutor

from distributed import Client
import dask.dataframe as dd
from coiled import Cluster
import coiled

import optuna
from sklearn.metrics import mean_squared_error
from dask_ml.metrics import mean_squared_error as lazy_mse
import xgboost as xgb
from xgboost.dask import DaskDMatrix

from dask_ml.datasets import make_classification_df
from dask_ml.model_selection import train_test_split, KFold
from dask_ml.preprocessing import OneHotEncoder
import dask.array as da


In [2]:
import dask.distributed  # importing the submodule also binds the top-level `dask` name
print("coiled:", coiled.__version__)
print("dask:", dask.__version__)
print("dask.distributed:", dask.distributed.__version__)
print("optuna:", optuna.__version__)
print("xgboost:", xgb.__version__)

coiled: 0.2.58
dask: 2022.12.0+13.g0d8e12be
dask.distributed: 2022.12.0+17.gf8302593
optuna: 3.1.0.dev
xgboost: 1.7.2


In [5]:
cluster_name = "eda"
cluster = coiled.Cluster(
    worker_vm_types=["m6i.4xlarge"],
    scheduler_vm_types=["m6i.2xlarge"],
    package_sync=True,  # copy local packages to the cluster
    name=cluster_name,
    shutdown_on_close=False,  # reuse the cluster across runs
    show_widget=False,
    n_workers=6,
    use_best_zone=True,
    account="dask-engineering",
    backend_options={"region": "us-east-2", "spot": True, "spot_on_demand_fallback": True},
)

print("starting run")
client = Client(cluster)

starting run


### Load data

In [6]:
import dask.dataframe as dd
from s3fs import S3FileSystem  # not used directly; s3fs must be installed to read s3:// paths


def load_data():
    start = datetime.datetime.now()
    print("loading data")
    # Drop string, categorical, and object columns
    to_exclude = ["string", "category", "object"]
    ddf = dd.read_parquet(
        "s3://prefect-dask-examples/nyc-uber-lyft/processed_files.parquet"
    ).select_dtypes(exclude=to_exclude)
    # ddf = ddf.drop(columns=["base_passenger_fare", "sales_tax", "bcf", "congestion_surcharge", "tips", "driver_pay", "dropoff_datetime"])
    ddf = ddf.assign(accessible_vehicle=1)
    print("Make accessible feature")
    # on_scene_datetime is only recorded for wheelchair-accessible vehicles.
    # where() keeps the 1 where the condition holds, so this yields 1 where
    # on_scene_datetime is null and 0 where it is present.
    ddf["accessible_vehicle"] = ddf.accessible_vehicle.where(ddf.on_scene_datetime.isnull(), 0)
    # Day-of-week (Monday=0) and hour-of-day features for request and pickup times
    ddf = ddf.assign(request_dow=ddf.request_datetime.dt.dayofweek)
    ddf = ddf.assign(pickup_datetime_dow=ddf.pickup_datetime.dt.dayofweek)
    ddf = ddf.assign(request_hour=ddf.request_datetime.dt.hour)
    ddf = ddf.assign(pickup_datetime_hour=ddf.pickup_datetime.dt.hour)
    # Drop the raw timestamps before dropna() so rows with a null on_scene_datetime survive
    ddf = ddf.drop(columns=["on_scene_datetime", "request_datetime", "pickup_datetime"])

    ddf = ddf.dropna(how="any")
    ddf = ddf.repartition(partition_size="128MB")
    ddf = ddf.reset_index(drop=True)

    categories = ["request_dow", "request_hour", "pickup_datetime_hour", "pickup_datetime_dow"]
    for cat in categories:
        ddf[cat] = ddf[cat].astype("category")

    # Ideally we would categorize the data here, but splitting
    # causes us to lose that information, so it's a wasted operation

    # len() triggers the computation, so the elapsed time covers the whole pipeline
    print(f"Completed data preprocessing in {datetime.datetime.now() - start} with {len(ddf.index)} rows")
    return ddf


ddf = load_data()

loading data
Make accessible feature
Completed data preprocessing in 0:01:37.987823 with 726579128 rows
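The `accessible_vehicle` step relies on `Series.where(cond, other)`, which keeps the original value where `cond` is true and substitutes `other` where it is false. Since the column starts as all 1s, the result is 1 exactly where `on_scene_datetime` is missing and 0 where it is present. A plain-Python sketch of that rule (no pandas or Dask; `where` is a hypothetical helper that mirrors the pandas semantics, and the timestamps are made up) makes the direction easy to check:

```python
from datetime import datetime
from typing import List, Optional


def where(values: List[int], cond: List[bool], other: int) -> List[int]:
    """Mimic pandas Series.where: keep the value where cond is True, else use other."""
    return [v if c else other for v, c in zip(values, cond)]


# Hypothetical on_scene_datetime values; None stands in for a missing timestamp
on_scene: List[Optional[datetime]] = [None, datetime(2019, 4, 23, 6, 50), None]

accessible = [1] * len(on_scene)            # like ddf.assign(accessible_vehicle=1)
is_null = [ts is None for ts in on_scene]   # like ddf.on_scene_datetime.isnull()
accessible = where(accessible, is_null, 0)  # like .where(..., 0)

print(accessible)  # [1, 0, 1]
```

Flagging rows where the timestamp *is* present would instead use `notnull()` as the condition.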


In [7]:
ddf.head()

| | dropoff_datetime | PULocationID | DOLocationID | trip_miles | trip_time | base_passenger_fare | tolls | bcf | sales_tax | congestion_surcharge | tips | driver_pay | accessible_vehicle | request_dow | pickup_datetime_dow | request_hour | pickup_datetime_hour |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2019-04-23 07:21:17 | 47 | 152 | 4.32 | 1279 | 17.05 | 0.0 | 0.43 | 1.51 | 0.0 | 0.0 | 15.33 | 0 | 1 | 1 | 6 | 6 |
| 1 | 2019-04-23 06:59:44 | 230 | 226 | 4.55 | 860 | 19.78 | 0.0 | 0.49 | 1.75 | 2.75 | 7.0 | 19.11 | 0 | 1 | 1 | 6 | 6 |
| 2 | 2019-04-23 06:59:58 | 249 | 161 | 2.61 | 809 | 13.24 | 0.0 | 0.33 | 1.17 | 2.75 | 1.0 | 10.57 | 0 | 1 | 1 | 6 | 6 |
| 3 | 2019-04-23 06:57:58 | 23 | 23 | 2.4 | 420 | 8.24 | 0.0 | 0.21 | 0.73 | 0.0 | 0.0 | 6.41 | 1 | 1 | 1 | 6 | 6 |
| 4 | 2019-04-23 06:39:26 | 23 | 23 | 2.88 | 504 | 9.98 | 0.0 | 0.25 | 0.89 | 0.0 | 1.0 | 8.33 | 0 | 1 | 1 | 6 | 6 |
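The `*_dow` and `*_hour` columns come from the `.dt.dayofweek` and `.dt.hour` accessors. pandas counts `dayofweek` from Monday=0 to Sunday=6, the same convention as Python's `datetime.weekday()`, so the first row can be sanity-checked: 2019-04-23 was a Tuesday, hence `request_dow` is 1. A stdlib sketch of the same derivation, where `time_features` is a hypothetical helper and the request timestamp is a made-up value (the raw `request_datetime` column is dropped before `head()`):

```python
from datetime import datetime


def time_features(ts: datetime, prefix: str) -> dict:
    """Derive day-of-week and hour features like the Dask code above.

    weekday() uses Monday=0 ... Sunday=6, matching pandas .dt.dayofweek.
    """
    return {f"{prefix}_dow": ts.weekday(), f"{prefix}_hour": ts.hour}


# Hypothetical request time consistent with row 0 (pickup hour 6, dropoff 07:21:17)
request = datetime(2019, 4, 23, 6, 55, 0)
print(time_features(request, "request"))  # {'request_dow': 1, 'request_hour': 6}
```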


In [13]:
ddf['PULocationID'].unique().compute()

0       47
1      230
2      249
3       23
4       39
      ... 
258      2
259    110
260    199
261    264
262    105
Name: PULocationID, Length: 263, dtype: int64

In [14]:
ddf['DOLocationID'].unique().compute()

0      152
1      226
2      161
3       23
4       39
      ... 
259    264
260    110
261    104
262    105
263    199
Name: DOLocationID, Length: 264, dtype: int64

2022-12-23 22:26:52,149 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
Traceback (most recent call last):
  File "/Users/greghayes/mambaforge/envs/xgboost_test/lib/python3.10/site-packages/distributed/comm/tcp.py", line 498, in connect
    stream = await self.client.connect(
  File "/Users/greghayes/mambaforge/envs/xgboost_test/lib/python3.10/site-packages/tornado/tcpclient.py", line 275, in connect
    af, addr, stream = await connector.start(connect_timeout=timeout)
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/greghayes/mambaforge/envs/xgboost_test/lib/python3.10/asyncio/tasks.py", line 456, in wait_for
    return fut.result()
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/greghayes/mambaforge/envs/xgboost_test/lib/pyt