Skip to content

Commit

Permalink
add indices in shuffle model. (#1710)
Browse files Browse the repository at this point in the history
* add indexs in shuffle model.

* fix pep

* fix bug
  • Loading branch information
guolinke committed Sep 29, 2018
1 parent 172caee commit 81e2485
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
2 changes: 1 addition & 1 deletion include/LightGBM/boosting.h
Expand Up @@ -47,7 +47,7 @@ class LIGHTGBM_EXPORT Boosting {
/*!
* \brief Shuffle Existing Models
*/
virtual void ShuffleModels() = 0;
virtual void ShuffleModels(int start_iter, int end_iter) = 0;

virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;
Expand Down
2 changes: 1 addition & 1 deletion include/LightGBM/c_api.h
Expand Up @@ -374,7 +374,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Shuffle Models
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle);
LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter);

/*!
* \brief Merge model in two booster to first handle
Expand Down
15 changes: 13 additions & 2 deletions python-package/lightgbm/basic.py
Expand Up @@ -1945,15 +1945,26 @@ def save_model(self, filename, num_iteration=None, start_iteration=0):
_save_pandas_categorical(filename, self.pandas_categorical)
return self

def shuffle_models(self):
def shuffle_models(self, start_iteration=0, end_iteration=-1):
"""Shuffle models.
Parameters
----------
start_iteration : int, optional (default=0)
Index of the iteration that will start to shuffle.
end_iteration : int, optional (default=-1)
The last iteration that will be shuffled.
If <= 0, means the last iteration.
Returns
-------
self : Booster
Booster with shuffled models.
"""
_safe_call(_LIB.LGBM_BoosterShuffleModels(self.handle))
_safe_call(_LIB.LGBM_BoosterShuffleModels(
self.handle,
ctypes.c_int(start_iter),
ctypes.c_int(end_iter)))
return self

def model_from_string(self, model_str, verbose=True):
Expand Down
11 changes: 8 additions & 3 deletions src/boosting/gbdt.h
Expand Up @@ -70,16 +70,21 @@ class GBDT : public GBDTBase {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
}

void ShuffleModels() override {
void ShuffleModels(int start_iter, int end_iter) override {
int total_iter = static_cast<int>(models_.size()) / num_tree_per_iteration_;
start_iter = std::max(0, start_iter);
if (end_iter <= 0) {
end_iter = total_iter;
}
end_iter = std::min(total_iter, end_iter);
auto original_models = std::move(models_);
std::vector<int> indices(total_iter);
for (int i = 0; i < total_iter; ++i) {
indices[i] = i;
}
Random tmp_rand(17);
for (int i = 0; i < total_iter - 1; ++i) {
int j = tmp_rand.NextShort(i + 1, total_iter);
for (int i = start_iter; i < end_iter - 1; ++i) {
int j = tmp_rand.NextShort(i + 1, end_iter);
std::swap(indices[i], indices[j]);
}
models_ = std::vector<std::unique_ptr<Tree>>();
Expand Down
8 changes: 4 additions & 4 deletions src/c_api.cpp
Expand Up @@ -294,9 +294,9 @@ class Booster {
dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
}

void ShuffleModels() {
void ShuffleModels(int start_iter, int end_iter) {
std::lock_guard<std::mutex> lock(mutex_);
boosting_->ShuffleModels();
boosting_->ShuffleModels(start_iter, end_iter);
}

int GetEvalCounts() const {
Expand Down Expand Up @@ -919,10 +919,10 @@ int LGBM_BoosterFree(BoosterHandle handle) {
API_END();
}

int LGBM_BoosterShuffleModels(BoosterHandle handle) {
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->ShuffleModels();
ref_booster->ShuffleModels(start_iter, end_iter);
API_END();
}

Expand Down

0 comments on commit 81e2485

Please sign in to comment.