[dask] Support all parameters in regressor and classifier. (#6471)
* Add eval_metric.
* Add callback.
* Add feature weights.
* Add custom objective.
trivialfis committed Dec 13, 2020
1 parent c31e3ef commit a30461c
Showing 5 changed files with 348 additions and 91 deletions.
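Taken together, the four bullets above surface existing single-node options through the dask layer. A minimal sketch of the resulting user-facing API, assuming a running dask cluster (the data, model settings, and weights here are illustrative, not part of the commit):

```python
import numpy as np
from dask import array as da
from dask.distributed import Client
import xgboost as xgb

client = Client()  # assumes a local or remote dask cluster is available

X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.random(1000, chunks=(100,))

reg = xgb.dask.DaskXGBRegressor(n_estimators=100, tree_method='hist')
reg.client = client
reg.fit(
    X, y,
    eval_set=[(X, y)],
    eval_metric='rmse',                                  # new: eval_metric
    feature_weights=da.from_array(np.arange(1, 11.0)),   # new: feature weights
    callbacks=[xgb.callback.EarlyStopping(rounds=5)],    # new: callbacks
)
```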
1 change: 0 additions & 1 deletion doc/tutorials/dask.rst
@@ -326,4 +326,3 @@ addressed yet:
- Label encoding for the ``DaskXGBClassifier`` classifier is not yet supported, so users need
to encode their training labels into discrete values first (see the sketch after this diff).
- Ranking is not yet supported.
- Callback functions are not tested.
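The label-encoding caveat kept above is actionable today; a minimal sketch of the workaround, with illustrative column names and codes:

```python
import pandas as pd
import dask.dataframe as dd

df = dd.from_pandas(
    pd.DataFrame({'f0': [0.1, 0.2, 0.3, 0.4],
                  'label': ['ham', 'spam', 'ham', 'spam']}),
    npartitions=2)

# Encode string labels into discrete integer values before calling fit().
codes = {'ham': 0, 'spam': 1}
y = df['label'].map(codes, meta=('label', 'int64'))
X = df[['f0']]
```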
172 changes: 136 additions & 36 deletions python-package/xgboost/dask.py
@@ -34,7 +34,7 @@
from .core import _deprecate_positional_args
from .training import train as worker_train
from .tracker import RabitTracker, get_host_ip
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
from .sklearn import xgboost_model_doc


@@ -47,8 +47,6 @@
# not properly supported yet.
#
# TODOs:
# - Callback.
# - Label encoding.
# - CV
# - Ranking
#
@@ -184,6 +182,8 @@ class DaskDMatrix:
Lower bound for survival training.
label_upper_bound : dask.array.Array/dask.dataframe.DataFrame
Upper bound for survival training.
feature_weights : dask.array.Array/dask.dataframe.DataFrame
Weight for features used in column sampling.
feature_names : list, optional
Set names for features.
feature_types : list, optional
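A short sketch of the new `feature_weights` argument documented above (cluster and data are illustrative):

```python
import numpy as np
from dask import array as da
from dask.distributed import Client
import xgboost as xgb

client = Client()
X = da.random.random((1000, 4), chunks=(250, 4))
y = da.random.random(1000, chunks=(250,))

# Make the first feature twice as likely to be sampled as the others;
# this only matters when one of the colsample_* parameters is below 1.
fw = da.from_array(np.array([2.0, 1.0, 1.0, 1.0]))

dtrain = xgb.dask.DaskDMatrix(client, X, y, feature_weights=fw)
```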
@@ -200,6 +200,7 @@ def __init__(self,
base_margin=None,
label_lower_bound=None,
label_upper_bound=None,
feature_weights=None,
feature_names=None,
feature_types=None):
_assert_dask_support()
@@ -227,14 +228,15 @@ def __init__(self,
self._init = client.sync(self.map_local_data,
client, data, label=label, weights=weight,
base_margin=base_margin,
feature_weights=feature_weights,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)

def __await__(self):
return self._init.__await__()

async def map_local_data(self, client, data, label=None, weights=None,
base_margin=None,
base_margin=None, feature_weights=None,
label_lower_bound=None, label_upper_bound=None):
'''Obtain references to local data.'''

@@ -328,6 +330,11 @@ def append_meta(m_parts, name: str):
self.worker_map = worker_map
self.meta_names = meta_names

if feature_weights is None:
self.feature_weights = None
else:
self.feature_weights = await client.compute(feature_weights).result()

return self

def create_fn_args(self, worker_addr: str):
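Unlike the training data, which stays partitioned across workers, `feature_weights` is one value per feature, so `map_local_data` above materializes it into a single concrete array on the client. A sketch of the equivalent synchronous call (in the async path above, `.result()` is awaited instead of blocking):

```python
import numpy as np
from dask import array as da
from dask.distributed import Client

client = Client()
fw = da.from_array(np.array([1.0, 2.0, 3.0]))

# client.compute returns a Future; .result() waits for the concrete array.
concrete = client.compute(fw).result()  # -> array([1., 2., 3.])
```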
@@ -337,6 +344,7 @@ def create_fn_args(self, worker_addr: str):
'''
return {'feature_names': self.feature_names,
'feature_types': self.feature_types,
'feature_weights': self.feature_weights,
'meta_names': self.meta_names,
'missing': self.missing,
'parts': self.worker_map.get(worker_addr, None),
@@ -518,6 +526,7 @@ def create_fn_args(self, worker_addr: str):


def _create_device_quantile_dmatrix(feature_names, feature_types,
feature_weights,
meta_names, missing, parts,
max_bin):
worker = distributed.get_worker()
@@ -546,10 +555,12 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
feature_types=feature_types,
nthread=worker.nthreads,
max_bin=max_bin)
dmatrix.set_info(feature_weights=feature_weights)
return dmatrix


def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
def _create_dmatrix(feature_names, feature_types, feature_weights, meta_names, missing,
parts):
'''Get data that is local to the worker from DaskDMatrix.
Returns
@@ -590,7 +601,8 @@ def concat_or_none(data):
nthread=worker.nthreads)
dmatrix.set_info(base_margin=base_margin, weight=weights,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
label_upper_bound=label_upper_bound,
feature_weights=feature_weights)
return dmatrix
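On each worker the concatenated partitions end up in a regular `DMatrix`, with the extra metadata attached through `set_info` as in the hunk above. The single-node equivalent, sketched with illustrative data:

```python
import numpy as np
import xgboost as xgb

X = np.random.random((100, 3))
y = np.random.random(100)

dmatrix = xgb.DMatrix(X, label=y)
dmatrix.set_info(feature_weights=np.array([0.5, 1.0, 2.0]))
```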


@@ -627,16 +639,15 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
async def _train_async(client,
global_config,
params,
dtrain: DaskDMatrix,
*args,
evals=(),
early_stopping_rounds=None,
**kwargs):
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')

dtrain,
num_boost_round,
evals,
obj,
feval,
early_stopping_rounds,
verbose_eval,
xgb_model,
callbacks):
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)

@@ -668,11 +679,15 @@ def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref)
local_param[p] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
num_boost_round=num_boost_round,
evals_result=local_history,
evals=local_evals,
obj=obj,
feval=feval,
early_stopping_rounds=early_stopping_rounds,
**kwargs)
verbose_eval=verbose_eval,
xgb_model=xgb_model,
callbacks=callbacks)
ret = {'booster': bst, 'history': local_history}
if local_dtrain.num_row() == 0:
ret = None
@@ -703,8 +718,17 @@ def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref)
return list(filter(lambda ret: ret is not None, results))[0]


def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
**kwargs):
def train(client,
params,
dtrain,
num_boost_round=10,
evals=(),
obj=None,
feval=None,
early_stopping_rounds=None,
xgb_model=None,
verbose_eval=True,
callbacks=None):
'''Train XGBoost model.
.. versionadded:: 1.0.0
@@ -737,9 +761,19 @@ def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
# Get global configuration before transferring computation to another thread or
# process.
global_config = config.get_config()
return client.sync(
_train_async, client, global_config, params, dtrain=dtrain, *args, evals=evals,
early_stopping_rounds=early_stopping_rounds, **kwargs)
return client.sync(_train_async,
client=client,
global_config=global_config,
num_boost_round=num_boost_round,
obj=obj,
feval=feval,
params=params,
dtrain=dtrain,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
xgb_model=xgb_model,
callbacks=callbacks)


async def _direct_predict_impl(client, data, predict_fn):
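With the signature spelled out, the functional interface reads like single-node `xgboost.train`. A hedged sketch of a call using the now-explicit keywords (cluster and data are illustrative):

```python
from dask import array as da
from dask.distributed import Client
import xgboost as xgb

client = Client()
X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.random(1000, chunks=(100,))
dtrain = xgb.dask.DaskDMatrix(client, X, y)

output = xgb.dask.train(
    client,
    {'objective': 'reg:squarederror', 'eval_metric': 'rmse'},
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, 'train')],
    early_stopping_rounds=10,
    verbose_eval=False,
)
booster, history = output['booster'], output['history']
```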
@@ -1030,10 +1064,13 @@ def fit(self, X, y, *,
sample_weight=None,
base_margin=None,
eval_set=None,
eval_metric=None,
sample_weight_eval_set=None,
early_stopping_rounds=None,
verbose=True):
'''Fit the regressor.
verbose=True,
feature_weights=None,
callbacks=None):
'''Fit gradient boosting model.
Parameters
----------
@@ -1047,14 +1084,31 @@ def fit(self, X, y, *,
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
Metric used for monitoring training; a string or list of strings names
built-in metrics, while a callable is used as a custom evaluation metric.
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is an array of
instance weights on the i-th validation set.
early_stopping_rounds : int
Activates early stopping.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.'''
metric measured on the validation set to stderr.
feature_weights: array_like
Weight for each feature, defining the probability of each feature being
selected when column sampling (`colsample_*`) is used. All values must be
greater than 0, otherwise a `ValueError` is thrown. Only available for the
`hist`, `gpu_hist` and `exact` tree methods.
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:
.. code-block:: python
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)]
'''
raise NotImplementedError

def predict(self, data): # pylint: disable=arguments-differ
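A hedged sketch of passing the docstring's `EarlyStopping` example into `fit` (regressor shown; the classifier takes the same arguments, and the data here is illustrative):

```python
from dask import array as da
from dask.distributed import Client
import xgboost as xgb

client = Client()
X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.random(1000, chunks=(100,))

# EarlyStopping needs an eval_set to monitor.
callbacks = [xgb.callback.EarlyStopping(rounds=3, save_best=True)]
reg = xgb.dask.DaskXGBRegressor(n_estimators=100)
reg.client = client
reg.fit(X, y, eval_set=[(X, y)], callbacks=callbacks)
```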
@@ -1089,25 +1143,42 @@ def client(self, clt):
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
# pylint: disable=missing-class-docstring
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set, early_stopping_rounds,
verbose):
eval_metric, sample_weight_eval_set,
early_stopping_rounds, verbose, feature_weights,
callbacks):
dtrain = await DaskDMatrix(client=self.client,
data=X,
label=y,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing)
params = self.get_xgb_params()
evals = await _evaluation_matrices(self.client, eval_set,
sample_weight_eval_set,
self.missing)

if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})

results = await train(client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds)
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks)
self._Booster = results['booster']
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
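The branching above decides where user callables go: a callable `objective` is wrapped by `_objective_decorator` into the low-level `obj` hook, while a callable `eval_metric` becomes `feval` and a string goes into `params`. A sketch of both callables, modeled on xgboost's custom squared-log-error demo (function names and data are illustrative):

```python
import numpy as np
from dask import array as da
from dask.distributed import Client
import xgboost as xgb

def squared_log(y_true, y_pred):
    # Sklearn-style objective: takes (labels, raw predictions) and
    # returns the gradient and hessian of squared log error.
    y_pred[y_pred < -1] = -1 + 1e-6
    grad = (np.log1p(y_pred) - np.log1p(y_true)) / (y_pred + 1)
    hess = ((-np.log1p(y_pred) + np.log1p(y_true) + 1) /
            np.power(y_pred + 1, 2))
    return grad, hess

def rmsle(y_pred, dtrain):
    # feval-style metric: takes (predictions, DMatrix), returns (name, value).
    y_true = dtrain.get_label()
    elements = np.power(np.log1p(y_true) - np.log1p(y_pred), 2)
    return 'rmsle', float(np.sqrt(np.sum(elements) / len(y_true)))

client = Client()
X = da.random.random((1000, 10), chunks=(100, 10))
y = da.random.random(1000, chunks=(100,))

reg = xgb.dask.DaskXGBRegressor(n_estimators=10, objective=squared_log)
reg.client = client
reg.fit(X, y, eval_set=[(X, y)], eval_metric=rmsle)
```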
@@ -1122,19 +1193,25 @@ def fit(self,
sample_weight=None,
base_margin=None,
eval_set=None,
eval_metric=None,
sample_weight_eval_set=None,
early_stopping_rounds=None,
verbose=True):
verbose=True,
feature_weights=None,
callbacks=None):
_assert_dask_support()
return self.client.sync(self._fit_async,
X=X,
y=y,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose)
verbose=verbose,
feature_weights=feature_weights,
callbacks=callbacks)

async def _predict_async(
self, data, output_margin=False, base_margin=None):
@@ -1161,13 +1238,15 @@ def predict(self, data, output_margin=False, base_margin=None):
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-class-docstring
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set, early_stopping_rounds,
verbose):
eval_metric, sample_weight_eval_set,
early_stopping_rounds, verbose, feature_weights,
callbacks):
dtrain = await DaskDMatrix(client=self.client,
data=X,
label=y,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing)
params = self.get_xgb_params()

@@ -1187,13 +1266,28 @@ async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
evals = await _evaluation_matrices(self.client, eval_set,
sample_weight_eval_set,
self.missing)

if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})

results = await train(client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose)
callbacks=callbacks)
self._Booster = results['booster']
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
@@ -1207,19 +1301,25 @@ def fit(self,
sample_weight=None,
base_margin=None,
eval_set=None,
eval_metric=None,
sample_weight_eval_set=None,
early_stopping_rounds=None,
verbose=True):
verbose=True,
feature_weights=None,
callbacks=None):
_assert_dask_support()
return self.client.sync(self._fit_async,
X=X,
y=y,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose)
verbose=verbose,
feature_weights=feature_weights,
callbacks=callbacks)

async def _predict_proba_async(self, data, output_margin=False,
base_margin=None):
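End to end, the classifier gains the same arguments. A hedged sketch combining them with `predict_proba` (data and settings are illustrative):

```python
from dask import array as da
from dask.distributed import Client
import xgboost as xgb

client = Client()
X = da.random.random((1000, 10), chunks=(100, 10))
# Labels must already be discrete integers (see the tutorial note above).
y = da.random.randint(0, 2, size=1000, chunks=(100,))

clf = xgb.dask.DaskXGBClassifier(n_estimators=50)
clf.client = client
clf.fit(
    X, y,
    eval_set=[(X, y)],
    eval_metric=['logloss', 'error'],
    early_stopping_rounds=5,
)
proba = clf.predict_proba(X).compute()
```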
