Skip to content

Commit

Permalink
support use objective name to create metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Jan 21, 2018
1 parent 169a271 commit fc16792
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
7 changes: 7 additions & 0 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ void GetMetricType(const std::unordered_map<std::string, std::string>& params, s
}
metric_types->shrink_to_fit();
}
// add names of objective function if not providing metric
if (metric_types->empty()) {
if (ConfigBase::GetString(params, "objective", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
metric_types->push_back(value);
}
}
}

void GetTaskType(const std::unordered_map<std::string, std::string>& params, TaskType* task_type) {
Expand Down
8 changes: 4 additions & 4 deletions src/metric/metric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
namespace LightGBM {

Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
if (type == std::string("regression") || type == std::string("regression_l2") || type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config);
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RMSEMetric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
} else if (type == std::string("regression_l1") || type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new L1Metric(config);
} else if (type == std::string("quantile")) {
return new QuantileMetric(config);
Expand All @@ -23,7 +23,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new FairLossMetric(config);
} else if (type == std::string("poisson")) {
return new PoissonMetric(config);
} else if (type == std::string("binary_logloss")) {
} else if (type == std::string("binary_logloss") || type == std::string("binary")) {
return new BinaryLoglossMetric(config);
} else if (type == std::string("binary_error")) {
return new BinaryErrorMetric(config);
Expand All @@ -33,7 +33,7 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new NDCGMetric(config);
} else if (type == std::string("map") || type == std::string("mean_average_precision")) {
return new MapMetric(config);
} else if (type == std::string("multi_logloss")) {
} else if (type == std::string("multi_logloss") || type == std::string("multiclass") || type == std::string("multiclass_ova")) {
return new MultiSoftmaxLoglossMetric(config);
} else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config);
Expand Down
3 changes: 2 additions & 1 deletion src/objective/objective_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ namespace LightGBM {

ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const ObjectiveConfig& config) {
if (type == std::string("regression") || type == std::string("regression_l2")
|| type == std::string("mean_squared_error") || type == std::string("mse")) {
|| type == std::string("mean_squared_error") || type == std::string("mse")
|| type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RegressionL2loss(config);
} else if (type == std::string("regression_l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new RegressionL1loss(config);
Expand Down

0 comments on commit fc16792

Please sign in to comment.