[python-package] [dask] Add DaskLGBMRanker (#3708)
* ranker support wip

* fix ranker tests

* fix _make_ranking rnd gen bug, add sleep to help with stochastic "Binding port failed" exceptions

* add wait_for_workers to prevent Binding port exception

* another attempt to stabilize test_dask.py

* requested changes: docstrings, dask_ml, tuples for list_of_parts

* fix lint bug, add group param to test_ranker_local_predict

* decorator to skip tests with errors on fixture teardown

* remove gpu ranker tests, reduce make_ranking data complexity

* another attempt to silence client, decorator does not silence fixture errors

* address requested changes on 1/20/21

* skip test_dask for all GPU tasks

* address changes requested on 1/21/21

* issubclass instead of __qualname__

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* parity in group docstr with sklearn

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* _make_ranking docstr cleanup

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
ffineis and StrikerRUS committed Jan 22, 2021
1 parent 6dbe736 commit 3c7e7e0
Showing 2 changed files with 265 additions and 21 deletions.
86 changes: 68 additions & 18 deletions python-package/lightgbm/dask.py
@@ -21,7 +21,7 @@
 from dask.distributed import Client, default_client, get_worker, wait
 
 from .basic import _ConfigAliases, _LIB, _safe_call
-from .sklearn import LGBMClassifier, LGBMRegressor
+from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
 
 logger = logging.getLogger(__name__)

@@ -133,15 +133,24 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
     }
     params.update(network_params)
 
+    is_ranker = issubclass(model_factory, LGBMRanker)
+
     # Concatenate many parts into one
     parts = tuple(zip(*list_of_parts))
     data = _concat(parts[0])
     label = _concat(parts[1])
-    weight = _concat(parts[2]) if len(parts) == 3 else None
 
     try:
         model = model_factory(**params)
-        model.fit(data, label, sample_weight=weight, **kwargs)
+
+        if is_ranker:
+            group = _concat(parts[-1])
+            weight = _concat(parts[2]) if len(parts) == 4 else None
+            model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
+        else:
+            weight = _concat(parts[2]) if len(parts) == 3 else None
+            model.fit(data, y=label, sample_weight=weight, **kwargs)
+
     finally:
         _safe_call(_LIB.LGBM_NetworkFree())
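
A toy, self-contained illustration (synthetic NumPy arrays, not from this commit) of the regrouping that ``tuple(zip(*list_of_parts))`` performs above, with the group sizes travelling as the last element of each part's tuple in the ranker case:

import numpy as np

# Two hypothetical (data, label, weight, group) parts landing on one worker.
list_of_parts = [
    (np.ones((2, 3)), np.array([0, 1]), np.array([0.5, 0.5]), np.array([2])),
    (np.zeros((3, 3)), np.array([1, 0, 1]), np.ones(3), np.array([3])),
]

parts = tuple(zip(*list_of_parts))  # regrouped as (datas, labels, weights, groups)
data = np.concatenate(parts[0])    # np.concatenate stands in for _concat; shape (5, 3)
group = np.concatenate(parts[-1])  # array([2, 3]); sum equals the number of rows
assert len(parts) == 4 and group.sum() == data.shape[0]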

@@ -156,7 +165,7 @@ def _split_to_parts(data, is_matrix):
     return parts
 
 
-def _train(client, data, label, params, model_factory, weight=None, **kwargs):
+def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
     """Inner train routine.
 
     Parameters
@@ -167,22 +176,36 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     y : dask array of shape = [n_samples]
         The target values (class labels in classification, real numbers in regression).
     params : dict
-    model_factory : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor class
+    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
     sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
-            Weights of training data.
+        Weights of training data.
+    group : array-like or None, optional (default=None)
+        Group/query data.
+        Only used in the learning-to-rank task.
+        sum(group) = n_samples.
+        For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
+        where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
     """
     params = deepcopy(params)
 
     # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
     data_parts = _split_to_parts(data, is_matrix=True)
     label_parts = _split_to_parts(label, is_matrix=False)
-    if weight is None:
-        parts = list(map(delayed, zip(data_parts, label_parts)))
+    weight_parts = _split_to_parts(sample_weight, is_matrix=False) if sample_weight is not None else None
+    group_parts = _split_to_parts(group, is_matrix=False) if group is not None else None
+
+    # choose between four options of (sample_weight, group) being (un)specified
+    if weight_parts is None and group_parts is None:
+        parts = zip(data_parts, label_parts)
+    elif weight_parts is not None and group_parts is None:
+        parts = zip(data_parts, label_parts, weight_parts)
+    elif weight_parts is None and group_parts is not None:
+        parts = zip(data_parts, label_parts, group_parts)
     else:
-        weight_parts = _split_to_parts(weight, is_matrix=False)
-        parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
+        parts = zip(data_parts, label_parts, weight_parts, group_parts)
 
     # Start computation in the background
+    parts = list(map(delayed, parts))
     parts = client.compute(parts)
     wait(parts)
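
To make the ``group`` docstring example above concrete, here is a small illustrative check (not part of the diff) of how the group sizes map onto row ranges:

import numpy as np

group = [10, 20, 40, 10, 10, 10]  # the docstring's hypothetical 100-document dataset
assert sum(group) == 100          # sum(group) must equal n_samples

# Half-open, 0-indexed row ranges covered by each query group.
bounds = np.append(0, np.cumsum(group))
row_ranges = [(int(lhs), int(rhs)) for lhs, rhs in zip(bounds[:-1], bounds[1:])]
print(row_ranges)  # [(0, 10), (10, 30), (30, 70), (70, 80), (80, 90), (90, 100)]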

@@ -281,13 +304,13 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
     Parameters
     ----------
-    model :
+    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
     data : dask array of shape = [n_samples, n_features]
         Input feature matrix.
     proba : bool
-        Should method return results of predict_proba (proba == True) or predict (proba == False)
+        Should method return results of predict_proba (proba == True) or predict (proba == False).
     dtype : np.dtype
-        Dtype of the output
+        Dtype of the output.
     kwargs : other parameters passed to predict or predict_proba method
     """
     if isinstance(data, dd._Frame):
@@ -304,13 +327,14 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
 
 class _LGBMModel:
 
-    def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs):
+    def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
         """Docstring is inherited from the LGBMModel."""
         if client is None:
             client = default_client()
 
         params = self.get_params(True)
-        model = _train(client, X, y, params, model_factory, sample_weight, **kwargs)
+        model = _train(client, data=X, label=y, params=params, model_factory=model_factory,
+                       sample_weight=sample_weight, group=group, **kwargs)
 
         self.set_params(**model.get_params())
         self._copy_extra_params(model, self)
@@ -335,8 +359,8 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
     """Distributed version of lightgbm.LGBMClassifier."""
 
     def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
-        """Docstring is inherited from the LGBMModel."""
-        return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)
+        """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
+        return self._fit(LGBMClassifier, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
     fit.__doc__ = LGBMClassifier.fit.__doc__
 
     def predict(self, X, **kwargs):
@@ -364,7 +388,7 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
 
     def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
         """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
-        return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs)
+        return self._fit(LGBMRegressor, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
     fit.__doc__ = LGBMRegressor.fit.__doc__
 
     def predict(self, X, **kwargs):
@@ -380,3 +404,29 @@ def to_local(self):
         model : lightgbm.LGBMRegressor
         """
         return self._to_local(LGBMRegressor)
+
+
+class DaskLGBMRanker(_LGBMModel, LGBMRanker):
+    """Docstring is inherited from the lightgbm.LGBMRanker."""
+
+    def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
+        if init_score is not None:
+            raise RuntimeError('init_score is not currently supported in lightgbm.dask')
+
+        return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs)
+    fit.__doc__ = LGBMRanker.fit.__doc__
+
+    def predict(self, X, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
+        return _predict(self.to_local(), X, **kwargs)
+    predict.__doc__ = LGBMRanker.predict.__doc__
+
+    def to_local(self):
+        """Create regular version of lightgbm.LGBMRanker from the distributed version.
+
+        Returns
+        -------
+        model : lightgbm.LGBMRanker
+        """
+        return self._to_local(LGBMRanker)
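
For orientation, a minimal end-to-end sketch (not part of the commit) of how the new DaskLGBMRanker might be used on a local dask.distributed cluster. The per-group chunking mirrors how this PR's tests lay out ranking data: each partition holds only whole query groups, and the group sizes ride along as a chunk-aligned dask array. Data and hyperparameters are illustrative.

import dask.array as da
import numpy as np
from dask.distributed import Client, LocalCluster

from lightgbm.dask import DaskLGBMRanker

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

rnd = np.random.RandomState(42)

# 100 documents in 6 query groups, mirroring the docstring example.
group_sizes = [10, 20, 40, 10, 10, 10]
bounds = np.append(0, np.cumsum(group_sizes))

X = rnd.rand(100, 5)
y = rnd.randint(0, 4, size=100)

# One chunk per query group, so no group straddles a partition boundary;
# dg carries one group-size entry per chunk, aligned with dX's chunks.
dX = da.concatenate([da.from_array(X[lhs:rhs]) for lhs, rhs in zip(bounds[:-1], bounds[1:])])
dy = da.concatenate([da.from_array(y[lhs:rhs]) for lhs, rhs in zip(bounds[:-1], bounds[1:])])
dg = da.concatenate([da.from_array(np.array([size])) for size in group_sizes])

ranker = DaskLGBMRanker(n_estimators=10)
ranker.fit(dX, dy, group=dg, client=client)
preds = ranker.predict(dX).compute()

client.close()
cluster.close()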
