Skip to content

Commit

Permalink
fix impurity importance with hellinger splitrule
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed May 17, 2019
1 parent 83cdc82 commit 84c2193
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 27 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.10.4-85
Date: 2018-07-30
Version: 0.10.4-86
Date: 2019-05-17
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <cran@wrig.de>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
3 changes: 3 additions & 0 deletions NEWS
@@ -1,3 +1,6 @@
##### Version 0.10.4-86
* Fix impurity importance for "hellinger" splitrule

##### Version 0.10.4-85
* Add "hellinger" splitrule
* Add inbag argument for manual selection of observations in trees
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,3 +1,6 @@
##### Version 0.10.4-86
* Fix impurity importance for "hellinger" splitrule

##### Version 0.10.4-85
* Add "hellinger" splitrule
* Add inbag argument for manual selection of observations in trees
Expand Down
2 changes: 1 addition & 1 deletion cpp_version/src/version.h
@@ -1,3 +1,3 @@
#ifndef RANGER_VERSION
#define RANGER_VERSION "0.10.4"
#define RANGER_VERSION "0.10.4-86"
#endif
28 changes: 16 additions & 12 deletions src/TreeClassification.cpp
Expand Up @@ -701,18 +701,22 @@ void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, si

void TreeClassification::addGiniImportance(size_t nodeID, size_t varID, double decrease) {

std::vector<size_t> class_counts;
class_counts.resize(class_values->size(), 0);
double best_decrease;
if (splitrule != HELLINGER) {
std::vector<size_t> class_counts;
class_counts.resize(class_values->size(), 0);

for (auto& sampleID : sampleIDs[nodeID]) {
uint sample_classID = (*response_classIDs)[sampleID];
class_counts[sample_classID]++;
}
double sum_node = 0;
for (auto& class_count : class_counts) {
sum_node += class_count * class_count;
for (auto& sampleID : sampleIDs[nodeID]) {
uint sample_classID = (*response_classIDs)[sampleID];
class_counts[sample_classID]++;
}

double sum_node = 0;
for (auto& class_count : class_counts) {
sum_node += class_count * class_count;
}
best_decrease = decrease - sum_node / (double) sampleIDs[nodeID].size();
}
double best_gini = decrease - sum_node / (double) sampleIDs[nodeID].size();

// No variable importance for no split variables
size_t tempvarID = data->getUnpermutedVarID(varID);
Expand All @@ -724,9 +728,9 @@ void TreeClassification::addGiniImportance(size_t nodeID, size_t varID, double d

// Subtract if corrected importance and permuted variable, else add
if (importance_mode == IMP_GINI_CORRECTED && varID >= data->getNumCols()) {
(*variable_importance)[tempvarID] -= best_gini;
(*variable_importance)[tempvarID] -= best_decrease;
} else {
(*variable_importance)[tempvarID] += best_gini;
(*variable_importance)[tempvarID] += best_decrease;
}
}

Expand Down
27 changes: 15 additions & 12 deletions src/TreeProbability.cpp
Expand Up @@ -706,18 +706,21 @@ void TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_

void TreeProbability::addImpurityImportance(size_t nodeID, size_t varID, double decrease) {

std::vector<size_t> class_counts;
class_counts.resize(class_values->size(), 0);
double best_decrease;
if (splitrule != HELLINGER) {
std::vector<size_t> class_counts;
class_counts.resize(class_values->size(), 0);

for (auto& sampleID : sampleIDs[nodeID]) {
uint sample_classID = (*response_classIDs)[sampleID];
class_counts[sample_classID]++;
}
double sum_node = 0;
for (auto& class_count : class_counts) {
sum_node += class_count * class_count;
for (auto& sampleID : sampleIDs[nodeID]) {
uint sample_classID = (*response_classIDs)[sampleID];
class_counts[sample_classID]++;
}
double sum_node = 0;
for (auto& class_count : class_counts) {
sum_node += class_count * class_count;
}
best_decrease = decrease - sum_node / (double) sampleIDs[nodeID].size();
}
double best_gini = decrease - sum_node / (double) sampleIDs[nodeID].size();

// No variable importance for no split variables
size_t tempvarID = data->getUnpermutedVarID(varID);
Expand All @@ -729,9 +732,9 @@ void TreeProbability::addImpurityImportance(size_t nodeID, size_t varID, double

// Subtract if corrected importance and permuted variable, else add
if (importance_mode == IMP_GINI_CORRECTED && varID >= data->getNumCols()) {
(*variable_importance)[tempvarID] -= best_gini;
(*variable_importance)[tempvarID] -= best_decrease;
} else {
(*variable_importance)[tempvarID] += best_gini;
(*variable_importance)[tempvarID] += best_decrease;
}
}

Expand Down

0 comments on commit 84c2193

Please sign in to comment.