In [None]:
import pandas as pd
import xgboost as xgb
import numpy as np
import sklearn
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline
from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTENC
import matplotlib.pyplot as plt
from PIL import Image
from scipy.interpolate import BSpline, make_interp_spline, interp1d
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
import csv
from dfply import *
from xgboost import XGBClassifier
import itertools
import os
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
import time
import pickle
from glob import glob
from sklearn.metrics import roc_auc_score
from catboost import Pool, cv
import xgboost
import catboost
import scipy.stats as st

In [None]:
def collectSHAP_sub(site, year, stg, fs, oversample, model_type, ckd_group=0, returnflag=False):
    """Compute SHAP values for one saved model and summarise them per feature value.

    Loads the pickled model and the matching X/y train/test tables from
    ``data/<site>/``, optionally restricts them to one CKD group, computes SHAP
    values on train+test combined, and builds a summary table with one row per
    (feature, feature value) pair: SHAP statistics overall and split by label,
    the model's own feature importance, several rankings, the test AUC, per-value
    label counts and basic statistics of the non-boolean features.

    Parameters
    ----------
    site, year, stg, fs, oversample, model_type :
        Components of the input/output file names; they are also stored as
        columns of the summary table. ``year`` is converted with ``str()``.
    ckd_group : int, default 0
        When non-zero, rows are kept only where the ``CKD_group`` column of the
        ``*_ckdg_*_nofs_raw.pkl`` tables equals this value. ``0`` keeps all rows
        (the CKD-group tables are then not read at all).
    returnflag : bool, default False
        If True, return the tables instead of writing them to disk.

    Returns
    -------
    tuple(pandas.DataFrame, pandas.DataFrame) or None
        ``(shap_data, shap_data_raw)`` when ``returnflag`` is True; otherwise the
        two tables are pickled under ``data/<site>/`` and ``None`` is returned.
    """
    print('Running shap '+model_type+' on site '+site+":"+str(year)+":"+stg+":"+fs+":"+oversample, flush = True)
    tic = time.perf_counter()

    # Shared pieces of every file name used below.
    data_dir = 'data/'+site+'/'
    run_tag = site+'_'+str(year)+'_'+stg+'_'+fs+'_'+oversample
    model_path = data_dir+'model_'+model_type+'_'+run_tag+'.pkl'
    x_train_path = data_dir+'X_train_'+run_tag+'.pkl'
    x_test_path = data_dir+'X_test_'+run_tag+'.pkl'

    #load model
    print(model_path)
    print(x_train_path)

    # SECURITY: pickle executes arbitrary code on load; only point this at
    # model files produced by a trusted pipeline.
    with open(model_path, 'rb') as model_file:
        model = pickle.load(model_file)

    #load tables
    X_train = pd.read_pickle(x_train_path)
    X_test = pd.read_pickle(x_test_path)
    y_train = pd.read_pickle(data_dir+'y_train_'+run_tag+'.pkl')
    y_test = pd.read_pickle(data_dir+'y_test_'+run_tag+'.pkl')

    if ckd_group != 0:
        # The CKD group labels live in separate tables; build each mask once
        # and apply it to both the features and the labels.
        X_train_ckdg = pd.read_pickle(data_dir+'X_train_ckdg_'+site+'_'+str(year)+'_'+stg+'_nofs_raw.pkl')
        X_test_ckdg = pd.read_pickle(data_dir+'X_test_ckdg_'+site+'_'+str(year)+'_'+stg+'_nofs_raw.pkl')
        train_mask = X_train_ckdg['CKD_group']==ckd_group
        test_mask = X_test_ckdg['CKD_group']==ckd_group
        X_train = X_train[train_mask]
        X_test = X_test[test_mask]
        y_train = y_train[train_mask]
        y_test = y_test[test_mask]

    # Get AUC on the (possibly filtered) test set
    pred = model.predict_proba(X_test)
    roc = roc_auc_score(y_test, pred[:,1])

    # SHAP values are computed on train and test together.
    shapX = pd.concat([X_train, X_test])
    shapy = pd.concat([y_train, y_test])

    # Calculate SHAP value
    if type(model) == xgboost.sklearn.XGBClassifier:
        booster = model.get_booster()
        dshap  = xgb.DMatrix(shapX, label=shapy)
        shap_values = booster.predict(dshap, pred_contribs=True)
        # Get feature importance
        model_data = pd.concat([pd.DataFrame(booster.get_score(importance_type='cover'), index=['Cover']),
                                pd.DataFrame(booster.get_score(importance_type='gain'), index=['Gain']),
                                pd.DataFrame(booster.get_score(importance_type='weight'), index=['Frequency'])]).transpose() >> mutate(Feature = X.index)
        model_data['rank'] = model_data['Gain'].rank(method='min', ascending=False)
    elif type(model) == catboost.core.CatBoostClassifier:
        cat_features = model.get_cat_feature_indices()
        pshap = Pool(data=shapX, label=shapy, cat_features=cat_features)
        shap_values = model.get_feature_importance(data=pshap, type='ShapValues')
        model_data = model.get_feature_importance(prettified=True)
        model_data['Feature'] = model_data['Feature Id']
        model_data = model_data >> select('Feature', 'Importances')
        model_data['rank'] = model_data['Importances'].rank(method='min', ascending=False)
    else:
        #Using shap package example
        import shap
        explainer = shap.Explainer(model, algorithm='permutation')
        shap_values = explainer.shap_values(shapX)
        # No model-native importance is available on this path. Provide an
        # empty importance table so the join and the sort on 'rank' below still
        # work (previously this path always failed on an undefined model_data).
        # NOTE(review): the code below indexes shap_values as a 2-D
        # (rows, features) array; confirm the explainer returns that shape.
        model_data = pd.DataFrame({'Feature': list(shapX.columns), 'rank': np.nan})

    # Collect SHAP value
    def CI95(data):
        # Placeholder: the confidence interval is currently disabled and always
        # NaN. The function (and its name) is kept because agg() uses the name
        # to build the 'valCI95' / 'absvalCI95' columns split further below.
        return (np.nan, np.nan)
#        return st.t.interval(alpha=0.95, df=len(data)-1, loc=np.mean(data), scale=st.sem(data)) #95% confidence interval

    shap_data = list()
    shap_data_raw = list()
    for i in range(shapX.columns.shape[0]):
        # One row per sample: feature value, SHAP value, |SHAP value|.
        df = pd.DataFrame(list(zip(shapX.iloc[:,i], shap_values[:, i], abs(shap_values[:, i]))),columns =['Name', 'val', 'absval'])
        # Share shapX's index so the label masks on shapy select the right rows.
        df.index = shapX.index

        plot_data_all = df.groupby("Name").agg([np.mean, np.var, np.std, np.median, CI95, 'size']).reset_index()
        plot_data_0= df[shapy==0].groupby("Name").agg([np.mean, np.var, np.std, np.median, CI95, 'size']).reset_index()
        plot_data_1= df[shapy==1].groupby("Name").agg([np.mean, np.var, np.std, np.median, CI95, 'size']).reset_index()

        # Flatten the (column, statistic) MultiIndex into names like 'valmean';
        # the per-label tables get the same names with a _0 / _1 suffix.
        plot_data_all.columns = [''.join(x) for x in plot_data_all.columns]
        plot_data_0.columns = [x+'_0' for x in plot_data_all.columns]
        plot_data_1.columns = [x+'_1' for x in plot_data_all.columns]
        # 'valsize' and 'absvalsize' are the same count; keep only one.
        plot_data_all = plot_data_all.drop('absvalsize', axis=1)
        plot_data_0   = plot_data_0.drop('absvalsize_0', axis=1)
        plot_data_1   = plot_data_1.drop('absvalsize_1', axis=1)
        plot_data_0 = plot_data_0.rename({'Name_0':'Name'},axis=1)
        plot_data_1 = plot_data_1.rename({'Name_1':'Name'},axis=1)
        plot_data = pd.merge(plot_data_all, plot_data_0, left_on='Name', right_on='Name', how='left')
        plot_data = pd.merge(plot_data, plot_data_1, left_on='Name', right_on='Name', how='left')

        plot_data = plot_data >> mutate(Feature=shapX.columns[i])
        plot_data.columns = [''.join(x) for x in plot_data.columns]
        # Split the (lower, upper) tuples produced by CI95 into two columns each.
        plot_data[['valCI95down', 'valCI95up']] = pd.DataFrame(plot_data['valCI95'].tolist(), index=plot_data.index)
        plot_data[['absvalCI95down', 'absvalCI95up']] = pd.DataFrame(plot_data['absvalCI95'].tolist(), index=plot_data.index)
        plot_data = plot_data.drop(['valCI95', 'absvalCI95'],axis=1)
        shap_data.append(plot_data.copy())
        plot_data_raw = df >> select(X.Name, X.val) >> mutate(Feature=shapX.columns[i])
        shap_data_raw.append(plot_data_raw.copy())
    shap_data = pd.concat(shap_data)
    shap_data_raw = pd.concat(shap_data_raw)

    # create csv for metaregression
    shap_data = shap_data >> left_join(model_data, by='Feature')
    siteyr = site+'_'+model_type+'_'+fs+'_'+stg+'_'+oversample+'_'+'005'+"_"+str(year)
    shap_data = shap_data >> mutate(siteyr=siteyr) >> rename(fval=X.Name) >> rename(mean_val=X.valmean) >> rename(se_val=X.valstd) >> rename(mean_imp = X.absvalmean) >> rename(se_imp = X.absvalstd) >> rename(var_imp = X.absvalvar) >> rename(median_val = X.valmedian) >> rename(median_imp = X.absvalmedian) >> rename(var_val = X.valvar)
    shap_data['site'] = site
    shap_data['year'] = year
    shap_data['stg'] = stg
    shap_data['fs'] = fs
    shap_data['oversample'] = oversample
    shap_data['model'] = model_type
    shap_data['rmcol'] = '005'

    # Calculate ranking base on absolute mean value of SHAP
    rank_abs_shap_max = (shap_data >> mutate(abs_shap_max = abs(X.mean_val))).loc[:,['Feature', 'abs_shap_max']].groupby(['Feature']).agg(np.max).reset_index()
    rank_abs_shap_max['rank_abs_shap_max'] = rank_abs_shap_max['abs_shap_max'].rank(method='min', ascending=False)
    shap_data = pd.merge(shap_data, rank_abs_shap_max, left_on=['Feature'], right_on=['Feature'], how='left')

    #Calculate ranking base on SHAP min max difference and variance
    tdata = shap_data.loc[:,['Feature', 'mean_val']].groupby(['Feature']).agg([np.max,np.min,np.var]).reset_index()
    tdata.columns = ['Feature', 'maxSHAP', 'minSHAP', 'varSHAP']
    tdata = (tdata >> mutate(minmax_SHAP = X.maxSHAP-X.minSHAP))
    tdata['rank_minmax_SHAP'] = tdata['minmax_SHAP'].rank(method='min', ascending=False)
    tdata['rank_var_SHAP'] = tdata['varSHAP'].rank(method='min', ascending=False)
    shap_data = pd.merge(shap_data, tdata, left_on=['Feature'], right_on=['Feature'], how='left')

    # add auc value
    shap_data = shap_data >> mutate(auc=roc)

    #sort
    shap_data = shap_data.sort_values(['rank', 'fval'])

    #calculate confusion matrix
    cdata = pd.concat([pd.concat([X_train, y_train], axis=1), pd.concat([X_test, y_test], axis=1)], axis=0)
    # With value_vars omitted, melt unpivots every column except 'FLAG'. (The
    # previous list(...).remove('FLAG') expression always evaluated to None,
    # which selected the same default by accident.)
    cmdata = cdata.melt(id_vars='FLAG')
    conmat = cmdata.groupby(['FLAG', 'variable','value']).size().reset_index()
    conmat2 = conmat.pivot(index=['variable', 'value'], columns='FLAG', values=0).fillna(0).reset_index()
    # NOTE(review): the renames below assume FLAG takes exactly two values, so
    # 'b'/'d' refer to the first FLAG level and 'a'/'c' to the second - confirm.
    conmat2.columns = ['Feature', 'fval', 'b', 'a']
    conmat3 = cmdata.groupby(['FLAG', 'variable']).size().reset_index()
    conmat4 = conmat3.pivot(index=['variable'], columns='FLAG', values=0).fillna(0).reset_index()
    conmat4.columns = ['Feature', 'd', 'c']
    conmat5 = pd.merge(conmat2, conmat4, left_on='Feature', right_on='Feature', how='left')
    # d / c become the per-label counts of rows NOT having this feature value.
    conmat6 = conmat5 >> mutate(d=X.d-X.b) >> mutate(c=X.c-X.a) >> mutate(num=X.a+X.b)
    conmat6['fval'] = conmat6['fval'].astype('float64')
    shap_data = pd.merge(shap_data, conmat6, left_on=['Feature', 'fval'], right_on=['Feature', 'fval'], how='left')

    #is categorical?
    # NOTE(review): X_test is reloaded unfiltered here. The boolean-dtype check
    # is unaffected, but when ckd_group != 0 the fval_* statistics below are
    # computed on filtered train rows plus unfiltered test rows. Behaviour is
    # kept as it was - confirm whether that mix is intended.
    X_test = pd.read_pickle(x_test_path)
    cat_features = pd.DataFrame(list(X_test.select_dtypes('bool').columns)) >> mutate(isCategorical = True)
    cat_features.columns = ['Feature', 'isCategorical']
    shap_data = pd.merge(shap_data, cat_features, right_on='Feature', left_on='Feature', how='left')
    shap_data.loc[:,'isCategorical'] = shap_data.loc[:,'isCategorical'].fillna(False)

    #Collect fval range and stats
    Xdata = pd.concat([X_train, X_test], axis=0)
    try:
        filtertable = Xdata.select_dtypes(exclude=bool).agg([np.min, np.max, np.mean,np.std],axis=0).transpose()
        filtertable = filtertable.assign(upr=filtertable['mean']+3*filtertable['std']).assign(lwr=filtertable['mean']-3*filtertable['std']).reset_index().rename({'index':'Feature', 'mean':'fval_mean', 'std':'fval_std', 'upr':'fval_upr', 'lwr':'fval_lwr', 'amax':'fval_max', 'amin':'fval_min'},axis=1)
        shap_data  = pd.merge(shap_data, filtertable, right_on='Feature', left_on='Feature', how='left')
    except Exception as err:
        # Still best-effort (e.g. no non-boolean columns), but report the
        # reason instead of silently swallowing it; KeyboardInterrupt and
        # SystemExit are no longer caught.
        print('Skipping fval range statistics: '+repr(err), flush = True)

    #Save shap_data
    if returnflag:
        return shap_data, shap_data_raw
    else:
        shap_data.to_pickle(data_dir+'shapdata_'+model_type+'_'+run_tag+'_'+str(ckd_group)+'_005.pkl')
        shap_data_raw.to_pickle(data_dir+'shapdataraw_'+model_type+'_'+run_tag+'_'+str(ckd_group)+'_005.pkl')
    #model.to_pickle('data/'+site+'/model_data_'+site+'_'+str(year)+'.pkl')

    toc = time.perf_counter()
    print(f"{site}:{year} finished in {toc - tic:0.4f} seconds")
    print('Finished shap '+model_type+' on site '+site+":"+str(year)+":"+stg+":"+fs+":"+oversample, flush = True)

#print('done')

In [None]:
def collectSHAP(site, year, stg, fs, oversample, model_type):
    """Run collectSHAP_sub for CKD groups 1-4 and save the combined tables.

    Each group's tables get a ``ckd_group`` column, are concatenated, and are
    pickled to ``data/<site>/shapdata_..._005.pkl`` and
    ``data/<site>/shapdataraw_..._005.pkl``.

    A group that raises is reported with its traceback and skipped, so it
    contributes no rows. If every group fails, nothing is written.

    Parameters
    ----------
    site, year, stg, fs, oversample, model_type :
        Passed through unchanged to ``collectSHAP_sub`` and used to build the
        output file names.

    Returns
    -------
    None
    """
    import traceback

    shap_data_list = list()
    shap_data_raw_list = list()
    for i in range(1,5):
        try:
            shap_data, shap_data_raw = collectSHAP_sub(site, year, stg, fs, oversample, model_type, ckd_group=i, returnflag=True)
        except Exception:
            # Exceptions have no `.traceback` attribute; format the active
            # traceback instead so the handler itself cannot fail.
            print(site+":"+str(year)+":"+stg+":"+fs+":"+oversample+":"+model_type+" raised " + "error" +"\n"+traceback.format_exc())
            # Skip this group: falling through would either hit an unbound
            # variable or re-append the previous group's data under the wrong
            # ckd_group label.
            continue
        shap_data['ckd_group'] = i
        shap_data_raw['ckd_group'] = i
        shap_data_list.append(shap_data)
        shap_data_raw_list.append(shap_data_raw)

    if not shap_data_list:
        # pd.concat raises on an empty list; report and leave without writing.
        print(site+":"+str(year)+":"+stg+":"+fs+":"+oversample+":"+model_type+" produced no SHAP data for any CKD group; nothing saved", flush = True)
        return

    shap_data_all = pd.concat(shap_data_list)
    shap_data_raw_all = pd.concat(shap_data_raw_list)
    shap_data_all.to_pickle('data/'+site+'/shapdata_'+model_type+'_'+site+'_'+str(year)+'_'+stg+'_'+fs+'_'+oversample+'_005.pkl')
    shap_data_raw_all.to_pickle('data/'+site+'/shapdataraw_'+model_type+'_'+site+'_'+str(year)+'_'+stg+'_'+fs+'_'+oversample+'_005.pkl')