diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c02b4642b58..c8cf81b223a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1339,6 +1339,7 @@ def set_label(self, label): if self.handle is not None: label = list_to_1d_numpy(_label_from_pandas(label), name='label') self.set_field('label', label) + self.label = self.get_field('label') # original values can be modified at cpp side return self def set_weight(self, weight): @@ -1360,6 +1361,7 @@ def set_weight(self, weight): if self.handle is not None and weight is not None: weight = list_to_1d_numpy(weight, name='weight') self.set_field('weight', weight) + self.weight = self.get_field('weight') # original values can be modified at cpp side return self def set_init_score(self, init_score): @@ -1379,6 +1381,7 @@ def set_init_score(self, init_score): if self.handle is not None and init_score is not None: init_score = list_to_1d_numpy(init_score, np.float64, name='init_score') self.set_field('init_score', init_score) + self.init_score = self.get_field('init_score') # original values can be modified at cpp side return self def set_group(self, group): diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 454a999fbf3..5d6509f5d3b 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -281,3 +281,34 @@ def test_cegb_scaling_equalities(self): with open(p2name, 'rt') as f: p2txt = f.read() self.assertEqual(p1txt, p2txt) + + def test_consistent_state_for_dataset_fields(self): + + def check_asserts(data): + np.testing.assert_allclose(data.label, data.get_label()) + np.testing.assert_allclose(data.label, data.get_field('label')) + self.assertFalse(np.isnan(data.label[0])) + self.assertFalse(np.isinf(data.label[1])) + np.testing.assert_allclose(data.weight, data.get_weight()) + np.testing.assert_allclose(data.weight, data.get_field('weight')) + self.assertFalse(np.isnan(data.weight[0])) + self.assertFalse(np.isinf(data.weight[1])) + np.testing.assert_allclose(data.init_score, data.get_init_score()) + np.testing.assert_allclose(data.init_score, data.get_field('init_score')) + self.assertFalse(np.isnan(data.init_score[0])) + self.assertFalse(np.isinf(data.init_score[1])) + self.assertTrue(np.all(np.isclose([data.label[0], data.weight[0], data.init_score[0]], + data.label[0]))) + self.assertAlmostEqual(data.label[1], data.weight[1]) + + X, y = load_breast_cancer(True) + sequence = np.ones(y.shape[0]) + sequence[0] = np.nan + sequence[1] = np.inf + lgb_data = lgb.Dataset(X, sequence, weight=sequence, init_score=sequence).construct() + check_asserts(lgb_data) + lgb_data = lgb.Dataset(X, y).construct() + lgb_data.set_label(sequence) + lgb_data.set_weight(sequence) + lgb_data.set_init_score(sequence) + check_asserts(lgb_data)