[MRG] expose feature importance to c_api (#860)
* expose feature importance to c_api

* support type=gain

* remove dump model from examples and tests temporarily because it's unstable

* use double instead of float
wxchan authored and guolinke committed Aug 24, 2017
1 parent e5aa565 commit 603bffc
Showing 9 changed files with 114 additions and 64 deletions.
9 changes: 0 additions & 9 deletions examples/python-guide/simple_example.py
@@ -52,16 +52,7 @@
# eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)

print('Dump model to JSON...')
# dump model to json (and save to file)
model_json = gbm.dump_model()

with open('model.json', 'w+') as f:
json.dump(model_json, f, indent=4)


print('Feature names:', gbm.feature_name())

print('Calculate feature importances...')
# feature importances
print('Feature importances:', list(gbm.feature_importance()))
1 change: 0 additions & 1 deletion examples/python-guide/sklearn_example.py
@@ -32,7 +32,6 @@
# eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)

print('Calculate feature importances...')
# feature importances
print('Feature importances:', list(gbm.feature_importances_))

8 changes: 8 additions & 0 deletions include/LightGBM/boosting.h
@@ -191,6 +191,14 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual bool LoadModelFromString(const std::string& model_str) = 0;

/*!
* \brief Calculate feature importances
* \param num_iteration Number of iterations to use for feature importance, -1 means use all
* \param importance_type 0 for split, 1 for gain
* \return Vector of feature importances
*/
virtual std::vector<double> FeatureImportance(int num_iteration, int importance_type) const = 0;

/*!
* \brief Get max feature index of this model
* \return Max feature index of this model
13 changes: 13 additions & 0 deletions include/LightGBM/c_api.h
@@ -722,6 +722,19 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int leaf_idx,
double val);

/*!
* \brief Get model feature importance
* \param handle handle
* \param num_iteration Number of iterations to use, <= 0 means use all
* \param importance_type 0 for split, 1 for gain
* \param out_results output value array
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int num_iteration,
int importance_type,
double* out_results);

#if defined(_MSC_VER)
// exception handle and error msg
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; }
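For reference, a minimal ctypes sketch of how a caller might drive this new entry point; the library filename and the pre-existing `BoosterHandle` are assumptions, not part of this commit. The caller owns `out_results` and must size it to the number of features, which `LGBM_BoosterGetNumFeature` reports:

```python
import ctypes
import numpy as np

# Assumed: lib_lightgbm.so is on the loader path and `handle` is a valid
# BoosterHandle obtained elsewhere (e.g. via LGBM_BoosterCreate).
_LIB = ctypes.cdll.LoadLibrary('lib_lightgbm.so')

def feature_importance(handle, num_iteration=-1, importance_type=0):
    # Ask the booster how many features it has, then allocate one double
    # per feature for the output buffer; both C calls return 0 on success.
    num_feature = ctypes.c_int(0)
    _LIB.LGBM_BoosterGetNumFeature(handle, ctypes.byref(num_feature))
    result = np.zeros(num_feature.value, dtype=np.float64)
    _LIB.LGBM_BoosterFeatureImportance(
        handle,
        ctypes.c_int(num_iteration),
        ctypes.c_int(importance_type),
        result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
    return result
```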
51 changes: 27 additions & 24 deletions python-package/lightgbm/basic.py
@@ -1680,6 +1680,14 @@ def _to_predictor(self, pred_parameter=None):
predictor.pandas_categorical = self.pandas_categorical
return predictor

def num_feature(self):
"""Get num of features"""
out_num_feature = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumFeature(
self.handle,
ctypes.byref(out_num_feature)))
return out_num_feature.value

def feature_name(self):
"""
Get feature names.
@@ -1689,12 +1697,7 @@ def feature_name(self):
result : array
Array of feature names.
"""
out_num_feature = ctypes.c_int(0)
"""Get num of features"""
_safe_call(_LIB.LGBM_BoosterGetNumFeature(
self.handle,
ctypes.byref(out_num_feature)))
num_feature = out_num_feature.value
num_feature = self.num_feature()
"""Get name of features"""
tmp_out_len = ctypes.c_int(0)
string_buffers = [ctypes.create_string_buffer(255) for i in range_(num_feature)]
@@ -1707,7 +1710,7 @@ def feature_name(self):
raise ValueError("Length of feature names doesn't equal with num_feature")
return [string_buffers[i].value.decode() for i in range_(num_feature)]

def feature_importance(self, importance_type='split'):
def feature_importance(self, importance_type='split', iteration=-1):
"""
Get feature importances
@@ -1723,23 +1726,23 @@ def feature_importance(self, importance_type='split', iteration=-1):
result : array
Array of feature importances.
"""
if importance_type not in ["split", "gain"]:
raise KeyError("importance_type must be split or gain")
dump_model = self.dump_model()
ret = [0] * (dump_model["max_feature_idx"] + 1)

def dfs(root):
if "split_feature" in root:
if root['split_gain'] > 0:
if importance_type == 'split':
ret[root["split_feature"]] += 1
elif importance_type == 'gain':
ret[root["split_feature"]] += root["split_gain"]
dfs(root["left_child"])
dfs(root["right_child"])
for tree in dump_model["tree_info"]:
dfs(tree["tree_structure"])
return np.array(ret)
if importance_type == "split":
importance_type_int = 0
elif importance_type == "gain":
importance_type_int = 1
else:
importance_type_int = -1
num_feature = self.num_feature()
result = np.array([0 for _ in range_(num_feature)], dtype=np.float64)
_safe_call(_LIB.LGBM_BoosterFeatureImportance(
self.handle,
ctypes.c_int(iteration),
ctypes.c_int(importance_type_int),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if importance_type_int == 0:
return result.astype(int)
else:
return result

def __inner_eval(self, data_name, data_idx, feval=None):
"""
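With the wrapper now forwarding to the C API, typical usage stays the same; a short sketch, assuming `gbm` is a trained Booster:

```python
# 'split' importances come back as an int array (number of times each feature
# is used in a split); 'gain' as float64 (total gain of those splits).
split_counts = gbm.feature_importance(importance_type='split')
gains = gbm.feature_importance(importance_type='gain', iteration=-1)  # -1: all iterations
```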
60 changes: 40 additions & 20 deletions src/boosting/gbdt.cpp
@@ -994,6 +994,8 @@ std::string GBDT::SaveModelToString(int num_iteration) const {

ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;

std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);

ss << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
@@ -1006,7 +1008,20 @@
ss << models_[i]->ToString() << std::endl;
}

std::vector<std::pair<size_t, std::string>> pairs = FeatureImportance(num_used_model);
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
size_t feature_importances_int = static_cast<size_t>(feature_importances[i]);
if (feature_importances_int > 0) {
pairs.emplace_back(feature_importances_int, feature_names_[i]);
}
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[] (const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
ss << std::endl << "feature importances:" << std::endl;
for (size_t i = 0; i < pairs.size(); ++i) {
ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
@@ -1130,30 +1145,35 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
return true;
}

std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance(int num_used_model) const {
std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_type) const {

int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_iteration += boost_from_average_ ? 1 : 0;
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}

std::vector<size_t> feature_importances(max_feature_idx_ + 1, 0);
for (int iter = 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
++feature_importances[models_[iter]->split_feature(split_idx)];
std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
if (importance_type == 0) {
for (int iter = boost_from_average_ ? 1 : 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
}
}
}
}
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
if (feature_importances[i] > 0) {
pairs.emplace_back(feature_importances[i], feature_names_[i]);
} else if (importance_type == 1) {
for (int iter = boost_from_average_ ? 1 : 0; iter < num_used_model; ++iter) {
for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
if (models_[iter]->split_gain(split_idx) > 0) {
feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
}
}
}
} else {
Log::Fatal("Unknown importance type: only support split=0 and gain=1.");
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[] (const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
return pairs;
return feature_importances;
}

} // namespace LightGBM
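The two importance types mirror what the removed Python implementation computed by walking the dumped model: `split` counts every split with positive gain, `gain` sums the gains of those splits. A pure-Python equivalent over a `dump_model()`-style dict, for illustration only:

```python
# Illustrative re-statement of the C++ loops above, walking a dump_model()
# dict the way the removed basic.py code did.
def feature_importance_from_dump(dump_model, importance_type='split'):
    ret = [0.0] * (dump_model['max_feature_idx'] + 1)

    def dfs(node):
        if 'split_feature' in node:  # internal node, not a leaf
            if node['split_gain'] > 0:
                ret[node['split_feature']] += (
                    1 if importance_type == 'split' else node['split_gain'])
            dfs(node['left_child'])
            dfs(node['right_child'])

    for tree in dump_model['tree_info']:
        dfs(tree['tree_structure'])
    return ret
```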
14 changes: 8 additions & 6 deletions src/boosting/gbdt.h
@@ -201,6 +201,14 @@ class GBDT: public Boosting {
*/
bool LoadModelFromString(const std::string& model_str) override;

/*!
* \brief Calculate feature importances
* \param num_iteration Number of iterations to use for feature importance, -1 means use all
* \param importance_type 0 for split, 1 for gain
* \return Vector of feature importances
*/
std::vector<double> FeatureImportance(int num_iteration, int importance_type) const override;

/*!
* \brief Get max feature index of this model
* \return Max feature index of this model
@@ -302,12 +310,6 @@
* \return best_msg if met early_stopping
*/
std::string OutputMetric(int iter);
/*!
* \brief Calculate feature importances
* \param num_used_model Number of model that want to use for feature importance, -1 means use all
* \return sorted pairs of (feature_importance, feature_name)
*/
std::vector<std::pair<size_t, std::string>> FeatureImportance(int num_used_model) const;

/*! \brief current iteration */
int iter_;
17 changes: 17 additions & 0 deletions src/c_api.cpp
@@ -248,6 +248,10 @@ class Booster {
return boosting_->DumpModel(num_iteration);
}

std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
return boosting_->FeatureImportance(num_iteration, importance_type);
}

double GetLeafValue(int tree_idx, int leaf_idx) const {
return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
}
@@ -1175,6 +1179,19 @@ int LGBM_BoosterSetLeafValue(BoosterHandle handle,
API_END();
}

int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int num_iteration,
int importance_type,
double* out_results) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::vector<double> feature_importances = ref_booster->FeatureImportance(num_iteration, importance_type);
for (size_t i = 0; i < feature_importances.size(); ++i) {
(out_results)[i] = feature_importances[i];
}
API_END();
}

// ---- start of some help functions

std::function<std::vector<double>(int row_idx)>
5 changes: 1 addition & 4 deletions tests/python_package_test/test_engine.py
@@ -290,7 +290,7 @@ def test_early_stopping(self):
self.assertIn(valid_set_name, gbm.best_score)
self.assertIn('binary_logloss', gbm.best_score[valid_set_name])

def test_continue_train_and_dump_model(self):
def test_continue_train(self):
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
@@ -317,9 +317,6 @@ def test_continue_train(self):
self.assertAlmostEqual(evals_result['valid_0']['l1'][-1], ret, places=5)
for l1, mae in zip(evals_result['valid_0']['l1'], evals_result['valid_0']['mae']):
self.assertAlmostEqual(l1, mae, places=5)
# test dump model
self.assertIn('tree_info', gbm.dump_model())
self.assertIsInstance(gbm.feature_importance(), np.ndarray)
os.remove(model_name)

def test_continue_train_multiclass(self):
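The dump-model assertions were dropped as unstable, but the feature-importance contract is still easy to check; a hypothetical standalone snippet (not part of this commit), assuming `gbm` is a trained Booster:

```python
import numpy as np

# Hypothetical check: importances come back as a numpy array with one entry
# per feature, matching the new Booster.num_feature() helper.
importances = gbm.feature_importance()
assert isinstance(importances, np.ndarray)
assert len(importances) == gbm.num_feature()
```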
