Skip to content

Commit

Permalink
fix bug in filter bin
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Mar 22, 2017
1 parent 8a0f07a commit 1c1749d
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ inline static double ApproximateHessianWithGaussian(const double y, const double
}

template <typename T>
inline static T** Vector2Ptr(std::vector<std::vector<T>>& data) {
T** ptr = new T*[data.size()];
inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>& data) {
std::vector<T*> ptr(data.size());
for (size_t i = 0; i < data.size(); ++i) {
ptr[i] = data[i].data();
}
Expand Down
12 changes: 6 additions & 6 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
Common::Vector2Ptr<int>(sample_idx),
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
Expand Down Expand Up @@ -487,8 +487,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
}
CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
Common::Vector2Ptr<int>(sample_idx),
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
Expand Down Expand Up @@ -546,8 +546,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
Common::Vector2Ptr<int>(sample_idx),
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
Common::Vector2Ptr<int>(sample_idx).data(),
static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
Expand Down
8 changes: 2 additions & 6 deletions src/io/bin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,14 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin
int sum_left = 0;
for (size_t i = 0; i < cnt_in_bin.size() - 1; ++i) {
sum_left += cnt_in_bin[i];
if (sum_left >= filter_cnt) {
return false;
} else if (total_cnt - sum_left >= filter_cnt) {
if (sum_left >= filter_cnt && total_cnt - sum_left >= filter_cnt) {
return false;
}
}
} else {
for (size_t i = 0; i < cnt_in_bin.size() - 1; ++i) {
int sum_left = cnt_in_bin[i];
if (sum_left >= filter_cnt) {
return false;
} else if (total_cnt - sum_left >= filter_cnt) {
if (sum_left >= filter_cnt && total_cnt - sum_left >= filter_cnt) {
return false;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ void OverallConfig::CheckParamConflict() {
bool objective_type_multiclass = (objective_type == std::string("multiclass"));
int num_class_check = boosting_config.num_class;
if (objective_type_multiclass) {
if (num_class_check <= 2) {
Log::Fatal("Number of classes should be specified and greater than 2 for multiclass training");
if (num_class_check <= 1) {
Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
}
} else {
if (task_type == TaskType::kTrain && num_class_check != 1) {
Expand Down
9 changes: 6 additions & 3 deletions src/io/dataset_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,9 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
feature_names_.push_back(str_buf.str());
}
}
const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * io_config_.min_data_in_leaf) / num_data * num_col);

const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf * total_sample_size) / num_data);

#pragma omp parallel for schedule(guided)
for (int i = 0; i < num_col; ++i) {
Expand Down Expand Up @@ -701,7 +703,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
}
dataset->feature_names_ = feature_names_;
std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size());
const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * io_config_.min_data_in_leaf) / dataset->num_data_ * sample_values.size());
const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf* sample_values.size()) / dataset->num_data_);

// start find bins
if (num_machines == 1) {
Expand Down Expand Up @@ -815,7 +818,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
}
}
sample_values.clear();
dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices),
dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices).data(),
Common::VectorSize<int>(sample_indices).data(), sample_data.size(), io_config_);
}

Expand Down

0 comments on commit 1c1749d

Please sign in to comment.