# RuleFit demo - Titanic Dataset

## H2O RuleFit algorithm

The RuleFit algorithm combines tree ensembles and linear models to take advantage of both: the accuracy of a tree ensemble and the interpretability of a linear model. The general algorithm fits a tree ensemble to the data, builds a rule ensemble by traversing each tree, evaluates the rules on the data to build a rule feature set, and fits a sparse linear model (LASSO) to the rule feature set joined with the original feature set.

For more information, refer to: http://statweb.stanford.edu/~jhf/ftp/RuleFit.pdf by Jerome H. Friedman and Bogdan E. Popescu.
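As a rough illustration of the "rule feature set" step, the sketch below evaluates two hand-written rules on three made-up rows and joins the resulting 0/1 indicators with the original features. The rows, the rule names, and the `rule_features` helper are all hypothetical; in the real algorithm the rules come from the fitted trees, and this is not how H2O implements it internally.

```python
# Toy illustration only: hand-written rules stand in for rules that a real
# tree ensemble would produce. The rows below are made up.
rows = [
    {"sex": "female", "pclass": 1, "age": 29.0},
    {"sex": "male", "pclass": 3, "age": 40.0},
    {"sex": "male", "pclass": 1, "age": 8.0},
]

# Each rule is a conjunction of conditions, like the conditions collected
# along a path from the root of a tree to one of its nodes.
rules = {
    "female_upper_class": lambda r: r["sex"] == "female" and r["pclass"] in (1, 2),
    "older_male": lambda r: r["sex"] == "male" and r["age"] >= 13.5,
}


def rule_features(row):
    """Evaluate every rule on one row, yielding 0/1 indicator features."""
    return {name: int(rule(row)) for name, rule in rules.items()}


# Rule features joined with the original features. A sparse linear model
# (LASSO) would then be fitted to this combined feature set.
design = [{**row, **rule_features(row)} for row in rows]

for d in design:
    print(d)
```

Because each rule becomes a single binary column, the coefficient that the linear model assigns to it can be read directly as the weight of that rule.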

## Demo example

We will train a RuleFit model on the Titanic dataset to find the rules that best predict whether a passenger survived:


In [35]:
import h2o
from h2o.estimators import H2ORuleFitEstimator, H2ORandomForestEstimator

# init h2o cluster
h2o.init(strict_version_check=False, url="http://192.168.59.147:54321")

Checking whether there is an H2O instance running at http://192.168.59.147:54321 ..... not found.
Attempting to start a local H2O server...
  Java Version: java version "1.8.0_231"; Java(TM) SE Runtime Environment (build 1.8.0_231-b11); Java HotSpot(TM) 64-Bit Server VM (build 25.231-b11, mixed mode)
  Starting server from /Users/zuzanaolajcova/IdeaProjects/h2o-3/build/h2o.jar
  Ice root: /var/folders/zn/5r1mf9597431rjrsg0lmr4tc0000gn/T/tmp4yqx03rk
  JVM stdout: /var/folders/zn/5r1mf9597431rjrsg0lmr4tc0000gn/T/tmp4yqx03rk/h2o_zuzanaolajcova_started_from_python.out
  JVM stderr: /var/folders/zn/5r1mf9597431rjrsg0lmr4tc0000gn/T/tmp4yqx03rk/h2o_zuzanaolajcova_started_from_python.err
  Server is running at http://127.0.0.1:54325
Connecting to H2O server at http://127.0.0.1:54325 ... successful.


| Property | Value |
|---|---|
| H2O_cluster_uptime: | 02 secs |
| H2O_cluster_timezone: | Europe/Prague |
| H2O_data_parsing_timezone: | UTC |
| H2O_cluster_version: | 3.32.0.99999 |
| H2O_cluster_version_age: | 12 hours and 6 minutes |
| H2O_cluster_name: | H2O_from_python_zuzanaolajcova_sr507w |
| H2O_cluster_total_nodes: | 1 |
| H2O_cluster_free_memory: | 3.556 Gb |
| H2O_cluster_total_cores: | 12 |
| H2O_cluster_allowed_cores: | 12 |


In [36]:
df = h2o.import_file("https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv",
                       col_types={'pclass': "enum", 'survived': "enum"})
# Predictor columns used by both models
x = ["age", "sibsp", "parch", "sex", "pclass"]

# Split the dataset into train and test
train, test = df.split_frame(ratios=[.8], seed=1234)

Parse progress: |█████████████████████████████████████████████████████████| 100%


The `algorithm` parameter sets whether DRF or GBM is used to fit the tree ensemble.

The `min_rule_length` and `max_rule_length` parameters set the interval of tree ensemble depths to be fitted. The wider this interval, the more tree ensembles are fitted (one per depth) and the larger the rule feature set becomes.

The `max_num_rules` parameter sets the maximum number of rules to return.

The `model_type` parameter sets the type of base learners in the ensemble.

The `rule_generation_ntrees` parameter sets the number of trees in the tree ensemble.

In [37]:
rfit = H2ORuleFitEstimator(algorithm="drf", 
                               min_rule_length=1, 
                               max_rule_length=10, 
                               max_num_rules=100, 
                               model_type="rules_and_linear",
                               rule_generation_ntrees=50,
                               seed=1234)
rfit.train(training_frame=train, x=x, y="survived")

rulefit Model Build progress: |███████████████████████████████████████████| 100%


The output for the RuleFit model includes:

- model parameters
- rule importances in tabular form
- training and validation metrics of the underlying linear model

In [38]:
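from IPython.display import HTML, display

# Fetch the rule importance table from the trained model
rule_importance = rfit.rule_importance()

# Make a pretty HTML table printout of the results
(table, nr, is_pandas) = rule_importance._as_show_table()
display(HTML(table.to_html()))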

| variable | coefficient | rule |
|---|---|---|
| M2T21N13 | 1.298698 | (sex in {female}) & (sibsp < 3.5 or sibsp is NA) & (pclass in {1, 2} or pclass is NA) |
| M2T23N21 | -0.8455729 | (sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) & (age >= 9.497750282287598 or age is NA) |
| M1T0N7 | 0.3807125 | (pclass in {1, 2}) & (sex in {female}) |
| M1T28N10 | -0.3445493 | (sex in {male} or sex is NA) & (age >= 13.496771812438965 or age is NA) |
| M1T23N7 | 0.33104 | (sex in {female}) & (sibsp < 2.5 or sibsp is NA) |
| M1T37N10 | -0.2323243 | (sex in {male} or sex is NA) & (age >= 14.977890968322754 or age is NA) |
| M4T3N45 | -0.02772966 | (sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) & (parch < 0.5 or parch is NA) & (age >= 14.977890968322754 or age is NA) & (age < 61.49930953979492 or age is NA) |
| M1T1N7 | 1.631369e-13 | (pclass in {1, 2}) & (sex in {female}) |
| M1T35N9 | -1.029435e-13 | (sex in {male} or sex is NA) & (age >= 13.496771812438965 or age is NA) |


The rules with non-negligible coefficients can be summarized as follows:

### Highest Likelihood of Survival:
1. women in class 1 or 2 with 3 or fewer siblings/spouses aboard
2. women in class 1 or 2
3. women with 2 or fewer siblings/spouses aboard

### Lowest Likelihood of Survival:
1. males in class 2 or 3 aged 9.5 or older
2. males aged 13.5 or older
3. males aged 15.0 or older
4. males in class 2 or 3, aged 15.0 to 61.5, with no parents/children aboard

Note: The rules are additive. If a passenger satisfies several rules, the coefficients of those rules are summed in the underlying linear model, and the combined score (not each rule on its own) determines the predicted probability.
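To make the additivity concrete, the sketch below sums the coefficients of the rules that a hypothetical passenger (a woman in first class travelling without siblings or a spouse) would satisfy. It assumes the underlying GLM uses a logit link, so the sum is a contribution to the log-odds of survival rather than a probability. The intercept and any linear terms are not shown in the rule importance table, so the sketch stops short of computing a final probability.

```python
import math

# Coefficients copied from the rule importance table above. These are the
# rules a hypothetical first-class woman with sibsp = 0 would satisfy.
matched_rules = {
    "M2T21N13": 1.298698,    # female, sibsp < 3.5, pclass in {1, 2}
    "M1T0N7": 0.3807125,     # pclass in {1, 2}, female
    "M1T23N7": 0.33104,      # female, sibsp < 2.5
    "M1T1N7": 1.631369e-13,  # same rule as M1T0N7, negligible coefficient
}

# Assumption: logit link, so coefficients add up on the log-odds scale.
rule_log_odds = sum(matched_rules.values())

# Multiplicative effect of the matched rules on the odds of survival,
# relative to a passenger who matches none of them.
odds_multiplier = math.exp(rule_log_odds)

print(f"log-odds contribution from rules: {rule_log_odds:.4f}")
print(f"odds multiplier: {odds_multiplier:.2f}")
```

The two rules with coefficients on the order of 1e-13 duplicate other rules and contribute essentially nothing to the sum.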

## Accuracy comparison with underlying tree ensemble

In [39]:
rfit.model_performance(test)


ModelMetricsBinomialGLM: rulefit
** Reported on test data. **

MSE: 0.13630214191578538
RMSE: 0.3691911996727243
LogLoss: 0.4340433289362752
Null degrees of freedom: 254
Residual degrees of freedom: 245
Null deviance: 335.9901436111088
Residual deviance: 221.36209775750038
AIC: 241.36209775750038
AUC: 0.8499405312541297
AUCPR: 0.8273951021706092
Gini: 0.6998810625082594

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.3073309333932125: 


| Act/Pred | 0 | 1 | Error | Rate |
|---|---|---|---|---|
| 0 | 114.0 | 47.0 | 0.2919 | (47.0/161.0) |
| 1 | 13.0 | 81.0 | 0.1383 | (13.0/94.0) |
| Total | 127.0 | 128.0 | 0.2353 | (60.0/255.0) |



Maximum Metrics: Maximum metrics at their respective thresholds


| metric | threshold | value | idx |
|---|---|---|---|
| max f1 | 0.307331 | 0.72973 | 3.0 |
| max f2 | 0.307331 | 0.803571 | 3.0 |
| max f0point5 | 0.855041 | 0.787402 | 0.0 |
| max accuracy | 0.441333 | 0.8 | 2.0 |
| max precision | 0.855041 | 1.0 | 0.0 |
| max recall | 0.156312 | 1.0 | 6.0 |
| max specificity | 0.855041 | 1.0 | 0.0 |
| max absolute_mcc | 0.441333 | 0.566043 | 2.0 |
| max min_per_class_accuracy | 0.307331 | 0.708075 | 3.0 |
| max mean_per_class_accuracy | 0.307331 | 0.784888 | 3.0 |



Gains/Lift Table: Avg response rate: 36.86 %, avg score: 35.70 %


| group | cumulative_data_fraction | lower_threshold | lift | cumulative_lift | response_rate | score | cumulative_response_rate | cumulative_score | capture_rate | cumulative_capture_rate | gain | cumulative_gain | kolmogorov_smirnov |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.156863 | 0.855041 | 2.712766 | 2.712766 | 1.0 | 0.855041 | 1.0 | 0.855041 | 0.425532 | 0.425532 | 171.276596 | 171.276596 | 0.425532 |
| 2 | 0.305882 | 0.523804 | 1.427772 | 2.086743 | 0.526316 | 0.523804 | 0.769231 | 0.693669 | 0.212766 | 0.638298 | 42.777156 | 108.674304 | 0.526497 |
| 3 | 0.501961 | 0.307331 | 1.139362 | 1.716672 | 0.42 | 0.336811 | 0.632812 | 0.554271 | 0.223404 | 0.861702 | 13.93617 | 71.667221 | 0.569777 |
| 4 | 1.0 | 0.156312 | 0.277685 | 1.0 | 0.102362 | 0.158187 | 0.368627 | 0.357006 | 0.138298 | 1.0 | -72.23153 | 0.0 | 0.0 |
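As a sanity check on how the reported RuleFit numbers relate to each other, the sketch below recomputes the F1 score and the overall error rate from the confusion matrix counts, and the Gini coefficient from the AUC using the standard identity Gini = 2 * AUC - 1. The variable names are illustrative; the counts and the AUC are copied from the RuleFit output above.

```python
# Counts from the RuleFit confusion matrix at the max-F1 threshold
# (rows are actual classes, columns are predicted classes).
tn, fp = 114, 47
fn, tp = 13, 81

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

# Share of all test rows that were misclassified at this threshold.
error_rate = (fp + fn) / (tn + fp + fn + tp)

# Gini is derived from AUC for a binary classifier.
auc = 0.8499405312541297
gini = 2 * auc - 1

print(f"F1: {f1:.5f}")
print(f"error rate: {error_rate:.4f}")
print(f"Gini: {gini:.6f}")
```

The recomputed values agree with the reported max F1 (0.72973), the total error in the confusion matrix (0.2353), and the reported Gini (0.6998810625082594).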







In [40]:
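# Baseline for comparison: a plain DRF with the same number of trees and the
# same maximum depth as the RuleFit model above
drf = H2ORandomForestEstimator(distribution="AUTO", ntrees=50, max_depth=10)
drf.train(x=x, y="survived", training_frame=train)
drf.model_performance(test)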

drf Model Build progress: |███████████████████████████████████████████████| 100%

ModelMetricsBinomial: drf
** Reported on test data. **

MSE: 0.1453956705434043
RMSE: 0.3813078422264671
LogLoss: 0.44538733890667054
Mean Per-Class Error: 0.22211576582529402
AUC: 0.8362296815118276
AUCPR: 0.8122356383011233
Gini: 0.6724593630236553

Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.5274066698551179: 


| Act/Pred | 0 | 1 | Error | Rate |
|---|---|---|---|---|
| 0 | 146.0 | 15.0 | 0.0932 | (15.0/161.0) |
| 1 | 33.0 | 61.0 | 0.3511 | (33.0/94.0) |
| Total | 179.0 | 76.0 | 0.1882 | (48.0/255.0) |



Maximum Metrics: Maximum metrics at their respective thresholds


| metric | threshold | value | idx |
|---|---|---|---|
| max f1 | 0.527407 | 0.717647 | 61.0 |
| max f2 | 0.071332 | 0.77686 | 149.0 |
| max f0point5 | 0.631524 | 0.778146 | 46.0 |
| max accuracy | 0.527407 | 0.811765 | 61.0 |
| max precision | 0.996534 | 1.0 | 0.0 |
| max recall | 0.071332 | 1.0 | 149.0 |
| max specificity | 0.996534 | 1.0 | 0.0 |
| max absolute_mcc | 0.527407 | 0.586189 | 61.0 |
| max min_per_class_accuracy | 0.386536 | 0.734043 | 78.0 |
| max mean_per_class_accuracy | 0.527407 | 0.777884 | 61.0 |



Gains/Lift Table: Avg response rate: 36.86 %, avg score: 36.82 %


| group | cumulative_data_fraction | lower_threshold | lift | cumulative_lift | response_rate | score | cumulative_response_rate | cumulative_score | capture_rate | cumulative_capture_rate | gain | cumulative_gain | kolmogorov_smirnov |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.015686 | 0.995845 | 2.712766 | 2.712766 | 1.0 | 0.996017 | 1.0 | 0.996017 | 0.042553 | 0.042553 | 171.276596 | 171.276596 | 0.042553 |
| 2 | 0.023529 | 0.993309 | 2.712766 | 2.712766 | 1.0 | 0.993459 | 1.0 | 0.995164 | 0.021277 | 0.06383 | 171.276596 | 171.276596 | 0.06383 |
| 3 | 0.035294 | 0.992853 | 2.712766 | 2.712766 | 1.0 | 0.992961 | 1.0 | 0.99443 | 0.031915 | 0.095745 | 171.276596 | 171.276596 | 0.095745 |
| 4 | 0.043137 | 0.989873 | 2.712766 | 2.712766 | 1.0 | 0.99173 | 1.0 | 0.993939 | 0.021277 | 0.117021 | 171.276596 | 171.276596 | 0.117021 |
| 5 | 0.05098 | 0.982463 | 2.712766 | 2.712766 | 1.0 | 0.98425 | 1.0 | 0.992449 | 0.021277 | 0.138298 | 171.276596 | 171.276596 | 0.138298 |
| 6 | 0.101961 | 0.917046 | 2.712766 | 2.712766 | 1.0 | 0.960553 | 1.0 | 0.976501 | 0.138298 | 0.276596 | 171.276596 | 171.276596 | 0.276596 |
| 7 | 0.152941 | 0.777207 | 2.504092 | 2.643208 | 0.923077 | 0.861215 | 0.974359 | 0.938072 | 0.12766 | 0.404255 | 150.409165 | 164.320786 | 0.398044 |
| 8 | 0.2 | 0.634188 | 1.808511 | 2.446809 | 0.666667 | 0.706371 | 0.901961 | 0.883554 | 0.085106 | 0.489362 | 80.851064 | 144.680851 | 0.458306 |
| 9 | 0.301961 | 0.521346 | 1.565057 | 2.149074 | 0.576923 | 0.582112 | 0.792208 | 0.781769 | 0.159574 | 0.648936 | 56.505728 | 114.907433 | 0.549557 |
| 10 | 0.4 | 0.33894 | 0.868085 | 1.835106 | 0.32 | 0.442132 | 0.676471 | 0.698524 | 0.085106 | 0.734043 | -13.191489 | 83.510638 | 0.529074 |
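To put the two test-set reports side by side, the sketch below collects the threshold-independent metrics printed above and picks the better model for each. The dictionary layout and the `better_model` helper are illustrative; the numbers are copied from the two `model_performance(test)` outputs, not recomputed.

```python
# Test-set metrics copied from the two model_performance(test) outputs above.
reported = {
    "rulefit": {"MSE": 0.13630214191578538, "LogLoss": 0.4340433289362752,
                "AUC": 0.8499405312541297, "AUCPR": 0.8273951021706092},
    "drf": {"MSE": 0.1453956705434043, "LogLoss": 0.44538733890667054,
            "AUC": 0.8362296815118276, "AUCPR": 0.8122356383011233},
}

# Whether a larger value means a better model, per metric.
higher_is_better = {"MSE": False, "LogLoss": False, "AUC": True, "AUCPR": True}


def better_model(metric):
    """Return the name of the model with the better reported value."""
    pick = max if higher_is_better[metric] else min
    return pick(reported, key=lambda model: reported[model][metric])


summary = {metric: better_model(metric) for metric in higher_is_better}
for metric, model in summary.items():
    print(f"{metric}: {model}")
```

On this split RuleFit is slightly ahead on all four of these metrics, while the threshold-dependent max accuracy is marginally higher for DRF (0.811765 versus 0.8). The differences are small and come from a single train/test split, and the DRF run is not seeded, so they should be read as indicative only.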





