Skip to content

Commit

Permalink
Document the optimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 29, 2021
1 parent cc33e73 commit eff113c
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly int
``predict`` function or using ``inplace_predict``, the output type depends on input data.
See next section for details.

Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
Alternatively, XGBoost also implements the Scikit-Learn interface with
``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest
variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost,
with dask collection as inputs and has an additional ``client`` attribute. See
``xgboost/demo/dask`` for more examples.


******************
Expand Down Expand Up @@ -160,6 +163,32 @@ if not using GPU, the number of threads used for prediction on each block matter
now, xgboost uses single thread for each partition. If the number of blocks on each
workers is smaller than number of cores, then the CPU workers might not be fully utilized.

One simple optimization for running consecutive predictions is using
``distributed.Future``:

.. code-block:: python
dataset = [X_0, X_1, X_2]
booster_f = client.scatter(booster, broadcast=True)
futures = []
for X in dataset:
# Here we pass in a future instead of concrete booster
shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
futures.append(shap_f)
results = client.gather(futures)
This is only available on functional interface, as the Scikit-Learn wrapper doesn't know
how to maintain a valid future for booster. To obtain the booster object from
Scikit-Learn wrapper object:

.. code-block:: python
cls = xgb.dask.DaskXGBClassifier()
cls.fit(X, y)
booster = cls.get_booster()
***************************
Expand Down

0 comments on commit eff113c

Please sign in to comment.