In [1]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, f1_score

from rule_ensemble import RuleFitClassifier, LocallyInterpretableRuleEnsembleClassifier
from datasets import Dataset

# Seed NumPy's global RNG once, up front, so a fresh "Run All" repeats the same draws.
np.random.seed(0)

# Notebook-wide pandas display settings, applied in the order listed.
for _option, _value in {
    "display.max_rows": 200,                  # print up to 200 rows before truncating
    "display.min_rows": 20,                   # rows shown once a frame is truncated
    "display.max_colwidth": 999,              # wide enough that long text cells are not cut
    "display.precision": 3,                   # 3 decimal places for floats
    "display.float_format": "{:.3f}".format,  # render floats with exactly 3 decimals
}.items():
    pd.set_option(_option, _value)

In [2]:
# Load dataset 'a' through the project's Dataset wrapper.
# NOTE(review): the previewed columns (Age, fnlwgt, ..., Income) look like the
# Adult census data -- confirm what the key 'a' maps to in datasets.py.
D = Dataset(dataset="a")

# With split=True the wrapper hands back four pieces; they are used below as
# train/test features (X_tr, X_ts) and train/test labels (y_tr, y_ts).
_split = D.get_dataset(split=True)
X_tr, X_ts, y_tr, y_ts = _split

# Preview the first rows of the underlying frame as this cell's output.
D.df.head()

Unnamed: 0,Age,fnlwgt,Education-Num,Capital_Gain,Capital_Loss,Hours_per_week,Workclass:Unknown,Workclass:Federal-gov,Workclass:Local-gov,Workclass:Never-worked,...,Country:Puerto-Rico,Country:Scotland,Country:South,Country:Taiwan,Country:Thailand,Country:Trinadad&Tobago,Country:United-States,Country:Vietnam,Country:Yugoslavia,Income
0,39,77516,13,2174,0,40,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
1,50,83311,13,0,0,13,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
2,38,215646,9,0,0,40,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
3,53,234721,7,0,0,40,0,0,0,0,...,0,0,0,0,0,0,1,0,0,1
4,28,338409,13,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [3]:
# Shallow base forest (30 trees, depth 2); the same fitted forest is passed to
# both rule-ensemble models below via their `forest=` argument.
# No random_state is given, so the result depends on the global NumPy seed set
# in the first cell.
forest = RandomForestClassifier(n_estimators=30, max_depth=2)
forest = forest.fit(X_tr, y_tr)

print('## Base Forest (RF)')
print('### Test Accuracy')

forest_acc = forest.score(X_ts, y_ts)
print('- Accuracy:', forest_acc)

forest_pred = forest.predict(X_ts)
print('- F1 score:', f1_score(y_ts, forest_pred))

# AUC is computed from the second predict_proba column (the class at index 1).
forest_proba = forest.predict_proba(X_ts)[:, 1]
print('- AUC     :', roc_auc_score(y_ts, forest_proba))

## Base Forest (RF)
### Test Accuracy
- Accuracy: 0.7779142611472792
- F1 score: 0.8723884810841332
- AUC     : 0.8774759466964266


In [4]:
# RuleFit on top of the shared base forest. C is set to 1 / (0.03 * n_train).
# NOTE(review): presumably C is an inverse regularisation strength -- confirm in rule_ensemble.
n_train = X_tr.shape[0]
rufi = RuleFitClassifier(C=1 / (n_train * 3e-2), forest=forest).fit(
    X_tr,
    y_tr,
    feature_names=D.feature_names,
    feature_types=D.feature_types,
    class_names=D.class_names,
)

# --- predictive quality on the held-out split ---
print('#### Test Accuracy')
rufi_acc = rufi.score(X_ts, y_ts)
print('- Accuracy:', rufi_acc)
rufi_f1 = f1_score(y_ts, rufi.predict(X_ts))
print('- F1 score:', rufi_f1)
# predict_proba is passed to roc_auc_score unsliced here (unlike the forest cell).
# NOTE(review): assumes it already returns one score per sample -- confirm.
rufi_auc = roc_auc_score(y_ts, rufi.predict_proba(X_ts))
print('- AUC     :', rufi_auc)

# --- size of the explanations ---
print('#### Interpretability')
rufi_n_candidates = rufi.coef_.shape[0]
print('- Rule Candidates:', rufi_n_candidates)
rufi_global = rufi.global_support_cardinality()
print('- Global Support :', rufi_global)
# Mean over the test samples of the per-sample local support.
rufi_local_mean = rufi.local_support_cardinality(X_ts).mean()
print('- Local Support  :', rufi_local_mean)
rufi_diversity = rufi.diversity(X_ts)
print('- Rule Diversity :', rufi_diversity)

#### Test Accuracy
- Accuracy: 0.8301191499815748
- F1 score: 0.8943549003131922
- AUC     : 0.8594671293529547
#### Interpretability
- Rule Candidates: 150
- Global Support : 8
- Local Support  : 3.752364574376612
- Rule Diversity : 0.3281527309798746


In [5]:
# Global explanation of the fitted RuleFit model; the returned table (rule,
# contribution, frequency, importance) is this cell's displayed output.
rufi_global_rules = rufi.explain_global()
rufi_global_rules

Unnamed: 0,No.,Rule,"Contribution to ""Income: <=50K""",Frequency,Importance
0,52,Education-Num <= 12 & Capital_Gain <= 5119,1.006,0.731,0.446
1,36,Marital_Status != Married-civ-spouse & Education != Prof-school,0.644,0.533,0.321
2,79,Capital_Loss <= 1820 & Marital_Status != Married-civ-spouse,0.411,0.532,0.205
3,91,Marital_Status != Married-civ-spouse & Hours_per_week <= 44,0.312,0.433,0.154
4,115,Age > 31 & Sex = Male,-0.191,0.456,0.095
5,6,Marital_Status != Married-civ-spouse & Education != Masters,0.05,0.517,0.025
6,94,Marital_Status = Married-civ-spouse & Education != HS-grad,-0.027,0.311,0.013
7,136,Hours_per_week > 43 & Marital_Status != Never-married,-0.014,0.222,0.006


In [6]:
# LIRE on top of the same base forest, so both models start from the same forest.
# NOTE(review): C_l0 / C_li look like the weights of two separate penalty
# terms -- confirm their meaning in rule_ensemble.
lire = LocallyInterpretableRuleEnsembleClassifier(
    C_l0=1.65e-2,
    C_li=3e-1,
    max_iter=1000,
    forest=forest,
).fit(
    X_tr,
    y_tr,
    feature_names=D.feature_names,
    feature_types=D.feature_types,
    class_names=D.class_names,
)

# --- predictive quality on the held-out split ---
print('#### Test Accuracy')
lire_acc = lire.score(X_ts, y_ts)
print('- Accuracy:', lire_acc)
lire_f1 = f1_score(y_ts, lire.predict(X_ts))
print('- F1 score:', lire_f1)
# predict_proba is passed to roc_auc_score unsliced here (unlike the forest cell).
# NOTE(review): assumes it already returns one score per sample -- confirm.
lire_auc = roc_auc_score(y_ts, lire.predict_proba(X_ts))
print('- AUC     :', lire_auc)

# --- size of the explanations ---
print('#### Interpretability')
lire_n_candidates = lire.coef_.shape[0]
print('- Rule Candidates:', lire_n_candidates)
lire_global = lire.global_support_cardinality()
print('- Global Support :', lire_global)
# Mean over the test samples of the per-sample local support.
lire_local_mean = lire.local_support_cardinality(X_ts).mean()
print('- Local Support  :', lire_local_mean)
lire_diversity = lire.diversity(X_ts)
print('- Rule Diversity :', lire_diversity)

#### Test Accuracy
- Accuracy: 0.8415428080088441
- F1 score: 0.9004475999382621
- AUC     : 0.866352944672449
#### Interpretability
- Rule Candidates: 150
- Global Support : 7
- Local Support  : 1.056749785038693
- Rule Diversity : 0.45489772283487817


In [7]:
# Global explanation of the fitted LIRE model; the returned table (rule,
# contribution, frequency, importance) is this cell's displayed output.
lire_global_rules = lire.explain_global()
lire_global_rules

Unnamed: 0,No.,Rule,"Contribution to ""Income: <=50K""",Frequency,Importance
0,49,Capital_Gain > 5119,-1.536,0.049,0.331
1,30,Relationship = Own-child & Hours_per_week <= 49,1.255,0.145,0.441
2,20,Capital_Loss > 1820 & Capital_Loss <= 1978,-1.245,0.02,0.174
3,93,Marital_Status = Married-civ-spouse,-1.192,0.462,0.594
4,134,Hours_per_week <= 43 & Occupation = Other-service,0.906,0.088,0.256
5,119,Education-Num > 12,-0.801,0.247,0.346
6,29,Relationship != Own-child & Capital_Gain > 5095,-0.661,0.048,0.141
7,8,Intercept,1.637,1.0,0.0


In [8]:
# Choose the test samples on which RuleFit's local explanation is largest
# relative to LIRE's, then display one of them.

# Keep only test rows for which LIRE reports a local support greater than 1.
X_ = X_ts[lire.local_support_cardinality(X_ts) > 1]

# Evaluate each model's per-sample local support once and reuse the arrays for
# both the maximum and the row selection.
rufi_support = rufi.local_support_cardinality(X_)
lire_support = lire.local_support_cardinality(X_)
support_gap = rufi_support - lire_support

# np.max on an empty array fails with an opaque reduction error; raise the
# same exception type with a message that names the actual cause.
if len(support_gap) == 0:
    raise ValueError('no test sample has a LIRE local support greater than 1')

s = np.max(support_gap)                       # largest RuleFit-minus-LIRE gap
X_target = X_[np.where(support_gap == s)[0]]  # every row that attains that gap

n = 0  # index into X_target of the sample shown here and explained in the next cells
pd.DataFrame(X_target[n].reshape(1, -1), columns=D.feature_names)

Unnamed: 0,Age,fnlwgt,Education-Num,Capital_Gain,Capital_Loss,Hours_per_week,Workclass:Unknown,Workclass:Federal-gov,Workclass:Local-gov,Workclass:Never-worked,...,Country:Portugal,Country:Puerto-Rico,Country:Scotland,Country:South,Country:Taiwan,Country:Thailand,Country:Trinadad&Tobago,Country:United-States,Country:Vietnam,Country:Yugoslavia
0,39,120985,9,0,0,40,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0


In [9]:
# RuleFit's local explanations for the selected samples; show the one at index n.
rufi_local_rules = rufi.explain_local(X_target)
rufi_local_rules[n]

Unnamed: 0,No.,Rule,"Contribution to ""Income: <=50K"""
0,52,Education-Num <= 12 & Capital_Gain <= 5119,1.006
1,36,Marital_Status != Married-civ-spouse & Education != Prof-school,0.644
2,79,Capital_Loss <= 1820 & Marital_Status != Married-civ-spouse,0.411
3,91,Marital_Status != Married-civ-spouse & Hours_per_week <= 44,0.312
4,115,Age > 31 & Sex = Male,-0.191
5,6,Marital_Status != Married-civ-spouse & Education != Masters,0.05


In [10]:
# LIRE's local explanations for the same samples; show the one at index n so it
# can be compared with the RuleFit table above.
lire_local_rules = lire.explain_local(X_target)
lire_local_rules[n]

Unnamed: 0,No.,Rule,"Contribution to ""Income: <=50K"""
0,30,Relationship = Own-child & Hours_per_week <= 49,1.255
1,134,Hours_per_week <= 43 & Occupation = Other-service,0.906
2,3,Intercept,1.637
