Skip to content

Commit

Permalink
[dask] Random forest estimators (#6602)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 13, 2021
1 parent 0027220 commit 89a00a5
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 84 deletions.
6 changes: 6 additions & 0 deletions doc/python/python_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,9 @@ Dask API
.. autofunction:: xgboost.dask.DaskXGBClassifier

.. autofunction:: xgboost.dask.DaskXGBRegressor

.. autofunction:: xgboost.dask.DaskXGBRanker

.. autofunction:: xgboost.dask.DaskXGBRFRegressor

.. autofunction:: xgboost.dask.DaskXGBRFClassifier
93 changes: 84 additions & 9 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from .core import _deprecate_positional_args
from .training import train as worker_train
from .tracker import RabitTracker, get_host_ip
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
from .sklearn import xgboost_model_doc
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc, _objective_decorator
from .sklearn import _cls_predict_proba
from .sklearn import XGBRanker

Expand Down Expand Up @@ -1262,7 +1262,6 @@ class DaskScikitLearnBase(XGBModel):

_client = None

# pylint: disable=arguments-differ
@_deprecate_positional_args
async def _predict_async(
self, data: _DaskCollection,
Expand All @@ -1282,7 +1281,7 @@ async def _predict_async(

def predict(
self,
data: _DaskCollection,
X: _DaskCollection,
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
Expand All @@ -1291,10 +1290,13 @@ def predict(
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
assert ntree_limit is None, msg
return self.client.sync(self._predict_async, data,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin)
return self.client.sync(
self._predict_async,
X,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin
)

def __await__(self) -> Awaitable[Any]:
# Generate a coroutine wrapper to make this class awaitable.
Expand Down Expand Up @@ -1586,7 +1588,8 @@ async def _predict_async(
""",
)
class DaskXGBRanker(DaskScikitLearnBase):
def __init__(self, objective: str = "rank:pairwise", **kwargs: Any):
@_deprecate_positional_args
def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
if callable(objective):
raise ValueError("Custom objective function not supported by XGBRanker.")
super().__init__(objective=objective, kwargs=kwargs)
Expand Down Expand Up @@ -1698,3 +1701,75 @@ def fit( # pylint: disable=arguments-differ

# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
fit.__doc__ = XGBRanker.fit.__doc__


@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
["model", "objective"],
extra_parameters="""
n_estimators : int
Number of trees in random forest to fit.
""",
)
class DaskXGBRFRegressor(DaskXGBRegressor):
@_deprecate_positional_args
def __init__(
self,
*,
learning_rate: Optional[float] = 1,
subsample: Optional[float] = 0.8,
colsample_bynode: Optional[float] = 0.8,
reg_lambda: Optional[float] = 1e-5,
**kwargs: Any
) -> None:
super().__init__(
learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs
)

def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params()
params["num_parallel_tree"] = self.n_estimators
return params

def get_num_boosting_rounds(self) -> int:
return 1


@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
["model", "objective"],
extra_parameters="""
n_estimators : int
Number of trees in random forest to fit.
""",
)
class DaskXGBRFClassifier(DaskXGBClassifier):
@_deprecate_positional_args
def __init__(
self,
*,
learning_rate: Optional[float] = 1,
subsample: Optional[float] = 0.8,
colsample_bynode: Optional[float] = 0.8,
reg_lambda: Optional[float] = 1e-5,
**kwargs: Any
) -> None:
super().__init__(
learning_rate=learning_rate,
subsample=subsample,
colsample_bynode=colsample_bynode,
reg_lambda=reg_lambda,
**kwargs
)

def get_xgb_params(self) -> Dict[str, Any]:
params = super().get_xgb_params()
params["num_parallel_tree"] = self.n_estimators
return params

def get_num_boosting_rounds(self) -> int:
return 1
4 changes: 2 additions & 2 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def inner(preds, dmatrix):
node of the tree.
min_child_weight : float
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : int
max_delta_step : float
Maximum delta step we allow each tree's weight estimation to be.
subsample : float
Subsample ratio of the training instance.
Expand Down Expand Up @@ -1465,7 +1465,7 @@ def fit(
xgb_model = xgb_model._Booster # pylint: disable=protected-access

self._Booster = train(params, train_dmatrix,
self.n_estimators,
self.get_num_boosting_rounds(),
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result, feval=feval,
Expand Down
173 changes: 100 additions & 73 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
if hasattr(HealthCheck, 'function_scoped_fixture'):
suppress = [HealthCheck.function_scoped_fixture]
else:
suppress = hypothesis.utils.conventions.not_set
suppress = hypothesis.utils.conventions.not_set # type:ignore


kRows = 1000
Expand Down Expand Up @@ -264,100 +264,127 @@ def test_dask_missing_value_cls() -> None:
assert hasattr(cls, 'missing')


def test_dask_regressor() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y, w = generate_array(with_weights=True)
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
assert regressor._estimator_type == "regressor"
assert sklearn.base.is_regressor(regressor)
@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_regressor(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
if model == "boosting":
regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
else:
regressor = xgb.dask.DaskXGBRFRegressor(verbosity=1, n_estimators=2)

regressor.set_params(tree_method='hist')
regressor.client = client
regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = regressor.predict(X)
assert regressor._estimator_type == "regressor"
assert sklearn.base.is_regressor(regressor)

assert prediction.ndim == 1
assert prediction.shape[0] == kRows
regressor.set_params(tree_method='hist')
regressor.client = client
regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = regressor.predict(X)

history = regressor.evals_result()
assert prediction.ndim == 1
assert prediction.shape[0] == kRows

assert isinstance(prediction, da.Array)
assert isinstance(history, dict)
history = regressor.evals_result()

assert list(history['validation_0'].keys())[0] == 'rmse'
assert len(history['validation_0']['rmse']) == 2
assert isinstance(prediction, da.Array)
assert isinstance(history, dict)

assert list(history['validation_0'].keys())[0] == 'rmse'
forest = int(
json.loads(regressor.get_booster().save_config())["learner"][
"gradient_booster"
]["gbtree_train_param"]["num_parallel_tree"]
)

def test_dask_classifier() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric='merror')
assert classifier._estimator_type == "classifier"
assert sklearn.base.is_classifier(classifier)
if model == "boosting":
assert len(history['validation_0']['rmse']) == 2
assert forest == 1
else:
assert len(history['validation_0']['rmse']) == 1
assert forest == 2

classifier.client = client
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = classifier.predict(X)

assert prediction.ndim == 1
assert prediction.shape[0] == kRows
@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_classifier(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
if model == "boosting":
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric="merror"
)
else:
classifier = xgb.dask.DaskXGBRFClassifier(
verbosity=1, n_estimators=2, eval_metric="merror"
)

assert classifier._estimator_type == "classifier"
assert sklearn.base.is_classifier(classifier)

classifier.client = client
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = classifier.predict(X)

history = classifier.evals_result()
assert prediction.ndim == 1
assert prediction.shape[0] == kRows

assert isinstance(prediction, da.Array)
assert isinstance(history, dict)
history = classifier.evals_result()

assert list(history.keys())[0] == 'validation_0'
assert list(history['validation_0'].keys())[0] == 'merror'
assert len(list(history['validation_0'])) == 1
assert len(history['validation_0']['merror']) == 2
assert isinstance(prediction, da.Array)
assert isinstance(history, dict)

# Test .predict_proba()
probas = classifier.predict_proba(X)
assert classifier.n_classes_ == 10
assert probas.ndim == 2
assert probas.shape[0] == kRows
assert probas.shape[1] == 10
assert list(history.keys())[0] == "validation_0"
assert list(history["validation_0"].keys())[0] == "merror"
assert len(list(history["validation_0"])) == 1
forest = int(
json.loads(classifier.get_booster().save_config())["learner"][
"gradient_booster"
]["gbtree_train_param"]["num_parallel_tree"]
)
if model == "boosting":
assert len(history["validation_0"]["merror"]) == 2
assert forest == 1
else:
assert len(history["validation_0"]["merror"]) == 1
assert forest == 2

cls_booster = classifier.get_booster()
single_node_proba = cls_booster.inplace_predict(X.compute())
# Test .predict_proba()
probas = classifier.predict_proba(X)
assert classifier.n_classes_ == 10
assert probas.ndim == 2
assert probas.shape[0] == kRows
assert probas.shape[1] == 10

np.testing.assert_allclose(single_node_proba,
probas.compute())
cls_booster = classifier.get_booster()
single_node_proba = cls_booster.inplace_predict(X.compute())

# Test with dataframe.
X_d = dd.from_dask_array(X)
y_d = dd.from_dask_array(y)
classifier.fit(X_d, y_d)
np.testing.assert_allclose(single_node_proba, probas.compute())

assert classifier.n_classes_ == 10
prediction = classifier.predict(X_d)
# Test with dataframe.
X_d = dd.from_dask_array(X)
y_d = dd.from_dask_array(y)
classifier.fit(X_d, y_d)

assert prediction.ndim == 1
assert prediction.shape[0] == kRows
assert classifier.n_classes_ == 10
prediction = classifier.predict(X_d)

assert prediction.ndim == 1
assert prediction.shape[0] == kRows


@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_grid_search() -> None:
def test_sklearn_grid_search(client: "Client") -> None:
from sklearn.model_selection import GridSearchCV
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y, _ = generate_array()
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
tree_method='hist')
reg.client = client
model = GridSearchCV(reg, {'max_depth': [2, 4],
'n_estimators': [5, 10]},
cv=2, verbose=1)
model.fit(X, y)
# Expect unique results for each parameter value This confirms
# sklearn is able to successfully update the parameter
means = model.cv_results_['mean_test_score']
assert len(means) == len(set(means))
X, y, _ = generate_array()
reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
tree_method='hist')
reg.client = client
model = GridSearchCV(reg, {'max_depth': [2, 4],
'n_estimators': [5, 10]},
cv=2, verbose=1)
model.fit(X, y)
# Expect unique results for each parameter value This confirms
# sklearn is able to successfully update the parameter
means = model.cv_results_['mean_test_score']
assert len(means) == len(set(means))


def test_empty_dmatrix_training_continuation(client: "Client") -> None:
Expand Down

0 comments on commit 89a00a5

Please sign in to comment.