Skip to content

Commit

Permalink
[python] avoid to set all weight to 1. (#1152)
Browse files Browse the repository at this point in the history
* avoid to set all weight to 1.

* fix valid_class_weight
  • Loading branch information
guolinke committed Dec 29, 2017
1 parent 4271082 commit 3f5b313
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
2 changes: 2 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,8 @@ def set_weight(self, weight):
weight : list, numpy array or None
Weight to be set for each data point.
"""
if weight is not None and np.all(weight == 1):
weight = None
self.weight = weight
if self.handle is not None and weight is not None:
weight = list_to_1d_numpy(weight, name='weight')
Expand Down
22 changes: 12 additions & 10 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,12 @@ def fit(self, X, y,
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight)

class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight
else:
sample_weight = np.multiply(sample_weight, class_sample_weight)
if self.class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight
else:
sample_weight = np.multiply(sample_weight, class_sample_weight)

self._n_features = X.shape[1]

Expand Down Expand Up @@ -452,11 +453,12 @@ def get_meta_data(collection, i):
else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
valid_weight = get_meta_data(eval_sample_weight, i)
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
if get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = get_meta_data(eval_init_score, i)
valid_group = get_meta_data(eval_group, i)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
Expand Down

0 comments on commit 3f5b313

Please sign in to comment.