Skip to content

Commit

Permalink
Allow unique prediction vector for each input matrix (#4275)
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Mar 20, 2019
1 parent 09bd9e6 commit 8eab966
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/learner.cc
Expand Up @@ -485,10 +485,10 @@ class LearnerImpl : public Learner {
this->PerformTreeMethodHeuristic(train);

monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_);
this->PredictRaw(train, &preds_[train]);
monitor_.Stop("PredictRaw");
monitor_.Start("GetGradient");
obj_->GetGradient(preds_, train->Info(), iter, &gpair_);
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
gbm_->DoBoost(train, &gpair_, obj_.get());
monitor_.Stop("UpdateOneIter");
Expand Down Expand Up @@ -520,11 +520,12 @@ class LearnerImpl : public Learner {
metrics_.back()->Configure(cfg_.begin(), cfg_.end());
}
for (size_t i = 0; i < data_sets.size(); ++i) {
this->PredictRaw(data_sets[i], &preds_);
obj_->EvalTransform(&preds_);
DMatrix * dmat = data_sets[i];
this->PredictRaw(data_sets[i], &preds_[dmat]);
obj_->EvalTransform(&preds_[dmat]);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
<< ev->Eval(preds_, data_sets[i]->Info(),
<< ev->Eval(preds_[dmat], data_sets[i]->Info(),
tparam_.dsplit == DataSplitMode::kRow);
}
}
Expand Down Expand Up @@ -565,10 +566,10 @@ class LearnerImpl : public Learner {
std::string metric) {
if (metric == "auto") metric = obj_->DefaultEvalMetric();
std::unique_ptr<Metric> ev(Metric::Create(metric.c_str()));
this->PredictRaw(data, &preds_);
obj_->EvalTransform(&preds_);
this->PredictRaw(data, &preds_[data]);
obj_->EvalTransform(&preds_[data]);
return std::make_pair(metric,
ev->Eval(preds_, data->Info(),
ev->Eval(preds_[data], data->Info(),
tparam_.dsplit == DataSplitMode::kRow));
}

Expand Down Expand Up @@ -771,7 +772,7 @@ class LearnerImpl : public Learner {
// name of objective function
std::string name_obj_;
// temporal storages for prediction
HostDeviceVector<bst_float> preds_;
std::map<DMatrix*, HostDeviceVector<bst_float>> preds_;
// gradient pairs
HostDeviceVector<GradientPair> gpair_;

Expand Down

0 comments on commit 8eab966

Please sign in to comment.