Skip to content

Commit

Permalink
less expensive access pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 6, 2023
1 parent 5c60a31 commit 017e5e5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
15 changes: 10 additions & 5 deletions python-package/lightgbm/basic.py
Expand Up @@ -2850,8 +2850,9 @@ def get_label(self) -> Optional[_LGBM_LabelType]:
Returns
-------
label : numpy array or None
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
The label information from the Dataset.
For a constructed ``Dataset``, this will only return a numpy array.
"""
if self.label is None:
self.label = self.get_field('label')
Expand All @@ -2862,8 +2863,9 @@ def get_weight(self) -> Optional[_LGBM_WeightType]:
Returns
-------
weight : numpy array or None
weight : list, numpy 1-D array, pandas Series or None
Weight for each data point from the Dataset. Weights should be non-negative.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.weight is None:
self.weight = self.get_field('weight')
Expand All @@ -2874,8 +2876,9 @@ def get_init_score(self) -> Optional[_LGBM_InitScoreType]:
Returns
-------
init_score : numpy array or None
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
Init score of Booster.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.init_score is None:
self.init_score = self.get_field('init_score')
Expand Down Expand Up @@ -2918,12 +2921,13 @@ def get_group(self) -> Optional[_LGBM_GroupType]:
Returns
-------
group : numpy array or None
group : list, numpy 1-D array, pandas Series or None
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.group is None:
self.group = self.get_field('group')
Expand All @@ -2937,8 +2941,9 @@ def get_position(self) -> Optional[_LGBM_PositionType]:
Returns
-------
position : numpy 1-D array or None
position : numpy 1-D array, pandas Series or None
Position of items used in unbiased learning-to-rank task.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.position is None:
self.position = self.get_field('position')
Expand Down
61 changes: 44 additions & 17 deletions python-package/lightgbm/sklearn.py
Expand Up @@ -86,6 +86,26 @@
_LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]


def _get_label_from_constructed_dataset(dataset: Dataset) -> np.ndarray:
    """Fetch the label from ``dataset`` and check that it is a numpy array.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose label is read through ``Dataset.get_label()``.
        Expected to be already constructed, so the label comes back as a numpy array.

    Returns
    -------
    label : numpy array
        The value returned by ``dataset.get_label()``.

    Raises
    ------
    AssertionError
        If ``dataset.get_label()`` returns anything other than a numpy array.
    """
    retrieved_label = dataset.get_label()
    # The check doubles as type narrowing for static checkers: ``get_label()`` is
    # annotated with a wider type than the ``np.ndarray`` promised by this helper.
    assert isinstance(retrieved_label, np.ndarray), (
        "Estimators in lightgbm.sklearn should only retrieve labels from a constructed Dataset. "
        "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
    )
    return retrieved_label


def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]:
    """Fetch the weight from ``dataset`` and check that it is ``None`` or a numpy array.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose weight is read through ``Dataset.get_weight()``.
        Expected to be already constructed.

    Returns
    -------
    weight : numpy array or None
        The value returned by ``dataset.get_weight()``.

    Raises
    ------
    AssertionError
        If ``dataset.get_weight()`` returns something that is neither ``None`` nor a numpy array.
    """
    retrieved_weight = dataset.get_weight()
    # A missing weight is a valid result, so hand it back without further checks.
    if retrieved_weight is None:
        return None
    # The check doubles as type narrowing for static checkers: ``get_weight()`` is
    # annotated with a wider type than the one promised by this helper.
    assert isinstance(retrieved_weight, np.ndarray), (
        "Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. "
        "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
    )
    return retrieved_weight


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

Expand Down Expand Up @@ -151,21 +171,25 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np.
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
"""
labels = dataset.get_field("label")
labels = _get_label_from_constructed_dataset(dataset)
argc = len(signature(self.func).parameters)
if argc == 2:
grad, hess = self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3:
grad, hess = self.func(labels, preds, dataset.get_field("weight")) # type: ignore[call-arg]
elif argc == 4:
return grad, hess

weight = _get_weight_from_constructed_dataset(dataset)
if argc == 3:
grad, hess = self.func(labels, preds, weight) # type: ignore[call-arg]
return grad, hess

if argc == 4:
group = dataset.get_field("group")
if group is not None:
return self.func(labels, preds, dataset.get_field("weight"), np.diff(group)) # type: ignore[call-arg]
return self.func(labels, preds, weight, np.diff(group)) # type: ignore[call-arg]
else:
return self.func(labels, preds, dataset.get_field("weight"), group) # type: ignore[call-arg]
else:
raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")
return grad, hess
return self.func(labels, preds, weight, group) # type: ignore[call-arg]

raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")


class _EvalFunctionWrapper:
Expand Down Expand Up @@ -233,20 +257,23 @@ def __call__(
is_higher_better : bool
Is eval result higher better, e.g. AUC is ``is_higher_better``.
"""
labels = dataset.get_field("label")
labels = _get_label_from_constructed_dataset(dataset)
argc = len(signature(self.func).parameters)
if argc == 2:
return self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3:
return self.func(labels, preds, dataset.get_field("weight")) # type: ignore[call-arg]
elif argc == 4:

weight = _get_weight_from_constructed_dataset(dataset)
if argc == 3:
return self.func(labels, preds, weight) # type: ignore[call-arg]

if argc == 4:
group = dataset.get_field("group")
if group is not None:
return self.func(labels, preds, dataset.get_field("weight"), np.diff(group)) # type: ignore[call-arg]
return self.func(labels, preds, weight, np.diff(group)) # type: ignore[call-arg]
else:
return self.func(labels, preds, dataset.get_field("weight"), group) # type: ignore[call-arg]
else:
raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")
return self.func(labels, preds, weight, group) # type: ignore[call-arg]

raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")


# documentation templates for LGBMModel methods are shared between the classes in
Expand Down

0 comments on commit 017e5e5

Please sign in to comment.