DT/RF: Add function to report importance scores
JIRA: MADLIB-925

This commit adds a new MADlib function (get_var_importance) that reports the
importance scores for decision tree and random forest models by unnesting the
importance values along with their corresponding features.

Closes apache#295

Co-authored-by: Rahul Iyer <riyer@apache.org>
Co-authored-by: Jingyi Mei <jmei@pivotal.io>
Co-authored-by: Orhan Kislal <okislal@pivotal.io>
4 people committed Jul 26, 2018
1 parent 2aac418 commit 3ab7554
Showing 13 changed files with 697 additions and 84 deletions.
11 changes: 2 additions & 9 deletions src/modules/recursive_partitioning/decision_tree.cpp
@@ -487,7 +487,7 @@ print_decision_tree::run(AnyType &args){
}

AnyType
get_variable_importance::run(AnyType &args){
compute_variable_importance::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
const int n_cat_features = args[1].getAs<int>();
const int n_con_features = args[2].getAs<int>();
@@ -496,19 +496,12 @@ get_variable_importance::run(AnyType &args){
ColumnVector con_var_importance = ColumnVector::Zero(n_con_features);
dt.computeVariableImportance(cat_var_importance, con_var_importance);

// Variable importance is scaled to represent a percentage. Even though
// the importance values are split between categorical and continuous, the
// percentages are relative to the combined set.
ColumnVector combined_var_imp(n_cat_features + n_con_features);
combined_var_imp << cat_var_importance, con_var_importance;

// Avoid divide by zero by adding a small number
double total_var_imp = combined_var_imp.sum();
double VAR_IMP_EPSILON = 1e-6;
combined_var_imp *= (100.0 / (total_var_imp + VAR_IMP_EPSILON));
return combined_var_imp;
}


AnyType
display_text_tree::run(AnyType &args){
Tree dt = args[0].getAs<ByteString>();
2 changes: 1 addition & 1 deletion src/modules/recursive_partitioning/decision_tree.hpp
@@ -14,7 +14,7 @@ DECLARE_UDF(recursive_partitioning, compute_surr_stats_transition)
DECLARE_UDF(recursive_partitioning, dt_surr_apply)

DECLARE_UDF(recursive_partitioning, print_decision_tree)
DECLARE_UDF(recursive_partitioning, get_variable_importance)
DECLARE_UDF(recursive_partitioning, compute_variable_importance)
DECLARE_UDF(recursive_partitioning, predict_dt_response)
DECLARE_UDF(recursive_partitioning, predict_dt_prob)

15 changes: 15 additions & 0 deletions src/modules/recursive_partitioning/random_forest.cpp
@@ -204,6 +204,21 @@ rf_con_imp_score::run(AnyType &args) {
// ------------------------------------------------------------


AnyType
normalize_sum_array::run(AnyType &args){
const MappedColumnVector input_vector = args[0].getAs<MappedColumnVector>();
const double sum_target = args[1].getAs<double>();

double sum_input_vector = input_vector.sum();
// Avoid divide by zero by dividing by a small number if sum is small
double VAR_IMP_EPSILON = 1e-6;
if (sum_input_vector < VAR_IMP_EPSILON)
sum_input_vector = VAR_IMP_EPSILON;
ColumnVector output_vector = input_vector * sum_target / sum_input_vector;
return output_vector;
}


} // namespace recursive_partitioning
} // namespace modules
} // namespace madlib
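
The new normalize_sum_array UDF rescales an input array so that its elements sum to a
caller-supplied target, clamping the divisor to a small epsilon so that an all-zero
array does not cause a division by zero. A minimal SQL sketch of the same arithmetic,
using made-up input values (illustration only; this is not how MADlib invokes the UDF
internally):

    -- Rescale an array so its elements sum to 100 (the sum_target), mirroring
    -- normalize_sum_array: divisor = max(sum(input), 1e-6).
    WITH vals AS (
        SELECT v, i
        FROM unnest(ARRAY[0.102, 0.0, 0.859]) WITH ORDINALITY AS t(v, i)
    )
    SELECT array_agg(v * 100.0 / GREATEST(total, 1e-6) ORDER BY i) AS normalized
    FROM vals, (SELECT sum(v) AS total FROM vals) s;
    -- Returns approximately {10.6, 0, 89.4}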
1 change: 1 addition & 0 deletions src/modules/recursive_partitioning/random_forest.hpp
@@ -7,3 +7,4 @@

DECLARE_UDF(recursive_partitioning, rf_cat_imp_score)
DECLARE_UDF(recursive_partitioning, rf_con_imp_score)
DECLARE_UDF(recursive_partitioning, normalize_sum_array)
src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in
@@ -21,6 +21,9 @@ from utilities.control import OptimizerControl
from utilities.control import HashaggControl
from utilities.utilities import _assert
from utilities.utilities import _array_to_string
from utilities.utilities import _check_groups
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import unique_string
from utilities.utilities import add_postfix
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import is_psql_numeric_type, is_psql_boolean_type
@@ -2012,7 +2015,7 @@ def _compute_var_importance(schema_madlib, tree,
impurity_var_importance: Array of importance values
"""
var_imp_sql = """
SELECT {schema_madlib}._get_var_importance(
SELECT {schema_madlib}._compute_var_importance(
$1, -- trained decision tree
{n_cat_features},
{n_con_features}) AS impurity_var_importance
@@ -2412,7 +2415,6 @@ def _tree_error(schema_madlib, source_table, dependent_varname,
plpy.execute(sql)
# ------------------------------------------------------------


def tree_train_help_message(schema_madlib, message, **kwargs):
""" Help message for Decision Tree
"""
@@ -2567,6 +2569,10 @@ SELECT madlib.tree_train(
5);

SELECT madlib.tree_display('tree_out');
-- View the impurity importance value of each feature
DROP TABLE IF EXISTS var_imp_out;
SELECT madlib.get_var_importance('tree_out', 'var_imp_out');
SELECT * FROM var_imp_out;
"""
else:
help_string = "No such option. Use {schema_madlib}.tree_train('usage')"
102 changes: 74 additions & 28 deletions src/ports/postgres/modules/recursive_partitioning/decision_tree.sql_in
@@ -19,6 +19,7 @@ m4_include(`SQLCommon.m4')
<li class="level1"><a href="#train">Training Function</a></li>
<li class="level1"><a href="#predict">Prediction Function</a></li>
<li class="level1"><a href="#display">Tree Display</a></li>
<li class="level1"><a href="#display_importance">Importance Display</a></li>
<li class="level1"><a href="#examples">Examples</a></li>
<li class="level1"><a href="#literature">Literature</a></li>
<li class="level1"><a href="#related">Related Topics</a></li>
@@ -281,14 +282,17 @@
</tr>
<tr>
<th>impurity_var_importance</th>
<td>DOUBLE PRECISION[]. Impurity importance (also referred to as Gini
importance) of each variable. The order of the variables is the same as
<td>DOUBLE PRECISION[]. Impurity importance of each variable.
The order of the variables is the same as
that of the 'independent_varnames' column in the summary table (see below).

The impurity importance of a feature is the decrease in impurity due to each
node that uses the feature as its primary split, summed over the whole
tree. If surrogates are used, then the importance value also includes the
impurity decrease scaled by the adjusted surrogate agreement.
Reported importance values are normalized to sum to 100 across
all variables.
Please refer to [1] for more information on variable importance.
</td>
</tr>

@@ -597,6 +601,35 @@ plotting software to plot the trees in a nice format:
Please see the examples below for more details on the contents
of the tree output formats.

An additional display function is provided to output the surrogate splits chosen
for each internal node:
<pre class="syntax">
tree_surr_display(tree_model)
</pre>

@anchor display_importance
@par Importance Display

This is a helper function that creates a table for easier viewing of the
impurity variable importance values of a given model table. The function
rescales the importance values to percentages, i.e., the values are scaled
so that they sum to 100.

<pre class="syntax">
get_var_importance(model_table, output_table)
</pre>

\b Arguments
<DL class="arglist">
<DT>model_table</DT>
<DD>TEXT. Name of the table containing the decision tree model.</DD>
<DT>output_table</DT>
<DD>TEXT. Name of the table to create for importance values.</DD>
</DL>

The summary table generated by the tree_train function is necessary for this
function to work.
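
For example, assuming tree_train has already produced a model table named
'tree_out', the importance values can be written to a new table and listed
in descending order:
<pre class="example">
DROP TABLE IF EXISTS var_imp_out;
SELECT madlib.get_var_importance('tree_out', 'var_imp_out');
SELECT * FROM var_imp_out ORDER BY impurity_var_importance DESC;
</pre>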

@anchor examples
@examp
<h4>Decision Tree Classification Examples</h4>
@@ -661,9 +694,8 @@ SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tr
pruning_cp | 0
cat_levels_in_text | {overcast,rain,sunny,False,True}
cat_n_levels | {3,2}
impurity_var_importance | {10.6171201061712,0,89.3828798938288}
impurity_var_importance | {0.102040816326531,0,0.85905612244898}
tree_depth | 5

</pre>
View the summary table:
<pre class="example">
@@ -695,7 +727,20 @@ independent_var_types | text, boolean, double precision
n_folds | 0
null_proxy |
</pre>

View the impurity importance table using the helper function:
<pre class="example">
\\x off
DROP TABLE IF EXISTS imp_output;
SELECT madlib.get_var_importance('train_output','imp_output');
SELECT * FROM imp_output;
</pre>
<pre class="result">
feature | impurity_var_importance
-------------+-------------------------
"OUTLOOK" | 10.6171090593052
windy | 0
temperature | 89.382786893026
</pre>
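Because the helper rescales the values to percentages, the reported importances
should sum to approximately 100 (the small epsilon used during normalization can
leave the total just short of 100). A quick check:
<pre class="example">
SELECT sum(impurity_var_importance) AS total_importance FROM imp_output;
</pre>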
-# Predict output categories. For the purpose
of this example, we use the same data that was
used for training:
@@ -920,24 +965,22 @@ SELECT madlib.tree_train('dt_golf', -- source table
1, -- min bucket
10 -- number of bins per continuous variable
);
SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
</pre>
View the output table (excluding the tree which is in binary format):
<pre class="example">
SELECT pruning_cp, cat_levels_in_text, cat_n_levels, tree_depth FROM train_output;
\\x on
SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tree_depth FROM train_output;
</pre>
<pre class="result">
-[ RECORD 1 ]-----------+-----------------------------------------------------
pruning_cp | 0
cat_levels_in_text | {medium,none,high,low,unhealthy,good,moderate}
cat_n_levels | {4,3}
impurity_var_importance | {0,40.2340084993653,5.6791213643137,54.086870136321}
impurity_var_importance | {0,0.330612244897959,0.0466666666666666,0.444444444444444}
tree_depth | 3

</pre>
The first 4 levels correspond to cloud ceiling and the next 3 levels
correspond to air quality.

-# Weighting observations. Use the 'weights' parameter to
adjust a row's vote to balance the dataset. In our
example, the weights are somewhat random but
@@ -1041,7 +1084,6 @@ INSERT INTO mt_cars VALUES
(31,19.7,6,145,175,3.62,2.77,15.5,0,1,5,6),
(32,21.4,4,121,109,4.11,2.78,18.6,1,1,4,2);
</pre>

-# We train a regression decision tree with surrogates
in order to handle the NULL feature values:
<pre class="example">
@@ -1074,7 +1116,6 @@ cat_levels_in_text | {0,1,4,6,8}
cat_n_levels | {2,3}
impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
tree_depth | 4

</pre>
View the summary table:
<pre class="example">
@@ -1105,9 +1146,23 @@ input_cp | 0
independent_var_types | integer, integer, double precision, double precision, double precision
n_folds | 0
null_proxy |

</pre>

View the impurity importance table using the helper function:
<pre class="example">
\\x off
DROP TABLE IF EXISTS imp_output;
SELECT madlib.get_var_importance('train_output','imp_output');
SELECT * FROM imp_output ORDER BY impurity_var_importance DESC;
</pre>
<pre class="result">
feature | impurity_var_importance
---------+-------------------------
cyl | 51.8593190075796
wt | 31.8447271176382
disp | 10.9769776775887
qsec | 5.31897390566817
vs | 0
</pre>
-# Predict regression output for the same data and compare with original:
<pre class="example">
\\x off
@@ -1158,7 +1213,6 @@ Result:
32 | 21.4 | 22.58 | -1.18
(32 rows)
</pre>

-# Display the decision tree in basic text format:
<pre class="example">
SELECT madlib.tree_display('train_output', FALSE);
@@ -1186,7 +1240,6 @@ SELECT madlib.tree_display('train_output', FALSE);
&nbsp;-------------------------------------
(1 row)
</pre>

-# Display the surrogate variables that are used
to compute the split for each node when the primary
variable is NULL:
@@ -1227,7 +1280,6 @@ descending order until a surrogate variable is found that is not NULL. In this c
the two tuples have non-NULL values for <em>disp</em>, hence the <em>disp <= 146.7</em>
split is used to make the prediction. If all the surrogate variables are
NULL then the majority branch would be followed.

-# Now let's use cross validation to select the best
value of the complexity parameter cp:
<pre class="example">
@@ -1259,9 +1311,8 @@ SELECT pruning_cp, cat_levels_in_text, cat_n_levels, impurity_var_importance, tr
pruning_cp | 0
cat_levels_in_text | {0,1,4,6,8}
cat_n_levels | {2,3}
impurity_var_importance | {0,51.8593201959496,10.976977929129,5.31897402755374,31.8447278473677}
impurity_var_importance | {0,22.6309172500677,4.79024943310653,2.32115,13.8967382920109}
tree_depth | 4

</pre>
The cp values tested and average error and standard deviation are:
<pre class="example">
@@ -1278,9 +1329,7 @@ SELECT * FROM train_output_cv ORDER BY cv_error_avg ASC;
0.643125226048458 | 5.76814538295394 | 2.10750950120742
(6 rows)
</pre>

<h4>NULL Handling Example</h4>

-# Create toy example to illustrate 'null-as-category' handling
for categorical features:
<pre class='example'>
@@ -1298,7 +1347,6 @@ INSERT INTO null_handling_example VALUES
(3,'US','NY',null,'c'),
(4,'US','NY','rainy','d');
</pre>

-# Train decision tree. Note that 'NULL' is set as a
valid level for the categorical features country, weather and city:
<pre class='example'>
@@ -1356,7 +1404,6 @@ independent_var_types | text, text, text
n_folds | 0
null_proxy | __NULL__
</pre>

-# Predict for data not previously seen by assuming NULL
value as the default:
<pre class='example'>
@@ -1399,7 +1446,6 @@ a NULL (not 'US') country level. Likewise, any
city in the 'US' that is not 'NY' will predict
response 'b', corresponding to a NULL (not 'NY')
city level.

-# Display the decision tree in basic text format:
<pre class="example">
SELECT madlib.tree_display('train_output', FALSE);
@@ -1427,7 +1473,7 @@ SELECT madlib.tree_display('train_output', FALSE);

@anchor literature
@par Literature
[1] Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984). Classification and regression trees. Monterey, CA: Wadsworth & Brooks/Cole Advanced Books & Software.
[1] Breiman, Leo; Friedman, J. H.; Olshen, R. A.; Stone, C. J. (1984). Classification and Regression Trees. Monterey, CA: Wadsworth & Brooks/Cole Advanced Books & Software.

@anchor related
@par Related Topics
@@ -1831,12 +1877,12 @@ LANGUAGE C IMMUTABLE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
-------------------------------------------------------------------------

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._get_var_importance(
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._compute_var_importance(
tree MADLIB_SCHEMA.BYTEA8,
n_cat_features INTEGER,
n_con_features INTEGER
) RETURNS DOUBLE PRECISION[] AS
'MODULE_PATHNAME', 'get_variable_importance'
'MODULE_PATHNAME', 'compute_variable_importance'
LANGUAGE C IMMUTABLE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
-------------------------------------------------------------------------
@@ -1976,6 +2022,7 @@ $$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');



/**
*@brief Display decision tree in dot or text format
*
@@ -2037,7 +2084,6 @@ programs.
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');


CREATE OR REPLACE FUNCTION MADLIB_SCHEMA._display_decision_tree(
tree MADLIB_SCHEMA.bytea8,
cat_features TEXT[],
