-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
dataset_loader.h
110 lines (83 loc) · 4.64 KB
/
dataset_loader.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*!
* Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_DATASET_LOADER_H_
#define LIGHTGBM_DATASET_LOADER_H_
#include <LightGBM/dataset.h>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace LightGBM {
/*!
* \brief Loader interface that produces Dataset objects from data files,
*        serialized reference buffers, or sampled column values.
*
* Apart from the single-argument LoadFromFile overload, this header only
* declares the member functions; their definitions are not visible here.
* The class is non-copyable (copy constructor and copy assignment are deleted).
*/
class DatasetLoader {
public:
/*!
* \brief Constructor
* \param io_config Configuration object. NOTE(review): config_ is a reference member, so presumably
*        it is bound to this argument and the Config must outlive the loader - confirm in the definition
* \param predict_fun Prediction function; presumably copied into predict_fun_ - confirm
* \param num_class Number of classes
* \param filename Name of a data file. NOTE(review): how the constructor uses it is not visible in this header
*/
LIGHTGBM_EXPORT DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename);
/*! \brief Destructor (defined out of line) */
LIGHTGBM_EXPORT ~DatasetLoader();
/*!
* \brief Load a dataset from a file
* \param filename Name of the data file
* \param rank NOTE(review): presumably the index of the local machine in a distributed run - confirm
* \param num_machines NOTE(review): presumably the total number of machines - confirm
* \return Raw pointer to a Dataset. NOTE(review): ownership is not expressed by the type; confirm who frees it
*/
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, int rank, int num_machines);
/*!
* \brief Convenience overload: forwards to LoadFromFile(filename, 0, 1)
* \param filename Name of the data file
* \return Whatever the three-argument overload returns
*/
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename) {
return LoadFromFile(filename, 0, 1);
}
/*!
* \brief Load a dataset from a file, given an existing dataset
* \param filename Name of the data file
* \param train_data Existing dataset. NOTE(review): the name suggests the new dataset is aligned
*        with this one (e.g. shares its binning) - confirm in the definition
* \return Raw pointer to a Dataset (ownership not expressed by the type)
*/
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
/*!
* \brief Build a dataset from a serialized reference held in memory
* \param buffer Pointer to the serialized bytes
* \param buffer_size Size of buffer. NOTE(review): presumably in bytes - confirm
* \param num_data Number of data rows (data_size_t)
* \param num_classes Number of classes
* \return Raw pointer to a Dataset (ownership not expressed by the type)
*/
LIGHTGBM_EXPORT Dataset* LoadFromSerializedReference(const char* buffer, size_t buffer_size, data_size_t num_data, int32_t num_classes);
/*!
* \brief Build a dataset from sampled values
* \param sample_values Array of per-column arrays of sampled values. NOTE(review): layout inferred from the names; confirm
* \param sample_indices Array of per-column arrays of indices matching sample_values. NOTE(review): confirm
* \param num_col Number of columns
* \param num_per_col Array of counts, presumably the number of sampled entries for each column - confirm
* \param total_sample_size Total size of the sample
* \param num_local_data Number of local data rows (data_size_t)
* \param num_dist_data NOTE(review): presumably the number of rows across all machines - confirm
* \return Raw pointer to a Dataset (ownership not expressed by the type)
*/
LIGHTGBM_EXPORT Dataset* ConstructFromSampleData(double** sample_values,
int** sample_indices,
int num_col,
const int* num_per_col,
size_t total_sample_size,
data_size_t num_local_data,
int64_t num_dist_data);
/*! \brief Disable copy assignment */
DatasetLoader& operator=(const DatasetLoader&) = delete;
/*! \brief Disable copy construction */
DatasetLoader(const DatasetLoader&) = delete;
/*!
* \brief Static helper returning one vector of double values per entry of the outer vector
* \param forced_bins_path Path of the forced-bins description (taken by value)
* \param num_total_features Total number of features
* \param categorical_features Set of categorical feature indices
* \return Nested vector of doubles. NOTE(review): presumably the forced bin bounds indexed by feature - confirm
*/
static std::vector<std::vector<double>> GetForcedBins(std::string forced_bins_path, int num_total_features,
const std::unordered_set<int>& categorical_features);
private:
/*! \brief Fill header information of dataset from an in-memory buffer (defined out of line) */
void LoadHeaderFromMemory(Dataset* dataset, const char* buffer);
/*!
* \brief Load a dataset from a binary file
* \param data_filename Name of the original data file
* \param bin_filename Name of the binary file
* \param rank See LoadFromFile
* \param num_machines See LoadFromFile
* \param num_global_data Output pointer. NOTE(review): presumably receives the global row count - confirm
* \param used_data_indices Output pointer. NOTE(review): presumably receives the row indices kept locally - confirm
*/
Dataset* LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices);
/*! \brief Set header-related state from the given file. NOTE(review): presumably fills label_idx_, weight_idx_, group_idx_ and feature_names_ - confirm */
void SetHeader(const char* filename);
/*! \brief Validate a loaded dataset; is_load_from_binary tells whether it came from a binary file */
void CheckDataset(const Dataset* dataset, bool is_load_from_binary);
/*! \brief Read text data of a file into memory and return it as strings; num_global_data and used_data_indices are output pointers */
std::vector<std::string> LoadTextDataToMemory(const char* filename, const Metadata& metadata, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices);
/*! \brief Return a sample of the given in-memory text data */
std::vector<std::string> SampleTextDataFromMemory(const std::vector<std::string>& data);
/*! \brief Return a sample of the text data of a file; num_global_data and used_data_indices are output pointers */
std::vector<std::string> SampleTextDataFromFile(const char* filename, const Metadata& metadata, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices);
/*! \brief Construct bin mappers for dataset from sampled text data, using parser to interpret it */
void ConstructBinMappersFromTextData(int rank, int num_machines, const std::vector<std::string>& sample_data, const Parser* parser, Dataset* dataset);
/*! \brief Extract local features from memory */
void ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset);
/*! \brief Extract local features from file */
void ExtractFeaturesFromFile(const char* filename, const Parser* parser, const std::vector<data_size_t>& used_data_indices, Dataset* dataset);
/*! \brief Check whether the data can be loaded from a binary file.
* NOTE(review): returns a std::string, presumably the binary file name (empty when not loadable) - confirm
*/
std::string CheckCanLoadFromBin(const char* filename);
/*! \brief Check the number of bins for categorical features.
* The number of bins for categorical features may exceed the configured maximum value.
* Log warnings when such cases happen.
*
* \param bin_mappers the bin_mappers of all features
* \param max_bin max_bin from Config
* \param max_bin_by_feature max_bin_by_feature from Config
*/
void CheckCategoricalFeatureNumBin(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers, const int max_bin, const std::vector<int>& max_bin_by_feature) const;
/*! \brief Configuration; held by const reference, so the loader does not own it */
const Config& config_;
/*! \brief Random generator*/
Random random_;
/*! \brief prediction function for initial model */
const PredictFunction predict_fun_;
/*! \brief number of classes */
int num_class_;
/*! \brief index of label column */
int label_idx_;
/*! \brief index of weight column */
int weight_idx_;
/*! \brief index of group column */
int group_idx_;
/*! \brief Set of feature indices to ignore (a set of indices, not an index mapping) */
std::unordered_set<int> ignore_features_;
/*! \brief store feature names */
std::vector<std::string> feature_names_;
/*! \brief Set of categorical feature indices (a set of indices, not an index mapping) */
std::unordered_set<int> categorical_features_;
/*! \brief Whether to store raw feature values */
bool store_raw_;
};
} // namespace LightGBM
#endif // LIGHTGBM_DATASET_LOADER_H_