-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
application.h
88 lines (69 loc) · 2.13 KB
/
application.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#ifndef LIGHTGBM_APPLICATION_H_
#define LIGHTGBM_APPLICATION_H_
#include <LightGBM/config.h>
#include <LightGBM/meta.h>
#include <memory>
#include <vector>
namespace LightGBM {
class DatasetLoader;
class Dataset;
class Boosting;
class ObjectiveFunction;
class Metric;
/*!
* \brief The main entrance of LightGBM. this application has two tasks:
* Train and Predict.
* Train task will train a new model
* Predict task will predict the scores of test data using existing model,
* and save the score to disk.
*/
class Application {
public:
Application(int argc, char** argv);
/*! \brief Destructor */
~Application();
/*! \brief To call this funciton to run application*/
inline void Run();
private:
/*! \brief Load parameters from command line and config file*/
void LoadParameters(int argc, char** argv);
/*! \brief Load data, including training data and validation data*/
void LoadData();
/*! \brief Initialization before training*/
void InitTrain();
/*! \brief Main Training logic */
void Train();
/*! \brief Initializations before prediction */
void InitPredict();
/*! \brief Main predicting logic */
void Predict();
/*! \brief Main Convert model logic */
void ConvertModel();
/*! \brief All configs */
Config config_;
/*! \brief Training data */
std::unique_ptr<Dataset> train_data_;
/*! \brief Validation data */
std::vector<std::unique_ptr<Dataset>> valid_datas_;
/*! \brief Metric for training data */
std::vector<std::unique_ptr<Metric>> train_metric_;
/*! \brief Metrics for validation data */
std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
/*! \brief Boosting object */
std::unique_ptr<Boosting> boosting_;
/*! \brief Training objective function */
std::unique_ptr<ObjectiveFunction> objective_fun_;
};
inline void Application::Run() {
if (config_.task == TaskType::kPredict || config_.task == TaskType::KRefitTree) {
InitPredict();
Predict();
} else if (config_.task == TaskType::kConvertModel) {
ConvertModel();
} else {
InitTrain();
Train();
}
}
} // namespace LightGBM
#endif // LightGBM_APPLICATION_H_