From c91e01a1543299c99e61df532f2af9b8f226c677 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 21 Sep 2023 00:08:59 -0500 Subject: [PATCH 01/10] use stricter data-getters in scikit-learn interface --- python-package/lightgbm/basic.py | 47 +++++++++++++------- python-package/lightgbm/sklearn.py | 51 +++++++++++++++++++--- tests/python_package_test/test_basic.py | 57 +++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 22 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c2964fcedd8d..d79a37fbe806 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -368,7 +368,7 @@ def _data_to_2d_numpy( "It should be list of lists, numpy 2-D array or pandas DataFrame") -def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes float pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_float)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -376,7 +376,7 @@ def _cfloat32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray raise RuntimeError('Expected float pointer') -def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes double pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_double)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -384,7 +384,7 @@ def _cfloat64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray raise RuntimeError('Expected double pointer') -def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes int pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -392,7 +392,7 @@ def _cint32_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: raise RuntimeError('Expected int32 pointer') -def _cint64_array_to_numpy(cptr: "ctypes._Pointer", length: int) -> np.ndarray: +def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: """Convert a ctypes int pointer array to a numpy array.""" if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)): return np.ctypeslib.as_array(cptr, shape=(length,)).copy() @@ -1229,18 +1229,18 @@ def __create_sparse_native( data_indices_len = out_shape[0] indptr_len = out_shape[1] if indptr_type == _C_API_DTYPE_INT32: - out_indptr = _cint32_array_to_numpy(out_ptr_indptr, indptr_len) + out_indptr = _cint32_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len) elif indptr_type == _C_API_DTYPE_INT64: - out_indptr = _cint64_array_to_numpy(out_ptr_indptr, indptr_len) + out_indptr = _cint64_array_to_numpy(cptr=out_ptr_indptr, length=indptr_len) else: raise TypeError("Expected int32 or int64 type for indptr") if data_type == _C_API_DTYPE_FLOAT32: - out_data = _cfloat32_array_to_numpy(out_ptr_data, data_indices_len) + out_data = _cfloat32_array_to_numpy(cptr=out_ptr_data, length=data_indices_len) elif data_type == _C_API_DTYPE_FLOAT64: - out_data = _cfloat64_array_to_numpy(out_ptr_data, data_indices_len) + out_data = _cfloat64_array_to_numpy(cptr=out_ptr_data, length=data_indices_len) else: raise TypeError("Expected float32 or float64 type for data") - out_indices = _cint32_array_to_numpy(out_ptr_indices, data_indices_len) + out_indices = _cint32_array_to_numpy(cptr=out_ptr_indices, length=data_indices_len) # break up indptr based on number of rows (note more than one matrix in multiclass case) per_class_indptr_shape = cs.indptr.shape[0] # for CSC there is extra column added @@ -2504,6 +2504,12 @@ def set_field( def get_field(self, field_name: str) -> Optional[np.ndarray]: """Get property from the Dataset. + Can only be run on a constructed Dataset. + + Unlike ``get_group()``, ``get_init_score()``, ``get_label()``, ``get_position()``, and ``get_weight()``, + this method ignores any raw data passed into ``lgb.Dataset()`` on the Python side, and will only read + data from the constructed C++ ``Dataset`` object. + Parameters ---------- field_name : str @@ -2530,11 +2536,20 @@ def get_field(self, field_name: str) -> Optional[np.ndarray]: if tmp_out_len.value == 0: return None if out_type.value == _C_API_DTYPE_INT32: - arr = _cint32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), tmp_out_len.value) + arr = _cint32_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), + length=tmp_out_len.value + ) elif out_type.value == _C_API_DTYPE_FLOAT32: - arr = _cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value) + arr = _cfloat32_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), + length=tmp_out_len.value + ) elif out_type.value == _C_API_DTYPE_FLOAT64: - arr = _cfloat64_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), tmp_out_len.value) + arr = _cfloat64_array_to_numpy( + cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), + length=tmp_out_len.value + ) else: raise TypeError("Unknown type") if field_name == 'init_score': @@ -2834,7 +2849,7 @@ def get_feature_name(self) -> List[str]: ptr_string_buffers)) return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)] - def get_label(self) -> Optional[np.ndarray]: + def get_label(self) -> Optional[_LGBM_LabelType]: """Get the label of the Dataset. Returns @@ -2846,7 +2861,7 @@ def get_label(self) -> Optional[np.ndarray]: self.label = self.get_field('label') return self.label - def get_weight(self) -> Optional[np.ndarray]: + def get_weight(self) -> Optional[_LGBM_WeightType]: """Get the weight of the Dataset. Returns @@ -2858,7 +2873,7 @@ def get_weight(self) -> Optional[np.ndarray]: self.weight = self.get_field('weight') return self.weight - def get_init_score(self) -> Optional[np.ndarray]: + def get_init_score(self) -> Optional[_LGBM_InitScoreType]: """Get the initial score of the Dataset. Returns @@ -2921,7 +2936,7 @@ def get_group(self) -> Optional[np.ndarray]: self.group = np.diff(self.group) return self.group - def get_position(self) -> Optional[np.ndarray]: + def get_position(self) -> Optional[_LGBM_PositionType]: """Get the position of the Dataset. Returns diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index c71c233df908..d49a242eae25 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -151,19 +151,54 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. The value of the second order derivative (Hessian) of the loss with respect to the elements of preds for each sample point. """ - labels = dataset.get_label() + labels = dataset.get_field("label") argc = len(signature(self.func).parameters) if argc == 2: grad, hess = self.func(labels, preds) # type: ignore[call-arg] elif argc == 3: - grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] + grad, hess = self.func(labels, preds, dataset.get_field("weight")) # type: ignore[call-arg] elif argc == 4: - grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg] + group = dataset.get_field("group") + if group is not None: + return self.func(labels, preds, dataset.get_field("weight"), np.diff(group)) # type: ignore[call-arg] + else: + return self.func(labels, preds, dataset.get_field("weight"), group) # type: ignore[call-arg] else: raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") return grad, hess +Fixed these: + +```text +python-package/lightgbm/basic.py:2826: error: Incompatible return value type (got "Union[List[float], List[int], ndarray[Any, Any], Any, Any, None]", expected "Optional[ndarray[Any, Any]]") [return-value] +python-package/lightgbm/basic.py:2838: error: Incompatible return value type (got "Union[List[float], List[int], ndarray[Any, Any], Any, None]", expected "Optional[ndarray[Any, Any]]") [return-value] +python-package/lightgbm/basic.py:2850: error: Incompatible return value type (got "Union[List[float], List[List[float]], ndarray[Any, Any], Any, Any, None]", expected "Optional[ndarray[Any, Any]]") [return-value] +python-package/lightgbm/basic.py:2901: error: Incompatible return value type (got "Union[List[float], Any, List[int], ndarray[Any, dtype[Any]], ndarray[Any, Any], None]", expected "Optional[ndarray[Any, Any]]") [return-value] +``` + +And then all of these that came as a result: + +```text +python-package/lightgbm/sklearn.py:157: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:157: note: Error code "arg-type" not covered by "type: ignore" comment +python-package/lightgbm/sklearn.py:159: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:159: note: Error code "arg-type" not covered by "type: ignore" comment +python-package/lightgbm/sklearn.py:159: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:161: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:161: note: Error code "arg-type" not covered by "type: ignore" comment +python-package/lightgbm/sklearn.py:161: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:235: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:235: note: Error code "arg-type" not covered by "type: ignore" comment +python-package/lightgbm/sklearn.py:237: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:237: note: Error code "arg-type" not covered by "type: ignore" comment +python-package/lightgbm/sklearn.py:237: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:239: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +python-package/lightgbm/sklearn.py:239: note: Error code "arg-type" not covered by "type: ignore" comment +python-package/lightgbm/sklearn.py:239: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] +``` + + class _EvalFunctionWrapper: """Proxy class for evaluation function.""" @@ -229,14 +264,18 @@ def __call__( is_higher_better : bool Is eval result higher better, e.g. AUC is ``is_higher_better``. """ - labels = dataset.get_label() + labels = dataset.get_field("label") argc = len(signature(self.func).parameters) if argc == 2: return self.func(labels, preds) # type: ignore[call-arg] elif argc == 3: - return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] + return self.func(labels, preds, dataset.get_field("weight")) # type: ignore[call-arg] elif argc == 4: - return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg] + group = dataset.get_field("group") + if group is not None: + return self.func(labels, preds, dataset.get_field("weight"), np.diff(group)) # type: ignore[call-arg] + else: + return self.func(labels, preds, dataset.get_field("weight"), group) # type: ignore[call-arg] else: raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 7f8980c271f7..1e146806f005 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -499,6 +499,63 @@ def check_asserts(data): check_asserts(lgb_data) +def test_dataset_construction_overwrites_user_provided_metadata_fields(): + + X = np.array([[1.0, 2.0], [3.0, 4.0]]) + + dtrain = lgb.Dataset( + X, + params={ + "min_data_in_bin": 1, + "min_data_in_leaf": 1, + "verbosity": -1 + }, + group=[1, 1], + init_score=[0.312, 0.708], + label=[1, 2], + weight=[0.5, 1.5], + ) + + # unconstructed, get_* methods should return whatever was provided + assert dtrain.group == [1, 1] + assert dtrain.get_group() == [1, 1] + assert dtrain.init_score == [0.312, 0.708] + assert dtrain.get_init_score() == [0.312, 0.708] + assert dtrain.label == [1, 2] + assert dtrain.get_label() == [1, 2] + assert dtrain.weight == [0.5, 1.5] + assert dtrain.get_weight() == [0.5, 1.5] + + # before construction, get_field() raises an exception + for field_name in ["group", "init_score", "label", "position", "weight"]: + with pytest.raises(Exception, match="Cannot get weight before construct Dataset") + dtrain.get_field(field_name) + + # constructed, get_* methods should return numpy arrays, even when the provided + # input was a list of floats or ints + dtrain.construct() + expected_group = [1, 1] + assert dtrain.group == expected_group + assert dtrain.get_group() == expected_group + # get_field("group") returns a numpy array with boundaries, instead of size + assert dtrain.get_field("group") == np.array([0, 1, 2], dtype=int32) + + expected_init_score = np.array([0.312, 0.708]) + assert np.testing.assert_array_equal(dtrain.init_score, expected_init_score, strict=True) + assert np.testing.assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) + assert np.testing.assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) + + expected_label = np.array([1, 2], dtype=np.float32) + assert np.testing.assert_array_equal(dtrain.label, expected_label, strict=True) + assert np.testing.assert_array_equal(dtrain.get_label(), label, strict=True) + assert np.testing.assert_array_equal(dtrain.get_field("label"), label, strict=True) + + expected_weight = np.array([0.5, 1.5], dtype=np.float32) + assert np.testing.assert_array_equal(dtrain.weight, expected_weight, strict=True) + assert np.testing.assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) + assert np.testing.assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) + + def test_choose_param_value(): original_params = { From 0119e57ff327910f24013af702829fa7fe7ffe4b Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 21 Sep 2023 22:56:11 -0500 Subject: [PATCH 02/10] more changes --- python-package/lightgbm/basic.py | 2 +- python-package/lightgbm/sklearn.py | 31 ------------------------------ 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d79a37fbe806..d6f95bd4b677 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2917,7 +2917,7 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]: "set free_raw_data=False when construct Dataset to avoid this.") return self.data - def get_group(self) -> Optional[np.ndarray]: + def get_group(self) -> Optional[_LGBM_GroupType]: """Get the group of the Dataset. Returns diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index d49a242eae25..450757c1e58f 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -168,37 +168,6 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. return grad, hess -Fixed these: - -```text -python-package/lightgbm/basic.py:2826: error: Incompatible return value type (got "Union[List[float], List[int], ndarray[Any, Any], Any, Any, None]", expected "Optional[ndarray[Any, Any]]") [return-value] -python-package/lightgbm/basic.py:2838: error: Incompatible return value type (got "Union[List[float], List[int], ndarray[Any, Any], Any, None]", expected "Optional[ndarray[Any, Any]]") [return-value] -python-package/lightgbm/basic.py:2850: error: Incompatible return value type (got "Union[List[float], List[List[float]], ndarray[Any, Any], Any, Any, None]", expected "Optional[ndarray[Any, Any]]") [return-value] -python-package/lightgbm/basic.py:2901: error: Incompatible return value type (got "Union[List[float], Any, List[int], ndarray[Any, dtype[Any]], ndarray[Any, Any], None]", expected "Optional[ndarray[Any, Any]]") [return-value] -``` - -And then all of these that came as a result: - -```text -python-package/lightgbm/sklearn.py:157: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:157: note: Error code "arg-type" not covered by "type: ignore" comment -python-package/lightgbm/sklearn.py:159: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:159: note: Error code "arg-type" not covered by "type: ignore" comment -python-package/lightgbm/sklearn.py:159: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:161: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:161: note: Error code "arg-type" not covered by "type: ignore" comment -python-package/lightgbm/sklearn.py:161: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:235: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:235: note: Error code "arg-type" not covered by "type: ignore" comment -python-package/lightgbm/sklearn.py:237: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:237: note: Error code "arg-type" not covered by "type: ignore" comment -python-package/lightgbm/sklearn.py:237: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:239: error: Argument 1 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -python-package/lightgbm/sklearn.py:239: note: Error code "arg-type" not covered by "type: ignore" comment -python-package/lightgbm/sklearn.py:239: error: Argument 3 has incompatible type "Union[List[float], List[int], ndarray[Any, Any], Any, None]"; expected "Optional[ndarray[Any, Any]]" [arg-type] -``` - - class _EvalFunctionWrapper: """Proxy class for evaluation function.""" From 48de34240dbff2931e816ed7623cd988c7b275f3 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 21 Sep 2023 23:25:01 -0500 Subject: [PATCH 03/10] fix tests --- tests/python_package_test/test_basic.py | 49 +++++++++++++++++++------ 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 1e146806f005..d9b2dec80481 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -513,6 +513,7 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): group=[1, 1], init_score=[0.312, 0.708], label=[1, 2], + position=np.array([0.0, 1.0], dtype=np.float32), weight=[0.5, 1.5], ) @@ -523,12 +524,22 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): assert dtrain.get_init_score() == [0.312, 0.708] assert dtrain.label == [1, 2] assert dtrain.get_label() == [1, 2] + np.testing.assert_array_equal( + dtrain.position, + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) + np.testing.assert_array_equal( + dtrain.get_position(), + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) assert dtrain.weight == [0.5, 1.5] assert dtrain.get_weight() == [0.5, 1.5] - # before construction, get_field() raises an exception + # before construction, get_field() should raise an exception for field_name in ["group", "init_score", "label", "position", "weight"]: - with pytest.raises(Exception, match="Cannot get weight before construct Dataset") + with pytest.raises(Exception, match=f"Cannot get {field_name} before construct Dataset"): dtrain.get_field(field_name) # constructed, get_* methods should return numpy arrays, even when the provided @@ -538,22 +549,36 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): assert dtrain.group == expected_group assert dtrain.get_group() == expected_group # get_field("group") returns a numpy array with boundaries, instead of size - assert dtrain.get_field("group") == np.array([0, 1, 2], dtype=int32) + np.testing.assert_array_equal( + dtrain.get_field("group"), + np.array([0, 1, 2], dtype=np.int32), + strict=True + ) expected_init_score = np.array([0.312, 0.708]) - assert np.testing.assert_array_equal(dtrain.init_score, expected_init_score, strict=True) - assert np.testing.assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) - assert np.testing.assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) + np.testing.assert_array_equal(dtrain.init_score, expected_init_score, strict=True) + np.testing.assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) + np.testing.assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) expected_label = np.array([1, 2], dtype=np.float32) - assert np.testing.assert_array_equal(dtrain.label, expected_label, strict=True) - assert np.testing.assert_array_equal(dtrain.get_label(), label, strict=True) - assert np.testing.assert_array_equal(dtrain.get_field("label"), label, strict=True) + np.testing.assert_array_equal(dtrain.label, expected_label, strict=True) + np.testing.assert_array_equal(dtrain.get_label(), expected_label, strict=True) + np.testing.assert_array_equal(dtrain.get_field("label"), expected_label, strict=True) + + expected_position = np.array([0.0, 1.0], dtype=np.float32) + np.testing.assert_array_equal(dtrain.position, expected_position, strict=True) + np.testing.assert_array_equal(dtrain.get_position(), expected_position, strict=True) + # NOTE: "position" is converted to int32 on thhe C++ side + np.testing.assert_array_equal( + dtrain.get_field("position"), + np.array([0.0, 1.0], dtype=np.int32), + strict=True + ) expected_weight = np.array([0.5, 1.5], dtype=np.float32) - assert np.testing.assert_array_equal(dtrain.weight, expected_weight, strict=True) - assert np.testing.assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) - assert np.testing.assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) + np.testing.assert_array_equal(dtrain.weight, expected_weight, strict=True) + np.testing.assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) + np.testing.assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) def test_choose_param_value(): From 8874f3b6be49b24e141aa86616976faca43f2613 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 21 Sep 2023 23:57:55 -0500 Subject: [PATCH 04/10] fix tests on CUDA --- tests/python_package_test/test_basic.py | 46 ++++++++++++++----------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index d9b2dec80481..d4ffd9a24dfc 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -503,6 +503,10 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): X = np.array([[1.0, 2.0], [3.0, 4.0]]) + position=np.array([0.0, 1.0], dtype=np.float32) + if getenv('TASK', '') == 'cuda': + position = None + dtrain = lgb.Dataset( X, params={ @@ -513,7 +517,7 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): group=[1, 1], init_score=[0.312, 0.708], label=[1, 2], - position=np.array([0.0, 1.0], dtype=np.float32), + position=position, weight=[0.5, 1.5], ) @@ -524,16 +528,17 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): assert dtrain.get_init_score() == [0.312, 0.708] assert dtrain.label == [1, 2] assert dtrain.get_label() == [1, 2] - np.testing.assert_array_equal( - dtrain.position, - np.array([0.0, 1.0], dtype=np.float32), - strict=True - ) - np.testing.assert_array_equal( - dtrain.get_position(), - np.array([0.0, 1.0], dtype=np.float32), - strict=True - ) + if getenv('TASK', '') != 'cuda': + np.testing.assert_array_equal( + dtrain.position, + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) + np.testing.assert_array_equal( + dtrain.get_position(), + np.array([0.0, 1.0], dtype=np.float32), + strict=True + ) assert dtrain.weight == [0.5, 1.5] assert dtrain.get_weight() == [0.5, 1.5] @@ -565,15 +570,16 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): np.testing.assert_array_equal(dtrain.get_label(), expected_label, strict=True) np.testing.assert_array_equal(dtrain.get_field("label"), expected_label, strict=True) - expected_position = np.array([0.0, 1.0], dtype=np.float32) - np.testing.assert_array_equal(dtrain.position, expected_position, strict=True) - np.testing.assert_array_equal(dtrain.get_position(), expected_position, strict=True) - # NOTE: "position" is converted to int32 on thhe C++ side - np.testing.assert_array_equal( - dtrain.get_field("position"), - np.array([0.0, 1.0], dtype=np.int32), - strict=True - ) + if getenv('TASK', '') != 'cuda': + expected_position = np.array([0.0, 1.0], dtype=np.float32) + np.testing.assert_array_equal(dtrain.position, expected_position, strict=True) + np.testing.assert_array_equal(dtrain.get_position(), expected_position, strict=True) + # NOTE: "position" is converted to int32 on thhe C++ side + np.testing.assert_array_equal( + dtrain.get_field("position"), + np.array([0.0, 1.0], dtype=np.int32), + strict=True + ) expected_weight = np.array([0.5, 1.5], dtype=np.float32) np.testing.assert_array_equal(dtrain.weight, expected_weight, strict=True) From 7ea687ec7d561bf3c3e7a037c409de6ef9abe9bc Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 22 Sep 2023 00:10:42 -0500 Subject: [PATCH 05/10] fix comment --- tests/python_package_test/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index d4ffd9a24dfc..7441a3520f65 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -574,7 +574,7 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): expected_position = np.array([0.0, 1.0], dtype=np.float32) np.testing.assert_array_equal(dtrain.position, expected_position, strict=True) np.testing.assert_array_equal(dtrain.get_position(), expected_position, strict=True) - # NOTE: "position" is converted to int32 on thhe C++ side + # NOTE: "position" is converted to int32 on the C++ side np.testing.assert_array_equal( dtrain.get_field("position"), np.array([0.0, 1.0], dtype=np.int32), From 7386ecd60b8881382afeeb38b7e0fc70769d25ac Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 22 Sep 2023 13:04:25 -0500 Subject: [PATCH 06/10] fix compatibility with np.testing.assert_array_equal() --- tests/python_package_test/test_basic.py | 32 ++++++++++++------------- tests/python_package_test/utils.py | 20 ++++++++++++++++ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 7441a3520f65..49ac7db789de 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -15,7 +15,7 @@ import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series -from .utils import dummy_obj, load_breast_cancer, mse_obj +from .utils import dummy_obj, load_breast_cancer, mse_obj, np_assert_array_equal def test_basic(tmp_path): @@ -529,12 +529,12 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): assert dtrain.label == [1, 2] assert dtrain.get_label() == [1, 2] if getenv('TASK', '') != 'cuda': - np.testing.assert_array_equal( + np_assert_array_equal( dtrain.position, np.array([0.0, 1.0], dtype=np.float32), strict=True ) - np.testing.assert_array_equal( + np_assert_array_equal( dtrain.get_position(), np.array([0.0, 1.0], dtype=np.float32), strict=True @@ -554,37 +554,37 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): assert dtrain.group == expected_group assert dtrain.get_group() == expected_group # get_field("group") returns a numpy array with boundaries, instead of size - np.testing.assert_array_equal( + np_assert_array_equal( dtrain.get_field("group"), np.array([0, 1, 2], dtype=np.int32), strict=True ) expected_init_score = np.array([0.312, 0.708]) - np.testing.assert_array_equal(dtrain.init_score, expected_init_score, strict=True) - np.testing.assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) - np.testing.assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) + np_assert_array_equal(dtrain.init_score, expected_init_score, strict=True) + np_assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) + np_assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) expected_label = np.array([1, 2], dtype=np.float32) - np.testing.assert_array_equal(dtrain.label, expected_label, strict=True) - np.testing.assert_array_equal(dtrain.get_label(), expected_label, strict=True) - np.testing.assert_array_equal(dtrain.get_field("label"), expected_label, strict=True) + np_assert_array_equal(dtrain.label, expected_label, strict=True) + np_assert_array_equal(dtrain.get_label(), expected_label, strict=True) + np_assert_array_equal(dtrain.get_field("label"), expected_label, strict=True) if getenv('TASK', '') != 'cuda': expected_position = np.array([0.0, 1.0], dtype=np.float32) - np.testing.assert_array_equal(dtrain.position, expected_position, strict=True) - np.testing.assert_array_equal(dtrain.get_position(), expected_position, strict=True) + np_assert_array_equal(dtrain.position, expected_position, strict=True) + np_assert_array_equal(dtrain.get_position(), expected_position, strict=True) # NOTE: "position" is converted to int32 on the C++ side - np.testing.assert_array_equal( + np_assert_array_equal( dtrain.get_field("position"), np.array([0.0, 1.0], dtype=np.int32), strict=True ) expected_weight = np.array([0.5, 1.5], dtype=np.float32) - np.testing.assert_array_equal(dtrain.weight, expected_weight, strict=True) - np.testing.assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) - np.testing.assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) + np_assert_array_equal(dtrain.weight, expected_weight, strict=True) + np_assert_array_equal(dtrain.get_weight(), expected_weight, strict=True) + np_assert_array_equal(dtrain.get_field("weight"), expected_weight, strict=True) def test_choose_param_value(): diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index df01e29852e7..b83fa64d1801 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -1,6 +1,7 @@ # coding: utf-8 import pickle from functools import lru_cache +from inspect import getfullargspec import cloudpickle import joblib @@ -193,3 +194,22 @@ def pickle_and_unpickle_object(obj, serializer): serializer=serializer ) return obj_from_disk # noqa: RET504 + + +# doing this here, at import time, to ensure it only runs once_per import +# instead of once per assertion +_numpy_testing_supports_strict_kwarg = ( + "strict" in getfullargspec(np.testing.assert_array_equal).kwonlyargs +) + + +def np_assert_array_equal(*args, **kwargs): + """ + np.testing.assert_array_equal() only got the kwarg ``strict`` in June 2022: + https://github.com/numpy/numpy/pull/21595 + + This function is here for testing on older Python (and therefore ``numpy``) + """ + if not _numpy_testing_supports_strict_kwarg: + kawrgs.pop("strict") + np.testing.assert_array_equal(*args, **kwargs) From ff01f800c1ce7e50ed7748c307227dde1c29bbeb Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 22 Sep 2023 13:28:21 -0500 Subject: [PATCH 07/10] fix --- tests/python_package_test/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index b83fa64d1801..7eae62b14369 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -211,5 +211,5 @@ def np_assert_array_equal(*args, **kwargs): This function is here for testing on older Python (and therefore ``numpy``) """ if not _numpy_testing_supports_strict_kwarg: - kawrgs.pop("strict") + kwargs.pop("strict") np.testing.assert_array_equal(*args, **kwargs) From 017e5e5703bf14f5828fa74d4854b4c72a91f953 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 5 Oct 2023 23:40:18 -0500 Subject: [PATCH 08/10] less expensive access pattern --- python-package/lightgbm/basic.py | 15 +++++--- python-package/lightgbm/sklearn.py | 61 +++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 9a84e5f8e2a3..d8389e3f851a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2850,8 +2850,9 @@ def get_label(self) -> Optional[_LGBM_LabelType]: Returns ------- - label : numpy array or None + label : list, numpy 1-D array, pandas Series / one-column DataFrame or None The label information from the Dataset. + For a constructed ``Dataset``, this will only return a numpy array. """ if self.label is None: self.label = self.get_field('label') @@ -2862,8 +2863,9 @@ def get_weight(self) -> Optional[_LGBM_WeightType]: Returns ------- - weight : numpy array or None + weight : list, numpy 1-D array, pandas Series or None Weight for each data point from the Dataset. Weights should be non-negative. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.weight is None: self.weight = self.get_field('weight') @@ -2874,8 +2876,9 @@ def get_init_score(self) -> Optional[_LGBM_InitScoreType]: Returns ------- - init_score : numpy array or None + init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None Init score of Booster. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.init_score is None: self.init_score = self.get_field('init_score') @@ -2918,12 +2921,13 @@ def get_group(self) -> Optional[_LGBM_GroupType]: Returns ------- - group : numpy array or None + group : list, numpy 1-D array, pandas Series or None Group/query data. Only used in the learning-to-rank task. sum(group) = n_samples. For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.group is None: self.group = self.get_field('group') @@ -2937,8 +2941,9 @@ def get_position(self) -> Optional[_LGBM_PositionType]: Returns ------- - position : numpy 1-D array or None + position : numpy 1-D array, pandas Series or None Position of items used in unbiased learning-to-rank task. + For a constructed ``Dataset``, this will only return ``None`` or a numpy array. """ if self.position is None: self.position = self.get_field('position') diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 450757c1e58f..9e194c0d53d9 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -86,6 +86,26 @@ _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType] +def _get_label_from_constructed_dataset(dataset: Dataset) -> np.ndarray: + label = dataset.get_label() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve labels from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert isinstance(label, np.ndarray), error_msg + return label + + +def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]: + weight = dataset.get_weight() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert (weight is None or isinstance(weight, np.ndarray)), error_msg + return weight + + class _ObjectiveFunctionWrapper: """Proxy class for objective function.""" @@ -151,21 +171,25 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. The value of the second order derivative (Hessian) of the loss with respect to the elements of preds for each sample point. """ - labels = dataset.get_field("label") + labels = _get_label_from_constructed_dataset(dataset) argc = len(signature(self.func).parameters) if argc == 2: grad, hess = self.func(labels, preds) # type: ignore[call-arg] - elif argc == 3: - grad, hess = self.func(labels, preds, dataset.get_field("weight")) # type: ignore[call-arg] - elif argc == 4: + return grad, hess + + weight = _get_weight_from_constructed_dataset(dataset) + if argc == 3: + grad, hess = self.func(labels, preds, weight) # type: ignore[call-arg] + return grad, hess + + if argc == 4: group = dataset.get_field("group") if group is not None: - return self.func(labels, preds, dataset.get_field("weight"), np.diff(group)) # type: ignore[call-arg] + return self.func(labels, preds, weight, np.diff(group)) # type: ignore[call-arg] else: - return self.func(labels, preds, dataset.get_field("weight"), group) # type: ignore[call-arg] - else: - raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") - return grad, hess + return self.func(labels, preds, weight, group) # type: ignore[call-arg] + + raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") class _EvalFunctionWrapper: @@ -233,20 +257,23 @@ def __call__( is_higher_better : bool Is eval result higher better, e.g. AUC is ``is_higher_better``. """ - labels = dataset.get_field("label") + labels = _get_label_from_constructed_dataset(dataset) argc = len(signature(self.func).parameters) if argc == 2: return self.func(labels, preds) # type: ignore[call-arg] - elif argc == 3: - return self.func(labels, preds, dataset.get_field("weight")) # type: ignore[call-arg] - elif argc == 4: + + weight = _get_weight_from_constructed_dataset(dataset) + if argc == 3: + return self.func(labels, preds, weight) # type: ignore[call-arg] + + if argc == 4: group = dataset.get_field("group") if group is not None: - return self.func(labels, preds, dataset.get_field("weight"), np.diff(group)) # type: ignore[call-arg] + return self.func(labels, preds, weight, np.diff(group)) # type: ignore[call-arg] else: - return self.func(labels, preds, dataset.get_field("weight"), group) # type: ignore[call-arg] - else: - raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") + return self.func(labels, preds, weight, group) # type: ignore[call-arg] + + raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") # documentation templates for LGBMModel methods are shared between the classes in From ada18f80441262ce3a191c0c55993e0b37ff3f68 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 26 Oct 2023 22:45:46 -0500 Subject: [PATCH 09/10] pass 'group' as numpy array of boundaries --- python-package/lightgbm/basic.py | 4 ++++ python-package/lightgbm/sklearn.py | 18 ++++++++++++++---- tests/python_package_test/test_basic.py | 8 ++++---- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d8389e3f851a..e8fef8722b49 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -2782,6 +2782,10 @@ def set_group( if self._handle is not None and group is not None: group = _list_to_1d_numpy(group, dtype=np.int32, name='group') self.set_field('group', group) + # original values can be modified at cpp side + constructed_group = self.get_field('group') + if constructed_group is not None: + self.group = np.diff(constructed_group) return self def set_position( diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 9e194c0d53d9..c334aa792330 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -86,6 +86,16 @@ _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType] +def _get_group_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray]: + group = dataset.get_group() + error_msg = ( + "Estimators in lightgbm.sklearn should only retrieve query groups from a constructed Dataset. " + "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." + ) + assert (group is None or isinstance(group, np.ndarray)), error_msg + return group + + def _get_label_from_constructed_dataset(dataset: Dataset) -> np.ndarray: label = dataset.get_label() error_msg = ( @@ -183,9 +193,9 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. return grad, hess if argc == 4: - group = dataset.get_field("group") + group = _get_group_from_constructed_dataset(dataset) if group is not None: - return self.func(labels, preds, weight, np.diff(group)) # type: ignore[call-arg] + return self.func(labels, preds, weight, group) # type: ignore[call-arg] else: return self.func(labels, preds, weight, group) # type: ignore[call-arg] @@ -267,9 +277,9 @@ def __call__( return self.func(labels, preds, weight) # type: ignore[call-arg] if argc == 4: - group = dataset.get_field("group") + group = _get_group_from_constructed_dataset(dataset) if group is not None: - return self.func(labels, preds, weight, np.diff(group)) # type: ignore[call-arg] + return self.func(labels, preds, weight, group) # type: ignore[call-arg] else: return self.func(labels, preds, weight, group) # type: ignore[call-arg] diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 49ac7db789de..2ab80e051ca0 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -550,9 +550,9 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): # constructed, get_* methods should return numpy arrays, even when the provided # input was a list of floats or ints dtrain.construct() - expected_group = [1, 1] - assert dtrain.group == expected_group - assert dtrain.get_group() == expected_group + expected_group = np.array([1, 1], dtype=np.int32) + np_assert_array_equal(dtrain.group, expected_group, strict=True) + np_assert_array_equal(dtrain.get_group(), expected_group, strict=True) # get_field("group") returns a numpy array with boundaries, instead of size np_assert_array_equal( dtrain.get_field("group"), @@ -560,7 +560,7 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): strict=True ) - expected_init_score = np.array([0.312, 0.708]) + expected_init_score = np.array([0.312, 0.708],) np_assert_array_equal(dtrain.init_score, expected_init_score, strict=True) np_assert_array_equal(dtrain.get_init_score(), expected_init_score, strict=True) np_assert_array_equal(dtrain.get_field("init_score"), expected_init_score, strict=True) From c0b507e631e147111eda31a8c085a537f33c19e2 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 7 Nov 2023 10:17:08 -0600 Subject: [PATCH 10/10] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Morales --- python-package/lightgbm/sklearn.py | 10 ++-------- tests/python_package_test/test_basic.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index c334aa792330..310d5d2ca6ea 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -194,10 +194,7 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. if argc == 4: group = _get_group_from_constructed_dataset(dataset) - if group is not None: - return self.func(labels, preds, weight, group) # type: ignore[call-arg] - else: - return self.func(labels, preds, weight, group) # type: ignore[call-arg] + return self.func(labels, preds, weight, group) # type: ignore[call-arg] raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") @@ -278,10 +275,7 @@ def __call__( if argc == 4: group = _get_group_from_constructed_dataset(dataset) - if group is not None: - return self.func(labels, preds, weight, group) # type: ignore[call-arg] - else: - return self.func(labels, preds, weight, group) # type: ignore[call-arg] + return self.func(labels, preds, weight, group) # type: ignore[call-arg] raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 2ab80e051ca0..2f6b07e7a77f 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -503,7 +503,7 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields(): X = np.array([[1.0, 2.0], [3.0, 4.0]]) - position=np.array([0.0, 1.0], dtype=np.float32) + position = np.array([0.0, 1.0], dtype=np.float32) if getenv('TASK', '') == 'cuda': position = None