Skip to content

Commit

Permalink
fix a bug in goss
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Jan 15, 2018
1 parent e283565 commit 0f0eb69
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/LightGBM/utils/array_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class ArrayArgs {
*r = i;
};

// Note: k refer to index here. e.g. k=0 means get max number.
inline static int ArgMaxAtK(std::vector<VAL_T>* arr, int start, int end, int k) {
if (start >= end - 1) {
return start;
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/goss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class GOSS: public GBDT {
data_size_t top_k = static_cast<data_size_t>(cnt * gbdt_config_->top_rate);
data_size_t other_k = static_cast<data_size_t>(cnt * gbdt_config_->other_rate);
top_k = std::max(1, top_k);
ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k);
ArrayArgs<score_t>::ArgMaxAtK(&tmp_gradients, 0, static_cast<int>(tmp_gradients.size()), top_k - 1);
score_t threshold = tmp_gradients[top_k - 1];

score_t multiply = static_cast<score_t>(cnt - top_k) / other_k;
Expand Down

0 comments on commit 0f0eb69

Please sign in to comment.