Skip to content

Commit

Permalink
add weight in tree model output (#2269)
Browse files Browse the repository at this point in the history
* add weight in tree model output

* fix bug

* updated Python plotting part to handle weights
  • Loading branch information
guolinke authored and StrikerRUS committed Jul 24, 2019
1 parent 86a9578 commit e1d7a7b
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 17 deletions.
Expand Up @@ -355,7 +355,9 @@
" 'split_gain',\n",
" 'internal_value',\n",
" 'internal_count',\n",
" 'leaf_count'],\n",
" 'internal_weight',\n",
" 'leaf_count',\n",
" 'leaf_weight'],\n",
" value=['None']),\n",
" precision=(0, 10))\n",
" tree = None\n",
Expand All @@ -382,7 +384,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
"version": "3.7.1"
},
"varInspector": {
"cols": {
Expand Down
23 changes: 18 additions & 5 deletions include/LightGBM/tree.h
Expand Up @@ -50,14 +50,17 @@ class Tree {
* \param right_value Model Right child output
* \param left_cnt Count of left child
* \param right_cnt Count of right child
* \param left_weight Weight of left child
* \param right_weight Weight of right child
* \param gain Split gain
* \param missing_type missing type
* \param default_left default direction for missing value
* \return The index of new leaf.
*/
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left);
int left_cnt, int right_cnt, double left_weight, double right_weight,
float gain, MissingType missing_type, bool default_left);

/*!
* \brief Performing a split on tree leaves, with categorical feature
Expand All @@ -72,12 +75,14 @@ class Tree {
* \param right_value Model Right child output
* \param left_cnt Count of left child
* \param right_cnt Count of right child
* \param left_weight Weight of left child
* \param right_weight Weight of right child
* \param gain Split gain
* \return The index of new leaf.
*/
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type);
int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type);

/*! \brief Get the output (model prediction value) stored at a leaf
 *  \param leaf Index of the leaf (0-based, < num_leaves_)
 *  \return The raw leaf value; shrinkage, if any, is presumably already folded in — confirm against Shrinkage() usage */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
Expand Down Expand Up @@ -297,8 +302,8 @@ class Tree {
}
}

inline void Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt, float gain);
inline void Split(int leaf, int feature, int real_feature, double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain);
/*!
* \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record
Expand Down Expand Up @@ -383,10 +388,14 @@ class Tree {
std::vector<int> leaf_parent_;
/*! \brief Output of leaves */
std::vector<double> leaf_value_;
/*! \brief weight of leaves */
std::vector<double> leaf_weight_;
/*! \brief DataCount of leaves */
std::vector<int> leaf_count_;
/*! \brief Output of non-leaf nodes */
std::vector<double> internal_value_;
/*! \brief weight of non-leaf nodes */
std::vector<double> internal_weight_;
/*! \brief DataCount of non-leaf nodes */
std::vector<int> internal_count_;
/*! \brief Depth for leaves */
Expand All @@ -396,7 +405,8 @@ class Tree {
};

inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt, float gain) {
double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain) {
int new_node_idx = num_leaves_ - 1;
// update parent info
int parent = leaf_parent_[leaf];
Expand All @@ -420,11 +430,14 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
leaf_parent_[leaf] = new_node_idx;
leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change
internal_weight_[new_node_idx] = leaf_weight_[leaf];
internal_value_[new_node_idx] = leaf_value_[leaf];
internal_count_[new_node_idx] = left_cnt + right_cnt;
leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
leaf_weight_[leaf] = left_weight;
leaf_count_[leaf] = left_cnt;
leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
leaf_weight_[num_leaves_] = right_weight;
leaf_count_[num_leaves_] = right_cnt;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
Expand Down
10 changes: 7 additions & 3 deletions python-package/lightgbm/plotting.py
Expand Up @@ -390,7 +390,7 @@ def add(root, parent=None, decision=None):
label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info:
if info in {'split_gain', 'internal_value'}:
if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
Expand All @@ -409,6 +409,8 @@ def add(root, parent=None, decision=None):
label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count'])
if 'leaf_weight' in show_info:
label += r'\nleaf_weight: {0}'.format(_float2str(root['leaf_weight'], precision))
graph.node(name, label=label)
if parent is not None:
graph.edge(parent, name, decision)
Expand Down Expand Up @@ -438,7 +440,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
The index of a target tree to convert.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Expand Down Expand Up @@ -515,7 +518,8 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
Figure size.
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt_model_text.cpp
Expand Up @@ -14,7 +14,7 @@

namespace LightGBM {

const std::string kModelVersion = "v2";
const std::string kModelVersion = "v3";

std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
std::stringstream str_buf;
Expand Down
33 changes: 28 additions & 5 deletions src/io/tree.cpp
Expand Up @@ -26,14 +26,17 @@ Tree::Tree(int max_leaves)
split_gain_.resize(max_leaves_ - 1);
leaf_parent_.resize(max_leaves_);
leaf_value_.resize(max_leaves_);
leaf_weight_.resize(max_leaves_);
leaf_count_.resize(max_leaves_);
internal_value_.resize(max_leaves_ - 1);
internal_weight_.resize(max_leaves_ - 1);
internal_count_.resize(max_leaves_ - 1);
leaf_depth_.resize(max_leaves_);
// root is in the depth 0
leaf_depth_[0] = 0;
num_leaves_ = 1;
leaf_value_[0] = 0.0f;
leaf_weight_[0] = 0.0f;
leaf_parent_[0] = -1;
shrinkage_ = 1.0f;
num_cat_ = 0;
Expand All @@ -47,8 +50,8 @@ Tree::~Tree() {

int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Expand All @@ -68,8 +71,8 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,

int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
Expand Down Expand Up @@ -221,10 +224,14 @@ std::string Tree::ToString() const {
<< Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
str_buf << "leaf_value="
<< Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
str_buf << "leaf_weight="
<< Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
str_buf << "leaf_count="
<< Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
str_buf << "internal_value="
<< Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
str_buf << "internal_weight="
<< Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
str_buf << "internal_count="
<< Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
if (num_cat_ > 0) {
Expand Down Expand Up @@ -294,6 +301,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"missing_type\":\"NaN\"," << '\n';
}
str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
Expand All @@ -304,6 +312,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "{" << '\n';
str_buf << "\"leaf_index\":" << index << "," << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
str_buf << "}";
}
Expand Down Expand Up @@ -472,7 +481,7 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
Tree::Tree(const char* str, size_t* used_len) {
auto p = str;
std::unordered_map<std::string, std::string> key_vals;
const int max_num_line = 15;
const int max_num_line = 17;
int read_line = 0;
while (read_line < max_num_line) {
if (*p == '\r' || *p == '\n') break;
Expand Down Expand Up @@ -557,6 +566,20 @@ Tree::Tree(const char* str, size_t* used_len) {
internal_value_.resize(num_leaves_ - 1);
}

if (key_vals.count("internal_weight")) {
internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
}
else {
internal_weight_.resize(num_leaves_ - 1);
}

if (key_vals.count("leaf_weight")) {
leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
}
else {
leaf_weight_.resize(num_leaves_);
}

if (key_vals.count("leaf_count")) {
leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
} else {
Expand Down
8 changes: 8 additions & 0 deletions src/treelearner/serial_tree_learner.cpp
Expand Up @@ -684,6 +684,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int*
static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
current_split_info.default_left);
Expand Down Expand Up @@ -711,6 +713,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int*
static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(current_leaf, train_data_, inner_feature_index,
Expand Down Expand Up @@ -792,6 +796,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.left_sum_hessian),
static_cast<double>(best_split_info.right_sum_hessian),
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left);
Expand All @@ -815,6 +821,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.left_sum_hessian),
static_cast<double>(best_split_info.right_sum_hessian),
static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
Expand Down
4 changes: 3 additions & 1 deletion tests/python_package_test/test_plotting.py
Expand Up @@ -120,7 +120,7 @@ def test_create_tree_digraph(self):
self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)

graph = lgb.create_tree_digraph(gbm, tree_index=3,
show_info=['split_gain', 'internal_value'],
show_info=['split_gain', 'internal_value', 'internal_weight'],
name='Tree4', node_attr={'color': 'red'})
graph.render(view=False)
self.assertIsInstance(graph, graphviz.Digraph)
Expand All @@ -137,8 +137,10 @@ def test_create_tree_digraph(self):
self.assertIn('leaf_index', graph_body)
self.assertIn('split_gain', graph_body)
self.assertIn('internal_value', graph_body)
self.assertIn('internal_weight', graph_body)
self.assertNotIn('internal_count', graph_body)
self.assertNotIn('leaf_count', graph_body)
self.assertNotIn('leaf_weight', graph_body)

@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_metrics(self):
Expand Down

0 comments on commit e1d7a7b

Please sign in to comment.