Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] Accept Future of model for prediction. #6650

Merged
merged 5 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 81 additions & 24 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly int
``predict`` function or using ``inplace_predict``, the output type depends on input data.
See next section for details.

Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
Alternatively, XGBoost also implements the Scikit-Learn interface with
``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest
variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost,
with dask collection as inputs and has an additional ``client`` attribute. See
``xgboost/demo/dask`` for more examples.


******************
Expand Down Expand Up @@ -160,6 +163,32 @@ if not using GPU, the number of threads used for prediction on each block matter
now, xgboost uses single thread for each partition. If the number of blocks on each
workers is smaller than number of cores, then the CPU workers might not be fully utilized.

One simple optimization for running consecutive predictions is using
``distributed.Future``:

.. code-block:: python

dataset = [X_0, X_1, X_2]
booster_f = client.scatter(booster, broadcast=True)
futures = []
for X in dataset:
# Here we pass in a future instead of concrete booster
shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
futures.append(shap_f)

results = client.gather(futures)


This is only available on functional interface, as the Scikit-Learn wrapper doesn't know
how to maintain a valid future for booster. To obtain the booster object from
Scikit-Learn wrapper object:

.. code-block:: python

cls = xgb.dask.DaskXGBClassifier()
cls.fit(X, y)

booster = cls.get_booster()


***************************
Expand Down Expand Up @@ -231,17 +260,17 @@ will override the configuration in Dask. For example:
with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster:

There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
threads in each process for both training and prediction. But if ``nthread`` parameter is
set:
threads in each process for training. But if ``nthread`` parameter is set:

.. code-block:: python

output = xgb.dask.train(client,
{'verbosity': 1,
'nthread': 8,
'tree_method': 'hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])
output = xgb.dask.train(
client,
{"verbosity": 1, "nthread": 8, "tree_method": "hist"},
dtrain,
num_boost_round=4,
evals=[(dtrain, "train")],
)

XGBoost will use 8 threads in each training process.

Expand Down Expand Up @@ -274,12 +303,12 @@ Functional interface:
with_X = await xgb.dask.predict(client, output, X)
inplace = await xgb.dask.inplace_predict(client, output, X)

# Use `client.compute` instead of the `compute` method from dask collection
# Use ``client.compute`` instead of the ``compute`` method from dask collection
print(await client.compute(with_m))


While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
attributes like ``evals_result_`` do not require ``await``. Other methods involving
attributes like ``evals_result()`` do not require ``await``. Other methods involving
actual computation will return a coroutine and hence require awaiting:

.. code-block:: python
Expand Down Expand Up @@ -373,6 +402,46 @@ If early stopping is enabled by also passing ``early_stopping_rounds``, you can
print(booster.best_iteration)
best_model = booster[: booster.best_iteration]


*******************
Other customization
*******************

XGBoost dask interface accepts other advanced features found in single node Python
interface, including callback functions, custom evaluation metric and objective:

def eval_error_metric(predt, dtrain: xgb.DMatrix):
label = dtrain.get_label()
r = np.zeros(predt.shape)
gt = predt > 0.5
r[gt] = 1 - label[gt]
le = predt <= 0.5
r[le] = label[le]
return 'CustomErr', np.sum(r)

# custom callback
early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds,
metric_name="CustomErr",
data_name="Train",
save_best=True,
)

booster = xgb.dask.train(
client,
params={
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
},
dtrain=D_train,
evals=[(D_train, "Train"), (D_valid, "Valid")],
feval=eval_error_metric, # custom evaluation metric
num_boost_round=100,
callbacks=[early_stop],
)


*****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
*****************************************************************************
Expand Down Expand Up @@ -414,15 +483,3 @@ References:

#. https://github.com/dask/dask/issues/6833
#. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array

***********
Limitations
***********

Basic functionality including model training and generating classification and regression predictions
have been implemented. However, there are still some other limitations we haven't
addressed yet:

- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need
to encode their training labels into discrete values first.
- Ranking is not yet supported.
79 changes: 46 additions & 33 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,24 +940,22 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:


async def _direct_predict_impl(
client: "distributed.Client",
mapped_predict: Callable,
booster: Booster,
booster: "distributed.Future",
data: _DaskCollection,
base_margin: Optional[_DaskCollection],
output_shape: Tuple[int, ...],
meta: Dict[int, str],
) -> _DaskCollection:
columns = list(meta.keys())
booster_f = await client.scatter(data=booster, broadcast=True)
if _can_output_df(data, output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
else:
base_margin_df = base_margin
predictions = dd.map_partitions(
mapped_predict,
booster_f,
booster,
data,
True,
columns,
Expand All @@ -984,7 +982,7 @@ async def _direct_predict_impl(
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks(
mapped_predict,
booster_f,
booster,
data,
False,
columns,
Expand All @@ -997,7 +995,10 @@ async def _direct_predict_impl(


def _infer_predict_output(
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
booster: Booster,
data: Union[DaskDMatrix, _DaskCollection],
inplace: bool,
**kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix):
Expand All @@ -1021,11 +1022,29 @@ def _infer_predict_output(
return test_predt.shape, meta


async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = await client.scatter(model["booster"])
elif isinstance(model, distributed.Future):
booster = model
if booster.type is not Booster:
raise TypeError(
f"Underlying type of model future should be `Booster`, got {booster.type}"
)
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
return booster


# pylint: disable=too-many-statements
async def _predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
output_margin: bool,
missing: float,
Expand All @@ -1035,12 +1054,7 @@ async def _predict_async(
pred_interactions: bool,
validate_features: bool,
) -> _DaskCollection:
if isinstance(model, Booster):
_booster = model
elif isinstance(model, dict):
_booster = model["booster"]
else:
raise TypeError(_expect([Booster, dict], type(model)))
_booster = await _get_model_future(client, model)
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))

Expand Down Expand Up @@ -1070,7 +1084,7 @@ def mapped_predict(
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
_output_shape, meta = _infer_predict_output(
_booster,
await _booster.result(),
data,
inplace=False,
output_margin=output_margin,
Expand All @@ -1081,10 +1095,11 @@ def mapped_predict(
validate_features=False,
)
return await _direct_predict_impl(
client, mapped_predict, _booster, data, None, _output_shape, meta
mapped_predict, _booster, data, None, _output_shape, meta
)

output_shape, _ = _infer_predict_output(
booster=_booster,
booster=await _booster.result(),
data=data,
inplace=False,
output_margin=output_margin,
Expand All @@ -1108,11 +1123,9 @@ def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
for i, blob in enumerate(part[1:]):
if meta_names[i] == "base_margin":
base_margin = blob
worker = distributed.get_worker()
with config.config_context(**global_config):
m = DMatrix(
data,
nthread=worker.nthreads,
missing=missing,
base_margin=base_margin,
feature_names=feature_names,
Expand Down Expand Up @@ -1148,9 +1161,8 @@ def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
all_shapes = [shape for part, shape, order in parts_with_order]

futures = []
booster_f = await client.scatter(data=_booster, broadcast=True)
for part in all_parts:
f = client.submit(dispatched_predict, booster_f, part)
f = client.submit(dispatched_predict, _booster, part)
futures.append(f)

# Constructing a dask array from list of numpy arrays
Expand All @@ -1168,7 +1180,7 @@ def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:

def predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: Union[DaskDMatrix, _DaskCollection],
output_margin: bool = False,
missing: float = numpy.nan,
Expand All @@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model.
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
data:
Input data used for prediction. When input is a dataframe object,
prediction output is a series.
Expand All @@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument
async def _inplace_predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
missing: float = numpy.nan
) -> _DaskCollection:
client = _xgb_get_client(client)
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
booster = model['booster']
else:
raise TypeError(_expect([Booster, dict], type(model)))
booster = await _get_model_future(client, model)
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

Expand Down Expand Up @@ -1261,16 +1269,20 @@ def mapped_predict(
return prediction

shape, meta = _infer_predict_output(
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
await booster.result(),
data,
True,
predict_type=predict_type,
iteration_range=iteration_range
)
return await _direct_predict_impl(
client, mapped_predict, booster, data, None, shape, meta
mapped_predict, booster, data, None, shape, meta
)


def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
Expand All @@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model.
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
iteration_range:
Specify the range of trees used for prediction.
predict_type:
Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def save_model(self, fname: str):
json.dumps({k: v})
meta[k] = v
except TypeError:
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
meta['_estimator_type'] = self._get_type()
meta_str = json.dumps(meta)
self.get_booster().set_attr(scikit_learn=meta_str)
Expand Down