# XGBoost.Dask in many threads

Sometimes we want to train many large XGBoost models in parallel.  We do so in this example with ...

1.  The `xgboost.dask` project to do large training runs
2.  Optuna to do hyper-parameter-optimization
3.  A thread pool, to run many of these in parallel
4.  Coiled to launch Dask clusters (but you could swap in your favorite Dask deployment technology as you like)

Using `xgboost.dask` from many threads tooks a couple of small tweaks across projects.  This notebook resulted in the following PRs and issues:

-  https://github.com/dask/distributed/issues/7377
-  https://github.com/dask/dask/pull/9723
-  https://github.com/dask/distributed/pull/7369
-  https://github.com/dmlc/xgboost/pull/8558 (mostly cosmetic, not necessary)
-  Also something in Coiled to allow package_sync to be thread-safe, should be released by 2022-12-07

In [1]:
import datetime
import threading
from concurrent.futures import ThreadPoolExecutor

from distributed import Client
import dask.dataframe as dd
from coiled import Cluster
import coiled

import optuna
from sklearn.metrics import roc_auc_score
import xgboost as xgb
from xgboost.dask import DaskDMatrix

from dask_ml.datasets import make_classification_df
from dask_ml.model_selection import train_test_split, KFold
from dask_ml.preprocessing import OneHotEncoder

In [2]:
import dask, coiled
print("coiled:", coiled.__version__)
print("dask:", dask.__version__)
print("dask.distributed:", dask.distributed.__version__)
print("optuna:", optuna.__version__)
print("xgboost:", xgb.__version__)
print("coiled:", coiled.__version__)

coiled: 0.2.55
dask: 2022.12.0+13.g0d8e12be
dask.distributed: 2022.12.0+17.gf8302593
optuna: 3.0.4
xgboost: 1.7.2
coiled: 0.2.55


In [3]:
dd.read_parquet("s3://coiled-datasets/uber-lyft-tlc/*.parquet", use_nullable_dtypes=False).select_dtypes(exclude=['string', 'category', 'datetime64[ns]']).astype("float").dtypes

PULocationID            float64
DOLocationID            float64
trip_miles              float64
trip_time               float64
base_passenger_fare     float64
tolls                   float64
bcf                     float64
sales_tax               float64
congestion_surcharge    float64
airport_fee             float64
tips                    float64
driver_pay              float64
dtype: object

### Load data

In [14]:
import dask.dataframe as dd

def load_data():
    s3_uri = "s3://coiled-datasets/uber-lyft-tlc/*.parquet"
    nyc_taxi = (
        dd.read_parquet(s3_uri,) #use_nullable_dtypes=True)
            .select_dtypes(exclude=["string", "category"])
    )

    nyc_taxi["pickup_hour"] = nyc_taxi["pickup_datetime"].dt.hour

    cols = nyc_taxi.select_dtypes(include="datetime64[ns]").columns.tolist()
    nyc_taxi[cols] = nyc_taxi[cols].astype(int).div(1e9).astype(int)
    nyc_taxi["trip_time"] = nyc_taxi["dropoff_datetime"] - nyc_taxi["pickup_datetime"]
    nyc_taxi = nyc_taxi.drop(columns=[
        # "dropoff_datetime", # outcome
        "base_passenger_fare", # outcome
        "driver_pay",  # outcome
        "sales_tax", # outcome
        "airport_fee", # bad data (need to fix)
    ])
    nyc_taxi = nyc_taxi.astype("float")
    # nyc_taxi = nyc_taxi.reset_index(drop=True)
    # nyc_taxi = nyc_taxi.categorize(
    #     columns=[ ]
    # )
    # encoded_vars = []
    # for col in [["shared_request_flag"],["shared_match_flag"],["wav_request_flag"],["wav_match_flag"]]:
    #     temp = OneHotEncoder().fit_transform(nyc_taxi[col])
    #     temp = temp.reset_index().set_index("request_datetime")
    #     encoded_vars.append(temp)
        

    # nyc_taxi = nyc_taxi.drop(columns=cols2)
    # nyc_taxi = nyc_taxi.reset_index().set_index("request_datetime")
    # encoded = dd.concat(encoded_vars, axis=1)
    # nyc_taxi = dd.concat([nyc_taxi, encoded], axis=1)

    
    X = nyc_taxi.drop(columns=["trip_time"])
    y = nyc_taxi["trip_time"]
    return X.to_dask_array(lengths=True), y.to_dask_array(lengths=True)

In [15]:
train_options = dict(
    n_splits = 5 
)


In [28]:
def cv_estimate(trial_number, clf_params, n_splits=5):
    thread_id = threading.get_ident()
    with coiled.Cluster(
        package_sync=True, # copy local packages
        # name="xgb-nyc-taxi-" + str(thread_id), 
        name="xgb-nyc-taxi-11005882368",
        shutdown_on_close=False,  # reuse cluster across runs
        show_widget=False,
        n_workers=64,
        worker_memory="16 GiB",
        account="dask-engineering",
        backend_options={"region": "us-east-2", "spot": True, "spot_on_demand_fallback": True}
    ) as cluster:
#     with LocalCluster() as cluster:  # for testing
        with Client(cluster) as client:
            # with client.as_current():  # this should maybe go away.  See https://github.com/dask/distributed/issues/7377
            print(f"Trial {trial_number} thread {thread_id} Cluster dashboard {cluster.dashboard_link}")

            # Load data here
            X, y = load_data()
            X = X.persist()
            y = y.persist()
            cv = KFold(n_splits=5)

            val_scores = 0
            # for i, (train, test) in enumerate(cv.split(X, y)):
                # print(thread_id, f"Trial {trial_number} KFold {i} started")
            start = datetime.datetime.now()
            print(start)

            dtrain = DaskDMatrix(client, X, y, enable_categorical=True)
            print("created d train")
            # dtest = DaskDMatrix(client, X[test], y[test])#, enable_categorical=True)

            model = xgb.dask.train(
                client,
                {
                    'verbosity': 1,
                    'tree_method': 'hist', 
                    "objective": "reg:squarederror",
                    **clf_params
                },
                dtrain,
                num_boost_round=4, 
                evals=[(dtrain, 'train')],
                early_stopping_rounds=1
            )
            print("made model")

            predictions = xgb.dask.predict(client, model, dtest)
            # predictions = xgb.dask.predict(client, model["booster"], X) #X_test)
            print("made predictions")

            # actual, predictions = dask.compute(y[test], predictions)
            actual, predictions = dask.compute(y, predictions)
            assert actual.shape == predictions.shape, (actual.shape, predictions.shape)  # sometimes this is off.  Not sure why.

            score = roc_auc_score(actual, predictions)
            val_scores += score
            end = datetime.datetime.now()
            print(end)
            print(f"Trial {trial_number} thread {thread_id} KFold {i}, score: {score}, seconds {((end - start).total_seconds())}")
            print(f"Trial {trial_number} thread {thread_id} finished")

    return val_scores / n_splits

cv_estimate(1, {}, data_kwargs)

In [29]:
def objective(trial):
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 5, 100),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.99),
        'subsample': trial.suggest_float('subsample', 0.1, 0.9),
        'max_depth': trial.suggest_int('max_depth', 1, 10),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 0.9),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 9),
    }
    accuracy = cv_estimate(
        trial_number=trial.number,
        clf_params=params, 
        n_splits=train_options["n_splits"]
    ) 
    return accuracy

In [30]:

# create a single study
study = optuna.create_study()

executor = ThreadPoolExecutor(1)

futures = [
    executor.submit(study.optimize, objective, n_trials=1) for _ in range(1)
]

[32m[I 2022-12-17 10:42:40,036][0m A new study created in memory with name: no-name-7e46e2c3-3243-435d-8293-93065e7ae1b0[0m


Trial 0 thread 10888441856 Cluster dashboard http://13.59.164.209:8787
2022-12-17 10:43:12.536066
created d train


In [19]:
futures[0].result()

Trial 0 thread 10972573696 Cluster dashboard http://13.59.164.209:8787
2022-12-17 10:28:22.246330
created d train


KeyboardInterrupt: 

In [35]:
with coiled.Cluster(
    package_sync=True, # copy local packages
    # name="xgb-nyc-taxi-" + str(thread_id), 
    name="xgb-nyc-taxi-11005882368",
    shutdown_on_close=False,  # reuse cluster across runs
    show_widget=False,
    n_workers=64,
    worker_memory="16 GiB",
    account="dask-engineering",
    backend_options={"region": "us-east-2", "spot": True, "spot_on_demand_fallback": True}
) as cluster:
    client = Client(cluster)
    client.restart()

    X, y= load_data()
    X = X.persist()
    y = y.persist()
    print(X)

    cv = KFold(n_splits=5)

    val_scores = 0
    # for i, (train, test) in enumerate(cv.split(X, y)):
        # print(thread_id, f"Trial {trial_number} KFold {i} started")
    start = datetime.datetime.now()
    print(start)

    dtrain = DaskDMatrix(client, X, y)#, enable_categorical=True)
    print("created d train")
    # dtest = DaskDMatrix(client, X[test], y[test])#, enable_categorical=True)
    params = {
        # 'n_estimators': 5,
        'learning_rate': 0.99,
        'subsample':  0.9,
        'max_depth': 10,
        'colsample_bytree': 0.9,
        'min_child_weight': 9,
    }
    
    
    model = xgb.dask.train(
        client,
        {
            'verbosity': 1,
            # 'tree_method': 'hist', 
            "objective": "reg:squarederror",
            **params
        },
        dtrain,
        num_boost_round=4, 
        evals=[(dtrain, 'train')],
        early_stopping_rounds=1
    )
    print("made model")

Collecting git+https://github.com/dask/distributed.git@f83025935383d13033fe0dcd4af2ee689078cb40
Collecting git+https://github.com/dask/dask.git@0d8e12be4c2261b3457978c16aba7e893b1cf4a1
  Cloning https://github.com/dask/distributed.git (to revision f83025935383d13033fe0dcd4af2ee689078cb40) to /private/var/folders/b5/f_y899x168j7cs2m7szjld5c0000gn/T/pip-req-build-59xg9djo
  Cloning https://github.com/dask/dask.git (to revision 0d8e12be4c2261b3457978c16aba7e893b1cf4a1) to /private/var/folders/b5/f_y899x168j7cs2m7szjld5c0000gn/T/pip-req-build-4fs_xccf


  Running command git clone --filter=blob:none --quiet https://github.com/dask/dask.git /private/var/folders/b5/f_y899x168j7cs2m7szjld5c0000gn/T/pip-req-build-4fs_xccf
  Running command git clone --filter=blob:none --quiet https://github.com/dask/distributed.git /private/var/folders/b5/f_y899x168j7cs2m7szjld5c0000gn/T/pip-req-build-59xg9djo
  Running command git rev-parse -q --verify 'sha^f83025935383d13033fe0dcd4af2ee689078cb40'
  Running command git fetch -q https://github.com/dask/distributed.git f83025935383d13033fe0dcd4af2ee689078cb40
  Running command git checkout -q f83025935383d13033fe0dcd4af2ee689078cb40
  Running command git rev-parse -q --verify 'sha^0d8e12be4c2261b3457978c16aba7e893b1cf4a1'
  Running command git fetch -q https://github.com/dask/dask.git 0d8e12be4c2261b3457978c16aba7e893b1cf4a1


  Resolved https://github.com/dask/distributed.git to commit f83025935383d13033fe0dcd4af2ee689078cb40
  Installing build dependencies: started


  Running command git checkout -q 0d8e12be4c2261b3457978c16aba7e893b1cf4a1


  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Resolved https://github.com/dask/dask.git to commit 0d8e12be4c2261b3457978c16aba7e893b1cf4a1
  Installing build dependencies: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: distributed
  Building wheel for distributed (pyproject.toml): started
  Building wheel for distributed (pyproject.toml): finished with status 'done'
  Created wheel for distributed: filename=distributed-2022.12.0+17.gf8302593-py3-none-any.whl size=930327 sha256=836f2e928318995f782223a5f3ad1b374a620fe8c321a65dea1189ed2ea3fc80
  Stored in directory: /private/var/folders/b5/f_y899x168j7cs2m7szjld5c0000gn/T/pip-ephem-wheel-cache-p_4_6lg4/wheels/fd/af/11/58dd2291a58d74b51b5fd96dedfb871cc8bba17bef3e91f0b7
Successfully built distribute

2022-12-17 14:56:40,500 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
