Skip to content

Commit

Permalink
fix release degradation, related to 5666
Browse files Browse the repository at this point in the history
  • Loading branch information
SHVETS, KIRILL committed May 27, 2020
1 parent 78b4e95 commit 0278f3a
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/objective/regression_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,26 @@ class RegLossObj : public ObjFunction {
<< "Number of weights should be equal to number of data points.";
}
auto scale_pos_weight = param_.scale_pos_weight;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
HostDeviceVector<float> scale_pos_weight_;
scale_pos_weight_.Resize(1);
scale_pos_weight_.Fill(scale_pos_weight);
HostDeviceVector<int> is_null_weight_;
is_null_weight_.Resize(1);
is_null_weight_.Fill(is_null_weight);

common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
common::Span<const bst_float> _weights,
common::Span<int> _is_null_weight,
common::Span<float> _scale_pos_weight) {
bst_float p = Loss::PredTransform(_preds[_idx]);
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float w = _is_null_weight[0] ? 1.0f : _weights[_idx];
bst_float label = _labels[_idx];
if (label == 1.0f) {
w *= scale_pos_weight;
w *= _scale_pos_weight[0];
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
Expand All @@ -94,7 +102,8 @@ class RegLossObj : public ObjFunction {
Loss::SecondOrderGradient(p, label) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_,
&is_null_weight_, &scale_pos_weight_);

// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
Expand Down

0 comments on commit 0278f3a

Please sign in to comment.