Skip to content

Commit

Permalink
[java][mmlspark] Fix cached predictor causing bad values for predicted probabilities (#2356)
Browse files Browse the repository at this point in the history

* [mmlspark] Fix cached predictor causing bad values for predicted probabilities

* updated based on comments

* removed tabs
  • Loading branch information
imatiach-msft authored and guolinke committed Aug 30, 2019
1 parent 254a869 commit 317b1bf
Showing 1 changed file with 56 additions and 25 deletions.
81 changes: 56 additions & 25 deletions src/c_api.cpp
Expand Up @@ -46,6 +46,55 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

// Number of slots in the cache of single-row predictors; that cache is an
// array indexed by the predict_type value passed to the C API.
constexpr int PREDICTOR_TYPES = 4;

/*!
 * \brief Single row predictor to abstract away caching logic.
 *
 * Owns a Predictor together with a snapshot of the settings it was built
 * from (early-stopping options, iteration count, and the booster's total
 * model count), so that a cached instance can be checked against the current
 * request before it is reused.
 */
class SingleRowPredictor {
 public:
  /*! \brief Prediction callback obtained from the owned Predictor */
  PredictFunction predict_function;
  /*! \brief Value of Boosting::NumPredictOneRow for this predict type and iteration */
  int64_t num_pred_in_one_row;

  /*!
   * \brief Build a predictor for one predict type and record its settings.
   * \param predict_type One of the C_API_PREDICT_* values; any value other
   *        than LEAF_INDEX, RAW_SCORE or CONTRIB leaves all three mode flags false
   * \param boosting Model used for prediction; its address is handed to the Predictor
   * \param config Source of the pred_early_stop* options
   * \param iter Iteration count forwarded to the Predictor and to NumPredictOneRow
   */
  SingleRowPredictor(int predict_type, Boosting& boosting, const Config& config, int iter) {
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    }
    // Snapshot the settings so IsPredictorEqual() can compare against them later.
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
    predictor_.reset(new Predictor(&boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
                                   early_stop_, early_stop_freq_, early_stop_margin_));
    num_pred_in_one_row = boosting.NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
    predict_function = predictor_->GetPredictFunction();
    num_total_model_ = boosting.NumberOfTotalModel();
  }

  ~SingleRowPredictor() = default;

  /*!
   * \brief Check whether this cached predictor still matches the request.
   * \param config Current config; its pred_early_stop* options are compared
   * \param iter Requested iteration count
   * \param boosting Current model; its NumberOfTotalModel() is compared
   * \return true only if every recorded setting equals the current one, i.e.
   *         the cached predictor can be reused; false if any of them differs
   */
  bool IsPredictorEqual(const Config& config, int iter, Boosting& boosting) const {
    // Must be a conjunction of equalities. OR-ing inequality checks here
    // would report "equal" exactly when something changed, so a caller that
    // rebuilds on `false` would keep a stale predictor and rebuild a valid one.
    return early_stop_ == config.pred_early_stop &&
           early_stop_freq_ == config.pred_early_stop_freq &&
           early_stop_margin_ == config.pred_early_stop_margin &&
           iter_ == iter &&
           num_total_model_ == boosting.NumberOfTotalModel();
  }

 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

class Booster {
public:
explicit Booster(const char* filename) {
Expand Down Expand Up @@ -205,33 +254,17 @@ class Booster {
const Config& config,
double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_);

if (single_row_predictor_.get() == nullptr) {
bool is_predict_leaf = false;
bool is_raw_score = false;
bool predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true;
} else if (predict_type == C_API_PREDICT_CONTRIB) {
predict_contrib = true;
} else {
is_raw_score = false;
}

// TODO(eisber): config could be optimized away... (maybe using lambda callback?)
single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, *boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, *boosting_.get(),
config, num_iteration));
}

auto one_row = get_row_fun(0);
auto pred_wrt_ptr = out_result;
single_row_predict_function_(one_row, pred_wrt_ptr);
single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);

*out_len = single_row_num_pred_in_one_row_;
*out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
}


Expand Down Expand Up @@ -364,9 +397,7 @@ class Booster {
private:
const Dataset* train_data_;
std::unique_ptr<Boosting> boosting_;
std::unique_ptr<Predictor> single_row_predictor_;
PredictFunction single_row_predict_function_;
int64_t single_row_num_pred_in_one_row_;
std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];

/*! \brief All configs */
Config config_;
Expand Down

0 comments on commit 317b1bf

Please sign in to comment.