Skip to content

Commit

Permalink
fix map metric
Browse files (browse the repository at this point in the history)
  • Loading branch information
guolinke committed Aug 18, 2017
1 parent 2f0fb20 commit 486f5db
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/metric/map_metric.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -56,6 +56,15 @@ class MapMetric:public Metric {
sum_query_weights_ += query_weights_[i];
}
}

npos_per_query_.resize(num_queries_, 0);
for (data_size_t i = 0; i < num_queries_; ++i) {
for (data_size_t j = query_boundaries_[i]; j < query_boundaries_[i + 1]; ++j) {
if (label_[j] > 0.5f) {
++npos_per_query_[i];
}
}
}
}

const std::vector<std::string>& GetName() const override {
Expand All @@ -66,7 +75,7 @@ class MapMetric:public Metric {
return 1.0f;
}

void CalMapAtK(std::vector<int> ks, const float* label,
void CalMapAtK(std::vector<int> ks, data_size_t npos, const float* label,
const double* score, data_size_t num_data, std::vector<double>* out) const {
// get sorted indices by score
std::vector<data_size_t> sorted_idx;
Expand All @@ -80,7 +89,7 @@ class MapMetric:public Metric {
double sum_ap = 0.0f;
data_size_t cur_left = 0;
for (size_t i = 0; i < ks.size(); ++i) {
data_size_t cur_k = ks[i];
data_size_t cur_k = static_cast<data_size_t>(ks[i]);
if (cur_k > num_data) { cur_k = num_data; }
for (data_size_t j = cur_left; j < cur_k; ++j) {
data_size_t idx = sorted_idx[j];
Expand All @@ -89,7 +98,11 @@ class MapMetric:public Metric {
sum_ap += static_cast<double>(num_hit) / (j + 1.0f);
}
}
(*out)[i] = sum_ap / cur_k;
if (npos > 0) {
(*out)[i] = sum_ap / std::min(npos, cur_k);
} else {
(*out)[i] = 1.0f;
}
cur_left = cur_k;
}
}
Expand All @@ -104,7 +117,7 @@ class MapMetric:public Metric {
#pragma omp parallel for schedule(guided) firstprivate(tmp_map)
for (data_size_t i = 0; i < num_queries_; ++i) {
const int tid = omp_get_thread_num();
CalMapAtK(eval_at_, label_ + query_boundaries_[i],
CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
for (size_t j = 0; j < eval_at_.size(); ++j) {
result_buffer_[tid][j] += tmp_map[j];
Expand All @@ -114,7 +127,7 @@ class MapMetric:public Metric {
#pragma omp parallel for schedule(guided) firstprivate(tmp_map)
for (data_size_t i = 0; i < num_queries_; ++i) {
const int tid = omp_get_thread_num();
CalMapAtK(eval_at_, label_ + query_boundaries_[i],
CalMapAtK(eval_at_, npos_per_query_[i], label_ + query_boundaries_[i],
score + query_boundaries_[i], query_boundaries_[i + 1] - query_boundaries_[i], &tmp_map);
for (size_t j = 0; j < eval_at_.size(); ++j) {
result_buffer_[tid][j] += tmp_map[j] * query_weights_[i];
Expand Down Expand Up @@ -150,6 +163,7 @@ class MapMetric:public Metric {
/*! \brief Number of threads */
int num_threads_;
std::vector<std::string> name_;
std::vector<data_size_t> npos_per_query_;
};

} // namespace LightGBM
Expand Down

0 comments on commit 486f5db

Please sign in to comment.