This notebook generates SHAP plots for the final models, each fully trained on the training dataset. It requires access to the Dataset101 folder containing the train/test datasets, as well as to the pickled benchmark models.

In progress: use KernelExplainer to generate SHAP plots for the trained FEAT model.

In [None]:
#Import Packages
import shap
import pickle
import pandas as pd
import warnings

# Outcome labels: raw label column in the CSVs -> short name used in the
# ../Dataset101/<short name>/ folder and file names.
targets = {
    'htn_dx_ia': 'Htndx',
    'res_htn_dx_ia': 'ResHtndx',
    'htn_hypok_dx_ia': 'HtnHypoKdx',
    'HTN_heuristic': 'HtnHeuri',
    'res_HTN_heuristic': 'ResHtnHeuri',
    'hypoK_heuristic_v4': 'HtnHypoKHeuri',
}
# Inverse lookup: short target name -> raw label column.
rev_targets = {v: k for k, v in targets.items()}
# Human-readable names for the short target names.
nice_targets = {
    'Htndx': 'HTN',
    'ResHtndx': 'Resistant HTN',
    'HtnHypoKdx': 'HTN-Hypokalemia',
    'HtnHeuri': 'HTN Heuristic',
    'ResHtnHeuri': 'Resistant HTN Heuristic',
    'HtnHypoKHeuri': 'HTN-Hypokalemia Heuristic',
}
# Human-readable model names, used in figure titles.
nice_model = {
    'LogisticRegression_L1': 'LR L1',
    'LogisticRegression_L2': 'LR L2',
    'RandomForest': 'RF',
    'Feat_reconstruct_linear': 'FEAT linear',
    'Feat_reconstruct_kernel': 'FEAT',
}
# Columns removed before building feature matrices: the patient ID and every
# outcome label, so no label is fed to a model as a feature.
drop_cols = ['UNI_ID', *targets]
# Results directory (relative to ../) holding the pickled benchmark models;
# also used to name the saved SHAP values and figures.
rdir = 'resultsFinal_r1'

# Patients shown in the SHAP decision plots are chosen later by `select()`,
# which samples hits and misses among positive and negative cases.

In [None]:
## load short feature names
class smart_dict(dict):
    """A dict whose absent keys look up to themselves.

    Indexing with a missing key returns that key unchanged (no KeyError) and
    does not insert it. Used so unmapped feature names display as-is.
    """

    def __missing__(self, name):
        # No short name known: fall back to the raw name.
        return name
    
df_names = pd.read_csv('Feat_Variable_Names.csv')
print(df_names.columns)
# Raw variable name -> short display name. Names not listed in the CSV map to
# themselves through smart_dict. The key column header is referenced with a
# trailing space ('Variable Name ').
feature_nice = smart_dict(zip(df_names['Variable Name '], df_names['Short Name']))
print(feature_nice)

In [None]:
from glob import glob
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
import math
import os
from feat_transformer import FeatTransformer

%matplotlib inline

def get_shap_values(model, target, interactions=False):
    """Compute SHAP values for one trained model on the test split of a target.

    Loads ``../Dataset101/<target>/<target>ATrain.csv`` / ``...ATest.csv`` and
    the pickled model, builds an explainer matching the model family, and
    explains the test features.

    Parameters
    ----------
    model : str
        Model key: 'RandomForest', 'DecisionTree', a 'Logistic*' model,
        'Feat_reconstruct_linear' or 'Feat_reconstruct_kernel'.
    target : str
        Short target name (a key of ``rev_targets``, e.g. 'ResHtndx').
    interactions : bool
        Linear explainers only: use the "correlation_dependent" feature
        perturbation instead of "interventional".

    Returns
    -------
    dict
        Keys 'shap_values', 'expected_value', 'y_true', 'y_pred',
        'y_pred_proba', 'features', 'feature_names'.

    Raises
    ------
    ValueError
        If not exactly one model pickle is found, or ``model`` has no
        configured explainer.
    """
    print(model, target)
    target_raw = rev_targets[target]
    # Every target has its own train/test split under ../Dataset101/<target>/.
    data_prefix = '../Dataset' + str(101) + '/' + target + '/' + target
    df_train = pd.read_csv(data_prefix + 'ATrain.csv')
    df_X_train = df_train.drop(drop_cols, axis=1)
    df_test = pd.read_csv(data_prefix + 'ATest.csv')
    df_X_test = df_test.drop(drop_cols, axis=1)
    df_y_test = df_test[target_raw].values

    if 'Feat' in model:
        # TODO: this FEAT pickle is hard-coded (res_htn_dx_ia) and used for any
        # target; also grab column names from it.
        pickles = ['Feat_reconstruct_res_htn_dx_ia_A_101_1318.pkl']
    else:
        pickles = glob('../' + rdir + '/' + target_raw + '/' + model + '/' + '*.pkl')
    print(pickles)
    if len(pickles) != 1:
        raise ValueError(
            f'expected exactly one model pickle, found {len(pickles)}: {pickles}')
    name = pickles[0]
    print('loading', name)
    # Only unpickle trusted, locally produced model files.
    with open(name, 'rb') as fh:
        m = pickle.load(fh)
    feature_names = df_X_train.columns

    if model in ['RandomForest', 'DecisionTree']:
        explainer = shap.TreeExplainer(m)
        expected_value = explainer.expected_value[1]
        features = df_X_test
    elif 'Logistic' in model or model == 'Feat_reconstruct_linear':
        feature_perturbation = "correlation_dependent" if interactions else "interventional"
        # Explain the 'est' step on the output of the 'prep' step.
        m_prep = m.named_steps['prep']
        m_est = m.named_steps['est']
        if 'Feat' in model:
            # FEAT's prep step exposes its own feature names.
            feature_names = m_prep.feature_names
        print('feature_names:', feature_names)
        df_X_train_trans = pd.DataFrame(m_prep.transform(df_X_train),
                                        columns=feature_names)
        df_X_test_trans = pd.DataFrame(m_prep.transform(df_X_test),
                                       columns=feature_names)
        # Transformed training data is the background distribution.
        explainer = shap.LinearExplainer(m_est, df_X_train_trans,
                                         feature_perturbation=feature_perturbation)
        expected_value = explainer.expected_value
        features = df_X_test_trans
    elif model == 'Feat_reconstruct_kernel':
        # NOTE(review): the k-means background is built from the *test* set,
        # while the linear branch uses the training set as background.
        # Confirm this is intended.
        explainer = shap.KernelExplainer(m.predict_proba, shap.kmeans(df_X_test, 10))
        features = df_X_test
        expected_value = explainer.expected_value[1]
    else:
        raise ValueError(f'no SHAP explainer configured for model {model!r}')

    print(f"Explainer expected value: {expected_value}")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        shap_values = explainer.shap_values(features)
    if model in ['RandomForest', 'DecisionTree', 'Feat_reconstruct_kernel']:
        # Per-class outputs here: keep index 1 (positive class), matching
        # expected_value[1] above.
        shap_values = shap_values[1]

    y_pred = m.predict(df_X_test)
    y_pred_proba = m.predict_proba(df_X_test)[:, 1]

    return {
        'shap_values': shap_values,
        'expected_value': expected_value,
        'y_true': df_y_test,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba,
        'features': features,
        'feature_names': feature_names,
    }


# calculate shap values
Commented out for now; the plotting cells below load the pickled SHAP values instead.

In [None]:
import pickle
# Compute SHAP values for every (model, target, interaction) combination and
# pickle each result so the plotting cells can reload it.
models = [
    # 'LogisticRegression_L1',
    # 'LogisticRegression_L2',
    # 'RandomForest',
    'Feat_reconstruct_linear',
    'Feat_reconstruct_kernel',
]
targets = ['ResHtndx']
shap_values = {}
expected_value = {}
y_true = {}
y_pred = {}
misclassified = {}
features = {}
feature_names = {}
results = {}
for model in models:
    # `interactions` only changes the linear explainers' feature perturbation,
    # so RandomForest / kernel FEAT are run once.
    if model in ('RandomForest', 'Feat_reconstruct_kernel'):
        interactions = [False]
    else:
        interactions = [True, False]
    for interaction in interactions:
        for target in targets:
            idx = (model, target, interaction)
            results[idx] = get_shap_values(model, target, interactions=interaction)

            # NOTE: there is no '/' after rdir, so the file is written as
            # shap_values/<rdir><model>-<target>-<interaction>.pkl rather than
            # inside the directory created here. The loading cells build the
            # same path, so change both together if fixing this.
            os.makedirs('shap_values/' + rdir, exist_ok=True)
            with open('shap_values/' + rdir + '-'.join(map(str, idx)) + '.pkl', 'wb') as f:
                pickle.dump(results[idx], f)
                

In [None]:
from glob import glob
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
import math
import os
from feat_transformer import FeatTransformer

%matplotlib inline
def select(y_true, y_pred, n_samples, random_state=42):
    """Pick a reproducible, stratified subset of samples to visualize.

    Samples are drawn from four groups: correctly and incorrectly classified
    positives, and correctly and incorrectly classified negatives. Each label
    gets about ``n_samples/2`` slots, split between hits and misses in
    proportion to the overall misclassification rate. Draws are made with
    replacement (the ``choice`` default), so indices may repeat. A group with
    no members contributes nothing.

    Parameters
    ----------
    y_true, y_pred : array-like of 0/1 labels of equal length.
    n_samples : int
        Approximate number of samples. Group sizes are rounded up, so the
        result can be slightly larger.
    random_state : int, default 42
        Seed for a private RandomState. 42 reproduces the existing sample;
        change it only if you want a new random sample.

    Returns
    -------
    selection : list of int
        Indices ordered as positive hits, positive misses, negative hits,
        negative misses.
    misclassified : numpy.ndarray of bool
        ``y_pred != y_true`` for every sample.
    """
    # A private RandomState draws the same sequence as seeding the global RNG,
    # without clobbering the caller's global NumPy random state.
    rng = np.random.RandomState(random_state)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    misclassified = y_pred != y_true
    n_miss = np.sum(misclassified)
    print('misclassified samples:', n_miss)
    rate = n_miss / len(y_pred) if len(y_pred) else 0.0
    print('misclassification rate:', rate)

    idx = np.arange(len(y_true))
    hit_size = math.ceil((1 - rate) * n_samples / 2)
    miss_size = math.ceil(rate * n_samples / 2)

    def _draw(mask, size):
        """Sample ``size`` indices from ``idx[mask]``; empty if that group is empty."""
        pool = idx[mask]
        if len(pool) == 0:
            # np.random.choice raises on an empty pool when size > 0.
            return pool
        return rng.choice(pool, size=size)

    # Draw order matches the original so the same seed yields the same sample.
    hit_pos_samples = _draw((y_true == 1) & ~misclassified, hit_size)
    miss_pos_samples = _draw((y_true == 1) & misclassified, miss_size)
    hit_neg_samples = _draw((y_true == 0) & ~misclassified, hit_size)
    miss_neg_samples = _draw((y_true == 0) & misclassified, miss_size)
    print(
        'positive hits:', len(hit_pos_samples),
        'positive misses:', len(miss_pos_samples),
        'negative hits:', len(hit_neg_samples),
        'negative misses:', len(miss_neg_samples),
    )
    selection = (list(hit_pos_samples) + list(miss_pos_samples)
                 + list(hit_neg_samples) + list(miss_neg_samples))
    return selection, misclassified
    
def make_shap_plots(model, target, shap_values, expected_value, select,
                    misclassified, features, feature_names,
                    n_features=20, interactions=False,
                    axes=None, axes_labels=('A', 'B')):
    """Draw a SHAP summary plot and a SHAP decision plot for one model/target.

    Parameters
    ----------
    model : str
        Model key; looked up in ``nice_model`` for the title.
    target : str
        Short target name, used in the title and output file name.
    shap_values : numpy.ndarray
        SHAP values, one row per sample (assumed (n_samples, n_features);
        TODO confirm for every model family).
    expected_value : scalar
        Explainer base value for the decision plot.
    select : sequence of int
        Row positions of the samples drawn in the decision plot.
    misclassified : numpy.ndarray of bool
        Per-sample flags; selected misclassified samples are highlighted.
    features : pandas.DataFrame
        Feature values aligned with ``shap_values``.
    feature_names : sequence of str
        Raw feature names. The part before the first '>' is shortened
        through ``feature_nice``.
    n_features : int
        Number of features displayed.
    interactions : bool
        Only affects the saved file name ('_interactions' suffix).
    axes : pair of matplotlib Axes, optional
        (summary, decision) axes. When None or empty, a new 1x2 figure is
        created and saved under ``figs/<rdir>/``.
    axes_labels : pair of str
        Panel letters drawn next to each axes.

    Returns
    -------
    (ax1, ax2) when ``axes`` is supplied; otherwise None (the figure is saved).
    """
    print('axes:', axes)
    # `axes` may be a NumPy row from plt.subplots; comparing an ndarray with
    # `[]` raises in recent NumPy, so test for None/emptiness explicitly.
    if axes is None or len(axes) == 0:
        figsize = (12, 6)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        save = True
    else:
        ax1, ax2 = axes[0], axes[1]
        save = False
    plt.sca(ax1)

    # Shorten the base variable name; keep any '>'-separated suffix parts.
    nice_feature_names = []
    for fn in feature_names:
        head, *rest = fn.split('>')
        nice_feature_names.append(' > '.join([feature_nice[head]] + rest))

    shap.summary_plot(shap_values,
                      features=features,
                      feature_names=nice_feature_names,
                      show=False,
                      max_display=n_features,
                      sort=True)
    ax1.text(-1, 1.0, axes_labels[0],
             color='k', fontsize=24,
             transform=ax1.transAxes)
    print('shap select;', shap_values[select].shape)
    print('features select;', features.loc[select, :].shape)

    plt.sca(ax2)
    # Linear/logistic explanations are drawn with the logit link; others use identity.
    model_lower = model.lower()
    link = 'logit' if ('linear' in model_lower or 'logistic' in model_lower) else 'identity'
    print('link:', link)
    # Features ordered by total |SHAP| over all samples (ascending).
    feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0))

    shap.decision_plot(expected_value, shap_values[select], features.loc[select, :].reset_index(),
                       feature_names=nice_feature_names,
                       feature_display_range=slice(None, -(n_features + 1), -1),
                       ignore_warnings=True,
                       highlight=misclassified[select],
                       show=False,
                       plot_color='viridis',
                       link=link,
                       feature_order=feature_order)
    # Dotted vertical reference line at 0.5.
    plt.plot([0.5, 0.5], [0, n_features], ':k', alpha=0.3)
    ax2.text(-0.2, 1.0, axes_labels[1],
             color='k', fontsize=24,
             transform=ax2.transAxes)
    ax2.set_yticklabels([])

    if save:
        fig.suptitle(nice_model[model] + ' model of ' + target)
        fig.set_size_inches(figsize)
        plt.tight_layout()
        name = 'shap_' + target + '_' + model
        if interactions and model != 'RandomForest':
            name = name + '_interactions'
        out_dir = 'figs/' + rdir
        # makedirs also creates 'figs/' itself if it does not exist yet.
        os.makedirs(out_dir, exist_ok=True)
        for filetype in ['.pdf', '.png', '.svg']:
            fig.savefig(out_dir + '/' + name + filetype, dpi=400)
    else:
        plt.suptitle(nice_model[model] + ' model of ' + target)
        return ax1, ax2

In [None]:
import numpy as np

def get_ids(model, target, selected, misclassified):
    """Print the patient IDs of the samples chosen for a SHAP decision plot.

    Reloads the test split and the pickled model for ``model``/``target``,
    then lists the selected patients sorted by predicted positive-class
    probability (highest first). Misclassified patients are prefixed "MISS".

    Parameters
    ----------
    model : str
        Model key (same convention as ``get_shap_values``).
    target : str
        Short target name (a key of ``rev_targets``).
    selected : sequence of int
        Row positions in the test set, e.g. from ``select()``.
    misclassified : array-like of bool
        Per-test-sample misclassification flags.

    Returns
    -------
    None; results are printed.

    Raises
    ------
    ValueError
        If not exactly one model pickle is found.
    """
    target_raw = rev_targets[target]
    df_test = pd.read_csv(
        '../Dataset' + str(101) + '/' + target + '/' + target + 'ATest.csv')
    patient_ids = df_test['UNI_ID'].values
    df_X_test = df_test.drop(drop_cols, axis=1)
    df_y_test = df_test[target_raw].values

    if 'Feat' in model:
        # TODO: this FEAT pickle is hard-coded (res_htn_dx_ia) and used for any
        # target; also grab column names from it.
        pickles = ['Feat_reconstruct_res_htn_dx_ia_A_101_1318.pkl']
    else:
        pickles = glob('../' + rdir + '/' + target_raw + '/' + model + '/' + '*.pkl')
    print(pickles)
    if len(pickles) != 1:
        raise ValueError(
            f'expected exactly one model pickle, found {len(pickles)}: {pickles}')
    name = pickles[0]
    print('loading', name)
    # Only unpickle trusted, locally produced model files.
    with open(name, 'rb') as fh:
        m = pickle.load(fh)

    y_pred_proba = m.predict_proba(df_X_test)[:, 1]

    # dtype=int keeps an empty selection usable as an index array.
    selected = np.asarray(selected, dtype=int)
    misclassified = np.asarray(misclassified)
    miss_select = selected[misclassified[selected]]
    y_pred_select = y_pred_proba[selected]
    y_true_select = df_y_test[selected]
    pt_select = patient_ids[selected]
    # Set gives O(1) membership checks in the loop below.
    missed_ids = set(patient_ids[miss_select])

    order = np.argsort(y_pred_select)[::-1]
    print('shap patient selections for', model, 'of', target)
    for pt, lab, yp in zip(pt_select[order], y_true_select[order], y_pred_select[order]):
        if pt in missed_ids:
            print('MISS patient:{}, label:{}, prediction:{}'.format(pt, lab, yp))
        else:
            print('patient:{}, label:{}, prediction:{}'.format(pt, lab, yp))
        


In [None]:
# models = ['Feat_reconstruct_kernel']
# targets = ['ResHtndx']
# Per-model SHAP figures for every model/target/interaction combination,
# built from the pickled SHAP results.
models = [
    'LogisticRegression_L1',
    'LogisticRegression_L2',
    'RandomForest',
    'Feat_reconstruct_linear',
    'Feat_reconstruct_kernel',
]
targets = ['ResHtndx']
n_samples = 20
for model in models:
    # Show 6 features for FEAT models, 20 otherwise.
    n_features = 6 if 'Feat' in model else 20
    if model in ('RandomForest', 'Feat_reconstruct_kernel'):
        interactions = [False]
    else:
        interactions = [True, False]
    for interaction in interactions:
        for target in targets:
            idx = (model, target, interaction)
            print(idx)
            # Same path scheme the SHAP-computation cell used when saving.
            with open('shap_values/' + rdir + '-'.join(map(str, idx)) + '.pkl', 'rb') as f:
                results = pickle.load(f)
            selected, misclassified = select(results['y_true'], results['y_pred'], n_samples)
            get_ids(model, target, selected, misclassified)
            make_shap_plots(
                model,
                target,
                results['shap_values'],
                results['expected_value'],
                selected,
                misclassified,
                results['features'],
                results['feature_names'],
                n_features=n_features,
                interactions=interaction,
            )

## make combined logistic regression / feat plot

In [None]:
# Combined figure: LR L1 (top row) and FEAT linear (bottom row). Each row holds
# a SHAP summary plot and example decision plots, without interactions.
models = [
    'LogisticRegression_L1',
    'Feat_reconstruct_linear',
]
targets = ['ResHtndx']

n_samples = 20
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
print('axes shape:', axes.shape)
for i, model in enumerate(models):
    n_features = 20 if 'Feat' not in model else 6
    interaction = False
    for target in targets:
        idx = (model, target, interaction)
        print(idx)
        # Same path scheme the SHAP-computation cell used when saving.
        with open('shap_values/' + rdir + '-'.join(map(str, idx)) + '.pkl', 'rb') as f:
            results = pickle.load(f)
        selected, misclassified = select(results['y_true'], results['y_pred'], n_samples)
        ax1, ax2 = make_shap_plots(model, target,
                                   results['shap_values'],
                                   results['expected_value'],
                                   selected,
                                   misclassified,
                                   results['features'],
                                   results['feature_names'],
                                   n_features=n_features,
                                   interactions=interaction,
                                   # Pass a plain list: comparing a NumPy array
                                   # of Axes against [] fails in recent NumPy.
                                   axes=list(axes[i]),
                                   axes_labels=['A', 'B'] if i == 0 else ['C', 'D'])
    ax1.set_title('LR L1 model of Resistant HTN' if i == 0 else 'FEAT model of Resistant HTN')
    ax2.set_title('Example Decisions')
fig.set_size_inches((12, 12))
fig.suptitle('')
fig.tight_layout()

# Ensure the output folder exists even if no single-model figure was saved yet.
fig_dir = 'figs/' + rdir
os.makedirs(fig_dir, exist_ok=True)
for filetype in ['.pdf', '.png', '.svg']:
    fig.savefig(fig_dir + '/' + 'shap_combined_LRL1_FEAT' + filetype, dpi=400)
plt.show()

## make combined logistic regression / feat plot with interactions

In [None]:
# Combined figure: LR L1 (top row) and FEAT linear (bottom row). Each row holds
# a SHAP summary plot and example decision plots, computed with interactions
# (correlation-dependent feature perturbation).
models = [
    'LogisticRegression_L1',
    'Feat_reconstruct_linear',
]
targets = ['ResHtndx']

n_samples = 20
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
print('axes shape:', axes.shape)
for i, model in enumerate(models):
    n_features = 20 if 'Feat' not in model else 6
    interaction = True
    for target in targets:
        idx = (model, target, interaction)
        print(idx)
        # Same path scheme the SHAP-computation cell used when saving.
        with open('shap_values/' + rdir + '-'.join(map(str, idx)) + '.pkl', 'rb') as f:
            results = pickle.load(f)
        selected, misclassified = select(results['y_true'], results['y_pred'], n_samples)
        ax1, ax2 = make_shap_plots(model, target,
                                   results['shap_values'],
                                   results['expected_value'],
                                   selected,
                                   misclassified,
                                   results['features'],
                                   results['feature_names'],
                                   n_features=n_features,
                                   interactions=interaction,
                                   # Pass a plain list: comparing a NumPy array
                                   # of Axes against [] fails in recent NumPy.
                                   axes=list(axes[i]),
                                   axes_labels=['A', 'B'] if i == 0 else ['C', 'D'])
    ax1.set_title('LR L1 model of Resistant HTN' if i == 0 else 'FEAT model of Resistant HTN')
    ax2.set_title('Example Decisions')
fig.set_size_inches((12, 12))
fig.suptitle('')
fig.tight_layout()

# Ensure the output folder exists even if no single-model figure was saved yet.
fig_dir = 'figs/' + rdir
os.makedirs(fig_dir, exist_ok=True)
for filetype in ['.pdf', '.png', '.svg']:
    fig.savefig(fig_dir + '/' + 'shap_combined_LRL1_FEAT_interactions' + filetype, dpi=400)
plt.show()

In [None]:
# grab missing keys
models = [
          'LogisticRegression_L1',
          'Feat_reconstruct_linear',
          'Feat_reconstruct_kernel',
          'RandomForest'
         ]
for model in models:
    n_features = 20 if 'Feat' not in model else 6
#     interaction = model != 'Feat_reconstruct_kernel'
    interaction = False
    for target in targets:
        idx = (model, target, interaction)
        print(idx)
        with open('shap_values/'+rdir+'-'.join([str(i) for i in idx])+'.pkl','rb') as f:
            results = pickle.load(f)
        shap_values = results['shap_values']
        mean_abs_sv = np.mean(np.abs(shap_values),axis=0)
        sv_order = np.argsort(mean_abs_sv)[::-1]
        feature_names = np.array(results['feature_names'])[sv_order[:n_features]]
        
        nice_feature_names = [' > '.join([feature_nice[fn.split('>')[0]]]+fn.split('>')[1:]) 
                            for fn in feature_names]