Skip to content

Commit

Permalink
[MRG] translate model to if-else (#469)
Browse files Browse the repository at this point in the history
* translate model to if-else

* support multiclass and predictleaf

* remove java option for now

* support multi-thread

* add task:convert_model
  • Loading branch information
wxchan authored and guolinke committed Apr 28, 2017
1 parent 7f94fd9 commit 8a19834
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 4 deletions.
5 changes: 5 additions & 0 deletions include/LightGBM/application.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class Application {
/*! \brief Main predicting logic */
void Predict();

/*! \brief Main Convert model logic */
void ConvertModel();

/*! \brief All configs */
OverallConfig config_;
/*! \brief Training data */
Expand All @@ -80,6 +83,8 @@ inline void Application::Run() {
if (config_.task_type == TaskType::kPredict) {
InitPredict();
Predict();
} else if (config_.task_type == TaskType::kConvertModel) {
ConvertModel();
} else {
InitTrain();
Train();
Expand Down
16 changes: 16 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,26 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model
*/
virtual std::string DumpModel(int num_iteration) const = 0;

/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \return if-else format codes of model
*/
virtual std::string ModelToIfElse(int num_iteration) const = 0;

/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \param filename Filename that want to save to
* \return is_finish Is training finished or not
*/
virtual bool SaveModelToIfElse(int num_iteration, const char* filename) const = 0;

/*!
* \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all
Expand Down
4 changes: 3 additions & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct ConfigBase {

/*! \brief Types of tasks */
enum TaskType {
kTrain, kPredict
kTrain, kPredict, kConvertModel
};

/*! \brief Config for input and output files */
Expand All @@ -93,6 +93,7 @@ struct IOConfig: public ConfigBase {
int snapshot_freq = 100;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "LightGBM_convert_model.cpp";
std::string input_model = "";
int verbosity = 1;
int num_iteration_predict = -1;
Expand Down Expand Up @@ -269,6 +270,7 @@ struct OverallConfig: public ConfigBase {
ObjectiveConfig objective_config;
std::vector<std::string> metric_types;
MetricConfig metric_config;
std::string convert_model_language = "";

LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;

Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ class Tree {
/*! \brief Serialize this object to json*/
std::string ToJSON();

/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index);

template<typename T>
static bool CategoricalDecision(T fval, T threshold) {
if (static_cast<int>(fval) == static_cast<int>(threshold)) {
Expand Down Expand Up @@ -160,6 +163,9 @@ class Tree {
/*! \brief Serialize one node to json*/
inline std::string NodeToJSON(int index);

/*! \brief Serialize one node to if-else statement*/
inline std::string NodeToIfElse(int index, bool is_predict_leaf_index);

/*! \brief Number of max leaves*/
int max_leaves_;
/*! \brief Number of current levas*/
Expand Down
14 changes: 12 additions & 2 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Application::Application(int argc, char** argv) {
if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
if (config_.io_config.data_filename.size() == 0) {
if (config_.io_config.data_filename.size() == 0 && config_.task_type != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}
}
Expand Down Expand Up @@ -239,10 +239,13 @@ void Application::Train() {
}
// save model to file
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
Log::Info("Finished training");
}


void Application::Predict() {
// create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
Expand All @@ -258,6 +261,13 @@ void Application::InitPredict() {
Log::Info("Finished initializing prediction");
}

void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}

template<typename T>
T Application::GlobalSyncUpByMin(T& local) {
T global = local;
Expand Down
93 changes: 93 additions & 0 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,99 @@ std::string GBDT::DumpModel(int num_iteration) const {
return str_buf.str();
}

std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf;

int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}

// PredictRaw
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, false) << std::endl;
}

str_buf << "double (*PredictTreePtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i;
}
str_buf << " };" << std::endl << std::endl;

std::stringstream pred_str_buf;

pred_str_buf << "\t" << "if (num_threads_ <= num_tree_per_iteration_) {" << std::endl;
pred_str_buf << "\t\t" << "#pragma omp parallel for schedule(static)" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "} else {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "double t = 0.0f;" << std::endl;
pred_str_buf << "\t\t\t" << "#pragma omp parallel for schedule(static) reduction(+:t)" << std::endl;
pred_str_buf << "\t\t\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t\t\t" << "t += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t\t" << "}" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] = t;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;

str_buf << "void GBDT::PredictRaw(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "}" << std::endl;
str_buf << std::endl;

// Predict
str_buf << "void GBDT::Predict(const double* features, double *output) const {" << std::endl;
str_buf << pred_str_buf.str();
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << std::endl;

// PredictLeafIndex
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, true) << std::endl;
}

str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << " , ";
}
str_buf << "PredictTree" << i << "Leaf";
}
str_buf << " };" << std::endl << std::endl;

str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "#pragma omp parallel for schedule(static)" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
return str_buf.str();
}

bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
output_file.open(filename);

output_file << ModelToIfElse(num_iteration);

output_file.close();

return (bool)output_file;
}

std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss;

Expand Down
18 changes: 17 additions & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,31 @@ class GBDT: public Boosting {

/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model
*/
std::string DumpModel(int num_iteration) const override;

/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \return if-else format codes of model
*/
std::string ModelToIfElse(int num_iteration) const override;

/*!
* \brief Translate model to if-else statement
* \param num_iteration Number of iterations that want to translate, -1 means translate all
* \param filename Filename that want to save to
* \return is_finish Is training finished or not
*/
bool SaveModelToIfElse(int num_iteration, const char* filename) const override;

/*!
* \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \return is_finish Is training finished or not
*/
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override;

Expand Down
4 changes: 4 additions & 0 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
// load main config types
GetInt(params, "num_threads", &num_threads);
GetString(params, "convert_model_language", &convert_model_language);

// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
Expand Down Expand Up @@ -129,6 +130,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
} else if (value == std::string("predict") || value == std::string("prediction")
|| value == std::string("test")) {
task_type = TaskType::kPredict;
} else if (value == std::string("convert_model")) {
task_type = TaskType::kConvertModel;
} else {
Log::Fatal("Unknown task type %s", value.c_str());
}
Expand Down Expand Up @@ -210,6 +213,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model);
GetString(params, "convert_model", &convert_model);
GetString(params, "output_result", &output_result);
std::string tmp_str = "";
if (GetString(params, "valid_data", &tmp_str)) {
Expand Down
48 changes: 48 additions & 0 deletions src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,54 @@ std::string Tree::NodeToJSON(int index) {
return str_buf.str();
}

std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
std::stringstream str_buf;
str_buf << "double PredictTree" << index;
if (is_predict_leaf_index) {
str_buf << "Leaf";
}
str_buf << "(const double* arr) { ";
if (num_leaves_ == 1) {
str_buf << "return 0";
} else {
str_buf << NodeToIfElse(0, is_predict_leaf_index);
}
str_buf << " }" << std::endl;
return str_buf.str();
}

std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) {
std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
if (index >= 0) {
// non-leaf
str_buf << "if ( arr[" << split_feature_[index] << "] ";
if (decision_type_[index] == 0) {
str_buf << "<";
} else {
str_buf << "=";
}
str_buf << "= " << threshold_[index] << " ) { ";
// left subtree
str_buf << NodeToIfElse(left_child_[index], is_predict_leaf_index);
str_buf << " } else { ";
// right subtree
str_buf << NodeToIfElse(right_child_[index], is_predict_leaf_index);
str_buf << " }";
} else {
// leaf
str_buf << "return ";
if (is_predict_leaf_index) {
str_buf << ~index;
} else {
str_buf << leaf_value_[~index];
}
str_buf << ";";
}

return str_buf.str();
}

Tree::Tree(const std::string& str) {
std::vector<std::string> lines = Common::Split(str.c_str(), '\n');
std::unordered_map<std::string, std::string> key_vals;
Expand Down

0 comments on commit 8a19834

Please sign in to comment.