Last updated 2020-02-03

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score

from clf_eval_utils import quickplot_eval_3, quickplot_eval_4

import shap

%matplotlib inline
%load_ext autoreload
%autoreload 2

TypeError: object of type <class 'float'> cannot be safely interpreted as an integer.

# Load NHEFS data

Available for download at https://cdn1.sph.harvard.edu/wp-content/uploads/sites/1268/1268/20/nhefs.csv

In [None]:
# Build the modelling table in one chained expression:
#   1. read the NHEFS extract from the local data/ directory,
#   2. remove the yrdth / modth / dadth columns (presumably year / month / day
#      of death, which would leak the 'death' target -- confirm in the codebook),
#   3. drop every remaining column that contains at least one missing value.
full_data = (
    pd.read_csv(os.path.join('data','nhefs.csv'))
    .drop(columns=['yrdth','modth','dadth'])
    .dropna(axis='columns')
)
full_data.head()

In [None]:
# Load the NHEFS codebook, which has one row per variable.
data_desc = pd.read_excel(os.path.join('data','NHEFS_Codebook.xls'))

# Lookup table {variable name -> description}, used later to relabel columns.
data_desc_dict = data_desc.set_index('Variable name')['Description'].to_dict()
data_desc.head()

# Split train/test

In [None]:
# Hold out 30% of the rows for testing. The split is seeded so that the
# train/test partition -- and therefore every metric and plot derived from it --
# is identical after Restart Kernel -> Run All. The seed 0 matches the one
# passed to the classifier.
data_train, data_test = train_test_split(full_data, test_size=0.3, random_state=0)
print(f'Train: {data_train.shape}, Test: {data_test.shape}')

# 'death' is the prediction target; every other remaining column is a feature.
X_train, y_train = data_train.drop(columns='death'), data_train['death']
X_test, y_test = data_test.drop(columns='death'), data_test['death']

# Predict

In [None]:
# Gradient boosting over 40 depth-1 trees, seeded for repeatable fits.
clf = GradientBoostingClassifier(
    n_estimators=40,
    learning_rate=1.0,
    max_depth=1,
    random_state=0,
)
clf.fit(X_train, y_train)

# Keep only the second predict_proba column, i.e. the probability of the
# second entry in clf.classes_ (presumably death == 1 -- confirm).
y_pred_train = clf.predict_proba(X_train)[:, 1]
y_pred_test = clf.predict_proba(X_test)[:, 1]

# Report ROC AUC on both splits so over-fitting is visible at a glance.
for split_name, split_labels, split_scores in (
    ('Train', y_train, y_pred_train),
    ('Test', y_test, y_pred_test),
):
    print(f'{split_name} AUC: {roc_auc_score(split_labels, split_scores)}')

# Eval

In [None]:
# Evaluate on the TRAINING split: point the shared y_true / y_pred names at the
# training labels and the model's training-set probabilities.
y_true, y_pred = y_train, y_pred_train

# Diagnostic plots from the local clf_eval_utils module (their exact contents
# are defined there, not in this notebook).
quickplot_eval_3(y_true, y_pred)
quickplot_eval_4(y_true, y_pred)

In [None]:
# Evaluate on the held-out TEST split: point the shared y_true / y_pred names
# at the test labels and the model's test-set probabilities.
y_true, y_pred = y_test, y_pred_test

# Same diagnostic plots as for the training split (from clf_eval_utils), so the
# two sets of figures can be compared side by side.
quickplot_eval_3(y_true, y_pred)
quickplot_eval_4(y_true, y_pred)

# Simple SHAP

In [None]:
# Wrap the fitted gradient-boosting model in SHAP's tree explainer.
explainer = shap.TreeExplainer(clf)
# SHAP values for the training rows; presumably one value per (row, feature)
# of X_train for this binary model -- TODO confirm the returned shape.
shap_values = explainer.shap_values(X_train)

In [None]:
# Replace the terse NHEFS column codes with their codebook descriptions so the
# plot labels are readable; columns absent from the codebook keep their code.
X_train_described = X_train.rename(columns=data_desc_dict)

# Bar-style SHAP summary of the features, using the relabelled columns.
shap.summary_plot(shap_values, X_train_described, plot_type='bar')