From 81e2485ac9f4a42581d1fbc392c20961360fae3a Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Sat, 29 Sep 2018 10:22:01 +0800 Subject: [PATCH] add indices in shuffle model. (#1710) * add indexs in shuffle model. * fix pep * fix bug --- include/LightGBM/boosting.h | 2 +- include/LightGBM/c_api.h | 2 +- python-package/lightgbm/basic.py | 15 +++++++++++++-- src/boosting/gbdt.h | 11 ++++++++--- src/c_api.cpp | 8 ++++---- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index 74161f30a86..2cd19b99d06 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -47,7 +47,7 @@ class LIGHTGBM_EXPORT Boosting { /*! * \brief Shuffle Existing Models */ - virtual void ShuffleModels() = 0; + virtual void ShuffleModels(int start_iter, int end_iter) = 0; virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector& training_metrics) = 0; diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index cc941e1fea4..4d50788efea 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -374,7 +374,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle); /*! * \brief Shuffle Models */ -LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle); +LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter); /*! * \brief Merge model in two booster to first handle diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 22a5a753664..d774473a74e 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1945,15 +1945,26 @@ def save_model(self, filename, num_iteration=None, start_iteration=0): _save_pandas_categorical(filename, self.pandas_categorical) return self - def shuffle_models(self): + def shuffle_models(self, start_iteration=0, end_iteration=-1): """Shuffle models. + Parameters + ---------- + start_iteration : int, optional (default=0) + Index of the iteration that will start to shuffle. + end_iteration : int, optional (default=-1) + The last iteration that will be shuffled. + If <= 0, means the last iteration. + Returns ------- self : Booster Booster with shuffled models. """ - _safe_call(_LIB.LGBM_BoosterShuffleModels(self.handle)) + _safe_call(_LIB.LGBM_BoosterShuffleModels( + self.handle, + ctypes.c_int(start_iter), + ctypes.c_int(end_iter))) return self def model_from_string(self, model_str, verbose=True): diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 4fc5ec6e3db..c519ddcdb7e 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -70,16 +70,21 @@ class GBDT : public GBDTBase { num_iteration_for_pred_ = static_cast(models_.size()) / num_tree_per_iteration_; } - void ShuffleModels() override { + void ShuffleModels(int start_iter, int end_iter) override { int total_iter = static_cast(models_.size()) / num_tree_per_iteration_; + start_iter = std::max(0, start_iter); + if (end_iter <= 0) { + end_iter = total_iter; + } + end_iter = std::min(total_iter, end_iter); auto original_models = std::move(models_); std::vector indices(total_iter); for (int i = 0; i < total_iter; ++i) { indices[i] = i; } Random tmp_rand(17); - for (int i = 0; i < total_iter - 1; ++i) { - int j = tmp_rand.NextShort(i + 1, total_iter); + for (int i = start_iter; i < end_iter - 1; ++i) { + int j = tmp_rand.NextShort(i + 1, end_iter); std::swap(indices[i], indices[j]); } models_ = std::vector>(); diff --git a/src/c_api.cpp b/src/c_api.cpp index 3943a9b428a..443a254d695 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -294,9 +294,9 @@ class Booster { dynamic_cast(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val); } - void ShuffleModels() { + void ShuffleModels(int start_iter, int end_iter) { std::lock_guard lock(mutex_); - boosting_->ShuffleModels(); + boosting_->ShuffleModels(start_iter, end_iter); } int GetEvalCounts() const { @@ -919,10 +919,10 @@ int LGBM_BoosterFree(BoosterHandle handle) { API_END(); } -int LGBM_BoosterShuffleModels(BoosterHandle handle) { +int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) { API_BEGIN(); Booster* ref_booster = reinterpret_cast(handle); - ref_booster->ShuffleModels(); + ref_booster->ShuffleModels(start_iter, end_iter); API_END(); }