Skip to content

Commit

Permalink
refine prediction logic. (#395)
Browse files Browse the repository at this point in the history
* refine prediction logic.

* fix test.

* fix out_len in training score of Dart.

* improve predict speed for high dimension data.
  • Loading branch information
guolinke committed Apr 10, 2017
1 parent f1ffc10 commit 71660f1
Show file tree
Hide file tree
Showing 20 changed files with 406 additions and 348 deletions.
28 changes: 19 additions & 9 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,29 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;

virtual int NumPredictOneRow(int num_iteration, int is_pred_leaf) const = 0;

/*!
* \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record
* \return Prediction result for this record
* \param output Prediction result for this record
*/
virtual std::vector<double> PredictRaw(const double* feature_values) const = 0;
virtual void PredictRaw(const double* feature_values, double* output) const = 0;

/*!
* \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record
* \return Prediction result for this record
* \param output Prediction result for this record
*/
virtual std::vector<double> Predict(const double* feature_values) const = 0;
virtual void Predict(const double* feature_values, double* output) const = 0;

/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \return Predicted leaf index for this record
* \param output Prediction result for this record
*/
virtual std::vector<int> PredictLeafIndex(
const double* feature_values) const = 0;
virtual void PredictLeafIndex(
const double* feature_values, double* output) const = 0;

/*!
* \brief Dump model to json format string
Expand Down Expand Up @@ -185,16 +187,24 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual int NumberOfTotalModel() const = 0;

/*!
* \brief Get number of trees per iteration
* \return Number of trees per iteration
*/
virtual int NumTreePerIteration() const = 0;

/*!
* \brief Get number of classes
* \return Number of classes
*/
virtual int NumberOfClasses() const = 0;

/*!
* \brief Set number of used model for prediction
* \brief Initial work for the prediction
* \param num_iteration number of used iteration
* \return the feature indices mapper
*/
virtual void SetNumIterationForPred(int num_iteration) = 0;
virtual std::vector<int> InitPredict(int num_iteration) = 0;

/*!
* \brief Name of submodel
Expand Down
2 changes: 1 addition & 1 deletion include/LightGBM/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const score_t kEpsilon = 1e-15f;
using ReduceFunction = std::function<void(const char*, char*, int)>;

using PredictFunction =
std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)>;
std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;

#define NO_SPECIFIC (-1)

Expand Down
3 changes: 1 addition & 2 deletions include/LightGBM/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class Metric {
* \brief Calculating and printing metric result
* \param score Current prediction score
*/
virtual std::vector<double> Eval(const double* score, const ObjectiveFunction* objective,
int num_tree_per_iteration) const = 0;
virtual std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const = 0;

Metric() = default;
/*! \brief Disable copy */
Expand Down
10 changes: 4 additions & 6 deletions include/LightGBM/objective_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,12 @@ class ObjectiveFunction {

virtual bool SkipEmptyClass() const { return false; }

virtual int numTreePerIteration() const { return 1; }
virtual int NumTreePerIteration() const { return 1; }

virtual std::vector<double> ConvertOutput(std::vector<double>& input) const {
return input;
}
virtual int NumPredictOneRow() const { return 1; }

virtual double ConvertOutput(double input) const {
return input;
/*!
* \brief Convert raw boosting score(s) into the objective's output space.
*        Default implementation is the identity transform on a single value;
*        objectives that need a transformation (cf. the "sigmoid transformation
*        will be used if needed" contract on Boosting::Predict) override this.
* \param input Raw score(s) for one record
* \param output Destination the converted prediction is written to
*/
virtual void ConvertOutput(const double* input, double* output) const {
output[0] = input[0];
}

virtual std::string ToString() const = 0;
Expand Down
36 changes: 28 additions & 8 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ class Tree {
shrinkage_ *= rate;
}

/*!
* \brief Cache the remapped split feature indices of all internal nodes,
*        so prediction can index directly into a caller-provided dense
*        feature array instead of translating per lookup.
* \param feature_mapper Maps an internal split feature index to the
*        position of that feature in the prediction-time feature vector
*/
inline void ReMapFeature(const std::vector<int>& feature_mapper) {
  mapped_feature_ = split_feature_;
  const int num_inner_nodes = num_leaves_ - 1;  // a tree with k leaves has k-1 splits
  for (int node = 0; node < num_inner_nodes; ++node) {
    mapped_feature_[node] = feature_mapper[split_feature_[node]];
  }
}

/*! \brief Serialize this object to string*/
std::string ToString();

Expand Down Expand Up @@ -194,9 +201,10 @@ class Tree {
std::vector<int> leaf_depth_;
double shrinkage_;
bool has_categorical_;
/*! \brief buffer of mapped split_feature_ */
std::vector<int> mapped_feature_;
};


inline double Tree::Predict(const double* feature_values) const {
if (num_leaves_ > 1) {
int leaf = GetLeaf(feature_values);
Expand All @@ -217,13 +225,25 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {

inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0;
while (node >= 0) {
if (decision_funs[decision_type_[node]](
feature_values[split_feature_[node]],
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
if (has_categorical_) {
while (node >= 0) {
if (decision_funs[decision_type_[node]](
feature_values[mapped_feature_[node]],
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
}
}
} else {
while (node >= 0) {
if (NumericalDecision<double>(
feature_values[mapped_feature_[node]],
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
}
}
}
return ~node;
Expand Down
12 changes: 6 additions & 6 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,18 +377,18 @@ inline void Softmax(std::vector<double>* p_rec) {
}
}

/*!
* \brief Numerically stable softmax: output[i] = exp(input[i]) / sum_j exp(input[j]).
*        The maximum input is subtracted before exponentiation to avoid overflow;
*        this does not change the mathematical result.
* \param input Raw scores, length \a len (may alias \a output)
* \param output Destination for the normalized probabilities, length \a len
* \param len Number of entries; values <= 0 are a no-op (original read input[0] unconditionally)
*/
inline void Softmax(const double* input, double* output, int len) {
  if (len <= 0) { return; }
  double wmax = input[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(input[i], wmax);
  }
  double wsum = 0.0;  // was 0.0f: use a double literal for a double accumulator
  for (int i = 0; i < len; ++i) {
    output[i] = std::exp(input[i] - wmax);
    wsum += output[i];
  }
  for (int i = 0; i < len; ++i) {
    output[i] /= wsum;  // wsum is already double; no cast needed
  }
}

Expand Down
41 changes: 19 additions & 22 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ void Application::LoadParameters(int argc, char** argv) {
continue;
}
params[key] = value;
}
else {
} else {
Log::Warning("Unknown parameter in command line: %s", argv[i]);
}
}
Expand Down Expand Up @@ -86,14 +85,13 @@ void Application::LoadParameters(int argc, char** argv) {
if (params.count(key) == 0) {
params[key] = value;
}
}
else {
} else {
Log::Warning("Unknown parameter in config file: %s", line.c_str());
}
}
} else {
Log::Warning("Config file %s doesn't exist, will ignore",
params["config_file"].c_str());
params["config_file"].c_str());
}
}
// check for alias again
Expand All @@ -110,23 +108,23 @@ void Application::LoadData() {
PredictFunction predict_fun = nullptr;
// need to continue training
if (boosting_->NumberOfTotalModel() > 0) {
predictor.reset(new Predictor(boosting_.get(), true, false));
predictor.reset(new Predictor(boosting_.get(), -1, true, false));
predict_fun = predictor->GetPredictFunction();
}

// sync up random seed for data partition
if (config_.is_parallel_find_bin) {
config_.io_config.data_random_seed =
GlobalSyncUpByMin<int>(config_.io_config.data_random_seed);
GlobalSyncUpByMin<int>(config_.io_config.data_random_seed);
}

DatasetLoader dataset_loader(config_.io_config, predict_fun,
boosting_->NumberOfClasses(), config_.io_config.data_filename.c_str());
config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
// load Training data
if (config_.is_parallel_find_bin) {
// load data for parallel training
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
Network::rank(), Network::num_machines()));
Network::rank(), Network::num_machines()));
} else {
// load data for single machine
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), 0, 1));
Expand Down Expand Up @@ -170,7 +168,7 @@ void Application::LoadData() {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_.back()->metadata(),
valid_datas_.back()->num_data());
valid_datas_.back()->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
Expand All @@ -181,7 +179,7 @@ void Application::LoadData() {
auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration
Log::Info("Finished loading data in %f seconds",
std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
std::chrono::duration<double, std::milli>(end_time - start_time) * 1e-3);
}

void Application::InitTrain() {
Expand All @@ -201,22 +199,22 @@ void Application::InitTrain() {
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str()));
config_.io_config.input_model.c_str()));
// create objective function
objective_fun_.reset(
ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
config_.objective_config));
// load training data
LoadData();
// initialize the objective function
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddValidDataset(valid_datas_[i].get(),
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
}
Log::Info("Finished initializing training");
}
Expand All @@ -232,7 +230,7 @@ void Application::Train() {
auto end_time = std::chrono::steady_clock::now();
// output used time per iteration
Log::Info("%f seconds elapsed, finished iteration %d", std::chrono::duration<double,
std::milli>(end_time - start_time) * 1e-3, iter + 1);
std::milli>(end_time - start_time) * 1e-3, iter + 1);
}
// save model to file
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
Expand All @@ -241,12 +239,11 @@ void Application::Train() {


void Application::Predict() {
  // Bind a predictor to the trained boosting object. The iteration count to
  // use and the output mode (raw score / leaf index) come straight from the
  // IO configuration; the predictor now takes num_iteration in its ctor
  // instead of a separate SetNumIterationForPred call.
  Predictor predictor(boosting_.get(),
                      config_.io_config.num_iteration_predict,
                      config_.io_config.is_predict_raw_score,
                      config_.io_config.is_predict_leaf_index);
  // Score the configured input file and write results to the output path.
  predictor.Predict(config_.io_config.data_filename.c_str(),
                    config_.io_config.output_result.c_str(),
                    config_.io_config.has_header);
  Log::Info("Finished prediction");
}

Expand All @@ -264,9 +261,9 @@ T Application::GlobalSyncUpByMin(T& local) {
return global;
}
Network::Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local),
sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global),
[](const char* src, char* dst, int len) {
[](const char* src, char* dst, int len) {
int used_size = 0;
const int type_size = sizeof(T);
const T *p1;
Expand Down

0 comments on commit 71660f1

Please sign in to comment.