/*!
* Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>

#include <string>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <sstream>
#include <vector>
namespace LightGBM {
/*!
* \brief Objective function for binary classification
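*
* Uses the scaled logistic loss L(y, s) = log(1 + exp(-y * sigmoid * s)),
* with labels y in {-1, +1} and raw model score s.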
*/
class BinaryLogloss: public ObjectiveFunction {
public:
explicit BinaryLogloss(const Config& config,
std::function<bool(label_t)> is_pos = nullptr)
: deterministic_(config.deterministic) {
sigmoid_ = static_cast<double>(config.sigmoid);
if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
}
is_unbalance_ = config.is_unbalance;
scale_pos_weight_ = static_cast<double>(config.scale_pos_weight);
if (is_unbalance_ && std::fabs(scale_pos_weight_ - 1.0f) > 1e-6) {
Log::Fatal("Cannot set is_unbalance and scale_pos_weight at the same time");
}
is_pos_ = is_pos;
if (is_pos_ == nullptr) {
is_pos_ = [](label_t label) { return label > 0; };
}
}
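/*!
* \brief Construct by parsing the "key:value" tokens of a saved objective
* string (see ToString()), e.g. "sigmoid:1".
*/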
explicit BinaryLogloss(const std::vector<std::string>& strs)
: deterministic_(false) {
sigmoid_ = -1;
for (const auto& str : strs) {
auto tokens = Common::Split(str.c_str(), ':');
if (tokens.size() == 2) {
if (tokens[0] == std::string("sigmoid")) {
Common::Atof(tokens[1].c_str(), &sigmoid_);
}
}
}
if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
}
}
~BinaryLogloss() {}
void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data;
label_ = metadata.label();
weights_ = metadata.weights();
data_size_t cnt_positive = 0;
data_size_t cnt_negative = 0;
// count positive and negative samples
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:cnt_positive, cnt_negative)
for (data_size_t i = 0; i < num_data_; ++i) {
if (is_pos_(label_[i])) {
++cnt_positive;
} else {
++cnt_negative;
}
}
num_pos_data_ = cnt_positive;
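// in distributed training, sum the class counts across all machines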
if (Network::num_machines() > 1) {
cnt_positive = Network::GlobalSyncUpBySum(cnt_positive);
cnt_negative = Network::GlobalSyncUpBySum(cnt_negative);
}
need_train_ = true;
if (cnt_negative == 0 || cnt_positive == 0) {
Log::Warning("Contains only one class");
// no need to boost.
need_train_ = false;
}
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// use -1 for negative class, and 1 for positive class
label_val_[0] = -1;
label_val_[1] = 1;
// weight for label
label_weights_[0] = 1.0f;
label_weights_[1] = 1.0f;
// if is_unbalance is set, up-weight the minority class by the class-count ratio
if (is_unbalance_ && cnt_positive > 0 && cnt_negative > 0) {
if (cnt_positive > cnt_negative) {
label_weights_[1] = 1.0f;
label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative;
} else {
label_weights_[1] = static_cast<double>(cnt_negative) / cnt_positive;
label_weights_[0] = 1.0f;
}
}
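// apply the user-supplied scale_pos_weight to the positive class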
label_weights_[1] *= scale_pos_weight_;
}
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
if (!need_train_) {
return;
}
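// For y = label_val_[is_pos] in {-1, +1}, the loss is
// log(1 + exp(-y * sigmoid_ * score)). Its first derivative w.r.t. the
// raw score is g = -y * sigmoid_ / (1 + exp(y * sigmoid_ * score)), and
// the second derivative simplifies to |g| * (sigmoid_ - |g|).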
if (weights_ == nullptr) {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights
const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians
const double response = -label * sigmoid_ / (1.0f + std::exp(label * sigmoid_ * score[i]));
const double abs_response = std::fabs(response);
gradients[i] = static_cast<score_t>(response * label_weight);
hessians[i] = static_cast<score_t>(abs_response * (sigmoid_ - abs_response) * label_weight);
}
} else {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights
const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians
const double response = -label * sigmoid_ / (1.0f + std::exp(label * sigmoid_ * score[i]));
const double abs_response = std::fabs(response);
gradients[i] = static_cast<score_t>(response * label_weight * weights_[i]);
hessians[i] = static_cast<score_t>(abs_response * (sigmoid_ - abs_response) * label_weight * weights_[i]);
}
}
}
// boost from the (weighted) average label when boost_from_average is enabled
double BoostFromScore(int) const override {
double suml = 0.0f;
double sumw = 0.0f;
if (weights_ != nullptr) {
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:suml, sumw) if (!deterministic_)
for (data_size_t i = 0; i < num_data_; ++i) {
suml += is_pos_(label_[i]) * weights_[i];
sumw += weights_[i];
}
} else {
sumw = static_cast<double>(num_data_);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:suml) if (!deterministic_)
for (data_size_t i = 0; i < num_data_; ++i) {
suml += is_pos_(label_[i]);
}
}
if (Network::num_machines() > 1) {
suml = Network::GlobalSyncUpBySum(suml);
sumw = Network::GlobalSyncUpBySum(sumw);
}
double pavg = suml / sumw;
pavg = std::min(pavg, 1.0 - kEpsilon);
pavg = std::max<double>(pavg, kEpsilon);
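// solve 1 / (1 + exp(-sigmoid_ * s)) = pavg for s, so the initial constant
// score reproduces the (weighted) positive-class rate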
double initscore = std::log(pavg / (1.0f - pavg)) / sigmoid_;
Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, pavg, initscore);
return initscore;
}
bool ClassNeedTrain(int /*class_id*/) const override {
return need_train_;
}
const char* GetName() const override {
return "binary";
}
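// map a raw model score to a probability via the scaled logistic function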
void ConvertOutput(const double* input, double* output) const override {
output[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[0]));
}
std::string ToString() const override {
std::stringstream str_buf;
str_buf << GetName() << " ";
str_buf << "sigmoid:" << sigmoid_;
return str_buf.str();
}
bool SkipEmptyClass() const override { return true; }
bool NeedAccuratePrediction() const override { return false; }
data_size_t NumPositiveData() const override { return num_pos_data_; }
protected:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of positive samples */
data_size_t num_pos_data_;
/*! \brief Pointer of label */
const label_t* label_;
/*! \brief True if using unbalance training */
bool is_unbalance_;
/*! \brief Sigmoid parameter */
double sigmoid_;
/*! \brief Values for positive and negative labels */
int label_val_[2];
/*! \brief Weights for positive and negative labels */
double label_weights_[2];
/*! \brief Weights for data */
const label_t* weights_;
/*! \brief Scaling factor applied to the positive class' weight */
double scale_pos_weight_;
/*! \brief Predicate deciding whether a label belongs to the positive class */
std::function<bool(label_t)> is_pos_;
/*! \brief True if training is needed (data contains both classes) */
bool need_train_;
/*! \brief True if reproducible, deterministic results are required */
const bool deterministic_;
};
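// Usage sketch (hypothetical driver code, not part of this header): the
// objective is normally created through the ObjectiveFunction factory, then
// initialized with the dataset's metadata before gradients are requested:
//
//   auto obj = ObjectiveFunction::CreateObjectiveFunction("binary", config);
//   obj->Init(train_data->metadata(), train_data->num_data());
//   obj->GetGradients(scores, gradients, hessians);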
} // namespace LightGBM
#endif  // LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_