Average precision score (#3347)
* Implement average precision score

* Fix lint errors

* Change name to average_precision

* Add to R-package list of metrics

* Empty commit to trigger CI jobs

* Change name to average_precision
btrotta committed Sep 23, 2020
1 parent 7744757 commit 2870490
Showing 6 changed files with 149 additions and 1 deletion.
1 change: 1 addition & 0 deletions R-package/R/metrics.R
@@ -22,6 +22,7 @@
, "ndcg" = TRUE
, "map" = TRUE
, "auc" = TRUE
, "average_precision" = TRUE
, "binary_logloss" = FALSE
, "binary_error" = FALSE
, "auc_mu" = TRUE
2 changes: 2 additions & 0 deletions docs/Parameters.rst
@@ -1012,6 +1012,8 @@ Metric Parameters

- ``auc``, `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__

- ``average_precision``, `average precision score <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`__

- ``binary_logloss``, `log loss <https://en.wikipedia.org/wiki/Cross_entropy>`__, aliases: ``binary``

- ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
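For context, requesting the new metric from the Python package only requires listing it under the ``metric`` parameter. A minimal sketch (the dataset and round count here are illustrative, not part of this commit):

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)   # any binary-classification data works here
train_set = lgb.Dataset(X, label=y)

params = {
    'objective': 'binary',
    'metric': 'average_precision',   # the metric added by this commit
    'verbose': -1,
}

# average_precision is reported for each valid set at every boosting round
booster = lgb.train(params, train_set, num_boost_round=10, valid_sets=[train_set])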
1 change: 1 addition & 0 deletions include/LightGBM/config.h
@@ -875,6 +875,7 @@ struct Config {
// descl2 = ``ndcg``, `NDCG <https://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG>`__, aliases: ``lambdarank``, ``rank_xendcg``, ``xendcg``, ``xe_ndcg``, ``xe_ndcg_mart``, ``xendcg_mart``
// descl2 = ``map``, `MAP <https://makarandtapaswi.wordpress.com/2012/07/02/intuition-behind-average-precision-and-map/>`__, aliases: ``mean_average_precision``
// descl2 = ``auc``, `AUC <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`__
// descl2 = ``average_precision``, `average precision score <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html>`__
// descl2 = ``binary_logloss``, `log loss <https://en.wikipedia.org/wiki/Cross_entropy>`__, aliases: ``binary``
// descl2 = ``binary_error``, for one sample: ``0`` for correct classification, ``1`` for error classification
// descl2 = ``auc_mu``, `AUC-mu <http://proceedings.mlr.press/v97/kleiman19a/kleiman19a.pdf>`__
121 changes: 121 additions & 0 deletions src/metric/binary_metric.hpp
@@ -263,5 +263,126 @@ class AUCMetric: public Metric {
std::vector<std::string> name_;
};


/*!
* \brief Average Precision Metric for binary classification task.
*/
class AveragePrecisionMetric: public Metric {
public:
explicit AveragePrecisionMetric(const Config&) {
}

virtual ~AveragePrecisionMetric() {
}

const std::vector<std::string>& GetName() const override {
return name_;
}

double factor_to_bigger_better() const override {
return 1.0f;
}

void Init(const Metadata& metadata, data_size_t num_data) override {
name_.emplace_back("average_precision");

num_data_ = num_data;
// get label
label_ = metadata.label();
// get weights
weights_ = metadata.weights();

if (weights_ == nullptr) {
sum_weights_ = static_cast<double>(num_data_);
} else {
sum_weights_ = 0.0f;
for (data_size_t i = 0; i < num_data; ++i) {
sum_weights_ += weights_[i];
}
}
}

std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
// get indices sorted by score, descending order
std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data_; ++i) {
sorted_idx.emplace_back(i);
}
Common::ParallelSort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
// temp sum of positive labels
double cur_actual_pos = 0.0f;
// total sum of positive labels
double sum_actual_pos = 0.0f;
// total sum of predicted positive
double sum_pred_pos = 0.0f;
// accumulated precision
double accum_prec = 1.0f;
// accumulated pr-auc
double accum = 0.0f;
// temp sum of negative label
double cur_neg = 0.0f;
double threshold = score[sorted_idx[0]];
if (weights_ == nullptr) { // no weights
for (data_size_t i = 0; i < num_data_; ++i) {
const label_t cur_label = label_[sorted_idx[i]];
const double cur_score = score[sorted_idx[i]];
// new threshold
if (cur_score != threshold) {
threshold = cur_score;
// accumulate
sum_actual_pos += cur_actual_pos;
sum_pred_pos += cur_actual_pos + cur_neg;
accum_prec = sum_actual_pos / sum_pred_pos;
accum += cur_actual_pos * accum_prec;
// reset
cur_neg = cur_actual_pos = 0.0f;
}
cur_neg += (cur_label <= 0);
cur_actual_pos += (cur_label > 0);
}
} else { // has weights
for (data_size_t i = 0; i < num_data_; ++i) {
const label_t cur_label = label_[sorted_idx[i]];
const double cur_score = score[sorted_idx[i]];
const label_t cur_weight = weights_[sorted_idx[i]];
// new threshold
if (cur_score != threshold) {
threshold = cur_score;
// accumulate
sum_actual_pos += cur_actual_pos;
sum_pred_pos += cur_actual_pos + cur_neg;
accum_prec = sum_actual_pos / sum_pred_pos;
accum += cur_actual_pos * accum_prec;
// reset
cur_neg = cur_actual_pos = 0.0f;
}
cur_neg += (cur_label <= 0) * cur_weight;
cur_actual_pos += (cur_label > 0) * cur_weight;
}
}
sum_actual_pos += cur_actual_pos;
sum_pred_pos += cur_actual_pos + cur_neg;
accum_prec = sum_actual_pos / sum_pred_pos;
accum += cur_actual_pos * accum_prec;
double ap = 1.0f;
if (sum_actual_pos > 0.0f && sum_actual_pos != sum_weights_) {
ap = accum / sum_actual_pos;
}
return std::vector<double>(1, ap);
}

private:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Pointer of label */
const label_t* label_;
/*! \brief Pointer of weights */
const label_t* weights_;
/*! \brief Sum weights */
double sum_weights_;
/*! \brief Name of this metric */
std::vector<std::string> name_;
};

} // namespace LightGBM
#endif // LightGBM_METRIC_BINARY_METRIC_HPP_
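The new Eval method walks the examples in descending score order and groups ties at equal scores into one threshold: when the score changes, the previous group's (weighted) positive and negative counts are folded into the running totals, the precision at that threshold is computed, and it is added to accum weighted by the group's positives. Dividing by the total positive weight gives AP = sum over thresholds of (R_n - R_{n-1}) * P_n, the same definition scikit-learn uses; if there are no positives, or every example is positive, the metric falls back to 1. A NumPy sketch of the same accumulation (the helper name is made up for illustration, not part of this commit):

import numpy as np

def average_precision_sketch(labels, scores, weights=None):
    # Mirrors the accumulation in AveragePrecisionMetric::Eval.
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    weights = np.ones(len(labels)) if weights is None else np.asarray(weights, dtype=float)

    order = np.argsort(-scores)        # descending by score, as in the C++ code
    sum_actual_pos = 0.0               # cumulative positive weight over processed thresholds
    sum_pred_pos = 0.0                 # cumulative total weight over processed thresholds
    accum = 0.0                        # running sum of (group positives) * (precision at threshold)
    cur_pos = cur_neg = 0.0            # tallies for the current score threshold
    threshold = scores[order[0]]

    for i in order:
        if scores[i] != threshold:     # score changed: close out the previous threshold group
            threshold = scores[i]
            sum_actual_pos += cur_pos
            sum_pred_pos += cur_pos + cur_neg
            accum += cur_pos * (sum_actual_pos / sum_pred_pos)
            cur_pos = cur_neg = 0.0
        if labels[i] > 0:
            cur_pos += weights[i]
        else:
            cur_neg += weights[i]

    # flush the final threshold group
    sum_actual_pos += cur_pos
    sum_pred_pos += cur_pos + cur_neg
    accum += cur_pos * (sum_actual_pos / sum_pred_pos)

    # degenerate cases (no positives, or only positives) return 1, as in the C++ code
    if sum_actual_pos > 0.0 and sum_actual_pos != weights.sum():
        return accum / sum_actual_pos
    return 1.0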
2 changes: 2 additions & 0 deletions src/metric/metric.cpp
@@ -34,6 +34,8 @@ Metric* Metric::CreateMetric(const std::string& type, const Config& config) {
return new BinaryErrorMetric(config);
} else if (type == std::string("auc")) {
return new AUCMetric(config);
} else if (type == std::string("average_precision")) {
return new AveragePrecisionMetric(config);
} else if (type == std::string("auc_mu")) {
return new AucMuMetric(config);
} else if (type == std::string("ndcg")) {
23 changes: 22 additions & 1 deletion tests/python_package_test/test_engine.py
@@ -12,7 +12,7 @@
from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_svmlight_file, make_multilabel_classification)
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold

try:
@@ -2507,3 +2507,24 @@ def inner_test(X, y, params, early_stopping_rounds):
inner_test(X, y, params, early_stopping_rounds=1)
inner_test(X, y, params, early_stopping_rounds=5)
inner_test(X, y, params, early_stopping_rounds=None)

def test_average_precision_metric(self):
# test against sklearn average precision metric
X, y = load_breast_cancer(return_X_y=True)
params = {
'objective': 'binary',
'metric': 'average_precision',
'verbose': -1
}
res = {}
lgb_X = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_X, num_boost_round=10, valid_sets=[lgb_X], evals_result=res)
ap = res['training']['average_precision'][-1]
pred = est.predict(X)
sklearn_ap = average_precision_score(y, pred)
self.assertAlmostEqual(ap, sklearn_ap)
# test that average precision is 1 where model predicts perfectly
y[:] = 1
lgb_X = lgb.Dataset(X, label=y)
lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
self.assertAlmostEqual(res['training']['average_precision'][-1], 1)
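The test above exercises only the unweighted code path. A hypothetical weighted cross-check against scikit-learn's sample_weight support (a sketch, not part of this commit) could look like:

import numpy as np
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import average_precision_score

X, y = load_breast_cancer(return_X_y=True)
w = np.random.RandomState(42).uniform(0.5, 2.0, size=y.shape)   # arbitrary positive weights

params = {'objective': 'binary', 'metric': 'average_precision', 'verbose': -1}
res = {}
train_set = lgb.Dataset(X, label=y, weight=w)
booster = lgb.train(params, train_set, num_boost_round=10,
                    valid_sets=[train_set], evals_result=res)

lgb_ap = res['training']['average_precision'][-1]
sklearn_ap = average_precision_score(y, booster.predict(X), sample_weight=w)
# the two values are expected to agree up to floating-point noise
assert abs(lgb_ap - sklearn_ap) < 1e-6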
