Commit 6152785
When loading a binary file, take feature penalty and monotone constraints from config if given there. (#1881)

* When loading a binary file, take feature penalty from config if given there.

* Fix crash when num_features != num_total_features and feature_contri is given.

* Apply the same logic to monotone_types_.

* Fix indentation
remcob-gr authored and StrikerRUS committed Jan 16, 2019
1 parent d038aa5 commit 6152785
Showing 1 changed file with 32 additions and 8 deletions.
src/io/dataset_loader.cpp: 32 additions & 8 deletions
@@ -370,21 +370,45 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
   }
   mem_ptr += sizeof(int) * (dataset->num_groups_);

-  const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
-  dataset->monotone_types_.clear();
-  for (int i = 0; i < dataset->num_features_; ++i) {
-    dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
+  if (!config_.monotone_constraints.empty()) {
+    CHECK(dataset->num_total_features_ == config_.monotone_constraints.size());
+    dataset->monotone_types_.resize(dataset->num_features_);
+    for (int i = 0; i < dataset->num_total_features_; ++i) {
+      int inner_fidx = dataset->InnerFeatureIndex(i);
+      if (inner_fidx >= 0) {
+        dataset->monotone_types_[inner_fidx] = config_.monotone_constraints[i];
+      }
+    }
+  }
+  else {
+    const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
+    dataset->monotone_types_.clear();
+    for (int i = 0; i < dataset->num_features_; ++i) {
+      dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
+    }
   }
   mem_ptr += sizeof(int8_t) * (dataset->num_features_);

   if (ArrayArgs<int8_t>::CheckAllZero(dataset->monotone_types_)) {
     dataset->monotone_types_.clear();
   }

-  const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
-  dataset->feature_penalty_.clear();
-  for (int i = 0; i < dataset->num_features_; ++i) {
-    dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
+  if (!config_.feature_contri.empty()) {
+    CHECK(dataset->num_total_features_ == config_.feature_contri.size());
+    dataset->feature_penalty_.resize(dataset->num_features_);
+    for (int i = 0; i < dataset->num_total_features_; ++i) {
+      int inner_fidx = dataset->InnerFeatureIndex(i);
+      if (inner_fidx >= 0) {
+        dataset->feature_penalty_[inner_fidx] = config_.feature_contri[i];
+      }
+    }
+  }
+  else {
+    const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
+    dataset->feature_penalty_.clear();
+    for (int i = 0; i < dataset->num_features_; ++i) {
+      dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
+    }
   }
   mem_ptr += sizeof(double) * (dataset->num_features_);
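Both hunks follow the same pattern. The config vectors (monotone_constraints, feature_contri) are indexed by the original column and have num_total_features_ entries, while the dataset keeps per-feature values only for the num_features_ columns actually loaded; InnerFeatureIndex(i) maps an original index to the inner one and is negative for columns that were dropped. Iterating over the total indices and writing only through valid inner indices is what fixes the crash called out in the commit message; indexing the config vector by an inner index instead would read out of bounds whenever the two counts differ. Note also that mem_ptr is advanced after both branches, so the values serialized in the binary file are skipped even when the config overrides them. A minimal standalone sketch of the remapping (hypothetical names, not the library's actual types):

    #include <cassert>
    #include <vector>

    // Standalone model of the remapping in both hunks above. `used_feature_map`
    // plays the role of Dataset::InnerFeatureIndex(): it maps an original column
    // index to the dataset's inner feature index, or -1 for dropped columns.
    std::vector<double> MapConfigToUsedFeatures(
        const std::vector<double>& config_values,  // one entry per original column
        const std::vector<int>& used_feature_map,  // original index -> inner index
        int num_used_features) {
      assert(config_values.size() == used_feature_map.size());
      std::vector<double> inner_values(num_used_features, 0.0);
      for (size_t i = 0; i < config_values.size(); ++i) {
        int inner_fidx = used_feature_map[i];
        if (inner_fidx >= 0) {  // skip columns absent from the loaded dataset
          inner_values[inner_fidx] = config_values[i];
        }
      }
      return inner_values;
    }

For example, with 4 original columns and column 1 dropped (3 inner features), MapConfigToUsedFeatures({1.0, 0.5, 1.0, 0.8}, {0, -1, 1, 2}, 3) yields {1.0, 1.0, 0.8}.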


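From the caller's side, the practical effect is that these two parameters are no longer ignored when a dataset is reloaded from a saved binary file. A hedged usage sketch against the C API (the file name and the three-column parameter values are made up for illustration; LightGBM auto-detects binary dataset files on load):

    #include <LightGBM/c_api.h>

    #include <cstdio>

    int main() {
      DatasetHandle dataset = nullptr;
      // With this change, the monotone_constraints and feature_contri passed here
      // override whatever was stored when train.bin was written; both lists are
      // indexed by the original (total) columns of the dataset.
      int err = LGBM_DatasetCreateFromFile(
          "train.bin",  // hypothetical binary dataset written earlier
          "monotone_constraints=1,0,-1 feature_contri=1.0,0.5,1.0",
          nullptr,      // no reference dataset
          &dataset);
      if (err != 0) {
        std::fprintf(stderr, "failed to load dataset\n");
        return 1;
      }
      // ... create a booster with this dataset and train as usual ...
      LGBM_DatasetFree(dataset);
      return 0;
    }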