Skip to content

Commit

Permalink
refine set params (#933)
Browse files Browse the repository at this point in the history
  • Loading branch information
wxchan authored and guolinke committed Sep 26, 2017
1 parent e66a8a3 commit 9c57793
Showing 1 changed file with 30 additions and 42 deletions.
72 changes: 30 additions & 42 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,42 +32,37 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return params;
}

std::string GetBoostingType(const std::unordered_map<std::string, std::string>& params) {
void GetBoostingType(const std::unordered_map<std::string, std::string>& params, std::string* boosting_type) {
std::string value;
std::string boosting_type = kDefaultBoostingType;
if (ConfigBase::GetString(params, "boosting_type", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("gbdt") || value == std::string("gbrt")) {
boosting_type = "gbdt";
*boosting_type = "gbdt";
} else if (value == std::string("dart")) {
boosting_type = "dart";
*boosting_type = "dart";
} else if (value == std::string("goss")) {
boosting_type = "goss";
*boosting_type = "goss";
} else if (value == std::string("rf") || value == std::string("randomforest")) {
boosting_type = "rf";
*boosting_type = "rf";
} else {
Log::Fatal("Unknown boosting type %s", value.c_str());
}
}
return boosting_type;
}

std::string GetObjectiveType(const std::unordered_map<std::string, std::string>& params) {
void GetObjectiveType(const std::unordered_map<std::string, std::string>& params, std::string* objective_type) {
std::string value;
std::string objective_type = kDefaultObjectiveType;
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
objective_type = value;
*objective_type = value;
}
return objective_type;
}

std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std::string>& params) {
void GetMetricType(const std::unordered_map<std::string, std::string>& params, std::vector<std::string>* metric_types) {
std::string value;
std::vector<std::string> metric_types;
if (ConfigBase::GetString(params, "metric", &value)) {
// clear old metrics
metric_types.clear();
metric_types->clear();
// to lower
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
// split
Expand All @@ -81,66 +76,59 @@ std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std
}
}
for (auto& metric : metric_sets) {
metric_types.push_back(metric);
metric_types->push_back(metric);
}
metric_types.shrink_to_fit();
metric_types->shrink_to_fit();
}
return metric_types;
}

TaskType GetTaskType(const std::unordered_map<std::string, std::string>& params) {
void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
std::string value;
TaskType task_type = TaskType::kTrain;
if (ConfigBase::GetString(params, "task", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("train") || value == std::string("training")) {
task_type = TaskType::kTrain;
*task_type = TaskType::kTrain;
} else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) {
task_type = TaskType::kPredict;
*task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel;
*task_type = TaskType::kConvertModel;
} else {
Log::Fatal("Unknown task type %s", value.c_str());
}
}
return task_type;
}

std::string GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
void GetDeviceType(const std::unordered_map<std::string, std::string>& params, std::string* device_type) {
std::string value;
std::string device_type = kDefaultDevice;
if (ConfigBase::GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
*device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
*device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
return device_type;
}

std::string GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
void GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params, std::string* tree_learner_type) {
std::string value;
std::string tree_learner_type = kDefaultTreeLearnerType;
if (ConfigBase::GetString(params, "tree_learner", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) {
tree_learner_type = "serial";
*tree_learner_type = "serial";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = "feature";
*tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) {
tree_learner_type = "data";
*tree_learner_type = "data";
} else if (value == std::string("voting") || value == std::string("voting_parallel")) {
tree_learner_type = "voting";
*tree_learner_type = "voting";
} else {
Log::Fatal("Unknown tree learner type %s", value.c_str());
}
}
return tree_learner_type;
}

void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
Expand All @@ -157,17 +145,17 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config.drop_seed = static_cast<int>(rand.NextShort(0, int_max));
boosting_config.tree_config.feature_fraction_seed = static_cast<int>(rand.NextShort(0, int_max));
}
task_type = GetTaskType(params);
boosting_type = GetBoostingType(params);
GetTaskType(params, &task_type);
GetBoostingType(params, &boosting_type);

metric_types = GetMetricType(params);
GetMetricType(params, &metric_types);

// sub-config setup
network_config.Set(params);
io_config.Set(params);

boosting_config.Set(params);
objective_type = GetObjectiveType(params);
GetObjectiveType(params, &objective_type);
objective_config.Set(params);
metric_config.Set(params);

Expand Down Expand Up @@ -298,7 +286,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
GetBool(params, "use_missing", &use_missing);
GetBool(params, "zero_as_missing", &zero_as_missing);
device_type = GetDeviceType(params);
GetDeviceType(params, &device_type);
}

void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
Expand Down Expand Up @@ -413,8 +401,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
device_type = GetDeviceType(params);
tree_learner_type = GetTreeLearnerType(params);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner_type);
tree_config.Set(params);
}

Expand Down

0 comments on commit 9c57793

Please sign in to comment.