Skip to content

Commit

Permalink
Switch RMSE to MSE (true L2 loss) (#408)
Browse files Browse the repository at this point in the history
* RMSE (L2) -> MSE (true L2)

* Remove sqrt unneeded reference

* Square L2 test (RMSE to MSE)

* No square root on test

* Attempt to add RMSE
  • Loading branch information
Laurae2 authored and guolinke committed Apr 26, 2017
1 parent faa3b57 commit d1742d8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
3 changes: 2 additions & 1 deletion docs/Parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,10 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can

## Metric parameters

* ```metric```, default={```l2``` for regression}, {```binary_logloss``` for binary classification},{```ndcg``` for lambdarank}, type=multi-enum, options=```l1```,```l2```,```ndcg```,```auc```,```binary_logloss```,```binary_error```
* ```metric```, default={```l2``` for regression}, {```binary_logloss``` for binary classification},{```ndcg``` for lambdarank}, type=multi-enum, options=```l1```,```l2```,```ndcg```,```auc```,```binary_logloss```,```binary_error```...
* ```l1```, absolute loss, alias=```mean_absolute_error```, ```mae```
* ```l2```, square loss, alias=```mean_squared_error```, ```mse```
* ```l2_root```, root square loss, alias=```root_mean_squared_error```, ```rmse```
* ```huber```, [Huber loss](https://en.wikipedia.org/wiki/Huber_loss "Huber loss - Wikipedia")
* ```fair```, [Fair loss](https://www.kaggle.com/c/allstate-claims-severity/discussion/24520)
* ```poisson```, [Poisson regression](https://en.wikipedia.org/wiki/Poisson_regression "Poisson regression")
Expand Down
4 changes: 3 additions & 1 deletion src/metric/metric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ namespace LightGBM {
Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config) {
if (type == std::string("l2") || type == std::string("mean_squared_error") || type == std::string("mse")) {
return new L2Metric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
} else if (type == std::string("l2_root") || type == std::string("root_mean_squared_error") || type == std::string("rmse")) {
return new RMSEMetric(config);
} else if (type == std::string("l1") || type == std::string("mean_absolute_error") || type == std::string("mae")) {
return new L1Metric(config);
} else if (type == std::string("huber")) {
return new HuberLossMetric(config);
Expand Down
23 changes: 21 additions & 2 deletions src/metric/regression_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,25 @@ class RegressionMetric: public Metric {
std::vector<std::string> name_;
};

/*! \brief RMSE loss for regression task */
class RMSEMetric: public RegressionMetric<RMSEMetric> {
public:
explicit RMSEMetric(const MetricConfig& config) :RegressionMetric<RMSEMetric>(config) {}

inline static double LossOnPoint(float label, double score, double, double) {
return (score - label)*(score - label);
}

inline static double AverageLoss(double sum_loss, double sum_weights) {
// need sqrt the result for RMSE loss
return std::sqrt(sum_loss / sum_weights);
}

inline static const char* Name() {
return "rmse";
}
};

/*! \brief L2 loss for regression task */
class L2Metric: public RegressionMetric<L2Metric> {
public:
Expand All @@ -101,8 +120,8 @@ class L2Metric: public RegressionMetric<L2Metric> {
}

inline static double AverageLoss(double sum_loss, double sum_weights) {
// need sqrt the result for L2 loss
return std::sqrt(sum_loss / sum_weights);
// need mean of the result for L2 loss
return sum_loss / sum_weights;
}

inline static const char* Name() {
Expand Down
3 changes: 1 addition & 2 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def test_binary(self):

def test_regreesion(self):
evals_result, ret = template.test_template()
ret **= 0.5
self.assertLess(ret, 4)
self.assertLess(ret, 16)
self.assertAlmostEqual(min(evals_result['eval']['l2']), ret, places=5)

def test_multiclass(self):
Expand Down

0 comments on commit d1742d8

Please sign in to comment.