Skip to content

Commit

Permalink
Merge pull request #7 from laipaang/sample-pr
Browse files Browse the repository at this point in the history
sample scale
  • Loading branch information
qingshui committed Jun 29, 2021
2 parents 1888e2b + bfbdca1 commit 5c4d959
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 6 deletions.
46 changes: 46 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,32 @@ void BasicAucCalculator::add_unlock_data(double pred, int label) {
++_table[label][pos];
}

// Accumulates a single (pred, label) sample, weighted by sample_scale, into
// the local statistics. No lock is taken here; callers are responsible for
// any synchronization they need.
//
// pred         - predicted value, must lie in [0, 1]
// label        - ground-truth label, must be 0 or 1
// sample_scale - weight applied to the bucket count and to the pred sum
void BasicAucCalculator::add_unlock_data(double pred, int label,
                                         float sample_scale) {
  // Reject out-of-range predictions before touching any state.
  PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
                                   "pred should be greater than 0"));
  PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
                                   "pred should be lower than 1"));
  // label * label == label holds only for 0 and 1.
  PADDLE_ENFORCE_EQ(
      label * label, label,
      platform::errors::PreconditionNotMet(
          "label must be equal to 0 or 1, but its value is: %d", label));

  // Map pred onto a bucket index; pred == 1.0 would land one past the end,
  // so it is clamped into the last bucket.
  const int last_bucket = _table_size - 1;
  int bucket = static_cast<int>(pred * _table_size);
  if (bucket > last_bucket) {
    bucket = last_bucket;
  }
  PADDLE_ENFORCE_GE(
      bucket, 0,
      platform::errors::PreconditionNotMet(
          "pos must be equal or greater than 0, but its value is: %d", bucket));
  PADDLE_ENFORCE_LT(
      bucket, _table_size,
      platform::errors::PreconditionNotMet(
          "pos must be less than table_size, but its value is: %d", bucket));

  // Weighted accumulators: bucket count and prediction sum.
  _table[label][bucket] += sample_scale;
  _local_pred += pred * sample_scale;

  // Error accumulators are NOT scaled by sample_scale.
  // NOTE(review): confirm this is intended given the weighted counts above.
  const double diff = pred - label;
  _local_abserr += fabs(diff);
  _local_sqrerr += diff * diff;
}

void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
int batch_size,
const paddle::platform::Place& place) {
Expand All @@ -81,6 +107,26 @@ void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
}
}
}

// Accumulates a batch of weighted samples into the AUC statistics.
//
// d_pred / d_label  - buffers of at least batch_size elements; they are
//                     copied to host with cudaMemcpy (device-to-host)
// d_sample_scale    - host-side per-sample weights; must hold at least
//                     batch_size entries (checked up front so a short vector
//                     fails loudly instead of being read out of bounds)
// batch_size        - number of samples to add; a negative value also fails
//                     the size check because it converts to a very large
//                     unsigned number
// place             - not used by this implementation
//
// The table mutex is held while the samples are accumulated.
void BasicAucCalculator::add_sample_data(
    const float* d_pred, const int64_t* d_label,
    const std::vector<float>& d_sample_scale, int batch_size,
    const paddle::platform::Place& place) {
  // Fail before any copy or state mutation if the weights cannot cover the
  // whole batch.
  PADDLE_ENFORCE_GE(
      d_sample_scale.size(), static_cast<size_t>(batch_size),
      platform::errors::PreconditionNotMet(
          "sample_scale size [%lu] should not be less than batch_size [%d]",
          d_sample_scale.size(), batch_size));

  // thread_local staging buffers are reused across calls on the same thread
  // to avoid reallocating for every batch.
  thread_local std::vector<float> h_pred;
  thread_local std::vector<int64_t> h_label;
  h_pred.resize(batch_size);
  h_label.resize(batch_size);
  // NOTE(review): the cudaMemcpy return codes are not checked here; a failed
  // copy would surface later as a range error on pred/label, if at all.
  cudaMemcpy(h_pred.data(), d_pred, sizeof(float) * batch_size,
             cudaMemcpyDeviceToHost);
  cudaMemcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size,
             cudaMemcpyDeviceToHost);

  std::lock_guard<std::mutex> lock(_table_mutex);
  for (int i = 0; i < batch_size; ++i) {
    add_unlock_data(h_pred[i], h_label[i], d_sample_scale[i]);
  }
}

// add mask data
void BasicAucCalculator::add_mask_data(const float* d_pred,
const int64_t* d_label,
Expand Down
33 changes: 27 additions & 6 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,18 @@ class BasicAucCalculator {
void reset();
// add single data in CPU with LOCK, deprecated
void add_unlock_data(double pred, int label);
// add single data in CPU without taking any lock; sample_scale weights the
// bucket count and the pred sum. pred must be in [0, 1] and label 0 or 1.
void add_unlock_data(double pred, int label, float sample_scale);
// add batch data
void add_data(const float* d_pred, const int64_t* d_label, int batch_size,
const paddle::platform::Place& place);
// add mask data
void add_mask_data(const float* d_pred, const int64_t* d_label,
const int64_t* d_mask, int batch_size,
const paddle::platform::Place& place);
// add sample data: copies batch_size preds/labels from device to host and
// accumulates each one weighted by d_sample_scale[i]. d_sample_scale is a
// host vector and must hold at least batch_size entries.
void add_sample_data(const float* d_pred, const int64_t* d_label,
const std::vector<float>& d_sample_scale, int batch_size,
const paddle::platform::Place& place);
void compute();
int table_size() const { return _table_size; }
double bucket_error() const { return _bucket_error; }
Expand Down Expand Up @@ -669,9 +674,11 @@ class BoxWrapper {
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int metric_phase, int bucket_size = 1000000,
bool mode_collect_in_gpu = false, int max_batch_size = 0)
bool mode_collect_in_gpu = false, int max_batch_size = 0,
const std::string& sample_scale_varname = "")
: label_varname_(label_varname),
pred_varname_(pred_varname),
sample_scale_varname_(sample_scale_varname),
metric_phase_(metric_phase) {
calculator = new BasicAucCalculator(mode_collect_in_gpu);
calculator->init(bucket_size, max_batch_size);
Expand All @@ -692,7 +699,19 @@ class BoxWrapper {
platform::errors::PreconditionNotMet(
"the predict data length should be consistent with "
"the label data length"));
calculator->add_data(pred_data, label_data, label_len, place);
std::vector<float> sample_scale_data;
if (!sample_scale_varname_.empty()) {
get_data<float>(exe_scope, sample_scale_varname_, &sample_scale_data);
PADDLE_ENFORCE_EQ(
label_len, sample_scale_data.size(),
platform::errors::PreconditionNotMet(
"lable size [%lu] and sample_scale_data[%lu] should be same",
label_len, sample_scale_data.size()));
calculator->add_sample_data(pred_data, label_data, sample_scale_data,
label_len, place);
} else {
calculator->add_data(pred_data, label_data, label_len, place);
}
}
template <class T = float>
static void get_data(const Scope* exe_scope, const std::string& varname,
Expand Down Expand Up @@ -728,6 +747,7 @@ class BoxWrapper {
protected:
std::string label_varname_;
std::string pred_varname_;
std::string sample_scale_varname_;
int metric_phase_;
BasicAucCalculator* calculator;
};
Expand Down Expand Up @@ -1050,12 +1070,13 @@ class BoxWrapper {
const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group, bool ignore_rank,
int bucket_size = 1000000, bool mode_collect_in_gpu = false,
int max_batch_size = 0) {
int max_batch_size = 0,
const std::string& sample_scale_varname = "") {
if (method == "AucCalculator") {
metric_lists_.emplace(
name,
new MetricMsg(label_varname, pred_varname, metric_phase, bucket_size,
mode_collect_in_gpu, max_batch_size));
name, new MetricMsg(label_varname, pred_varname, metric_phase,
bucket_size, mode_collect_in_gpu, max_batch_size,
sample_scale_varname));
} else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname,
Expand Down

0 comments on commit 5c4d959

Please sign in to comment.