-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
rank_obj.cc
340 lines (323 loc) · 12.9 KB
/
rank_obj.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
/*!
* Copyright 2015 by Contributors
* \file rank_obj.cc
* \brief Definition of rank loss.
* \author Tianqi Chen, Kailong Chen
*/
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "../common/math.h"
#include "../common/random.h"
namespace xgboost {
namespace obj {
// dmlc registry file tag for this translation unit (tag name: rank_obj).
// NOTE(review): presumably paired with a DMLC_REGISTRY_LINK_TAG(rank_obj) elsewhere so the
// objective registrations in this file are not dropped by the linker — confirm.
DMLC_REGISTRY_FILE_TAG(rank_obj);
/*!
 * \brief Hyper-parameters shared by the pairwise / LambdaRank objectives.
 */
struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
  /*! \brief how many sampling rounds are run per instance when forming pairs (>= 1) */
  int num_pairsample;
  /*!
   * \brief when non-zero, pair weights of a list are scaled by
   *        fix_list_weight / (list size); 0 disables the scaling
   */
  float fix_list_weight;

  DMLC_DECLARE_PARAMETER(LambdaRankParam) {
    DMLC_DECLARE_FIELD(num_pairsample)
        .set_lower_bound(1)
        .set_default(1)
        .describe("Number of pair generated for each instance.");
    DMLC_DECLARE_FIELD(fix_list_weight)
        .set_lower_bound(0.0f)
        .set_default(0.0f)
        .describe("Normalize the weight of each list by this value,"
                  " if equals 0, no effect will happen");
  }
};
// objective for lambda rank
class LambdaRankObj : public ObjFunction {
 public:
  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
    param_.InitAllowUnknown(args);
  }
  /*!
   * \brief Compute pairwise ranking gradients for every query group.
   *
   * For each group the rows are sorted by prediction (descending). For every row,
   * param_.num_pairsample partners with a different label are sampled at random.
   * The subclass then assigns a lambda weight to each pair via GetLambdaWeight.
   * Finally, the sigmoid pair-loss gradient/hessian is accumulated into both rows
   * of every pair.
   *
   * \param preds      current predictions, one per row
   * \param info       meta info supplying labels, group boundaries and per-group weights
   * \param iter       boosting iteration, mixed into the per-thread RNG seed
   * \param out_gpair  output gradient pairs, resized to the number of rows
   */
  void GetGradient(const HostDeviceVector<bst_float>& preds,
                   const MetaInfo& info,
                   int iter,
                   HostDeviceVector<GradientPair>* out_gpair) override {
    CHECK_EQ(preds.Size(), info.labels_.Size()) << "label size predict size not match";
    const auto& preds_h = preds.HostVector();
    out_gpair->Resize(preds_h.size());
    std::vector<GradientPair>& gpair = out_gpair->HostVector();
    // quick consistency when group is not available: treat all rows as a single group
    std::vector<unsigned> tgptr(2, 0); tgptr[1] = static_cast<unsigned>(info.labels_.Size());
    const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
    CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size())
        << "group structure not consistent with #rows";
    const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
    // Fetch the host-side labels once, before the parallel region, rather than
    // letting every OpenMP thread call HostVector() concurrently.
    const auto& labels = info.labels_.HostVector();
    // Depends only on the group weights: compute it once instead of once per thread.
    const bst_float weight_normalization_factor = ComputeWeightNormalization(info, ngroup);
#pragma omp parallel
    {
      // parallel construct: declare the random number generator here so that each
      // thread uses its own generator, seeded by thread id and current iteration
      common::RandomEngine rnd(iter * 1111 + omp_get_thread_num());
      std::vector<LambdaPair> pairs;
      std::vector<ListEntry> lst;
      std::vector< std::pair<bst_float, unsigned> > rec;
#pragma omp for schedule(static)
      for (bst_omp_uint k = 0; k < ngroup; ++k) {
        lst.clear(); pairs.clear();
        for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
          lst.emplace_back(preds_h[j], labels[j], j);
          gpair[j] = GradientPair(0.0f, 0.0f);
        }
        // after this sort, a position in lst is the rank by descending prediction
        std::sort(lst.begin(), lst.end(), ListEntry::CmpPred);
        rec.resize(lst.size());
        for (unsigned i = 0; i < lst.size(); ++i) {
          rec[i] = std::make_pair(lst[i].label, i);
        }
        std::sort(rec.begin(), rec.end(), common::CmpFirst);
        // every pair sampled from this group carries the same normalized group weight
        const bst_float group_weight = info.GetWeight(k) * weight_normalization_factor;
        // enumerate buckets with same label, for each item in the lst, grab another sample randomly
        for (unsigned i = 0; i < rec.size(); ) {
          unsigned j = i + 1;
          while (j < rec.size() && rec[j].first == rec[i].first) ++j;
          // bucket in [i,j), get a sample outside bucket
          unsigned nleft = i, nright = static_cast<unsigned>(rec.size() - j);
          if (nleft + nright != 0) {
            int nsample = param_.num_pairsample;
            while (nsample --) {
              for (unsigned pid = i; pid < j; ++pid) {
                unsigned ridx = std::uniform_int_distribution<unsigned>(0, nleft + nright - 1)(rnd);
                // NOTE(review): assumes common::CmpFirst orders labels descending, so an
                // entry before the bucket is the positive (higher-label) side of the pair.
                if (ridx < nleft) {
                  pairs.emplace_back(rec[ridx].second, rec[pid].second, group_weight);
                } else {
                  // skip over the bucket: indices >= nleft map to rec[j + (ridx - nleft)]
                  pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second, group_weight);
                }
              }
            }
          }
          i = j;
        }
        // get lambda weight for the pairs
        this->GetLambdaWeight(lst, &pairs);
        // rescale each gradient and hessian so that the lst have constant weighted
        float scale = 1.0f / param_.num_pairsample;
        if (param_.fix_list_weight != 0.0f) {
          scale *= param_.fix_list_weight / (gptr[k + 1] - gptr[k]);
        }
        for (auto & pair : pairs) {
          const ListEntry &pos = lst[pair.pos_index];
          const ListEntry &neg = lst[pair.neg_index];
          const bst_float w = pair.weight * scale;
          const float eps = 1e-16f;
          bst_float p = common::Sigmoid(pos.pred - neg.pred);
          bst_float g = p - 1.0f;
          bst_float h = std::max(p * (1.0f - p), eps);
          // accumulate gradient and hessian in both pid, and nid
          gpair[pos.rindex] += GradientPair(g * w, 2.0f*w*h);
          gpair[neg.rindex] += GradientPair(-g * w, 2.0f*w*h);
        }
      }
    }
  }
  const char* DefaultEvalMetric() const override {
    return "map";
  }

 protected:
  /*! \brief helper information in a list */
  struct ListEntry {
    /*! \brief the predicted score of this row */
    bst_float pred;
    /*! \brief the actual label of the entry */
    bst_float label;
    /*! \brief row index in the data matrix */
    unsigned rindex;
    // constructor
    ListEntry(bst_float pred, bst_float label, unsigned rindex)
        : pred(pred), label(label), rindex(rindex) {}
    // comparator by prediction (descending)
    inline static bool CmpPred(const ListEntry &a, const ListEntry &b) {
      return a.pred > b.pred;
    }
    // comparator by label (descending)
    inline static bool CmpLabel(const ListEntry &a, const ListEntry &b) {
      return a.label > b.label;
    }
  };
  /*! \brief a pair in the lambda rank */
  struct LambdaPair {
    /*! \brief positive index: this is a position in the list */
    unsigned pos_index;
    /*! \brief negative index: this is a position in the list */
    unsigned neg_index;
    /*! \brief weight to be filled in */
    bst_float weight;
    // constructor
    LambdaPair(unsigned pos_index, unsigned neg_index)
        : pos_index(pos_index), neg_index(neg_index), weight(1.0f) {}
    // constructor
    LambdaPair(unsigned pos_index, unsigned neg_index, bst_float weight)
        : pos_index(pos_index), neg_index(neg_index), weight(weight) {}
  };
  /*!
   * \brief get lambda weight for existing pairs
   * \param sorted_list a list that is sorted by pred score (descending)
   * \param io_pairs record of pairs, containing the pairs to fill in weights
   */
  virtual void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
                               std::vector<LambdaPair> *io_pairs) = 0;

 private:
  /*!
   * \brief Factor that rescales the group weights so that they average to 1.
   * \param info    meta info providing GetWeight(k) for each group k
   * \param ngroup  number of groups
   * \return ngroup / (sum of group weights), or 1 when that sum is exactly zero.
   *         A zero sum would otherwise make the division infinite, which turns
   *         every gradient into inf/NaN.
   */
  static bst_float ComputeWeightNormalization(const MetaInfo& info, bst_omp_uint ngroup) {
    // accumulate in double: a float running sum loses precision with many groups
    double sum_weights = 0.0;
    for (bst_omp_uint k = 0; k < ngroup; ++k) {
      sum_weights += info.GetWeight(k);
    }
    if (sum_weights == 0.0) {
      return 1.0f;
    }
    return static_cast<bst_float>(ngroup / sum_weights);
  }

  LambdaRankParam param_;
};
/*!
 * \brief Plain pairwise ranking objective.
 *
 * Applies no metric-based reweighting: each sampled pair keeps the weight it was
 * given when LambdaRankObj::GetGradient constructed it.
 */
class PairwiseRankObj : public LambdaRankObj {
 protected:
  void GetLambdaWeight(const std::vector<ListEntry> & /*sorted_list*/,
                       std::vector<LambdaPair> * /*io_pairs*/) override {
    // intentionally a no-op
  }
};
// beta version: NDCG lambda rank
class LambdaRankObjNDCG : public LambdaRankObj {
protected:
void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
std::vector<LambdaPair> *io_pairs) override {
std::vector<LambdaPair> &pairs = *io_pairs;
float IDCG; // NOLINT
{
std::vector<bst_float> labels(sorted_list.size());
for (size_t i = 0; i < sorted_list.size(); ++i) {
labels[i] = sorted_list[i].label;
}
std::sort(labels.begin(), labels.end(), std::greater<bst_float>());
IDCG = CalcDCG(labels);
}
if (IDCG == 0.0) {
for (auto & pair : pairs) {
pair.weight = 0.0f;
}
} else {
IDCG = 1.0f / IDCG;
for (auto & pair : pairs) {
unsigned pos_idx = pair.pos_index;
unsigned neg_idx = pair.neg_index;
float pos_loginv = 1.0f / std::log2(pos_idx + 2.0f);
float neg_loginv = 1.0f / std::log2(neg_idx + 2.0f);
auto pos_label = static_cast<int>(sorted_list[pos_idx].label);
auto neg_label = static_cast<int>(sorted_list[neg_idx].label);
bst_float original =
((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
float changed =
((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
bst_float delta = (original - changed) * IDCG;
if (delta < 0.0f) delta = - delta;
pair.weight *= delta;
}
}
}
inline static bst_float CalcDCG(const std::vector<bst_float> &labels) {
double sumdcg = 0.0;
for (size_t i = 0; i < labels.size(); ++i) {
const auto rel = static_cast<unsigned>(labels[i]);
if (rel != 0) {
sumdcg += ((1 << rel) - 1) / std::log2(static_cast<bst_float>(i + 2));
}
}
return static_cast<bst_float>(sumdcg);
}
};
class LambdaRankObjMAP : public LambdaRankObj {
protected:
struct MAPStats {
/*! \brief the accumulated precision */
float ap_acc;
/*!
* \brief the accumulated precision,
* assuming a positive instance is missing
*/
float ap_acc_miss;
/*!
* \brief the accumulated precision,
* assuming that one more positive instance is inserted ahead
*/
float ap_acc_add;
/* \brief the accumulated positive instance count */
float hits;
MAPStats() = default;
MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits)
: ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {}
};
/*!
* \brief Obtain the delta MAP if trying to switch the positions of instances in index1 or index2
* in sorted triples
* \param sorted_list the list containing entry information
* \param index1,index2 the instances switched
* \param map_stats a vector containing the accumulated precisions for each position in a list
*/
inline bst_float GetLambdaMAP(const std::vector<ListEntry> &sorted_list,
int index1, int index2,
std::vector<MAPStats> *p_map_stats) {
std::vector<MAPStats> &map_stats = *p_map_stats;
if (index1 == index2 || map_stats[map_stats.size() - 1].hits == 0) {
return 0.0f;
}
if (index1 > index2) std::swap(index1, index2);
bst_float original = map_stats[index2].ap_acc;
if (index1 != 0) original -= map_stats[index1 - 1].ap_acc;
bst_float changed = 0;
bst_float label1 = sorted_list[index1].label > 0.0f ? 1.0f : 0.0f;
bst_float label2 = sorted_list[index2].label > 0.0f ? 1.0f : 0.0f;
if (label1 == label2) {
return 0.0;
} else if (label1 < label2) {
changed += map_stats[index2 - 1].ap_acc_add - map_stats[index1].ap_acc_add;
changed += (map_stats[index1].hits + 1.0f) / (index1 + 1);
} else {
changed += map_stats[index2 - 1].ap_acc_miss - map_stats[index1].ap_acc_miss;
changed += map_stats[index2].hits / (index2 + 1);
}
bst_float ans = (changed - original) / (map_stats[map_stats.size() - 1].hits);
if (ans < 0) ans = -ans;
return ans;
}
/*
* \brief obtain preprocessing results for calculating delta MAP
* \param sorted_list the list containing entry information
* \param map_stats a vector containing the accumulated precisions for each position in a list
*/
inline void GetMAPStats(const std::vector<ListEntry> &sorted_list,
std::vector<MAPStats> *p_map_acc) {
std::vector<MAPStats> &map_acc = *p_map_acc;
map_acc.resize(sorted_list.size());
bst_float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0;
for (size_t i = 1; i <= sorted_list.size(); ++i) {
if (sorted_list[i - 1].label > 0.0f) {
hit++;
acc1 += hit / i;
acc2 += (hit - 1) / i;
acc3 += (hit + 1) / i;
}
map_acc[i - 1] = MAPStats(acc1, acc2, acc3, hit);
}
}
void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
std::vector<LambdaPair> *io_pairs) override {
std::vector<LambdaPair> &pairs = *io_pairs;
std::vector<MAPStats> map_stats;
GetMAPStats(sorted_list, &map_stats);
for (auto & pair : pairs) {
pair.weight *=
GetLambdaMAP(sorted_list, pair.pos_index,
pair.neg_index, &map_stats);
}
}
};
// Register the shared parameter struct and the three ranking objectives
// ("rank:pairwise", "rank:ndcg", "rank:map") with their factories.
DMLC_REGISTER_PARAMETER(LambdaRankParam);

XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, "rank:pairwise")
    .describe("Pairwise rank objective.")
    .set_body([] { return new PairwiseRankObj(); });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, "rank:ndcg")
    .describe("LambdaRank with NDCG as objective.")
    .set_body([] { return new LambdaRankObjNDCG(); });

XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, "rank:map")
    .describe("LambdaRank with MAP as objective.")
    .set_body([] { return new LambdaRankObjMAP(); });
} // namespace obj
} // namespace xgboost