add support of refit-decay (#1603)
* add support of refit-decay

* add refit into c_api

* add test

* update document

* Update basic.py

* Update test_engine.py

* Update basic.py

* Update test_engine.py

* fix comments

* update test

* fix the comments

* Update test_engine.py
guolinke committed Aug 25, 2018
1 parent b1bbeba commit 2db6377
Showing 9 changed files with 109 additions and 3 deletions.
8 changes: 7 additions & 1 deletion docs/Parameters.rst
@@ -49,7 +49,7 @@ Core Parameters

- ``refit``, for refitting existing models with new data, aliases: ``refit_tree``

- **Note**: can be used only in CLI version
- **Note**: can be used only in CLI version; for language-specific packages you can use the corresponding functions

- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">&#x1F517;&#xFE0E;</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gamma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``, aliases: ``objective_type``, ``app``, ``application``

@@ -364,6 +364,12 @@ Learning Control Parameters

- see `this file <https://github.com/Microsoft/LightGBM/tree/master/examples/binary_classification/forced_splits.json>`__ as an example

- ``refit_decay_rate`` :raw-html:`<a id="refit_decay_rate" title="Permalink to this parameter" href="#refit_decay_rate">&#x1F517;&#xFE0E;</a>`, default = ``0.9``, type = double, constraints: ``0.0 <= refit_decay_rate <= 1.0``

- decay rate of ``refit`` task, will use ``leaf_output = refit_decay_rate * old_leaf_output + (1.0 - refit_decay_rate) * new_leaf_output`` to refit trees

- used only in ``refit`` task in CLI version or as argument in ``refit`` function in language-specific package

IO Parameters
-------------

10 changes: 10 additions & 0 deletions include/LightGBM/c_api.h
@@ -427,6 +427,16 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len);
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);

/*!
* \brief Refit the tree model using the new data (online learning)
* \param handle handle
* \param leaf_preds flattened leaf index predictions (row-major, nrow * ncol entries)
* \param nrow number of rows of leaf_preds
* \param ncol number of columns of leaf_preds
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol);

/*!
* \brief update the model by directly specifying the gradient and second order gradient,
* this can be used to support customized loss functions
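For C API consumers, a hedged sketch of the calling convention via Python ctypes. It assumes ``LIB`` is an already-loaded ``lib_lightgbm`` and ``booster_handle`` is a booster that already holds the new training data and the copied trees — the ``refit`` method in ``basic.py`` below shows how that state is prepared with ``LGBM_BoosterMerge``.

    import ctypes
    import numpy as np

    # leaf indices from a pred_leaf prediction on the new data: int32, row-major
    leaf_preds = np.array([[0, 2], [1, 2]], dtype=np.int32)  # toy values
    nrow, ncol = leaf_preds.shape
    flat = np.ascontiguousarray(leaf_preds.reshape(-1))
    ret = LIB.LGBM_BoosterRefit(
        booster_handle,  # assumed: a prepared BoosterHandle
        flat.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        ctypes.c_int32(nrow),
        ctypes.c_int32(ncol))
    assert ret == 0  # 0 when succeed, -1 when failure happens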
8 changes: 7 additions & 1 deletion include/LightGBM/config.h
@@ -93,7 +93,7 @@ struct Config {
// desc = ``predict``, for prediction, aliases: ``prediction``, ``test``
// desc = ``convert_model``, for converting model file into if-else format, see more information in `IO Parameters <#io-parameters>`__
// desc = ``refit``, for refitting existing models with new data, aliases: ``refit_tree``
// desc = **Note**: can be used only in CLI version
// desc = **Note**: can be used only in CLI version; for language-specific packages you can use the corresponding functions
TaskType task = TaskType::kTrain;

// [doc-only]
@@ -368,6 +368,12 @@ struct Config {
// desc = see `this file <https://github.com/Microsoft/LightGBM/tree/master/examples/binary_classification/forced_splits.json>`__ as an example
std::string forcedsplits_filename = "";

// check = >=0.0
// check = <=1.0
// desc = decay rate of ``refit`` task, will use ``leaf_output = refit_decay_rate * old_leaf_output + (1.0 - refit_decay_rate) * new_leaf_output`` to refit trees
// desc = used only in ``refit`` task in CLI version or as argument in ``refit`` function in language-specific package
double refit_decay_rate = 0.9;

#pragma endregion

#pragma region IO Parameters
39 changes: 39 additions & 0 deletions python-package/lightgbm/basic.py
@@ -1459,6 +1459,7 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
    self.model_from_string(params['model_str'])
else:
    raise TypeError('Need at least one training dataset or model file to create booster instance')
self.params = params.copy()

def __del__(self):
try:
@@ -1624,6 +1625,7 @@ def reset_parameter(self, params):
_safe_call(_LIB.LGBM_BoosterResetParameter(
    self.handle,
    c_str(params_str)))
self.params.update(params)
return self

def update(self, train_set=None, fobj=None):
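Tracking ``params`` on the Booster (copied at construction, updated on ``reset_parameter``) is what lets ``refit`` below clone the booster's configuration. A small sketch of the resulting behavior (toy data; assumes the lightgbm Python package as patched here):

    import numpy as np
    import lightgbm as lgb

    ds = lgb.Dataset(np.random.rand(100, 3), np.random.rand(100))
    bst = lgb.Booster(params={'objective': 'regression', 'verbose': -1}, train_set=ds)
    bst.reset_parameter({'learning_rate': 0.05})
    # bst.params now reflects the constructor params plus the update,
    # so bst.params['learning_rate'] == 0.05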
@@ -2019,6 +2021,43 @@ def predict(self, data, num_iteration=None, raw_score=False, pred_leaf=False, pr
    num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)

def refit(self, data, label, decay_rate=0.9):
    """Refit the existing Booster by new data.

    Parameters
    ----------
    data : string, numpy array or scipy.sparse
        Data source for refit.
        If string, it represents the path to a txt file.
    label : list or numpy 1-D array
        Label for refit.
    decay_rate : float, optional (default=0.9)
        Decay rate of refit,
        will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.

    Returns
    -------
    result : Booster
        Refitted Booster.
    """
    predictor = self._to_predictor()
    leaf_preds = predictor.predict(data, -1, pred_leaf=True)
    nrow, ncol = leaf_preds.shape
    train_set = Dataset(data, label)
    # pass decay_rate through as the refit_decay_rate parameter added in this
    # commit; copy first so the original booster's params are not mutated
    new_params = self.params.copy()
    new_params['refit_decay_rate'] = decay_rate
    new_booster = Booster(new_params, train_set, silent=True)
    # copy the trees of the existing model into the new booster
    _safe_call(_LIB.LGBM_BoosterMerge(
        new_booster.handle,
        predictor.handle))
    leaf_preds = leaf_preds.reshape(-1)
    ptr_data, type_ptr_data, _ = c_int_array(leaf_preds)
    _safe_call(_LIB.LGBM_BoosterRefit(
        new_booster.handle,
        ptr_data,
        ctypes.c_int(nrow),
        ctypes.c_int(ncol)))
    return new_booster

def get_leaf_output(self, tree_id, leaf_id):
    """Get the output of a leaf.
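End to end, the new method reads like this from user code (a sketch with toy data; all values illustrative):

    import numpy as np
    import lightgbm as lgb

    # "old" data the model was originally trained on, and "new" data arriving later
    X_old, y_old = np.random.rand(200, 5), np.random.rand(200)
    X_new, y_new = np.random.rand(100, 5), np.random.rand(100)

    booster = lgb.train({'objective': 'regression', 'verbose': -1, 'min_data': 10},
                        lgb.Dataset(X_old, y_old), num_boost_round=20)

    # keep the tree structures, re-estimate leaf outputs on the new data,
    # blending 90% old output with 10% new output per leaf
    refitted = booster.refit(X_new, y_new, decay_rate=0.9)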
1 change: 1 addition & 0 deletions src/boosting/gbdt.cpp
@@ -348,6 +348,7 @@ void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction)
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_data_; ++i) {
  leaf_pred[i] = tree_leaf_prediction[i][model_index];
  CHECK(leaf_pred[i] < models_[model_index]->num_leaves());
}
size_t bias = static_cast<size_t>(tree_id) * num_data_;
auto grad = gradients_.data() + bias;
18 changes: 18 additions & 0 deletions src/c_api.cpp
@@ -182,6 +182,17 @@ class Booster {
    return boosting_->TrainOneIter(nullptr, nullptr);
  }

  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    // re-assemble the flat row-major array into the per-row vectors RefitTree expects
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians);
@@ -956,6 +967,13 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_END();
}

int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->Refit(leaf_preds, nrow, ncol);
  API_END();
}

int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
6 changes: 6 additions & 0 deletions src/io/config_auto.cpp
@@ -199,6 +199,7 @@ std::unordered_set<std::string> Config::parameter_set({
"monotone_constraints",
"feature_contri",
"forcedsplits_filename",
"refit_decay_rate",
"verbosity",
"max_bin",
"min_data_in_bin",
@@ -368,6 +369,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {

  GetString(params, "forcedsplits_filename", &forcedsplits_filename);

  GetDouble(params, "refit_decay_rate", &refit_decay_rate);
  CHECK(refit_decay_rate >= 0.0);
  CHECK(refit_decay_rate <= 1.0);

  GetInt(params, "verbosity", &verbosity);

  GetInt(params, "max_bin", &max_bin);
@@ -554,6 +559,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast<int8_t, int>(monotone_constraints),",") << "]\n";
str_buf << "[feature_contri: " << Common::Join(feature_contri,",") << "]\n";
str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n";
str_buf << "[refit_decay_rate: " << refit_decay_rate << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
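On the CLI side, the new parameter slots into an ordinary config file for the ``refit`` task (a sketch; ``task``, ``data`` and ``input_model`` are standard LightGBM parameters, file names hypothetical):

    task = refit
    data = new_data.txt
    input_model = LightGBM_model.txt
    refit_decay_rate = 0.8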
4 changes: 3 additions & 1 deletion src/treelearner/serial_tree_learner.cpp
@@ -238,7 +238,9 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
    }
    double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
                                                                  config_->lambda_l1, config_->lambda_l2, config_->max_delta_step);
    tree->SetLeafOutput(i, output * tree->shrinkage());
    auto old_leaf_output = tree->LeafOutput(i);
    auto new_leaf_output = output * tree->shrinkage();
    tree->SetLeafOutput(i, config_->refit_decay_rate * old_leaf_output + (1.0 - config_->refit_decay_rate) * new_leaf_output);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
18 changes: 18 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -633,3 +633,21 @@ def is_correctly_constrained(learner):
}
constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(constrained_model)

def test_refit(self):
    X, y = load_breast_cancer(True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbose': -1,
        'min_data': 10
    }
    lgb_train = lgb.Dataset(X_train, y_train)
    gbm = lgb.train(params, lgb_train,
                    num_boost_round=20,
                    verbose_eval=False)
    err_pred = log_loss(y_test, gbm.predict(X_test))
    new_gbm = gbm.refit(X_test, y_test)
    new_err_pred = log_loss(y_test, new_gbm.predict(X_test))
    self.assertGreater(err_pred, new_err_pred)

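A boundary case that follows directly from the formula (a sketch extending the test above, not part of the commit): with ``decay_rate=1.0`` the blend reduces to the old leaf outputs, so predictions should be unchanged.

    # continuing from test_refit above; np is numpy, assumed imported by the test module
    same_gbm = gbm.refit(X_test, y_test, decay_rate=1.0)
    np.testing.assert_allclose(gbm.predict(X_test), same_gbm.predict(X_test))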