support kwargs for sklearn (#577)
* support kwargs for sklearn

* add link to Parameters
wxchan authored and guolinke committed Jun 2, 2017
1 parent 4894cc4 commit e465f92
Showing 1 changed file with 29 additions and 137 deletions.
166 changes: 29 additions & 137 deletions python-package/lightgbm/sklearn.py
@@ -124,16 +124,10 @@ class LGBMModel(LGBMModelBase):
 
     def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
                  learning_rate=0.1, n_estimators=10, max_bin=255,
-                 subsample_for_bin=50000, objective="regression",
+                 subsample_for_bin=50000, objective=None,
                  min_split_gain=0, min_child_weight=5, min_child_samples=10,
                  subsample=1, subsample_freq=1, colsample_bytree=1,
-                 reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
-                 is_unbalance=False, seed=0, nthread=-1, silent=True,
-                 sigmoid=1.0, huber_delta=1.0, gaussian_eta=1.0, fair_c=1.0,
-                 poisson_max_delta_step=0.7,
-                 max_position=20, label_gain=None,
-                 drop_rate=0.1, skip_drop=0.5, max_drop=50,
-                 uniform_drop=False, xgboost_dart_mode=False, use_missing=True):
+                 reg_alpha=0, reg_lambda=0, seed=0, nthread=-1, silent=True, **kwargs):
         """
         Implementation of the Scikit-Learn API for LightGBM.
 
@@ -174,45 +168,15 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
             L1 regularization term on weights
         reg_lambda : float
             L2 regularization term on weights
-        scale_pos_weight : float
-            Balancing of positive and negative weights.
-        is_unbalance : bool
-            Is unbalance for binary classification
         seed : int
             Random number seed.
         nthread : int
             Number of parallel threads
         silent : boolean
             Whether to print messages while running boosting.
-        sigmoid : float
-            Only used in binary classification and lambdarank. Parameter for sigmoid function.
-        huber_delta : float
-            Only used in regression. Parameter for Huber loss function.
-        gaussian_eta : float
-            Only used in regression. Parameter for L1 and Huber loss function.
-            It is used to control the width of Gaussian function to approximate hessian.
-        fair_c : float
-            Only used in regression. Parameter for Fair loss function.
-        poisson_max_delta_step : float
-            parameter used to safeguard optimization in Poisson regression.
-        max_position : int
-            Only used in lambdarank, will optimize NDCG at this position.
-        label_gain : list of float
-            Only used in lambdarank, relevant gain for labels.
-            For example, the gain of label 2 is 3 if using default label gains.
-            None (default) means use default value of CLI version: {0,1,3,7,15,31,63,...}.
-        drop_rate : float
-            Only used when boosting_type='dart'. Probablity to select dropping trees.
-        skip_drop : float
-            Only used when boosting_type='dart'. Probablity to skip dropping trees.
-        max_drop : int
-            Only used when boosting_type='dart'. Max number of dropped trees in one iteration.
-        uniform_drop : bool
-            Only used when boosting_type='dart'. If true, drop trees uniformly, else drop according to weights.
-        xgboost_dart_mode : bool
-            Only used when boosting_type='dart'. Whether use xgboost dart mode.
-        use_missing : bool
-            Set to False will disbale the special handle of missing value (default: True).
+        **kwargs : other parameters
+            Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
+            Note: **kwargs is not supported in sklearn, it may cause unexpected issues.
 
         Note
         ----
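
With this change to the docstring, any additional LightGBM parameter documented on the linked Parameters page can be passed to the sklearn wrapper as a plain keyword argument. A minimal sketch of the intended usage, assuming a LightGBM build that includes this commit; min_data_in_leaf and is_unbalance are illustrative picks from that page, not part of the diff:

import lightgbm as lgb

# Native LightGBM parameters without a named constructor argument can now be
# forwarded directly; LGBMModel.__init__ collects them via **kwargs.
clf = lgb.LGBMClassifier(num_leaves=63, min_data_in_leaf=50, is_unbalance=True)

As the added note warns, arbitrary **kwargs sit outside scikit-learn's estimator contract and may cause unexpected issues.
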
@@ -240,13 +204,23 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
             raise LightGBMError('Scikit-learn is required for this module')
 
         self.boosting_type = boosting_type
+        if objective is None:
+            if isinstance(self, LGBMRegressor):
+                self.objective = "regression"
+            elif isinstance(self, LGBMClassifier):
+                self.objective = "binary"
+            elif isinstance(self, LGBMRanker):
+                self.objective = "lambdarank"
+            else:
+                raise TypeError("Unknown LGBMModel type.")
+        else:
+            self.objective = objective
         self.num_leaves = num_leaves
         self.max_depth = max_depth
         self.learning_rate = learning_rate
         self.n_estimators = n_estimators
         self.max_bin = max_bin
         self.subsample_for_bin = subsample_for_bin
-        self.objective = objective
         self.min_split_gain = min_split_gain
         self.min_child_weight = min_child_weight
         self.min_child_samples = min_child_samples
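
The hunk above makes objective optional. A small sketch of the resulting defaults, assuming a LightGBM build that includes this commit; an explicit objective string still takes precedence:

import lightgbm as lgb

print(lgb.LGBMRegressor().objective)   # "regression"
print(lgb.LGBMClassifier().objective)  # "binary"
print(lgb.LGBMRanker().objective)      # "lambdarank"
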
@@ -255,24 +229,9 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
         self.colsample_bytree = colsample_bytree
         self.reg_alpha = reg_alpha
         self.reg_lambda = reg_lambda
-        self.scale_pos_weight = scale_pos_weight
-        self.is_unbalance = is_unbalance
         self.seed = seed
         self.nthread = nthread
         self.silent = silent
-        self.sigmoid = sigmoid
-        self.huber_delta = huber_delta
-        self.gaussian_eta = gaussian_eta
-        self.fair_c = fair_c
-        self.poisson_max_delta_step = poisson_max_delta_step
-        self.max_position = max_position
-        self.label_gain = label_gain
-        self.drop_rate = drop_rate
-        self.skip_drop = skip_drop
-        self.max_drop = max_drop
-        self.uniform_drop = uniform_drop
-        self.xgboost_dart_mode = xgboost_dart_mode
-        self.use_missing = use_missing
         self._Booster = None
         self.evals_result = None
         self.best_iteration = -1
@@ -281,6 +240,19 @@ def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
             self.fobj = _objective_function_wrapper(self.objective)
         else:
             self.fobj = None
+        self.other_params = kwargs
+
+    def get_params(self, deep=True):
+        params = super(LGBMModel, self).get_params(deep=deep)
+        params.update(self.other_params)
+        return params
+
+    # minor change to support `**kwargs`
+    def set_params(self, **params):
+        for key, value in params.items():
+            setattr(self, key, value)
+            self.other_params[key] = value
+        return self
 
     def fit(self, X, y,
             sample_weight=None, init_score=None, group=None,
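
The get_params/set_params overrides above keep the extra keyword arguments visible to scikit-learn's parameter machinery: get_params merges other_params into the named constructor arguments, and set_params writes each value both to the attribute and back into other_params. A minimal illustration, with min_data_in_leaf again standing in for any pass-through parameter:

import lightgbm as lgb

est = lgb.LGBMRegressor(min_data_in_leaf=100)
params = est.get_params()             # includes 'min_data_in_leaf' alongside the named args
est.set_params(min_data_in_leaf=200)  # updates est.min_data_in_leaf and est.other_params

Parameters passed this way round-trip through get_params/set_params, subject to the caveat in the docstring above.
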
@@ -363,12 +335,8 @@ def fit(self, X, y,
         params['verbose'] = -1 if self.silent else 1
         if hasattr(self, 'n_classes_') and self.n_classes_ > 2:
             params['num_class'] = self.n_classes_
-        if hasattr(self, 'eval_at'):
-            params['ndcg_eval_at'] = self.eval_at
         if self.fobj:
             params['objective'] = 'None'  # objective = nullptr for unknown objective
-        if 'label_gain' in params and params['label_gain'] is None:
-            del params['label_gain']  # use default of cli version
 
         if callable(eval_metric):
             feval = _eval_function_wrapper(eval_metric)
@@ -494,32 +462,6 @@ def feature_importance(self):
 
 class LGBMRegressor(LGBMModel, LGBMRegressorBase):
 
-    def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
-                 learning_rate=0.1, n_estimators=10, max_bin=255,
-                 subsample_for_bin=50000, objective="regression",
-                 min_split_gain=0, min_child_weight=5, min_child_samples=10,
-                 subsample=1, subsample_freq=1, colsample_bytree=1,
-                 reg_alpha=0, reg_lambda=0,
-                 seed=0, nthread=-1, silent=True,
-                 huber_delta=1.0, gaussian_eta=1.0, fair_c=1.0,
-                 poisson_max_delta_step=0.7,
-                 drop_rate=0.1, skip_drop=0.5, max_drop=50,
-                 uniform_drop=False, xgboost_dart_mode=False, use_missing=True):
-        super(LGBMRegressor, self).__init__(boosting_type=boosting_type, num_leaves=num_leaves,
-                                            max_depth=max_depth, learning_rate=learning_rate,
-                                            n_estimators=n_estimators, max_bin=max_bin,
-                                            subsample_for_bin=subsample_for_bin, objective=objective,
-                                            min_split_gain=min_split_gain, min_child_weight=min_child_weight,
-                                            min_child_samples=min_child_samples, subsample=subsample,
-                                            subsample_freq=subsample_freq, colsample_bytree=colsample_bytree,
-                                            reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-                                            seed=seed, nthread=nthread, silent=silent,
-                                            huber_delta=huber_delta, gaussian_eta=gaussian_eta, fair_c=fair_c,
-                                            poisson_max_delta_step=poisson_max_delta_step,
-                                            drop_rate=drop_rate, skip_drop=skip_drop, max_drop=max_drop,
-                                            uniform_drop=uniform_drop, xgboost_dart_mode=xgboost_dart_mode,
-                                            use_missing=use_missing)
-
     def fit(self, X, y,
             sample_weight=None, init_score=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,
@@ -543,31 +485,6 @@ def fit(self, X, y,
 
 class LGBMClassifier(LGBMModel, LGBMClassifierBase):
 
-    def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
-                 learning_rate=0.1, n_estimators=10, max_bin=255,
-                 subsample_for_bin=50000, objective="binary",
-                 min_split_gain=0, min_child_weight=5, min_child_samples=10,
-                 subsample=1, subsample_freq=1, colsample_bytree=1,
-                 reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
-                 is_unbalance=False, seed=0, nthread=-1,
-                 silent=True, sigmoid=1.0,
-                 drop_rate=0.1, skip_drop=0.5, max_drop=50,
-                 uniform_drop=False, xgboost_dart_mode=False, use_missing=True):
-        self.classes, self.n_classes = None, None
-        super(LGBMClassifier, self).__init__(boosting_type=boosting_type, num_leaves=num_leaves,
-                                             max_depth=max_depth, learning_rate=learning_rate,
-                                             n_estimators=n_estimators, max_bin=max_bin,
-                                             subsample_for_bin=subsample_for_bin, objective=objective,
-                                             min_split_gain=min_split_gain, min_child_weight=min_child_weight,
-                                             min_child_samples=min_child_samples, subsample=subsample,
-                                             subsample_freq=subsample_freq, colsample_bytree=colsample_bytree,
-                                             reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-                                             scale_pos_weight=scale_pos_weight, is_unbalance=is_unbalance,
-                                             seed=seed, nthread=nthread, silent=silent, sigmoid=sigmoid,
-                                             drop_rate=drop_rate, skip_drop=skip_drop, max_drop=max_drop,
-                                             uniform_drop=uniform_drop, xgboost_dart_mode=xgboost_dart_mode,
-                                             use_missing=use_missing)
-
     def fit(self, X, y,
             sample_weight=None, init_score=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,
@@ -659,31 +576,6 @@ def n_classes_(self):
 
 class LGBMRanker(LGBMModel):
 
-    def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
-                 learning_rate=0.1, n_estimators=10, max_bin=255,
-                 subsample_for_bin=50000, objective="lambdarank",
-                 min_split_gain=0, min_child_weight=5, min_child_samples=10,
-                 subsample=1, subsample_freq=1, colsample_bytree=1,
-                 reg_alpha=0, reg_lambda=0, scale_pos_weight=1,
-                 is_unbalance=False, seed=0, nthread=-1, silent=True,
-                 sigmoid=1.0, max_position=20, label_gain=None,
-                 drop_rate=0.1, skip_drop=0.5, max_drop=50,
-                 uniform_drop=False, xgboost_dart_mode=False, use_missing=True):
-        super(LGBMRanker, self).__init__(boosting_type=boosting_type, num_leaves=num_leaves,
-                                         max_depth=max_depth, learning_rate=learning_rate,
-                                         n_estimators=n_estimators, max_bin=max_bin,
-                                         subsample_for_bin=subsample_for_bin, objective=objective,
-                                         min_split_gain=min_split_gain, min_child_weight=min_child_weight,
-                                         min_child_samples=min_child_samples, subsample=subsample,
-                                         subsample_freq=subsample_freq, colsample_bytree=colsample_bytree,
-                                         reg_alpha=reg_alpha, reg_lambda=reg_lambda,
-                                         scale_pos_weight=scale_pos_weight, is_unbalance=is_unbalance,
-                                         seed=seed, nthread=nthread, silent=silent,
-                                         sigmoid=sigmoid, max_position=max_position, label_gain=label_gain,
-                                         drop_rate=drop_rate, skip_drop=skip_drop, max_drop=max_drop,
-                                         uniform_drop=uniform_drop, xgboost_dart_mode=xgboost_dart_mode,
-                                         use_missing=use_missing)
-
     def fit(self, X, y,
             sample_weight=None, init_score=None, group=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,