# Imports

In [None]:
%load_ext autoreload
%autoreload 0

In [None]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

In [4]:
from explainerdashboard.explainers import *
from explainerdashboard.dashboards import *

# load data

In [5]:
d_train = pd.read_csv('train.csv')
d_test = pd.read_csv('test.csv')

# Generate train and test set:

In [14]:
X_train = d_train.drop(['Survived', 'Name'], axis=1)
y_train = d_train['Survived']
X_test = d_test.drop(['Survived', 'Name'], axis=1)
y_test = d_test['Survived']

X_train.shape, y_train.shape, X_test.shape, y_test.shape

((691, 23), (691,), (200, 23), (200,))

# Save onehotencoded vars en names:

# Fit Random Forest model:

In [15]:
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=None, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=100,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

# Build explainer object:

In [16]:
explainer = RandomForestClassifierBunch(model, X_test, y_test, roc_auc_score, 
                                                       cats=['Sex', 'Cabin', 'Embarked'],
                                                       idxs=d_test['Name'].values.tolist(), #names of passengers 
                                                       labels=['Not survived', 'Survived'])

In [17]:
explainer.plot_pdp('Age', 0)

Calculating prediction probabilities...


# Build dashboard object:

In [18]:
db = ClassifierDashboard(explainer,
                        contributions=True,
                        shap_dependence=True,
                        shap_interaction=True,
                        classifier_summary=True)
db.run(8053)

Calculating importances...
Calculating shap values...
Generating shap TreeExplainer...
Calculating predictions...
Calculating shap interaction values...
Generating shap TreeExplainer...
Running Model Explainer on http://localhost:8053
 * Serving Flask app "explainerdashboard.dashboards" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:8053/ (Press CTRL+C to quit)
127.0.0.1 - - [28/Oct/2019 13:55:20] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:20] "GET /_dash-component-suites/dash_renderer/react@16.8.6.min.js?v=1.0.0&m=1564145200 HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:21] "GET /_dash-component-suites/dash_renderer/react-dom@16.8.6.min.js?v=1.0.0&m=1564145200 HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:21] "GET /_dash-component-suites/dash_renderer/prop-types@15.7.2.min.js?v=1.0.0&m=1564145200 HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:21] "GET /_dash-component-suites/dash_daq/bundle.js?v=0.1.7&m=1564144129 HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:21] "GET /_dash-component-suites/dash_core_components/highlight.pack.js?v=1.0.0&m=1564145200 HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:21] "GET /_dash-component-suites/dash_core_components/plotly-1.48.3.min.js?v=1.0.0&m=1564145200 HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:21] "GET /_dash-co

Generating shap TreeExplainer...


127.0.0.1 - - [28/Oct/2019 13:55:28] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:28] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:29] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:34] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:34] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:34] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:55:47] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:57:29] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:57:59] "POST /_dash-update-component HTTP/1.1" 200 -
127.0.0.1 - - [28/Oct/2019 13:58:14] "POST /_dash-update-component HTTP/1.1" 200 -


In [None]:
db2 = RandomForestDashboard(explainer,
                        contributions=True,
                        shap_dependence=True,
                        shap_interaction=True,
                        shadow_trees=True)
db2.run(8054)

In [None]:
db3 = RandomForestClassifierDashboard(explainer,
                        contributions=True,
                        shap_dependence=True,
                        shap_interaction=True,
                        shadow_trees=True,
                        classifier_summary=True)
db3.run(8055)