Skip to content

Commit

Permalink
slight speed up the iterator of sparse bin.
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Mar 1, 2017
1 parent 5800068 commit 362f9ac
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 17 deletions.
22 changes: 21 additions & 1 deletion src/boosting/goss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

namespace LightGBM {

#ifdef TIMETAG
std::chrono::duration<double, std::milli> subset_time;
std::chrono::duration<double, std::milli> re_init_tree_time;
#endif

class GOSS: public GBDT {
public:
/*!
Expand All @@ -27,7 +32,10 @@ class GOSS: public GBDT {
}

~GOSS() {

#ifdef TIMETAG
Log::Info("GOSS::subset costs %f", subset_time * 1e-3);
Log::Info("GOSS::re_init_tree costs %f", re_init_tree_time * 1e-3);
#endif
}

void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
Expand Down Expand Up @@ -165,9 +173,21 @@ class GOSS: public GBDT {
tree_learner_->SetBaggingData(bag_data_indices_.data(), bag_data_cnt_);
} else {
// get subset
#ifdef TIMETAG
auto start_time = std::chrono::steady_clock::now();
#endif
tmp_subset_->ReSize(bag_data_cnt_);
tmp_subset_->CopySubset(train_data_, bag_data_indices_.data(), bag_data_cnt_, false);
#ifdef TIMETAG
subset_time += std::chrono::steady_clock::now() - start_time;
#endif
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
#endif
tree_learner_->ResetTrainingData(tmp_subset_.get());
#ifdef TIMETAG
re_init_tree_time += std::chrono::steady_clock::now() - start_time;
#endif
}
}

Expand Down
52 changes: 36 additions & 16 deletions src/io/sparse_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

namespace LightGBM {

template <typename VAL_T>
class SparseBin;
template <typename VAL_T> class SparseBin;

const size_t kNumFastIndex = 64;
const uint8_t kMaxDelta = 255;

template <typename VAL_T>
class SparseBinIterator: public BinIterator {
Expand Down Expand Up @@ -108,16 +106,18 @@ class SparseBin: public Bin {
inline bool NextNonzero(data_size_t* i_delta,
data_size_t* cur_pos) const {
++(*i_delta);
*cur_pos += deltas_[*i_delta];
data_size_t factor = 1;
data_size_t shift = 0;
data_size_t delta = deltas_[*i_delta];
while (*i_delta < num_vals_ && vals_[*i_delta] == 0) {
++(*i_delta);
factor *= kMaxDelta;
*cur_pos += deltas_[*i_delta] * factor;
shift += 8;
delta |= static_cast<data_size_t>(deltas_[*i_delta]) << shift;
}
if (*i_delta >= 0 && *i_delta < num_vals_) {
*cur_pos += delta;
if (*i_delta < num_vals_) {
return true;
} else {
*cur_pos = num_data_;
return false;
}
}
Expand Down Expand Up @@ -193,10 +193,10 @@ class SparseBin: public Bin {
const data_size_t cur_idx = idx_val_pairs[i].first;
const VAL_T bin = idx_val_pairs[i].second;
data_size_t cur_delta = cur_idx - last_idx;
while (cur_delta > kMaxDelta) {
deltas_.push_back(cur_delta % kMaxDelta);
while (cur_delta >= 256) {
deltas_.push_back(cur_delta & 0xff);
vals_.push_back(0);
cur_delta /= kMaxDelta;
cur_delta >>= 8;
}
deltas_.push_back(static_cast<uint8_t>(cur_delta));
vals_.push_back(bin);
Expand Down Expand Up @@ -298,14 +298,34 @@ class SparseBin: public Bin {
void CopySubset(const Bin* full_bin, const data_size_t* used_indices, data_size_t num_used_indices) override {
auto other_bin = reinterpret_cast<const SparseBin<VAL_T>*>(full_bin);
SparseBinIterator<VAL_T> iterator(other_bin, used_indices[0]);
std::vector<std::pair<data_size_t, VAL_T>> tmp_pair;
deltas_.clear();
vals_.clear();
// transform to delta array
data_size_t last_idx = 0;
for (data_size_t i = 0; i < num_used_indices; ++i) {
VAL_T bin = iterator.RawGet(used_indices[i]);
if (bin > 0) {
tmp_pair.emplace_back(i, bin);
data_size_t cur_delta = i - last_idx;
while (cur_delta >= 256) {
deltas_.push_back(cur_delta & 0xff);
vals_.push_back(0);
cur_delta >>= 8;
}
deltas_.push_back(static_cast<uint8_t>(cur_delta));
vals_.push_back(bin);
last_idx = i;
}
}
LoadFromPair(tmp_pair);
// avoid out of range
deltas_.push_back(0);
num_vals_ = static_cast<data_size_t>(vals_.size());

// reduce memory cost
deltas_.shrink_to_fit();
vals_.shrink_to_fit();

// generate fast index
GetFastIndex();
}

protected:
Expand All @@ -320,10 +340,10 @@ class SparseBin: public Bin {

template <typename VAL_T>
inline VAL_T SparseBinIterator<VAL_T>::RawGet(data_size_t idx) {
while (cur_pos_ < idx && i_delta_ < bin_data_->num_vals_) {
while (cur_pos_ < idx) {
bin_data_->NextNonzero(&i_delta_, &cur_pos_);
}
if (cur_pos_ == idx && i_delta_ < bin_data_->num_vals_ && i_delta_ >= 0) {
if (cur_pos_ == idx) {
return bin_data_->vals_[i_delta_];
} else {
return 0;
Expand Down

0 comments on commit 362f9ac

Please sign in to comment.