check the shape for mat, csr and csc in prediction (#2464)
* check the shape for mat, csr and csc

* guess from csr

* support file checking

* better error msg

* grammar

* clean code

* code clean

* check range for CSR

* Update test_.py

* Update test_.py

* added tests
guolinke committed Oct 3, 2019
1 parent dc65e0a commit dee7215
Showing 10 changed files with 178 additions and 82 deletions.
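
Taken together, the changes below make the prediction entry points validate the caller's column count against the model (``MaxFeatureIdx() + 1``) instead of guessing it from the data. As a minimal caller-side sketch of the dense-matrix path: the model path and feature values are placeholders, and the helper calls (``LGBM_BoosterCreateFromModelfile``, ``LGBM_BoosterGetNumFeature``, ``LGBM_BoosterCalcNumPredict``, ``LGBM_GetLastError``, ``LGBM_BoosterFree``) are not part of this diff, with signatures assumed from the ``c_api.h`` of this era.

```cpp
#include <LightGBM/c_api.h>

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  BoosterHandle booster = nullptr;
  int num_iterations = 0;
  // "model.txt" is a placeholder for an already trained model file.
  if (LGBM_BoosterCreateFromModelfile("model.txt", &num_iterations, &booster) != 0) {
    std::fprintf(stderr, "load failed: %s\n", LGBM_GetLastError());
    return 1;
  }

  // Ask the model how many features it was trained with; with this commit,
  // passing any other ncol to the predict call is rejected.
  int num_feature = 0;
  LGBM_BoosterGetNumFeature(booster, &num_feature);

  const int32_t nrow = 2;
  const int32_t ncol = num_feature;
  std::vector<double> row_major(static_cast<size_t>(nrow) * ncol, 0.0);  // dummy feature values

  int64_t out_len = 0;
  LGBM_BoosterCalcNumPredict(booster, nrow, C_API_PREDICT_NORMAL, /*num_iteration=*/-1, &out_len);
  std::vector<double> out(static_cast<size_t>(out_len));

  int err = LGBM_BoosterPredictForMat(booster, row_major.data(), C_API_DTYPE_FLOAT64,
                                      nrow, ncol, /*is_row_major=*/1,
                                      C_API_PREDICT_NORMAL, /*num_iteration=*/-1, "",
                                      &out_len, out.data());
  if (err != 0) {
    // A wrong ncol now fails fast with "The number of features in data (...)
    // is not the same as it was in training data (...)".
    std::fprintf(stderr, "predict failed: %s\n", LGBM_GetLastError());
  }
  LGBM_BoosterFree(booster);
  return err;
}
```

Querying ``LGBM_BoosterGetNumFeature`` first keeps the caller and the model in agreement about ``ncol``, which is exactly the invariant the new checks enforce.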
4 changes: 2 additions & 2 deletions include/LightGBM/c_api.h
@@ -683,7 +683,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix
* \param num_col Number of columns; when it's set to 0, then guess from data
* \param num_col Number of columns
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
@@ -726,7 +726,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix
* \param num_col Number of columns; when it's set to 0, then guess from data
* \param num_col Number of columns
* \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score;
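
The two documentation edits above drop the "when it's set to 0, then guess from data" behaviour: ``num_col`` is now a required, validated argument for the CSR prediction functions. A hedged sketch of such a call follows; the toy CSR matrix and model path are made up for illustration, and the column indices assume the model has at least six features.

```cpp
#include <LightGBM/c_api.h>

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  BoosterHandle booster = nullptr;
  int num_iterations = 0;
  if (LGBM_BoosterCreateFromModelfile("model.txt", &num_iterations, &booster) != 0) {
    std::fprintf(stderr, "load failed: %s\n", LGBM_GetLastError());
    return 1;
  }

  // Two sparse rows in CSR layout.
  std::vector<int32_t> indptr  = {0, 2, 3};         // nindptr = nrow + 1
  std::vector<int32_t> indices = {0, 5, 3};         // column index of each non-zero
  std::vector<double>  values  = {1.0, 2.5, -0.5};  // value of each non-zero

  // num_col must be the true width of the matrix; 0 ("guess from data") is no
  // longer accepted, and a value that disagrees with the model is an error.
  int num_feature = 0;
  LGBM_BoosterGetNumFeature(booster, &num_feature);
  const int64_t num_col = num_feature;

  int64_t out_len = 0;
  LGBM_BoosterCalcNumPredict(booster, static_cast<int>(indptr.size()) - 1,
                             C_API_PREDICT_NORMAL, /*num_iteration=*/-1, &out_len);
  std::vector<double> out(static_cast<size_t>(out_len));

  int err = LGBM_BoosterPredictForCSR(booster,
                                      indptr.data(), C_API_DTYPE_INT32,
                                      indices.data(), values.data(), C_API_DTYPE_FLOAT64,
                                      static_cast<int64_t>(indptr.size()),
                                      static_cast<int64_t>(values.size()),
                                      num_col, C_API_PREDICT_NORMAL, /*num_iteration=*/-1, "",
                                      &out_len, out.data());
  if (err != 0) {
    std::fprintf(stderr, "predict failed: %s\n", LGBM_GetLastError());
  }
  LGBM_BoosterFree(booster);
  return err;
}
```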
3 changes: 2 additions & 1 deletion include/LightGBM/dataset.h
@@ -265,7 +265,7 @@ class Parser {
virtual void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0;

virtual int TotalColumns() const = 0;
virtual int NumFeatures() const = 0;

/*!
* \brief Create a object of parser, will auto choose the format depend on file
@@ -290,6 +290,7 @@ class Dataset {

void Construct(
std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
int num_total_features,
const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices,
const int* num_per_col,
4 changes: 3 additions & 1 deletion src/application/predictor.hpp
@@ -140,7 +140,9 @@ class Predictor {
if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
}

if (parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1);
}
TextReader<data_size_t> predict_data_reader(data_filename, header);
std::unordered_map<int, int> feature_names_map_;
bool need_adjust = false;
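
The file-based prediction path gets the same guard: the parser's feature count is compared with the model before any rows are scored, so a mismatched prediction file aborts with a clear message instead of producing misaligned scores. Through the C API this path is ``LGBM_BoosterPredictForFile``; the sketch below uses placeholder file names and a signature assumed from the ``c_api.h`` of this era.

```cpp
#include <LightGBM/c_api.h>

#include <cstdio>

int main() {
  BoosterHandle booster = nullptr;
  int num_iterations = 0;
  if (LGBM_BoosterCreateFromModelfile("model.txt", &num_iterations, &booster) != 0) {
    std::fprintf(stderr, "load failed: %s\n", LGBM_GetLastError());
    return 1;
  }
  // If "predict.txt" has a different number of feature columns than the
  // training data, the call now fails instead of scoring misaligned rows.
  int err = LGBM_BoosterPredictForFile(booster, "predict.txt", /*data_has_header=*/0,
                                       C_API_PREDICT_NORMAL, /*num_iteration=*/-1, "",
                                       "predictions.txt");
  if (err != 0) {
    std::fprintf(stderr, "predict failed: %s\n", LGBM_GetLastError());
  }
  LGBM_BoosterFree(booster);
  return err;
}
```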
55 changes: 39 additions & 16 deletions src/c_api.cpp
@@ -249,17 +249,19 @@ class Booster {
boosting_->RollbackOneIter();
}

void PredictSingleRow(int num_iteration, int predict_type,
void PredictSingleRow(int num_iteration, int predict_type, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration));
}

auto one_row = get_row_fun(0);
auto pred_wrt_ptr = out_result;
single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
@@ -268,10 +270,13 @@ class Booster {
}


void Predict(int num_iteration, int predict_type, int nrow,
void Predict(int num_iteration, int predict_type, int nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
bool is_raw_score = false;
@@ -647,7 +652,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
Common::Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()),
ncol,
Common::VectorSize<double>(sample_values).data(),
sample_cnt, total_nrow));
} else {
Expand Down Expand Up @@ -687,6 +692,11 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
const DatasetHandle reference,
DatasetHandle* out) {
API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
@@ -718,7 +728,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
Common::Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()),
static_cast<int>(num_col),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
} else {
Expand Down Expand Up @@ -748,9 +758,12 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
const DatasetHandle reference,
DatasetHandle* out) {
API_BEGIN();

if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

auto param = Config::Str2Map(parameters);
Config config;
config.Set(param);
@@ -783,7 +796,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
Common::Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()),
static_cast<int>(num_col),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
} else {
@@ -1299,13 +1312,18 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int64_t num_col,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len,
double* out_result) {
API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
@@ -1315,7 +1333,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
config, out_result, out_len);
API_END();
}
@@ -1328,13 +1346,18 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
int data_type,
int64_t nindptr,
int64_t nelem,
int64_t,
int64_t num_col,
int predict_type,
int num_iteration,
const char* parameter,
int64_t* out_len,
double* out_result) {
API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
@@ -1343,7 +1366,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len);
ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
API_END();
}

@@ -1395,7 +1418,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
}
return one_row;
};
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
out_result, out_len);
API_END();
}
@@ -1420,7 +1443,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
config, out_result, out_len);
API_END();
}
@@ -1444,7 +1467,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len);
ref_booster->PredictSingleRow(num_iteration, predict_type, ncol, get_row_fun, config, out_result, out_len);
API_END();
}

@@ -1468,7 +1491,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, config, out_result, out_len);
ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len);
API_END();
}

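
Several of the entry points above now reject out-of-range column counts with the same pair of bounds before doing any work. A caller that wants to fail even earlier can mirror that validation; the helper below is hypothetical (``ValidateNumCol`` is not part of LightGBM) and simply restates the bounds enforced in this diff.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical caller-side mirror of the checks added to LGBM_DatasetCreateFromCSR,
// LGBM_DatasetCreateFromCSRFunc, and the CSR prediction functions: the column count
// must be positive and small enough to narrow to a 32-bit int.
inline int32_t ValidateNumCol(int64_t num_col) {
  if (num_col <= 0) {
    throw std::invalid_argument("The number of columns should be greater than zero.");
  }
  if (num_col >= std::numeric_limits<int32_t>::max()) {
    throw std::invalid_argument("The number of columns should be smaller than INT32_MAX.");
  }
  return static_cast<int32_t>(num_col);
}
```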
4 changes: 3 additions & 1 deletion src/io/dataset.cpp
@@ -215,12 +215,14 @@ std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_

void Dataset::Construct(
std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
int num_total_features,
const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices,
const int* num_per_col,
size_t total_sample_cnt,
const Config& io_config) {
num_total_features_ = static_cast<int>(bin_mappers->size());
num_total_features_ = num_total_features;
CHECK(num_total_features_ == static_cast<int>(bin_mappers->size()));
sparse_threshold_ = io_config.sparse_threshold;
// get num_features
std::vector<int> used_features;
6 changes: 3 additions & 3 deletions src/io/dataset_loader.cpp
@@ -721,7 +721,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
}
}
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->Construct(&bin_mappers, forced_bin_bounds, sample_indices, num_per_col, total_sample_size, config_);
dataset->Construct(&bin_mappers, num_col, forced_bin_bounds, sample_indices, num_per_col, total_sample_size, config_);
dataset->set_feature_names(feature_names_);
return dataset.release();
}
@@ -897,7 +897,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,

if (feature_names_.empty()) {
// -1 means doesn't use this feature
dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->TotalColumns() - 1);
dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
dataset->used_feature_map_ = std::vector<int>(dataset->num_total_features_, -1);
} else {
dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1);
@@ -1059,7 +1059,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
}
}
sample_values.clear();
dataset->Construct(&bin_mappers, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Common::VectorSize<int>(sample_indices).data(), sample_data.size(), config_);
}

