[dask] Random forest estimators #6602

Merged · 7 commits · Jan 13, 2021
doc/python/python_api.rst (6 additions, 0 deletions)

@@ -106,3 +106,9 @@ Dask API
.. autofunction:: xgboost.dask.DaskXGBClassifier

.. autofunction:: xgboost.dask.DaskXGBRegressor
+
+.. autofunction:: xgboost.dask.DaskXGBRanker
+
+.. autofunction:: xgboost.dask.DaskXGBRFRegressor
+
+.. autofunction:: xgboost.dask.DaskXGBRFClassifier
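
The two estimators documented above mirror the single-node XGBRFRegressor/XGBRFClassifier API on dask collections. A minimal usage sketch, assuming a local cluster and synthetic dask arrays (shapes and parameter values here are illustrative, not taken from the PR):

    from dask.distributed import Client, LocalCluster
    import dask.array as da
    import xgboost as xgb

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((100_000, 20), chunks=(10_000, 20))
        y = da.random.random(100_000, chunks=10_000)

        # n_estimators is the size of the forest; training still runs a
        # single boosting round (see get_num_boosting_rounds below).
        rf = xgb.dask.DaskXGBRFRegressor(n_estimators=100, tree_method="hist")
        rf.client = client
        rf.fit(X, y)
        pred = rf.predict(X)  # a lazy dask array
        print(pred.compute()[:5])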
python-package/xgboost/dask.py (84 additions, 9 deletions)

@@ -38,8 +38,8 @@
from .core import _deprecate_positional_args
from .training import train as worker_train
from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
-from .sklearn import xgboost_model_doc
+from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
+from .sklearn import xgboost_model_doc, _objective_decorator
from .sklearn import _cls_predict_proba
from .sklearn import XGBRanker

@@ -1262,7 +1262,6 @@ class DaskScikitLearnBase(XGBModel):

_client = None

-    # pylint: disable=arguments-differ
@_deprecate_positional_args
async def _predict_async(
self, data: _DaskCollection,
@@ -1282,7 +1281,7 @@ async def _predict_async(

def predict(
self,
-        data: _DaskCollection,
+        X: _DaskCollection,
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
@@ -1291,10 +1290,13 @@ def predict(
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
assert ntree_limit is None, msg
-        return self.client.sync(self._predict_async, data,
-                                output_margin=output_margin,
-                                validate_features=validate_features,
-                                base_margin=base_margin)
+        return self.client.sync(
+            self._predict_async,
+            X,
+            output_margin=output_margin,
+            validate_features=validate_features,
+            base_margin=base_margin
+        )

def __await__(self) -> Awaitable[Any]:
# Generate a coroutine wrapper to make this class awaitable.
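
The ntree_limit assertion in predict above steers users toward model slicing. A sketch of that alternative, assuming a fitted DaskXGBRegressor named reg and a dask array X (booster slicing shipped in the same XGBoost generation as this PR):

    booster = reg.get_booster()
    first_five = booster[:5]  # keep only the first five trees
    # Distributed prediction with the sliced model.
    preds = xgb.dask.inplace_predict(reg.client, first_five, X)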
@@ -1586,7 +1588,8 @@ async def _predict_async(
""",
)
class DaskXGBRanker(DaskScikitLearnBase):
-    def __init__(self, objective: str = "rank:pairwise", **kwargs: Any):
+    @_deprecate_positional_args
+    def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
if callable(objective):
raise ValueError("Custom objective function not supported by XGBRanker.")
super().__init__(objective=objective, kwargs=kwargs)
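
With _deprecate_positional_args applied and the bare *, objective becomes keyword-only. Based on the scikit-learn helper this decorator is named after, a positional call is expected to emit a FutureWarning rather than bind silently (behaviour assumed, not shown in this diff):

    ranker = xgb.dask.DaskXGBRanker(objective="rank:ndcg")  # keyword form: fine
    ranker = xgb.dask.DaskXGBRanker("rank:ndcg")            # expected to warn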
@@ -1698,3 +1701,75 @@ def fit(  # pylint: disable=arguments-differ

# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
fit.__doc__ = XGBRanker.fit.__doc__
+
+
+@xgboost_model_doc(
+    "Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.",
+    ["model", "objective"],
+    extra_parameters="""
+    n_estimators : int
+        Number of trees in random forest to fit.
+""",
+)
+class DaskXGBRFRegressor(DaskXGBRegressor):
+    @_deprecate_positional_args
+    def __init__(
+        self,
+        *,
+        learning_rate: Optional[float] = 1,
+        subsample: Optional[float] = 0.8,
+        colsample_bynode: Optional[float] = 0.8,
+        reg_lambda: Optional[float] = 1e-5,
+        **kwargs: Any
+    ) -> None:
+        super().__init__(
+            learning_rate=learning_rate,
+            subsample=subsample,
+            colsample_bynode=colsample_bynode,
+            reg_lambda=reg_lambda,
+            **kwargs
+        )
+
+    def get_xgb_params(self) -> Dict[str, Any]:
+        params = super().get_xgb_params()
+        params["num_parallel_tree"] = self.n_estimators
+        return params
+
+    def get_num_boosting_rounds(self) -> int:
+        return 1
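
These two overrides are the whole random-forest mechanism: n_estimators is rerouted from the number of boosting rounds to num_parallel_tree, and training is pinned to a single round, so the full forest is grown in one iteration. A quick check of what that produces (attribute and method names from the diff above):

    rf = xgb.dask.DaskXGBRFRegressor(n_estimators=50)
    assert rf.get_xgb_params()["num_parallel_tree"] == 50
    assert rf.get_num_boosting_rounds() == 1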


+@xgboost_model_doc(
+    "Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.",
+    ["model", "objective"],
+    extra_parameters="""
+    n_estimators : int
+        Number of trees in random forest to fit.
+""",
+)
+class DaskXGBRFClassifier(DaskXGBClassifier):
+    @_deprecate_positional_args
+    def __init__(
+        self,
+        *,
+        learning_rate: Optional[float] = 1,
+        subsample: Optional[float] = 0.8,
+        colsample_bynode: Optional[float] = 0.8,
+        reg_lambda: Optional[float] = 1e-5,
+        **kwargs: Any
+    ) -> None:
+        super().__init__(
+            learning_rate=learning_rate,
+            subsample=subsample,
+            colsample_bynode=colsample_bynode,
+            reg_lambda=reg_lambda,
+            **kwargs
+        )
+
+    def get_xgb_params(self) -> Dict[str, Any]:
+        params = super().get_xgb_params()
+        params["num_parallel_tree"] = self.n_estimators
+        return params
+
+    def get_num_boosting_rounds(self) -> int:
+        return 1
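
The classifier variant behaves the same way; a short sketch assuming the cluster setup from the earlier example and integer labels derived as in the tests below:

    clf = xgb.dask.DaskXGBRFClassifier(n_estimators=100, tree_method="hist")
    clf.client = client
    clf.fit(X, (y * 10).astype("int32"))
    proba = clf.predict_proba(X)  # dask array of shape (n_samples, n_classes)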
python-package/xgboost/sklearn.py (2 additions, 2 deletions)

@@ -91,7 +91,7 @@ def inner(preds, dmatrix):
node of the tree.
min_child_weight : float
Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : int
+    max_delta_step : float
Maximum delta step we allow each tree's weight estimation to be.
subsample : float
Subsample ratio of the training instance.
@@ -1465,7 +1465,7 @@ def fit(
xgb_model = xgb_model._Booster # pylint: disable=protected-access

self._Booster = train(params, train_dmatrix,
-                               self.n_estimators,
+                               self.get_num_boosting_rounds(),
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result, feval=feval,
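The switch from self.n_estimators to self.get_num_boosting_rounds() is what lets the dask RF subclasses force a single round without touching fit itself. The base-class default is presumably the old behaviour; a sketch of that assumed default (not shown in this diff):

    def get_num_boosting_rounds(self) -> int:
        """Number of boosting rounds; RF estimators override this to return 1."""
        return self.n_estimators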
tests/python/test_with_dask.py (100 additions, 73 deletions)

@@ -34,7 +34,7 @@
if hasattr(HealthCheck, 'function_scoped_fixture'):
suppress = [HealthCheck.function_scoped_fixture]
else:
-    suppress = hypothesis.utils.conventions.not_set
+    suppress = hypothesis.utils.conventions.not_set  # type:ignore


kRows = 1000
@@ -264,100 +264,127 @@ def test_dask_missing_value_cls() -> None:
assert hasattr(cls, 'missing')


-def test_dask_regressor() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y, w = generate_array(with_weights=True)
-            regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
-            assert regressor._estimator_type == "regressor"
-            assert sklearn.base.is_regressor(regressor)
+@pytest.mark.parametrize("model", ["boosting", "rf"])
+def test_dask_regressor(model: str, client: "Client") -> None:
+    X, y, w = generate_array(with_weights=True)
+    if model == "boosting":
+        regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
+    else:
+        regressor = xgb.dask.DaskXGBRFRegressor(verbosity=1, n_estimators=2)

-            regressor.set_params(tree_method='hist')
-            regressor.client = client
-            regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
-            prediction = regressor.predict(X)
+    assert regressor._estimator_type == "regressor"
+    assert sklearn.base.is_regressor(regressor)

-            assert prediction.ndim == 1
-            assert prediction.shape[0] == kRows
+    regressor.set_params(tree_method='hist')
+    regressor.client = client
+    regressor.fit(X, y, sample_weight=w, eval_set=[(X, y)])
+    prediction = regressor.predict(X)

-            history = regressor.evals_result()
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows

-            assert isinstance(prediction, da.Array)
-            assert isinstance(history, dict)
+    history = regressor.evals_result()

-            assert list(history['validation_0'].keys())[0] == 'rmse'
-            assert len(history['validation_0']['rmse']) == 2
+    assert isinstance(prediction, da.Array)
+    assert isinstance(history, dict)

+    assert list(history['validation_0'].keys())[0] == 'rmse'
+    forest = int(
+        json.loads(regressor.get_booster().save_config())["learner"][
+            "gradient_booster"
+        ]["gbtree_train_param"]["num_parallel_tree"]
+    )

-def test_dask_classifier() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y, w = generate_array(with_weights=True)
-            y = (y * 10).astype(np.int32)
-            classifier = xgb.dask.DaskXGBClassifier(
-                verbosity=1, n_estimators=2, eval_metric='merror')
-            assert classifier._estimator_type == "classifier"
-            assert sklearn.base.is_classifier(classifier)
+    if model == "boosting":
+        assert len(history['validation_0']['rmse']) == 2
+        assert forest == 1
+    else:
+        assert len(history['validation_0']['rmse']) == 1
+        assert forest == 2
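
The forest checks above work because the trained booster serialises its resolved training parameters, and the test reads them back through save_config(). The access path, isolated (values are what the parametrised cases expect):

    config = json.loads(regressor.get_booster().save_config())
    n_parallel = config["learner"]["gradient_booster"]["gbtree_train_param"]["num_parallel_tree"]
    # "1" for the plain boosting model, "2" for the rf model (n_estimators=2)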

-            classifier.client = client
-            classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
-            prediction = classifier.predict(X)

-            assert prediction.ndim == 1
-            assert prediction.shape[0] == kRows
+@pytest.mark.parametrize("model", ["boosting", "rf"])
+def test_dask_classifier(model: str, client: "Client") -> None:
+    X, y, w = generate_array(with_weights=True)
+    y = (y * 10).astype(np.int32)
+    if model == "boosting":
+        classifier = xgb.dask.DaskXGBClassifier(
+            verbosity=1, n_estimators=2, eval_metric="merror"
+        )
+    else:
+        classifier = xgb.dask.DaskXGBRFClassifier(
+            verbosity=1, n_estimators=2, eval_metric="merror"
+        )

+    assert classifier._estimator_type == "classifier"
+    assert sklearn.base.is_classifier(classifier)

+    classifier.client = client
+    classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
+    prediction = classifier.predict(X)

-            history = classifier.evals_result()
+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows

-            assert isinstance(prediction, da.Array)
-            assert isinstance(history, dict)
+    history = classifier.evals_result()

-            assert list(history.keys())[0] == 'validation_0'
-            assert list(history['validation_0'].keys())[0] == 'merror'
-            assert len(list(history['validation_0'])) == 1
-            assert len(history['validation_0']['merror']) == 2
+    assert isinstance(prediction, da.Array)
+    assert isinstance(history, dict)

-            # Test .predict_proba()
-            probas = classifier.predict_proba(X)
-            assert classifier.n_classes_ == 10
-            assert probas.ndim == 2
-            assert probas.shape[0] == kRows
-            assert probas.shape[1] == 10
+    assert list(history.keys())[0] == "validation_0"
+    assert list(history["validation_0"].keys())[0] == "merror"
+    assert len(list(history["validation_0"])) == 1
+    forest = int(
+        json.loads(classifier.get_booster().save_config())["learner"][
+            "gradient_booster"
+        ]["gbtree_train_param"]["num_parallel_tree"]
+    )
+    if model == "boosting":
+        assert len(history["validation_0"]["merror"]) == 2
+        assert forest == 1
+    else:
+        assert len(history["validation_0"]["merror"]) == 1
+        assert forest == 2

-            cls_booster = classifier.get_booster()
-            single_node_proba = cls_booster.inplace_predict(X.compute())
+    # Test .predict_proba()
+    probas = classifier.predict_proba(X)
+    assert classifier.n_classes_ == 10
+    assert probas.ndim == 2
+    assert probas.shape[0] == kRows
+    assert probas.shape[1] == 10

-            np.testing.assert_allclose(single_node_proba,
-                                       probas.compute())
+    cls_booster = classifier.get_booster()
+    single_node_proba = cls_booster.inplace_predict(X.compute())

-            # Test with dataframe.
-            X_d = dd.from_dask_array(X)
-            y_d = dd.from_dask_array(y)
-            classifier.fit(X_d, y_d)
+    np.testing.assert_allclose(single_node_proba, probas.compute())

-            assert classifier.n_classes_ == 10
-            prediction = classifier.predict(X_d)
+    # Test with dataframe.
+    X_d = dd.from_dask_array(X)
+    y_d = dd.from_dask_array(y)
+    classifier.fit(X_d, y_d)

-            assert prediction.ndim == 1
-            assert prediction.shape[0] == kRows
+    assert classifier.n_classes_ == 10
+    prediction = classifier.predict(X_d)

+    assert prediction.ndim == 1
+    assert prediction.shape[0] == kRows


@pytest.mark.skipif(**tm.no_sklearn())
-def test_sklearn_grid_search() -> None:
+def test_sklearn_grid_search(client: "Client") -> None:
from sklearn.model_selection import GridSearchCV
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y, _ = generate_array()
-            reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
-                                            tree_method='hist')
-            reg.client = client
-            model = GridSearchCV(reg, {'max_depth': [2, 4],
-                                       'n_estimators': [5, 10]},
-                                 cv=2, verbose=1)
-            model.fit(X, y)
-            # Expect unique results for each parameter value This confirms
-            # sklearn is able to successfully update the parameter
-            means = model.cv_results_['mean_test_score']
-            assert len(means) == len(set(means))
+    X, y, _ = generate_array()
+    reg = xgb.dask.DaskXGBRegressor(learning_rate=0.1,
+                                    tree_method='hist')
+    reg.client = client
+    model = GridSearchCV(reg, {'max_depth': [2, 4],
+                               'n_estimators': [5, 10]},
+                         cv=2, verbose=1)
+    model.fit(X, y)
+    # Expect unique results for each parameter value. This confirms
+    # sklearn is able to successfully update the parameters.
+    means = model.cv_results_['mean_test_score']
+    assert len(means) == len(set(means))
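
Across this file the per-test LocalCluster/Client blocks give way to a client fixture argument. A sketch of the fixture wiring this presumably relies on, using distributed's pytest helpers (the actual import is not visible in this excerpt):

    # conftest-style import; distributed provides client/loop/cluster fixtures.
    from distributed.utils_test import client, loop, cluster_fixture  # noqa: F401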


def test_empty_dmatrix_training_continuation(client: "Client") -> None: