Skip to content

Commit

Permalink
Fix release degradation (dmlc#5720)
Browse files Browse the repository at this point in the history
* fix release degradation, related to 5666

* less resizes

Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
  • Loading branch information
2 people authored and hcho3 committed May 31, 2020
1 parent 882b966 commit cff7013
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/objective/regression_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
template<typename Loss>
class RegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
HostDeviceVector<float> additional_input_;

public:
RegLossObj() = default;
// 0 - label_correct flag, 1 - scale_pos_weight, 2 - is_null_weight
RegLossObj(): additional_input_(3) {}

void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
Expand All @@ -64,44 +65,45 @@ class RegLossObj : public ObjFunction {
size_t const ndata = preds.Size();
out_gpair->Resize(ndata);
auto device = tparam_->gpu_id;
label_correct_.Resize(1);
label_correct_.Fill(1);
additional_input_.HostVector().begin()[0] = 1; // Fill the label_correct flag

bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
auto scale_pos_weight = param_.scale_pos_weight;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
additional_input_.HostVector().begin()[1] = scale_pos_weight;
additional_input_.HostVector().begin()[2] = is_null_weight;

common::Transform<>::Init([] XGBOOST_DEVICE(size_t _idx,
common::Span<float> _additional_input,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
const float _scale_pos_weight = _additional_input[1];
const bool _is_null_weight = _additional_input[2];

bst_float p = Loss::PredTransform(_preds[_idx]);
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float w = _is_null_weight ? 1.0f : _weights[_idx];
bst_float label = _labels[_idx];
if (label == 1.0f) {
w *= scale_pos_weight;
w *= _scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
_label_correct[0] = 0;
_additional_input[0] = 0;
}
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, device).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
&additional_input_, out_gpair, &preds, &info.labels_, &info.weights_);

// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
auto const flag = additional_input_.HostVector().begin()[0];
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}

Expand Down

0 comments on commit cff7013

Please sign in to comment.