In [None]:
import pandas as pd
import xgboost as xgb
import numpy as np
import sklearn
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline
from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTENC
import matplotlib.pyplot as plt
from PIL import Image
from scipy.interpolate import BSpline, make_interp_spline, interp1d
#import rpy2.robjects as robjects
#from rpy2.robjects.packages import importr
import csv
from dfply import *
from xgboost import XGBClassifier
import itertools
import os
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
import time
import pickle
import math

import importlib
import ipynb.fs.full.postprocessing3_collect

In [None]:
def plot_importance(df, ax=None, height=0.2,
                    xlim=None, ylim=None,
                    xlabel='score', ylabel='Feature', fmap='',
                    importance_type='auc', max_num_features=None,
                    grid=True, show_values=True, **kwargs):
    """Draw a horizontal bar chart of one score column, one bar per label.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the label column (see ``ylabel``) and the column named
        by ``importance_type``. When a label occurs on several rows, the
        score of the last such row is used.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure with one axes is created when omitted.
    height : float
        Bar thickness, forwarded to ``Axes.barh``.
    xlim, ylim : tuple of 2 elements, optional
        Axis limits. Defaults are ``(0, 1.1 * max score)`` and
        ``(-1, number of bars)``.
    xlabel : str or None
        x-axis caption; ``None`` leaves it unset.
    ylabel : str or None
        Name of the label column in ``df`` and also the y-axis caption.
        ``None`` reads the ``'Feature'`` column and leaves the caption unset.
    fmap : str
        Unused; kept so existing calls keep working.
    importance_type : str
        Name of the score column; also used as the plot title.
    max_num_features : int, optional
        Keep only this many highest-scoring labels.
    grid : bool
        Forwarded to ``Axes.grid``.
    show_values : bool
        When exactly ``True``, print each score (rounded to 2 decimals)
        next to its bar.
    **kwargs
        Forwarded to ``Axes.barh``.

    Returns
    -------
    The axes that was drawn on.

    Raises
    ------
    ValueError
        If there is nothing to plot, or ``xlim``/``ylim`` is not a 2-tuple.
    """
    title = importance_type

    # The label column follows `ylabel` instead of being fixed to 'Feature',
    # so callers no longer have to rename their column before plotting.
    label_col = 'Feature' if ylabel is None else ylabel
    scores = (
        df.loc[:, [label_col, importance_type]]
        .set_index(label_col)[importance_type]
        .to_dict()
    )

    # Ascending order: the largest score ends up as the top bar.
    ranked = sorted(scores.items(), key=lambda item: item[1])
    if max_num_features is not None:
        # pylint: disable=invalid-unary-operand-type
        ranked = ranked[-max_num_features:]
    if not ranked:
        raise ValueError('plot_importance received no rows to plot')
    labels, values = zip(*ranked)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    if show_values is True:
        for x, y in zip(values, ylocs):
            # Nudge the text 1% to the right of the bar end.
            ax.text(x + x/100, y, round(x, 2), va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        if not isinstance(xlim, tuple) or len(xlim) != 2:
            raise ValueError('xlim must be a tuple of 2 elements')
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        if not isinstance(ylim, tuple) or len(ylim) != 2:
            raise ValueError('ylim must be a tuple of 2 elements')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax

In [None]:
def top_n_SHAP(result, site, year, importance_type = 'Importances', max_num_features = 10, numgraphcol=2):
    """Show one scatter panel per top-ranked feature for a single site and year.

    Parameters
    ----------
    result : pandas.DataFrame
        Must contain the columns ``site``, ``year``, ``Feature``, ``fval``,
        ``mean_val``, ``se_val`` and the column named by ``importance_type``.
        (Presumably ``mean_val``/``se_val`` are the mean and standard error
        of the SHAP value at each feature value ``fval`` - confirm upstream.)
    site, year
        Values matched against the ``site`` and ``year`` columns.
    importance_type : str
        Column used to rank features, largest first.
    max_num_features : int
        Maximum number of features (panels) to draw.
    numgraphcol : int
        Number of panel columns in the figure.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.
    """
    shap_data = result[(result['site'] == site) & (result['year'] == year)]

    # Feature -> score. A feature spanning several rows keeps its last score.
    score_by_feature = shap_data.set_index('Feature')[importance_type].to_dict()
    topf_n = sorted(score_by_feature, key=score_by_feature.get, reverse=True)[:max_num_features]

    # Size the grid from the features actually found, so a short list does
    # not leave blank rows. Keep at least one row so the figure never has
    # zero height.
    nrows = max(1, math.ceil(len(topf_n) / numgraphcol))

    # No plt.clf() here: with no current figure it would create an extra,
    # empty figure that is then displayed alongside the real one.
    fig = plt.figure(figsize=(9, 4.5 * nrows))

    for plotindex, f in enumerate(topf_n, start=1):
        plot_data = shap_data.loc[shap_data['Feature'] == f, ['fval', 'mean_val', 'se_val']]
        ax = fig.add_subplot(nrows, numgraphcol, plotindex)
        ax.scatter(x=plot_data['fval'], y=plot_data['mean_val'])
        ax.errorbar(plot_data['fval'], plot_data['mean_val'], yerr=plot_data['se_val'], fmt="o")
        ax.set_title(f)
        # Grid on every panel, not only those with more than two points.
        ax.grid(True)
    plt.show()

In [None]:
def one_feature_SHAP_allyear_allsite(shap_data, feature, sites=None, numgraphcol=2):
    """Show one panel per site for a single feature, overlaying every year.

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Must contain the columns ``Feature``, ``site``, ``year``, ``fval``,
        ``mean_val`` and ``se_val``.
    feature
        Value matched against the ``Feature`` column.
    sites : sequence, optional
        Sites to draw, one panel each, in the given order. Defaults to every
        site that has rows for ``feature``.
    numgraphcol : int
        Number of panel columns in the figure.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.
    """
    shap_data = shap_data[shap_data['Feature'] == feature]
    if sites is None:
        sites = shap_data['site'].unique()
    # Years are collected across all sites, so each panel iterates over the
    # same year sequence even when a site has no rows for some of them.
    years = sorted(shap_data['year'].unique())

    # At least one row, so an empty site list does not give a zero-height figure.
    nrows = max(1, math.ceil(len(sites) / numgraphcol))

    # No plt.clf() here: with no current figure it would create an extra,
    # empty figure that is then displayed alongside the real one.
    fig = plt.figure(figsize=(18 / numgraphcol, 9 / numgraphcol * nrows))

    for plotindex, site in enumerate(sites, start=1):
        ax = fig.add_subplot(nrows, numgraphcol, plotindex)
        site_rows = shap_data[shap_data['site'] == site]
        for yr in years:
            plot_data = site_rows.loc[site_rows['year'] == yr, ['fval', 'mean_val', 'se_val']]
            ax.scatter(x=plot_data['fval'], y=plot_data['mean_val'])
            ax.errorbar(plot_data['fval'], plot_data['mean_val'], yerr=plot_data['se_val'], fmt="o")
        # f-string so non-string site identifiers also work.
        ax.set_title(f"{site}_{feature}")
        ax.grid(True)
    plt.show()

In [None]:
def one_feature_SHAP(shap_data, feature, site, numgraphcol=2):
    """Show one panel per year for a single feature at a single site.

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Must contain the columns ``site``, ``Feature``, ``year``, ``fval``,
        ``mean_val`` and ``se_val``.
    feature
        Value matched against the ``Feature`` column.
    site
        Value matched against the ``site`` column.
    numgraphcol : int
        Number of panel columns in the figure.

    Returns
    -------
    None. The figure is displayed with ``plt.show()``.
    """
    shap_dataX = shap_data[(shap_data['site'] == site) & (shap_data['Feature'] == feature)]
    years = sorted(shap_dataX['year'].unique())

    # At least one row, so a site/feature with no data does not give a
    # zero-height figure.
    nrows = max(1, math.ceil(len(years) / numgraphcol))

    # No plt.clf() here: with no current figure it would create an extra,
    # empty figure that is then displayed alongside the real one.
    fig = plt.figure(figsize=(9, 4.5 * nrows))

    for plotindex, yr in enumerate(years, start=1):
        plot_data = shap_dataX.loc[shap_dataX['year'] == yr, ['fval', 'mean_val', 'se_val']]
        ax = fig.add_subplot(nrows, numgraphcol, plotindex)
        ax.scatter(x=plot_data['fval'], y=plot_data['mean_val'])
        ax.errorbar(plot_data['fval'], plot_data['mean_val'], yerr=plot_data['se_val'], fmt="o")
        ax.set_title(f"{site}_{feature}_{yr}")
        # Grid on every panel, not only those with more than two points.
        ax.grid(True)
    plt.show()

In [None]:
def model_comparison(model1, model2, stg='stg01', site = '', year='2016', oversample='raw', fs='rmscrbun', rmcol='005',
                     ax=None, height=0.35, xlim=None, ylim=None,
                     title='auc', xlabel='auc', ylabel='site',
                     grid=True, show_values=True, **kwargs):
    """Draw per-site AUC of two models as grouped horizontal bars.

    The results of both models are fetched through
    ``postprocessing3_collect.result_split(..., return_result=True)``, which
    is expected to return a DataFrame with at least ``site`` and ``auc``
    columns. Sites are aligned by name, so a site present for only one model
    simply has a single bar.

    Parameters
    ----------
    model1, model2
        Model identifiers passed to ``result_split``; also used as legend
        labels.
    stg, site, year, oversample, fs, rmcol
        Forwarded unchanged to ``result_split``.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure with one axes is created when omitted.
    height : float
        Thickness of each bar. The two bars of a site sit at
        ``y + height/2`` (``model1``) and ``y - height/2`` (``model2``).
    xlim, ylim : tuple of 2 elements, optional
        Axis limits. Defaults are ``(0, 1.1 * max AUC)`` and
        ``(-1, number of sites)``.
    title, xlabel, ylabel : str or None
        Captions; ``None`` leaves the corresponding caption unset.
    grid : bool
        Forwarded to ``Axes.grid``.
    show_values : bool
        When exactly ``True``, print each AUC (rounded to 2 decimals) next
        to its bar.
    **kwargs
        Forwarded to both ``Axes.barh`` calls.

    Returns
    -------
    The axes that was drawn on.

    Raises
    ------
    ValueError
        If neither model has any result rows, or ``xlim``/``ylim`` is not a
        2-tuple.
    """
    import importlib
    import ipynb.fs.full.postprocessing3_collect
    # Reload on every call so edits to the collection notebook take effect.
    collect = importlib.reload(ipynb.fs.full.postprocessing3_collect)

    def _auc_by_site(model):
        # One AUC per site for `model`, indexed by site.
        res = collect.result_split(model, stg=stg, site=site, year=year, oversample=oversample,
                                   fs=fs, rmcol=rmcol, return_result=True)
        # NOTE(review): assumes rows may repeat the same AUC (e.g. one row per
        # feature) - confirm against result_split. Duplicates are removed
        # per (site[, year]) and any remaining values are averaged.
        cols = [c for c in ('site', 'year', 'auc') if c in res.columns]
        return res.loc[:, cols].drop_duplicates().groupby('site')['auc'].mean()

    # Building the frame from two Series aligns them on the union of sites.
    table = pd.DataFrame({'first': _auc_by_site(model1),
                          'second': _auc_by_site(model2)}).sort_index()
    if table.empty:
        raise ValueError('no AUC results found for either model')

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(table))
    series = (('first', model1, height / 2), ('second', model2, -height / 2))
    for column, label, offset in series:
        values = table[column].to_numpy(dtype=float)
        # Skip sites this model has no result for.
        present = np.isfinite(values)
        bar_y = ylocs[present] + offset
        ax.barh(bar_y, values[present], align='center', height=height, label=label, **kwargs)
        if show_values is True:
            for x, y in zip(values[present], bar_y):
                ax.text(x + x/100, y, round(x, 2), va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(table.index)
    ax.legend()

    if xlim is not None:
        if not isinstance(xlim, tuple) or len(xlim) != 2:
            raise ValueError('xlim must be a tuple of 2 elements')
    else:
        finite = table.to_numpy(dtype=float)
        finite = finite[np.isfinite(finite)]
        xlim = (0, finite.max() * 1.1) if finite.size else (0, 1)
    ax.set_xlim(xlim)

    if ylim is not None:
        if not isinstance(ylim, tuple) or len(ylim) != 2:
            raise ValueError('ylim must be a tuple of 2 elements')
    else:
        ylim = (-1, len(table))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax

In [None]:
def top_features(shap_data, importance_type = 'Importances', max_num_features = 10):
    """Tabulate the highest-ranked features of every site/year.

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Must contain the columns ``site``, ``year``, ``Feature`` and the
        column named by ``importance_type``.
    importance_type : str
        Column used to rank features within a site/year, largest first.
    max_num_features : int
        Number of rank columns to keep per site/year.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``(site, year)`` with one column per rank (1 = highest).
        Cells hold feature names. A site/year with fewer than
        ``max_num_features`` features has missing values in the trailing
        rank columns.
    """
    keys = ['site', 'year']
    ranked = (
        shap_data
        .sort_values(keys + [importance_type], ascending=False)
        .loc[:, keys + ['Feature']]
        .drop_duplicates()
        .groupby(keys)
        .head(max_num_features)
        .reset_index(drop=True)
    )
    # Number the rows inside each site/year. Tiling 1..max_num_features once
    # per group only lines up when every group has exactly that many
    # features, and raises a length mismatch otherwise.
    ranked['rank'] = ranked.groupby(keys).cumcount() + 1
    return ranked.pivot(index=keys, columns='rank', values='Feature')

In [None]:
# Re-import the collection notebook so edits made to it are picked up.
importlib.reload(ipynb.fs.full.postprocessing3_collect)

#Load statistics
# Run configuration; these values also make up the result file name below.
stg = 'stg23'
fs = 'nofs'
oversample='raw'
model = 'catd'
rmcol = '005'

# Presumably these calls (re)generate the pickle read below - confirm in
# postprocessing3_collect before relying on that.
#ipynb.fs.full.postprocessing3_collect.result_split(model, stg=stg, site = '', year='', oversample=oversample, fs=fs, rmcol=rmcol, return_result=False)
#ipynb.fs.full.postprocessing3_collect.DEID(model, stg=stg, site = '', year='', oversample=oversample, fs=fs, rmcol=rmcol, return_result=False)

# TODO: this absolute path is machine-specific; make it configurable.
data_dir = "/home/hoyinchan/blue/Data/data2021/data2021/"
# The suffix comes from `rmcol` rather than a second hard-coded '005', so
# changing `rmcol` above changes which file is loaded.
result_file = 'DEID_resultsplit_' + '_'.join([model, stg, fs, oversample, rmcol]) + '.pkl'
# NOTE: unpickling can execute arbitrary code - only load files you produced.
result = pd.read_pickle(data_dir + result_file)

In [None]:
# Show the top features for each site/year: one row per (site, year), one column per rank.
top_features(result, max_num_features=5)

In [None]:
# AUC of each site/year combination, labelled "<site>_<year>"; the ten
# highest are drawn.
plotdata = (
    result.astype({'year': 'str'})   # year must be text to concatenate with site
    >> mutate(Feature=X.site + '_' + X.year)
    >> select('Feature', 'auc')
).drop_duplicates()
ax = plot_importance(plotdata, importance_type='auc', max_num_features=10)

In [None]:
# average auc per site
# De-duplicate on (site, year, auc) rather than (site, auc): otherwise two
# years of a site that happen to share the same AUC collapse into one row
# and the site's mean is computed over too few years.
plotdata = (
    result.loc[:, ['site', 'year', 'auc']]
    .drop_duplicates()
    .groupby('site')['auc']
    .mean()
    .rename_axis('Feature')   # plot_importance reads its labels from 'Feature'
    .reset_index()
)
ax = plot_importance(plotdata, importance_type='auc', max_num_features = 10)

In [None]:
# Feature-importance bar chart for a single site and year.
site, year = 'KUMC', 2013

# Column to rank by. The notebook also lists 'minmax_SHAP' and 'varSHAP'
# as alternatives.
importance_type = 'Importances'

shap_data = (
    result
    >> mask(X.site == site)
    >> mask(X.year == year)
)
ax = plot_importance(shap_data, importance_type=importance_type, max_num_features=10)
# To keep a copy on disk:
#ax.figure.savefig('data/'+site+'/model_'+site+'_'+str(year)+"_feature_"+importance_type+".png")

In [None]:
# Panels for the four top-ranked features at KUMC in 2013, two per row.
# `importance_type` comes from the site/year importance cell above it in the notebook.
top_n_SHAP(result, 'KUMC', 2013, importance_type=importance_type, max_num_features = 4, numgraphcol=2)

In [None]:
# AGE at site KUMC: one panel per year.
one_feature_SHAP(result, 'AGE', 'KUMC')

In [None]:
# AGE at site MCRI: one panel per year.
one_feature_SHAP(result, 'AGE', 'MCRI')

In [None]:
# AGE across all sites: one panel per site, every year overlaid.
one_feature_SHAP_allyear_allsite(result, 'AGE', numgraphcol=2)

In [None]:
# SYSTOLIC across all sites: one panel per site, every year overlaid.
one_feature_SHAP_allyear_allsite(result, 'SYSTOLIC')

In [None]:
# Calcium (lab feature 17861-6, mg/dL; presumably a LOINC code - confirm):
# one panel per site, every year overlaid.
one_feature_SHAP_allyear_allsite(result, 'LAB::17861-6(mg/dL)')

In [None]:
# Potassium (lab feature 2823-3, mmol/L; presumably a LOINC code - confirm):
# one panel per site, every year overlaid.
one_feature_SHAP_allyear_allsite(result, 'LAB::2823-3(mmol/L)')

In [None]:
# Chloride (lab feature 2075-0, mmol/L; presumably a LOINC code - confirm):
# one panel per site, every year overlaid.
one_feature_SHAP_allyear_allsite(result, 'LAB::2075-0(mmol/L)')