Skip to content

Commit

Permalink
fix nan in tree model (#2303)
Browse files Browse the repository at this point in the history
* fix nan in tree model

* fix
  • Loading branch information
guolinke committed Aug 14, 2019
1 parent 578a8c8 commit 9558417
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion include/LightGBM/tree.h
Expand Up @@ -422,7 +422,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
split_feature_inner_[new_node_idx] = feature;
split_feature_[new_node_idx] = real_feature;

split_gain_[new_node_idx] = Common::AvoidInf(gain);
split_gain_[new_node_idx] = gain;
// add two new leaves
left_child_[new_node_idx] = ~leaf;
right_child_[new_node_idx] = ~num_leaves_;
Expand Down
8 changes: 6 additions & 2 deletions include/LightGBM/utils/common.h
Expand Up @@ -663,7 +663,9 @@ inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& dat
}

inline static double AvoidInf(double x) {
if (x >= 1e300) {
if (std::isnan(x)) {
return 0.0;
} else if (x >= 1e300) {
return 1e300;
} else if (x <= -1e300) {
return -1e300;
Expand All @@ -673,7 +675,9 @@ inline static double AvoidInf(double x) {
}

inline static float AvoidInf(float x) {
if (x >= 1e38) {
if (std::isnan(x)){
return 0.0f;
} else if (x >= 1e38) {
return 1e38f;
} else if (x <= -1e38) {
return -1e38f;
Expand Down
4 changes: 2 additions & 2 deletions src/io/tree.cpp
Expand Up @@ -64,7 +64,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
SetMissingType(&decision_type_[new_node_idx], 2);
}
threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
threshold_[new_node_idx] = threshold_double;
++num_leaves_;
return num_leaves_ - 1;
}
Expand Down Expand Up @@ -268,7 +268,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "{" << '\n';
str_buf << "\"split_index\":" << index << "," << '\n';
str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
int cat_idx = static_cast<int>(threshold_[index]);
std::vector<int> cats;
Expand Down

0 comments on commit 9558417

Please sign in to comment.