Skip to content

Commit

Permalink
[python-package] fix access to Dataset metadata in scikit-learn custo…
Browse files Browse the repository at this point in the history
…m metrics and objectives (#6108)
  • Loading branch information
jameslamb committed Nov 7, 2023
1 parent b7f6311 commit aeafccf
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 38 deletions.
68 changes: 46 additions & 22 deletions python-package/lightgbm/basic.py
Expand Up @@ -434,31 +434,31 @@ def _data_to_2d_numpy(
"It should be list of lists, numpy 2-D array or pandas DataFrame")


def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray:
def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
"""Convert a ctypes float pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else:
raise RuntimeError('Expected float pointer')


def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray:
def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
"""Convert a ctypes double pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_double)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else:
raise RuntimeError('Expected double pointer')


def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray:
def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
"""Convert a ctypes int pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else:
raise RuntimeError('Expected int32 pointer')


def _cint64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray:
def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
"""Convert a ctypes int pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
Expand Down Expand Up @@ -1295,18 +1295,18 @@ def __create_sparse_native(
data_indices_len = out_shape[0]
indptr_len = out_shape[1]
if indptr_type == _C_API_DTYPE_INT32:
out_indptr = _cint32_array_to_numpy(out_ptr_indptr, indptr_len)
out_indptr = _cint32_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len)
elif indptr_type == _C_API_DTYPE_INT64:
out_indptr = _cint64_array_to_numpy(out_ptr_indptr, indptr_len)
out_indptr = _cint64_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len)
else:
raise TypeError("Expected int32 or int64 type for indptr")
if data_type == _C_API_DTYPE_FLOAT32:
out_data = _cfloat32_array_to_numpy(out_ptr_data, data_indices_len)
out_data = _cfloat32_array_to_numpy(cptr=out_ptr_data, length=data_indices_len)
elif data_type == _C_API_DTYPE_FLOAT64:
out_data = _cfloat64_array_to_numpy(out_ptr_data, data_indices_len)
out_data = _cfloat64_array_to_numpy(cptr=out_ptr_data, length=data_indices_len)
else:
raise TypeError("Expected float32 or float64 type for data")
out_indices = _cint32_array_to_numpy(out_ptr_indices, data_indices_len)
out_indices = _cint32_array_to_numpy(cptr=out_ptr_indices, length=data_indices_len)
# break up indptr based on number of rows (note more than one matrix in multiclass case)
per_class_indptr_shape = cs.indptr.shape[0]
# for CSC there is extra column added
Expand Down Expand Up @@ -2609,6 +2609,12 @@ def set_field(
def get_field(self, field_name: str) -> Optional[np.ndarray]:
"""Get property from the Dataset.
Can only be run on a constructed Dataset.
Unlike ``get_group()``, ``get_init_score()``, ``get_label()``, ``get_position()``, and ``get_weight()``,
this method ignores any raw data passed into ``lgb.Dataset()`` on the Python side, and will only read
data from the constructed C++ ``Dataset`` object.
Parameters
----------
field_name : str
Expand All @@ -2635,11 +2641,20 @@ def get_field(self, field_name: str) -> Optional[np.ndarray]:
if tmp_out_len.value == 0:
return None
if out_type.value == _C_API_DTYPE_INT32:
arr = _cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value)
arr = _cint32_array_to_numpy(
cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)),
length=tmp_out_len.value
)
elif out_type.value == _C_API_DTYPE_FLOAT32:
arr = _cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value)
arr = _cfloat32_array_to_numpy(
cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)),
length=tmp_out_len.value
)
elif out_type.value == _C_API_DTYPE_FLOAT64:
arr = _cfloat64_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), tmp_out_len.value)
arr = _cfloat64_array_to_numpy(
cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)),
length=tmp_out_len.value
)
else:
raise TypeError("Unknown type")
if field_name == 'init_score':
Expand Down Expand Up @@ -2878,6 +2893,10 @@ def set_group(
if self._handle is not None and group is not None:
group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
self.set_field('group', group)
# original values can be modified at cpp side
constructed_group = self.get_field('group')
if constructed_group is not None:
self.group = np.diff(constructed_group)
return self

def set_position(
Expand Down Expand Up @@ -2941,37 +2960,40 @@ def get_feature_name(self) -> List[str]:
ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]

def get_label(self) -> Optional[np.ndarray]:
def get_label(self) -> Optional[_LGBM_LabelType]:
"""Get the label of the Dataset.
Returns
-------
label : numpy array or None
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
The label information from the Dataset.
For a constructed ``Dataset``, this will only return a numpy array.
"""
if self.label is None:
self.label = self.get_field('label')
return self.label

def get_weight(self) -> Optional[np.ndarray]:
def get_weight(self) -> Optional[_LGBM_WeightType]:
"""Get the weight of the Dataset.
Returns
-------
weight : numpy array or None
weight : list, numpy 1-D array, pandas Series or None
Weight for each data point from the Dataset. Weights should be non-negative.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.weight is None:
self.weight = self.get_field('weight')
return self.weight

def get_init_score(self) -> Optional[np.ndarray]:
def get_init_score(self) -> Optional[_LGBM_InitScoreType]:
"""Get the initial score of the Dataset.
Returns
-------
init_score : numpy array or None
init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
Init score of Booster.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.init_score is None:
self.init_score = self.get_field('init_score')
Expand Down Expand Up @@ -3009,17 +3031,18 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
"set free_raw_data=False when construct Dataset to avoid this.")
return self.data

def get_group(self) -> Optional[np.ndarray]:
def get_group(self) -> Optional[_LGBM_GroupType]:
"""Get the group of the Dataset.
Returns
-------
group : numpy array or None
group : list, numpy 1-D array, pandas Series or None
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.group is None:
self.group = self.get_field('group')
Expand All @@ -3028,13 +3051,14 @@ def get_group(self) -> Optional[np.ndarray]:
self.group = np.diff(self.group)
return self.group

def get_position(self) -> Optional[np.ndarray]:
def get_position(self) -> Optional[_LGBM_PositionType]:
"""Get the position of the Dataset.
Returns
-------
position : numpy 1-D array or None
position : numpy 1-D array, pandas Series or None
Position of items used in unbiased learning-to-rank task.
For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
"""
if self.position is None:
self.position = self.get_field('position')
Expand Down
69 changes: 54 additions & 15 deletions python-package/lightgbm/sklearn.py
Expand Up @@ -86,6 +86,36 @@
_LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]


def _get_group_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]:
group = dataset.get_group()
error_msg = (
"Estimators in lightgbm.sklearn should only retrieve query groups from a constructed Dataset. "
"If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
)
assert (group is None or isinstance(group, np.ndarray)), error_msg
return group


def _get_label_from_constructed_dataset(dataset: Dataset) -> np.ndarray:
label = dataset.get_label()
error_msg = (
"Estimators in lightgbm.sklearn should only retrieve labels from a constructed Dataset. "
"If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
)
assert isinstance(label, np.ndarray), error_msg
return label


def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]:
weight = dataset.get_weight()
error_msg = (
"Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. "
"If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
)
assert (weight is None or isinstance(weight, np.ndarray)), error_msg
return weight


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

Expand Down Expand Up @@ -151,17 +181,22 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np.
The value of the second order derivative (Hessian) of the loss
with respect to the elements of preds for each sample point.
"""
labels = dataset.get_label()
labels = _get_label_from_constructed_dataset(dataset)
argc = len(signature(self.func).parameters)
if argc == 2:
grad, hess = self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3:
grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg]
elif argc == 4:
grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg]
else:
raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")
return grad, hess
return grad, hess

weight = _get_weight_from_constructed_dataset(dataset)
if argc == 3:
grad, hess = self.func(labels, preds, weight) # type: ignore[call-arg]
return grad, hess

if argc == 4:
group = _get_group_from_constructed_dataset(dataset)
return self.func(labels, preds, weight, group) # type: ignore[call-arg]

raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")


class _EvalFunctionWrapper:
Expand Down Expand Up @@ -229,16 +264,20 @@ def __call__(
is_higher_better : bool
Is eval result higher better, e.g. AUC is ``is_higher_better``.
"""
labels = dataset.get_label()
labels = _get_label_from_constructed_dataset(dataset)
argc = len(signature(self.func).parameters)
if argc == 2:
return self.func(labels, preds) # type: ignore[call-arg]
elif argc == 3:
return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg]
elif argc == 4:
return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg]
else:
raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")

weight = _get_weight_from_constructed_dataset(dataset)
if argc == 3:
return self.func(labels, preds, weight) # type: ignore[call-arg]

if argc == 4:
group = _get_group_from_constructed_dataset(dataset)
return self.func(labels, preds, weight, group) # type: ignore[call-arg]

raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")


# documentation templates for LGBMModel methods are shared between the classes in
Expand Down
90 changes: 89 additions & 1 deletion tests/python_package_test/test_basic.py
Expand Up @@ -15,7 +15,7 @@
import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series

from .utils import dummy_obj, load_breast_cancer, mse_obj
from .utils import dummy_obj, load_breast_cancer, mse_obj, np_assert_array_equal


def test_basic(tmp_path):
Expand Down Expand Up @@ -499,6 +499,94 @@ def check_asserts(data):
check_asserts(lgb_data)


def test_dataset_construction_overwrites_user_provided_metadata_fields():

X = np.array([[1.0, 2.0], [3.0, 4.0]])

position = np.array([0.0, 1.0], dtype=np.float32)
if getenv('TASK', '') == 'cuda':
position = None

dtrain = lgb.Dataset(
X,
params={
"min_data_in_bin": 1,
"min_data_in_leaf": 1,
"verbosity": -1
},
group=[1, 1],
init_score=[0.312, 0.708],
label=[1, 2],
position=position,
weight=[0.5, 1.5],
)

# unconstructed, get_* methods should return whatever was provided
assert dtrain.group == [1, 1]
assert dtrain.get_group() == [1, 1]
assert dtrain.init_score == [0.312, 0.708]
assert dtrain.get_init_score() == [0.312, 0.708]
assert dtrain.label == [1, 2]
assert dtrain.get_label() == [1, 2]
if getenv('TASK', '') != 'cuda':
np_assert_array_equal(
dtrain.position,
np.array([0.0, 1.0], dtype=np.float32),
strict=True
)
np_assert_array_equal(
dtrain.get_position(),
np.array([0.0, 1.0], dtype=np.float32),
strict=True
)
assert dtrain.weight == [0.5, 1.5]
assert dtrain.get_weight() == [0.5, 1.5]

# before construction, get_field() should raise an exception
for field_name in ["group", "init_score", "label", "position", "weight"]:
with pytest.raises(Exception, match=f"Cannot get {field_name} before construct Dataset"):
dtrain.get_field(field_name)

# constructed, get_* methods should return numpy arrays, even when the provided
# input was a list of floats or ints
dtrain.construct()
expected_group = np.array([1, 1], dtype=np.int32)
np_assert_array_equal(dtrain.group, expected_group, strict=True)
np_assert_array_equal(dtrain.get_group(), expected_group, strict=True)
# get_field("group") returns a numpy array with boundaries, instead of size
np_assert_array_equal(
dtrain.get_field("group"),
np.array([0, 1, 2], dtype=np.int32),
strict=True
)

expected_init_score = np.array([0.312, 0.708],)
np_assert_array_equal(dtrain.init_score, expected_init_score, strict=True)
np_assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True)
np_assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True)

expected_label = np.array([1, 2], dtype=np.float32)
np_assert_array_equal(dtrain.label, expected_label, strict=True)
np_assert_array_equal(dtrain.get_label(), expected_label, strict=True)
np_assert_array_equal(dtrain.get_field("label"), expected_label, strict=True)

if getenv('TASK', '') != 'cuda':
expected_position = np.array([0.0, 1.0], dtype=np.float32)
np_assert_array_equal(dtrain.position, expected_position, strict=True)
np_assert_array_equal(dtrain.get_position(), expected_position, strict=True)
# NOTE: "position" is converted to int32 on the C++ side
np_assert_array_equal(
dtrain.get_field("position"),
np.array([0.0, 1.0], dtype=np.int32),
strict=True
)

expected_weight = np.array([0.5, 1.5], dtype=np.float32)
np_assert_array_equal(dtrain.weight, expected_weight, strict=True)
np_assert_array_equal(dtrain.get_weight(), expected_weight, strict=True)
np_assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True)


def test_choose_param_value():

original_params = {
Expand Down

0 comments on commit aeafccf

Please sign in to comment.