support label as double type (#1120)
guolinke committed Dec 17, 2017
1 parent 162509a commit aa78a6b
Showing 20 changed files with 175 additions and 137 deletions.
20 changes: 10 additions & 10 deletions include/LightGBM/dataset.h
@@ -82,9 +82,9 @@ class Metadata {
void CheckOrPartition(data_size_t num_all_data,
const std::vector<data_size_t>& used_data_indices);

void SetLabel(const float* label, data_size_t len);
void SetLabel(const label_t* label, data_size_t len);

void SetWeights(const float* weights, data_size_t len);
void SetWeights(const label_t* weights, data_size_t len);

void SetQuery(const data_size_t* query, data_size_t len);

@@ -110,14 +110,14 @@ class Metadata {
* \brief Get pointer of label
* \return Pointer of label
*/
inline const float* label() const { return label_.data(); }
inline const label_t* label() const { return label_.data(); }

/*!
* \brief Set label for one record
* \param idx Index of this record
* \param value Label value of this record
*/
inline void SetLabelAt(data_size_t idx, float value)
inline void SetLabelAt(data_size_t idx, label_t value)
{
label_[idx] = value;
}
@@ -127,7 +127,7 @@
* \param idx Index of this record
* \param value Weight value of this record
*/
inline void SetWeightAt(data_size_t idx, float value)
inline void SetWeightAt(data_size_t idx, label_t value)
{
weights_[idx] = value;
}
@@ -146,7 +146,7 @@
* \brief Get weights; returns nullptr if weights do not exist
* \return Pointer of weights
*/
inline const float* weights() const {
inline const label_t* weights() const {
if (!weights_.empty()) {
return weights_.data();
} else {
@@ -179,7 +179,7 @@ class Metadata {
* \brief Get weights for queries; returns nullptr if they do not exist
* \return Pointer of weights for queries
*/
inline const float* query_weights() const {
inline const label_t* query_weights() const {
if (!query_weights_.empty()) {
return query_weights_.data();
} else {
@@ -225,13 +225,13 @@ class Metadata {
/*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_;
/*! \brief Label data */
std::vector<float> label_;
std::vector<label_t> label_;
/*! \brief Weights data */
std::vector<float> weights_;
std::vector<label_t> weights_;
/*! \brief Query boundaries */
std::vector<data_size_t> query_boundaries_;
/*! \brief Query weights */
std::vector<float> query_weights_;
std::vector<label_t> query_weights_;
/*! \brief Number of queries */
data_size_t num_queries_;
/*! \brief Number of initial scores, used to check correct initial score file */
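
A note on usage (not part of this commit): code that consumes labels or weights should now be written against the label_t and data_size_t typedefs from include/LightGBM/meta.h (shown in the next file) rather than against float, so that it compiles under either label precision. A minimal sketch, assuming <LightGBM/meta.h> is the header that provides these typedefs:

// Sketch only; LightGBM::label_t and LightGBM::data_size_t come from meta.h below.
#include <LightGBM/meta.h>

double SumLabels(const LightGBM::label_t* label, LightGBM::data_size_t num_data) {
  double sum = 0.0;  // accumulate in double regardless of the label precision
  for (LightGBM::data_size_t i = 0; i < num_data; ++i) {
    sum += label[i];
  }
  return sum;
}
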
18 changes: 18 additions & 0 deletions include/LightGBM/meta.h
@@ -12,8 +12,26 @@ namespace LightGBM {

/*! \brief Type of data size; it is better to use a signed type */
typedef int32_t data_size_t;

// Enable the following macro to use double for score_t
// #define SCORE_T_USE_DOUBLE

// Enable the following macro to use double for label_t
// #define LABEL_T_USE_DOUBLE

/*! \brief Type of score and gradients */
#ifdef SCORE_T_USE_DOUBLE
typedef double score_t;
#else
typedef float score_t;
#endif

/*! \brief Type of metadata, including weight and label */
#ifdef LABEL_T_USE_DOUBLE
typedef double label_t;
#else
typedef float label_t;
#endif

const score_t kMinScore = -std::numeric_limits<score_t>::infinity();

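
For reference, a minimal sketch (not from this commit) of how the new typedefs are switched: define the macro before meta.h is included, typically through a compiler flag, and label_t/score_t resolve to double instead of float. The -DLABEL_T_USE_DOUBLE flag and the <LightGBM/meta.h> include path below are assumptions for illustration.

// Build with, e.g.:  g++ -I include -DLABEL_T_USE_DOUBLE check_types.cpp
#include <cstdio>
#include <LightGBM/meta.h>

int main() {
  // Prints 8 when LABEL_T_USE_DOUBLE / SCORE_T_USE_DOUBLE are defined, 4 otherwise.
  std::printf("sizeof(label_t) = %zu\n", sizeof(LightGBM::label_t));
  std::printf("sizeof(score_t) = %zu\n", sizeof(LightGBM::score_t));
  return 0;
}
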
10 changes: 5 additions & 5 deletions include/LightGBM/metric.h
@@ -70,7 +70,7 @@ class DCGCalculator {
* \param num_data Number of data
* \return The DCG score
*/
static double CalDCGAtK(data_size_t k, const float* label,
static double CalDCGAtK(data_size_t k, const label_t* label,
const double* score, data_size_t num_data);

/*!
@@ -82,7 +82,7 @@ class DCGCalculator {
* \param out Output result
*/
static void CalDCG(const std::vector<data_size_t>& ks,
const float* label, const double* score,
const label_t* label, const double* score,
data_size_t num_data, std::vector<double>* out);

/*!
@@ -93,14 +93,14 @@
* \return The max DCG score
*/
static double CalMaxDCGAtK(data_size_t k,
const float* label, data_size_t num_data);
const label_t* label, data_size_t num_data);

/*!
* \brief Check the label range for NDCG and lambdarank
* \param label Pointer of label
* \param num_data Number of data
*/
static void CheckLabel(const float* label, data_size_t num_data);
static void CheckLabel(const label_t* label, data_size_t num_data);

/*!
* \brief Calculate the Max DCG score at multiple positions
@@ -110,7 +110,7 @@
* \param out Output result
*/
static void CalMaxDCG(const std::vector<data_size_t>& ks,
const float* label, data_size_t num_data, std::vector<double>* out);
const label_t* label, data_size_t num_data, std::vector<double>* out);

/*!
* \brief Get discount score of position k
2 changes: 1 addition & 1 deletion src/boosting/gbdt.cpp
@@ -295,7 +295,7 @@ void GBDT::Bagging(int iter) {
* (i) and (ii) could be selected as, say, "auto_init_score" = 0 or 1, etc.
*
*/
double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const float* label, data_size_t num_data) {
double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const label_t* label, data_size_t num_data) {
double init_score = 0.0f;
bool got_custom = false;
if (fobj != nullptr) {
6 changes: 5 additions & 1 deletion src/c_api.cpp
@@ -164,7 +164,7 @@ class Booster {
return boosting_->TrainOneIter(nullptr, nullptr);
}

bool TrainOneIter(const float* gradients, const float* hessians) {
bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
std::lock_guard<std::mutex> lock(mutex_);
return boosting_->TrainOneIter(gradients, hessians);
}
@@ -904,11 +904,15 @@ int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
#ifdef SCORE_T_USE_DOUBLE
Log::Fatal("Don't support Custom loss function when enable SCORE_T_USE_DOUBLE.");
#else
if (ref_booster->TrainOneIter(grad, hess)) {
*is_finished = 1;
} else {
*is_finished = 0;
}
#endif
API_END();
}

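
The new #ifdef guard above exists because the public C API presumably hands custom gradients and hessians to the booster as float buffers, while Booster::TrainOneIter now expects const score_t*; when score_t is double those pointer types are incompatible, so the custom-objective path is rejected with a fatal log instead of reinterpreting memory. A simplified, self-contained sketch of that pattern (hypothetical names, not the actual LightGBM sources):

#include <cstdio>

#ifdef SCORE_T_USE_DOUBLE
typedef double score_t;
#else
typedef float score_t;
#endif

// Stand-in for Booster::TrainOneIter(const score_t*, const score_t*).
static bool TrainOneIterSketch(const score_t*, const score_t*) { return false; }

// Stand-in for the C entry point, which receives float gradients from the caller.
int UpdateOneIterCustomSketch(const float* grad, const float* hess) {
#ifdef SCORE_T_USE_DOUBLE
  (void)grad; (void)hess;
  std::fprintf(stderr, "custom objectives are unsupported with SCORE_T_USE_DOUBLE\n");
  return -1;
#else
  return TrainOneIterSketch(grad, hess) ? 1 : 0;  // types line up only when score_t == float
#endif
}
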
16 changes: 16 additions & 0 deletions src/io/dataset.cpp
@@ -423,9 +423,17 @@ bool Dataset::SetFloatField(const char* field_name, const float* field_data, dat
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
metadata_.SetLabel(field_data, num_element);
#endif
} else if (name == std::string("weight") || name == std::string("weights")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
metadata_.SetWeights(field_data, num_element);
#endif
} else {
return false;
}
@@ -458,11 +466,19 @@ bool Dataset::GetFloatField(const char* field_name, data_size_t* out_len, const
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
*out_ptr = metadata_.label();
*out_len = num_data_;
#endif
} else if (name == std::string("weight") || name == std::string("weights")) {
#ifdef LABEL_T_USE_DOUBLE
Log::Fatal("Don't Support LABEL_T_USE_DOUBLE.");
#else
*out_ptr = metadata_.weights();
*out_len = num_data_;
#endif
} else {
return false;
}
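
The same reasoning applies to the guards above: SetFloatField and GetFloatField move data through float buffers, and a const float* cannot alias label or weight storage that has become std::vector<label_t> with label_t == double, so the float field interface reports failure rather than casting silently. A self-contained sketch of the GetFloatField side (simplified, hypothetical names):

#include <cstddef>
#include <vector>

#ifdef LABEL_T_USE_DOUBLE
typedef double label_t;
#else
typedef float label_t;
#endif

static std::vector<label_t> label_storage;  // stand-in for Metadata::label_

// Stand-in for Dataset::GetFloatField("label", ...): hands out a float view of the labels.
bool GetLabelAsFloatSketch(const float** out_ptr, std::size_t* out_len) {
#ifdef LABEL_T_USE_DOUBLE
  // A const float* cannot point at double storage; refuse instead of casting.
  (void)out_ptr; (void)out_len;
  return false;
#else
  *out_ptr = label_storage.data();
  *out_len = label_storage.size();
  return true;
#endif
}
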
12 changes: 6 additions & 6 deletions src/io/dataset_loader.cpp
@@ -921,7 +921,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
// parser
parser->ParseOneLine(text_data[i].c_str(), &oneline_features, &tmp_label);
// set label
dataset->metadata_.SetLabelAt(i, static_cast<float>(tmp_label));
dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
// free processed line:
text_data[i].clear();
// shrink_to_fit is very slow on Linux and does not seem to free memory; disabled for now
@@ -937,7 +937,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
} else {
if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast<float>(inner_data.second));
dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
} else if (inner_data.first == group_idx_) {
dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
}
@@ -964,7 +964,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
}
// set label
dataset->metadata_.SetLabelAt(i, static_cast<float>(tmp_label));
dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
// free processed line:
text_data[i].clear();
// shrink_to_fit is very slow on Linux and does not seem to free memory; disabled for now
@@ -980,7 +980,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
} else {
if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast<float>(inner_data.second));
dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
} else if (inner_data.first == group_idx_) {
dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
}
@@ -1025,7 +1025,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
}
}
// set label
dataset->metadata_.SetLabelAt(start_idx + i, static_cast<float>(tmp_label));
dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
// push data
for (auto& inner_data : oneline_features) {
if (inner_data.first >= dataset->num_total_features_) { continue; }
Expand All @@ -1037,7 +1037,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
} else {
if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(start_idx + i, static_cast<float>(inner_data.second));
dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
} else if (inner_data.first == group_idx_) {
dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
}
