[python] added class_weight parameter in sklearn wrapper (#1114)
* added class_weight parameter in sklearn wrapper

* added note about alternative parameters in case of binary classification task
StrikerRUS committed Dec 25, 2017
1 parent 4bbe17f commit d110d6f
Showing 2 changed files with 43 additions and 8 deletions.
3 changes: 3 additions & 0 deletions python-package/lightgbm/compat.py
@@ -63,6 +63,7 @@ class DataFrame(object):
     from sklearn.base import BaseEstimator
     from sklearn.base import RegressorMixin, ClassifierMixin
     from sklearn.preprocessing import LabelEncoder
+    from sklearn.utils.class_weight import compute_sample_weight
     from sklearn.utils.multiclass import check_classification_targets
     from sklearn.utils.validation import check_X_y, check_array, check_consistent_length
     try:
@@ -83,6 +84,7 @@ class DataFrame(object):
     _LGBMCheckArray = check_array
     _LGBMCheckConsistentLength = check_consistent_length
     _LGBMCheckClassificationTargets = check_classification_targets
+    _LGBMComputeSampleWeight = compute_sample_weight
 except ImportError:
     SKLEARN_INSTALLED = False
     _LGBMModelBase = object
@@ -96,6 +98,7 @@ class DataFrame(object):
     _LGBMCheckArray = None
     _LGBMCheckConsistentLength = None
     _LGBMCheckClassificationTargets = None
+    _LGBMComputeSampleWeight = None
 
 
 # DeprecationWarning is not shown by default, so let's create our own with higher level
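For context, compat.py binds each scikit-learn helper to an _LGBM-prefixed alias at import time, so the rest of the package never imports scikit-learn directly. A minimal standalone sketch of that guard pattern (names mirror the diff above; an illustration, not the full module):

# Bind the real helper when scikit-learn is importable, otherwise bind
# None so that callers can fail with a clear error at use time.
try:
    from sklearn.utils.class_weight import compute_sample_weight
    SKLEARN_INSTALLED = True
    _LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
    SKLEARN_INSTALLED = False
    _LGBMComputeSampleWeight = None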
48 changes: 40 additions & 8 deletions python-package/lightgbm/sklearn.py
@@ -15,7 +15,8 @@
 from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
                      LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
                      _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
-                     _LGBMCheckClassificationTargets, argc_, range_, LGBMDeprecationWarning)
+                     _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
+                     argc_, range_, LGBMDeprecationWarning)
 from .engine import train


@@ -134,7 +135,7 @@ class LGBMModel(_LGBMModelBase):
 
     def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
                  learning_rate=0.1, n_estimators=100,
-                 subsample_for_bin=200000, objective=None,
+                 subsample_for_bin=200000, objective=None, class_weight=None,
                  min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
                  subsample=1., subsample_freq=1, colsample_bytree=1.,
                  reg_alpha=0., reg_lambda=0., random_state=None,
@@ -162,6 +163,15 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
             Specify the learning task and the corresponding learning objective or
             a custom objective function to be used (see note below).
             default: 'regression' for LGBMRegressor, 'binary' or 'multiclass' for LGBMClassifier, 'lambdarank' for LGBMRanker.
+        class_weight : dict, 'balanced' or None, optional (default=None)
+            Weights associated with classes in the form ``{class_label: weight}``.
+            Use this parameter only for multi-class classification tasks;
+            for binary classification tasks you may use the ``is_unbalance`` or ``scale_pos_weight`` parameters.
+            The 'balanced' mode uses the values of y to automatically adjust weights
+            inversely proportional to class frequencies in the input data as ``n_samples / (n_classes * np.bincount(y))``.
+            If None, all classes are supposed to have weight one.
+            Note that these weights will be multiplied with ``sample_weight`` (passed through the fit method)
+            if ``sample_weight`` is specified.
         min_split_gain : float, optional (default=0.)
             Minimum loss reduction required to make a further partition on a leaf node of the tree.
         min_child_weight : float, optional (default=1e-3)
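A short usage sketch of the new parameter, and of the binary-task alternatives the note above recommends (synthetic data; assumes a LightGBM build that includes this commit):

import numpy as np
from lightgbm import LGBMClassifier

rng = np.random.RandomState(42)
X = rng.rand(300, 5)
y = rng.randint(0, 3, size=300)  # 3-class task

# explicit per-class weights, or 'balanced' to derive them from y
clf = LGBMClassifier(class_weight={0: 1.0, 1: 2.0, 2: 5.0}).fit(X, y)
clf_bal = LGBMClassifier(class_weight='balanced').fit(X, y)

# for binary tasks, the docstring recommends the native parameters instead
y_bin = rng.randint(0, 2, size=300)
clf_bin = LGBMClassifier(scale_pos_weight=3.0).fit(X, y_bin)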
@@ -262,6 +272,7 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
         self._best_iteration = None
         self._other_params = {}
         self._objective = objective
+        self.class_weight = class_weight
         self._n_features = None
         self._classes = None
         self._n_classes = None
@@ -284,9 +295,9 @@ def set_params(self, **params):
     def fit(self, X, y,
             sample_weight=None, init_score=None, group=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,
-            eval_init_score=None, eval_group=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, feature_name='auto',
-            categorical_feature='auto', callbacks=None):
+            eval_class_weight=None, eval_init_score=None, eval_group=None,
+            eval_metric=None, early_stopping_rounds=None, verbose=True,
+            feature_name='auto', categorical_feature='auto', callbacks=None):
         """Build a gradient boosting model from the training set (X, y).
 
         Parameters
@@ -303,10 +314,12 @@ def fit(self, X, y,
             Group data of training data.
         eval_set : list or None, optional (default=None)
             A list of (X, y) tuple pairs to use as validation sets for early-stopping.
-        eval_names: list of strings or None, optional (default=None)
+        eval_names : list of strings or None, optional (default=None)
             Names of eval_set.
         eval_sample_weight : list of arrays or None, optional (default=None)
             Weights of eval data.
+        eval_class_weight : list or None, optional (default=None)
+            Class weights of eval data.
         eval_init_score : list of arrays or None, optional (default=None)
             Init score of eval data.
         eval_group : list of arrays or None, optional (default=None)
@@ -386,6 +399,7 @@ def fit(self, X, y,
         params['verbose'] = -1
         params.pop('silent', None)
         params.pop('n_estimators', None)
+        params.pop('class_weight', None)
         if self._n_classes is not None and self._n_classes > 2:
             params['num_class'] = self._n_classes
         if hasattr(self, '_eval_at'):
@@ -404,6 +418,12 @@ def fit(self, X, y,
             X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
             _LGBMCheckConsistentLength(X, y, sample_weight)
 
+        class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
+        if sample_weight is None or len(sample_weight) == 0:
+            sample_weight = class_sample_weight
+        else:
+            sample_weight = np.multiply(sample_weight, class_sample_weight)
+
         self._n_features = X.shape[1]
 
         def _construct_dataset(X, y, sample_weight, init_score, group, params):
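The added block converts class_weight into a per-sample weight vector and folds it into any user-supplied sample_weight. An equivalent standalone computation with scikit-learn, worked out for a tiny imbalanced example:

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0, 0, 0, 1])
sample_weight = np.array([1.0, 1.0, 2.0, 1.0])

# 'balanced' means n_samples / (n_classes * np.bincount(y)):
# class 0 -> 4 / (2 * 3) ~ 0.67, class 1 -> 4 / (2 * 1) = 2.0
class_sample_weight = compute_sample_weight('balanced', y)

# the two weightings combine multiplicatively, as in the wrapper
combined = np.multiply(sample_weight, class_sample_weight)
print(combined)  # approximately [0.67, 0.67, 1.33, 2.0]

Note that compute_sample_weight(None, y) returns all ones, which is why the wrapper can call it unconditionally.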
@@ -430,8 +450,13 @@ def get_meta_data(collection, i):
                    elif isinstance(collection, dict):
                        return collection.get(i, None)
                    else:
-                        raise TypeError('eval_sample_weight, eval_init_score, and eval_group should be dict or list')
+                        raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
                valid_weight = get_meta_data(eval_sample_weight, i)
+                valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
+                if valid_weight is None or len(valid_weight) == 0:
+                    valid_weight = valid_class_sample_weight
+                else:
+                    valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
                valid_init_score = get_meta_data(eval_init_score, i)
                valid_group = get_meta_data(eval_group, i)
                valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
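Since get_meta_data resolves one entry per validation set, eval_class_weight accepts a list (or a dict keyed by the set's index) in the same forms as class_weight. A hedged usage sketch with synthetic data:

import numpy as np
from lightgbm import LGBMClassifier

rng = np.random.RandomState(0)
X_train, y_train = rng.rand(200, 4), rng.randint(0, 3, size=200)
X_val, y_val = rng.rand(50, 4), rng.randint(0, 3, size=50)

clf = LGBMClassifier(class_weight='balanced')
clf.fit(X_train, y_train,
        eval_set=[(X_val, y_val)],
        # one entry per validation set
        eval_class_weight=[{0: 1.0, 1: 2.0, 2: 5.0}],
        verbose=False)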
@@ -592,6 +617,9 @@ def fit(self, X, y,
         return self
 
     base_doc = LGBMModel.fit.__doc__
+    fit.__doc__ = (base_doc[:base_doc.find('eval_class_weight :')] +
+                   base_doc[base_doc.find('eval_init_score :'):])
+    base_doc = fit.__doc__
     fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')] +
                    'eval_metric : string, list of strings, callable or None, optional (default="l2")\n' +
                    base_doc[base_doc.find('            If string, it should be a built-in evaluation metric to use.'):])
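The slicing above is docstring surgery: fit inherits LGBMModel.fit's docstring, and the entry for eval_class_weight is cut out where the parameter does not apply (here apparently for LGBMRegressor, and below for LGBMRanker). The technique in isolation, with illustrative names:

# remove one numpydoc entry by slicing from the start of its line
# to the start of the next entry; surrounding indentation survives
base_doc = """Parameters
    alpha : int
        First parameter.
    beta : int
        Entry to remove.
    gamma : int
        Last parameter.
"""
trimmed = (base_doc[:base_doc.find('beta :')] +
           base_doc[base_doc.find('gamma :'):])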
@@ -603,7 +631,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
     def fit(self, X, y,
             sample_weight=None, init_score=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,
-            eval_init_score=None, eval_metric="logloss",
+            eval_class_weight=None, eval_init_score=None, eval_metric="logloss",
             early_stopping_rounds=None, verbose=True,
             feature_name='auto', categorical_feature='auto', callbacks=None):
         _LGBMCheckClassificationTargets(y)
@@ -639,6 +667,7 @@ def fit(self, X, y,
                                         init_score=init_score, eval_set=eval_set,
                                         eval_names=eval_names,
                                         eval_sample_weight=eval_sample_weight,
+                                        eval_class_weight=eval_class_weight,
                                         eval_init_score=eval_init_score,
                                         eval_metric=eval_metric,
                                         early_stopping_rounds=early_stopping_rounds,
@@ -742,6 +771,9 @@ def fit(self, X, y,
         return self
 
     base_doc = LGBMModel.fit.__doc__
+    fit.__doc__ = (base_doc[:base_doc.find('eval_class_weight :')] +
+                   base_doc[base_doc.find('eval_init_score :'):])
+    base_doc = fit.__doc__
     fit.__doc__ = (base_doc[:base_doc.find('eval_metric :')] +
                    'eval_metric : string, list of strings, callable or None, optional (default="ndcg")\n' +
                    base_doc[base_doc.find('            If string, it should be a built-in evaluation metric to use.'):base_doc.find('early_stopping_rounds :')] +