In [None]:
import pandas as pd
import xgboost as xgb
import numpy as np
import sklearn
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline
from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTENC
import matplotlib.pyplot as plt
from PIL import Image

from scipy.interpolate import BSpline, make_interp_spline, interp1d
#import rpy2.robjects as robjects
#from rpy2.robjects.packages import importr
import csv
from dfply import *
from xgboost import XGBClassifier
import itertools
import os
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
import time
import pickle
import math

import importlib
import ipynb.fs.full.postprocessing3_collect

In [None]:
def plot_importance(df, ax=None, height=0.2,
                    xlim=None, ylim=None,
                    xlabel='score', ylabel='Feature', fmap='',
                    importance_type='auc', max_num_features=None,
                    grid=True, show_values=True, **kwargs):
    """Draw a horizontal bar chart of one value column keyed by a label column.

    Modelled on ``xgboost.plot_importance`` but driven by a DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns named by ``ylabel`` (bar labels) and
        ``importance_type`` (bar lengths). When a label occurs on several rows,
        the value from the last such row is plotted.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure/axes is created when omitted.
    height : float
        Bar height passed to ``ax.barh``.
    xlim, ylim : tuple of 2 numbers, optional
        Axis limits; defaults are ``(0, 1.1 * max value)`` and ``(-1, n_bars)``.
    xlabel : str or None
        X-axis label.
    ylabel : str or None
        Name of the label column; also used as the Y-axis label.
    fmap : str
        Unused; kept only for signature compatibility with xgboost's helper.
    importance_type : str
        Name of the value column; also used as the plot title.
    max_num_features : int, optional
        Keep only the N largest values (all values when None).
    grid : bool
        Passed to ``ax.grid``.
    show_values : bool
        Annotate each bar with its value rounded to 2 decimals.
    **kwargs
        Forwarded to ``ax.barh``.

    Returns
    -------
    The Axes that was drawn on.

    Raises
    ------
    ValueError
        If there is nothing to plot, or ``xlim``/``ylim`` is not a 2-tuple.
    """
    title = importance_type
    # Plain pandas column pick (same result as the dfply ``select`` form).
    # Duplicate labels keep their last value, as Series.to_dict() does.
    importance = df[[ylabel, importance_type]].set_index(ylabel)[importance_type].to_dict()
    tuples = sorted(importance.items(), key=lambda kv: kv[1])
    if max_num_features is not None:
        tuples = tuples[-max_num_features:]
    if not tuples:
        # Previously surfaced as "not enough values to unpack" from zip(*tuples).
        raise ValueError('no values of %r to plot' % (importance_type,))
    labels, values = zip(*tuples)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    if show_values:
        for x, y in zip(values, ylocs):
            # Nudge the label 1% past the bar end.
            ax.text(x + x/100, y, round(x, 2), va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        if not isinstance(xlim, tuple) or len(xlim) != 2:
            raise ValueError('xlim must be a tuple of 2 elements')
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        if not isinstance(ylim, tuple) or len(ylim) != 2:
            raise ValueError('ylim must be a tuple of 2 elements')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax

In [None]:
def top_n_SHAP(result, site, year, importance_type = 'Importances', max_num_features = 10, numgraphcol=2):
    """Plot SHAP dependence (mean_val +/- se_val vs fval) for the top features of one site/year.

    Parameters
    ----------
    result : pandas.DataFrame
        SHAP summary with ``site``, ``year``, ``Feature``, ``fval``,
        ``mean_val``, ``se_val`` and ``importance_type`` columns (presumably one
        row per site/year/feature/value — TODO confirm).
    site, year :
        The site/year to plot (matched with ``==``).
    importance_type : str
        Column used to rank features, largest first.
    max_num_features : int
        Maximum number of features (subplots) to draw.
    numgraphcol : int
        Number of subplot columns.
    """
    shap_data = result >> mask(X.site == site) >> mask(X.year == year)
    importance = (shap_data >> select('Feature', importance_type)).set_index('Feature').to_dict()[importance_type]
    topf_n = sorted(importance, key=importance.get, reverse=True)[:max_num_features]

    # Size the grid from the features actually found, not from max_num_features,
    # so a site/year with fewer features does not get empty rows.
    fltrow = max(1, math.ceil(len(topf_n) / numgraphcol))
    plt.clf()
    plt.figure(figsize=(9, 4.5 * fltrow))

    for plotindex, f in enumerate(topf_n, start=1):
        plot_data = shap_data >> mask(X.Feature == f) >> select(X.fval, X.mean_val, X.se_val)
        plt.subplot(fltrow, numgraphcol, plotindex)
        plt.scatter(x=plot_data['fval'], y=plot_data['mean_val'])
        plt.errorbar(plot_data['fval'], plot_data['mean_val'], yerr=plot_data['se_val'], fmt="o")
        plt.title(f)
        plt.grid()
    plt.show()

In [None]:
def one_feature_SHAP_allyear_allsite(shap_data, feature, sites=None, numgraphcol=5,
                                     ylim=(-0.5, 2), figsize=(22.5, 9)):
    """Plot one feature's SHAP dependence for several sites, one subplot per site.

    Every year of a site is overlaid in that site's subplot. Each point is
    ``mean_val`` against ``fval`` with ``se_val`` error bars.

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Needs ``Feature``, ``site``, ``year``, ``fval``, ``mean_val`` and ``se_val`` columns.
    feature : str
        Feature name to plot.
    sites : iterable of str, optional
        Sites to plot (default: every site that has ``feature``).
    numgraphcol : int
        Number of subplot columns.
    ylim : 2-sequence or None
        Y-axis range per subplot. The default is the range that used to be
        hard-coded; larger SHAP effects were silently clipped by it. Pass None
        to autoscale.
    figsize : 2-tuple
        Figure size in inches (default: the previously hard-coded size).

    Returns
    -------
    The matplotlib Figure.
    """
    shap_data = shap_data >> mask(X.Feature == feature)
    if sites is None:
        sites = shap_data['site'].unique()
    years = shap_data['year'].unique()
    years.sort()

    fltrow = math.ceil(len(sites) / numgraphcol)
    plt.clf()
    fig = plt.figure(figsize=figsize)

    for plotindex, site in enumerate(sites, start=1):
        plt.subplot(fltrow, numgraphcol, plotindex)
        site_data = shap_data >> mask(X.site == site)
        for yr in years:
            plot_data = site_data >> mask(X.year == yr) >> select(X.fval, X.mean_val, X.se_val)
            plt.scatter(x=plot_data['fval'], y=plot_data['mean_val'])
            plt.errorbar(plot_data['fval'], plot_data['mean_val'], yerr=plot_data['se_val'], fmt="o")
        plt.title(site + "_" + feature)
        plt.grid()
        if ylim is not None:
            plt.ylim(ylim)
    plt.show()
    return fig

In [None]:
def one_feature_SHAP(shap_data, feature, site, numgraphcol=2):
    """Plot one feature's SHAP dependence at one site, one panel per year.

    Each panel shows ``mean_val`` against ``fval`` with ``se_val`` error bars
    and is titled ``"<site>_<feature>_<year>"``. Returns the matplotlib Figure.
    """
    is_target = (shap_data['site'] == site) & (shap_data['Feature'] == feature)
    feature_rows = shap_data[is_target]
    ordered_years = np.sort(feature_rows['year'].unique())

    n_rows = math.ceil(len(ordered_years) / numgraphcol)
    plt.clf()
    fig = plt.figure(figsize=(9, 4.5 * n_rows))

    for panel, yr in enumerate(ordered_years, start=1):
        year_rows = feature_rows[feature_rows['year'] == yr]
        xs = year_rows['fval']
        ys = year_rows['mean_val']
        errs = year_rows['se_val']
        plt.subplot(n_rows, numgraphcol, panel)
        plt.scatter(x=xs, y=ys)
        plt.errorbar(xs, ys, yerr=errs, fmt="o")
        plt.title(site + "_" + feature + "_" + str(yr))
        plt.grid()

    plt.show()
    return fig

In [None]:
def zero_feature_SHAP(shap_data, feature, site, yr, vline=None, vlinelabel=None):
    """Plot one site/feature/year SHAP dependence curve with optional vertical markers.

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Needs ``site``, ``Feature``, ``year``, ``fval``, ``mean_val`` and ``se_val`` columns.
    feature, site, yr :
        The selection to plot.
    vline : sequence of float, optional
        X positions of red vertical lines, each drawn from the minimum to the
        maximum plotted ``mean_val``.
    vlinelabel : sequence, optional
        Label for each ``vline`` entry; must be at least as long as ``vline``.

    Returns
    -------
    The matplotlib Figure.

    Raises
    ------
    ValueError
        If fewer labels than lines are given.
    """
    # None sentinels instead of shared mutable [] defaults.
    vline = [] if vline is None else list(vline)
    vlinelabel = [] if vlinelabel is None else list(vlinelabel)
    if len(vlinelabel) < len(vline):
        raise ValueError('vlinelabel needs one label per vline entry (got %d labels for %d lines)'
                         % (len(vlinelabel), len(vline)))

    fig = plt.figure()
    shap_dataX = shap_data >> mask(X.site == site) >> mask(X.Feature == feature)
    plot_data = shap_dataX >> mask(X.year == yr) >> select(X.fval, X.mean_val, X.se_val)
    plt.scatter(x=plot_data['fval'], y=plot_data['mean_val'])
    plt.errorbar(plot_data['fval'], plot_data['mean_val'], yerr=plot_data['se_val'], fmt="o")

    # The line extent is the same for every marker; compute it once.
    ymin = plot_data['mean_val'].min()
    ymax = plot_data['mean_val'].max()
    for x, label in zip(vline, vlinelabel):
        plt.vlines(x, ymin=ymin, ymax=ymax, label=label, colors='r')

    plt.title(site + "_" + feature + "_" + str(yr))
    plt.grid()
    plt.show()
    return fig

In [None]:
def model_comparison(model1, model2, stg='stg01', site = '', year='2016', oversample='raw', fs='rmscrbun', rmcol='005',
                     ax=None, height=0.4, xlim=None, ylim=None, title=None,
                     xlabel='auc', ylabel='site', grid=True, show_values=True, **kwargs):
    """Draw side-by-side horizontal AUC bars for two models, one pair per site.

    Both models' results are loaded with
    ``postprocessing3_collect.result_split(..., return_result=True)``, using the
    same stg/site/year/oversample/fs/rmcol settings. Each returned frame must
    have ``site`` and ``auc`` columns. When a site appears on several rows, its
    AUC values are averaged (presumably identical repeats — TODO confirm).

    The previous version referenced ``ax``, ``height``, ``kwargs``,
    ``show_values``, ``labels``, ``values``, ``xlim``, ``ylim``, ``title``,
    ``xlabel``, ``ylabel`` and ``grid`` without defining them, so it could
    never run. They are now keyword parameters with defaults.

    Parameters
    ----------
    model1, model2 : str
        Model identifiers understood by ``result_split``; also used as legend labels.
    stg, site, year, oversample, fs, rmcol :
        Forwarded unchanged to ``result_split``.
    ax : matplotlib Axes, optional
        Axes to draw on; a new one is created when omitted.
    height : float
        Height of each bar. The two bars of a site are offset by +/- height/2.
    xlim, ylim : tuple of 2 numbers, optional
        Axis limits. Defaults are ``(0, 1.1 * largest AUC)`` and ``(-1, n_sites)``.
    title : str, optional
        Plot title. Defaults to ``"<model1> vs <model2>"``; pass ``''`` for none.
    xlabel, ylabel : str or None
        Axis labels (skipped when None).
    grid : bool
        Passed to ``ax.grid``.
    show_values : bool
        Annotate each bar with its AUC rounded to 2 decimals.
    **kwargs
        Forwarded to both ``ax.barh`` calls.

    Returns
    -------
    The Axes that was drawn on.

    Raises
    ------
    ValueError
        If neither model yields any site, or ``xlim``/``ylim`` is not a 2-tuple.
    """
    import importlib
    import ipynb.fs.full.postprocessing3_collect as collect
    # Reload so edits to the collector notebook are picked up without a kernel restart.
    collect = importlib.reload(collect)

    split_kwargs = dict(stg=stg, site=site, year=year, oversample=oversample,
                        fs=fs, rmcol=rmcol, return_result=True)
    auc1 = collect.result_split(model1, **split_kwargs).groupby('site')['auc'].mean()
    auc2 = collect.result_split(model2, **split_kwargs).groupby('site')['auc'].mean()

    # Align both models on the sorted union of sites so each bar pair belongs to
    # the same site. A site missing for one model gets NaN (no bar, no label).
    labels = auc1.index.union(auc2.index).sort_values()
    if len(labels) == 0:
        raise ValueError('no AUC values to plot')
    values1 = auc1.reindex(labels).to_numpy(dtype=float)
    values2 = auc2.reindex(labels).to_numpy(dtype=float)

    if ax is None:
        _, ax = plt.subplots(1, 1)

    ylocs = np.arange(len(labels))
    # Offset the two series by half a bar so they sit side by side, instead of
    # model2's bars being drawn exactly on top of model1's.
    series = ((ylocs - height / 2, values1, model1),
              (ylocs + height / 2, values2, model2))
    for ypos, vals, model_name in series:
        ax.barh(ypos, vals, align='center', height=height, label=str(model_name), **kwargs)
        if show_values:
            for x, y in zip(vals, ypos):
                if not np.isnan(x):
                    ax.text(x + x/100, y, round(x, 2), va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(list(labels))

    if xlim is not None:
        if not isinstance(xlim, tuple) or len(xlim) != 2:
            raise ValueError('xlim must be a tuple of 2 elements')
    else:
        observed = np.concatenate([values1, values2])
        observed = observed[~np.isnan(observed)]
        xlim = (0, observed.max() * 1.1 if observed.size else 1)
    ax.set_xlim(xlim)

    if ylim is not None:
        if not isinstance(ylim, tuple) or len(ylim) != 2:
            raise ValueError('ylim must be a tuple of 2 elements')
    else:
        ylim = (-1, len(labels))
    ax.set_ylim(ylim)

    if title is None:
        title = '%s vs %s' % (model1, model2)
    if title:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(grid)
    return ax

In [None]:
def top_features(shap_data, importance_type = 'Importances', max_num_features = 10):
    """Tabulate the top-ranked features for every (site, year).

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Must contain ``site``, ``year``, ``Feature`` and ``importance_type``
        columns. A feature may appear on several rows; after sorting by
        descending importance, its first row wins.
    importance_type : str
        Column to rank by, largest first.
    max_num_features : int
        Number of ranks (table columns) to keep per (site, year).

    Returns
    -------
    pandas.DataFrame indexed by (site, year), with columns ``1..max_num_features``
    (column axis named ``rank``) holding feature names. A site-year with fewer
    features than ``max_num_features`` gets NaN in its missing ranks; the old
    code raised a length-mismatch error in that case.
    """
    ranked = (
        shap_data
        .sort_values(['site', 'year', importance_type], ascending=False)
        .loc[:, ['site', 'year', 'Feature']]
        .drop_duplicates()
        .groupby(['site', 'year'])
        .head(max_num_features)
        .reset_index(drop=True)
    )
    # Number rows 1, 2, ... within each group. Rows are already in descending
    # importance order, so this is the rank.
    ranked = ranked.assign(rank=ranked.groupby(['site', 'year']).cumcount() + 1)
    return ranked.pivot(index=['site', 'year'], columns='rank', values='Feature')

In [None]:
def top_features_no_lab(shap_data, importance_type = 'Importances', max_num_features = 10):
    """Tabulate each (site, year)'s top non-laboratory features with their overall rank.

    Features are dense-ranked by ``importance_type`` (largest = 1) among ALL
    features of the site/year. Then every feature whose name contains
    ``'LAB'`` is removed. The survivors are laid out left to right in rank
    order as ``"<Feature>(<rank>)"`` strings, where ``<rank>`` is the float
    overall rank (e.g. ``"AGE(2.0)"``).

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Must contain ``site``, ``year``, ``Feature`` and ``importance_type`` columns.
    importance_type : str
        Column to rank by. It was previously ignored and ``'Importances'``
        was always used.
    max_num_features : int
        Number of feature columns to return.

    Returns
    -------
    pandas.DataFrame with columns ``site``, ``year`` and then up to
    ``max_num_features`` position columns (1.0, 2.0, ...).
    """
    keys = ['site', 'year']
    feats = shap_data[keys + ['Feature', importance_type]].drop_duplicates()
    # Dense-rank every row of the group. Tied features now share a rank; before,
    # all but the first tied feature got NaN and were silently dropped.
    feats = feats.assign(
        irank=feats.groupby(keys)[importance_type].rank(method='dense', ascending=False))
    feats = feats[~feats['Feature'].str.contains('LAB', regex=False, na=False)]
    feats = feats.assign(**{
        'Feature(Rank)': feats['Feature'] + '(' + feats['irank'].astype(str) + ')',
        # 'first' breaks ties by row order, so every feature gets its own column.
        'Dummyrank': feats.groupby(keys)['irank'].rank(method='first'),
    })
    table = (feats[keys + ['Dummyrank', 'Feature(Rank)']]
             .dropna()
             .pivot(index=keys, columns='Dummyrank', values='Feature(Rank)')
             .reset_index())
    return table.iloc[:, :(max_num_features + 2)]

In [None]:
def top_features_med(shap_data, importance_type = 'Importances', max_num_features = 10):
    """Tabulate each (site, year)'s top medication features with their overall rank.

    Features are dense-ranked by ``importance_type`` (largest = 1) among ALL
    features of the site/year. Then only features whose name contains
    ``'MED'`` are kept. They are laid out left to right in rank order as
    ``"<Feature>(<rank>)"`` strings, where ``<rank>`` is the float overall rank
    (e.g. ``"MED:ATC:A01(3.0)"``). Later cells parse that suffix.

    Parameters
    ----------
    shap_data : pandas.DataFrame
        Must contain ``site``, ``year``, ``Feature`` and ``importance_type`` columns.
    importance_type : str
        Column to rank by. It was previously ignored and ``'Importances'``
        was always used.
    max_num_features : int
        Number of feature columns to return.

    Returns
    -------
    pandas.DataFrame with columns ``site``, ``year`` and then up to
    ``max_num_features`` position columns (1.0, 2.0, ...).
    """
    keys = ['site', 'year']
    feats = shap_data[keys + ['Feature', importance_type]].drop_duplicates()
    # Dense-rank every row of the group. Tied features now share a rank; before,
    # all but the first tied feature got NaN and were silently dropped.
    feats = feats.assign(
        irank=feats.groupby(keys)[importance_type].rank(method='dense', ascending=False))
    feats = feats[feats['Feature'].str.contains('MED', regex=False, na=False)]
    feats = feats.assign(**{
        'Feature(Rank)': feats['Feature'] + '(' + feats['irank'].astype(str) + ')',
        # 'first' breaks ties by row order, so every feature gets its own column.
        'Dummyrank': feats.groupby(keys)['irank'].rank(method='first'),
    })
    table = (feats[keys + ['Dummyrank', 'Feature(Rank)']]
             .dropna()
             .pivot(index=keys, columns='Dummyrank', values='Feature(Rank)')
             .reset_index())
    return table.iloc[:, :(max_num_features + 2)]

In [None]:
importlib.reload(ipynb.fs.full.postprocessing3_collect)

# Settings for the DEID result split loaded below.
stg = 'stg01'
#fs = 'rmscrbun'
fs = 'nofs'
#fs = 'onlymed'
oversample='raw'
model = 'catd'
rmcol = '005'
year = '3000'
# Directory holding the DEID_resultsplit_*.pkl files. The absolute path is kept
# from the original; change it here rather than inside the path expression.
DATA_DIR = "/home/hoyinchan/blue/Data/data2021/data2021/"

#ipynb.fs.full.postprocessing3_collect.result_split(model, stg=stg, site = '', year='', oversample=oversample, fs=fs, rmcol=rmcol, return_result=False)

ipynb.fs.full.postprocessing3_collect.DEID(model, stg=stg, site = '', year=year, oversample=oversample, fs=fs, rmcol=rmcol, return_result=False)

# Build the file name from rmcol instead of a hard-coded '_005', so the file
# read always matches the settings passed to DEID above.
# (read_pickle can execute code from the file; only load pipeline outputs you trust.)
result_file = 'DEID_resultsplit_' + '_'.join([model, stg, year, fs, oversample, rmcol]) + '.pkl'
result = pd.read_pickle(os.path.join(DATA_DIR, result_file))

In [None]:
def lablonic2name(lonic, timeout=30):
    """Map a lab feature name such as ``'LAB::2075-0(mmol/L)'`` to ``'<name>(<LOINC>)'``.

    Inputs whose first ``:``-separated field is not ``'LAB'`` are returned
    unchanged. For lab features, the LOINC code is the third ``:``-field up to
    its first ``'('``. The display name is the first word of the ``og:title``
    meta tag on ``https://loinc.org/<code>/``.

    Parameters
    ----------
    lonic : str
        Feature name.
    timeout : float
        Seconds to wait for loinc.org. The request previously had no timeout
        and could hang the notebook indefinitely.

    Raises
    ------
    ValueError
        If a LAB name has no LOINC field, or the page has no og:title tag.
    requests.RequestException
        On network or HTTP errors. Error pages used to be parsed as if they
        were valid and could yield bogus names.
    """
    parts = lonic.split(':')
    if parts[0] != 'LAB':
        return lonic
    if len(parts) < 3:
        raise ValueError('cannot find a LOINC code in %r' % (lonic,))
    code = parts[2].split('(')[0]
    import requests
    from bs4 import BeautifulSoup
    r = requests.get('https://loinc.org/' + code + '/', timeout=timeout)
    r.raise_for_status()
    soup = BeautifulSoup(r.text, 'html.parser')
    meta = soup.find('meta', {'property': "og:title"})
    if meta is None:
        raise ValueError('no og:title on the loinc.org page for %r' % (code,))
    return meta['content'].split(' ')[0] + '(' + code + ')'

In [None]:
def atc2name(ATC, timeout=30):
    """Map ``'MED:ATC:<code>(<rank>)'`` to ``'<WHO ATC name>(<code>)(<rank>)'``.

    Any input not of that shape (not MED:ATC, fewer than three ``:``-fields,
    or no ``'('`` rank part) is returned unchanged. Malformed MED:ATC strings
    used to raise IndexError. The name is the text of the first link to the
    code on the WHO ATC/DDD index page. ``L01XC`` is answered from a
    hard-coded name without a lookup (presumably because the WHO page does not
    resolve it — TODO confirm).

    Parameters
    ----------
    ATC : str
        Feature string, e.g. ``'MED:ATC:A01(3.0)'``.
    timeout : float
        Seconds to wait for whocc.no (previously no timeout).

    Raises
    ------
    ValueError
        If the WHO page has no link for the code.
    requests.RequestException
        On network or HTTP errors.
    """
    ATC_Split = str(ATC).split(':')
    if len(ATC_Split) < 3 or ATC_Split[0] != 'MED' or ATC_Split[1] != 'ATC' or '(' not in ATC_Split[2]:
        return ATC
    pieces = ATC_Split[2].split('(')
    # rank keeps its trailing ')' (e.g. '3.0)'), so it is appended after a bare '('.
    code, rank = pieces[0], pieces[1]
    if code == 'L01XC':
        return 'Monoclonal antibodies(L01XC)' + '(' + rank
    import requests
    from bs4 import BeautifulSoup
    url = 'https://www.whocc.no/atc_ddd_index/?code=' + code + '&showdescription=no'
    r = requests.get(url, timeout=timeout)
    r.raise_for_status()
    soup = BeautifulSoup(r.text, 'html.parser')
    links = soup.find_all('a', {'href': "./?code=" + code + "&showdescription=no"})
    if not links:
        raise ValueError('ATC code %r not found on the WHO index page' % (code,))
    return links[0].text + '(' + code + ')' + '(' + rank

In [None]:
def atc2name2(ATC, timeout=30):
    """Best-effort lookup of the WHO ATC/DDD index name for a bare ATC code.

    Returns the text of the first link to the code on the WHO index page.
    ``'L01XC'`` is answered from a hard-coded name. On any ordinary failure
    (network or HTTP error, code not on the page, non-string input), ``ATC``
    itself is returned.

    Parameters
    ----------
    ATC : str
        Bare ATC code, e.g. ``'A01'``.
    timeout : float
        Seconds to wait for whocc.no (previously no timeout, so a stalled
        connection could hang the notebook).
    """
    try:
        if ATC == 'L01XC':
            return 'Monoclonal antibodies(L01XC)'
        import requests
        from bs4 import BeautifulSoup
        url = 'https://www.whocc.no/atc_ddd_index/?code=' + ATC + '&showdescription=no'
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()
        soup = BeautifulSoup(r.text, 'html.parser')
        return soup.find_all('a', {'href': "./?code=" + ATC + "&showdescription=no"})[0].text
    except Exception:
        # Deliberate best-effort fallback. Unlike the old bare ``except:``, this
        # no longer swallows KeyboardInterrupt/SystemExit.
        return ATC

In [None]:
# Re-binds `plt` to matplotlib.pylab (a superset of pyplot). The pyplot import in
# the first cell already provides rcParams and everything else used here.
import matplotlib.pylab as plt
plt.rcParams['figure.dpi'] = 200
# One bar per site. plot_importance keeps the LAST 'auc' value seen for each
# site label (presumably `result` repeats a site's AUC on many rows — confirm).
myax = plot_importance(result,ylabel='site')

In [None]:
# Aggregate each feature's rank across the non-MCRI sites (presumably rank 1 = most
# important, so a smaller summed rank means more consistently important — confirm).
#avg_rank = result[result['site']!='MCRI'][['site','Feature','Importances','rank']].drop_duplicates()[['Feature','rank']].groupby('Feature').sum().reset_index().sort_values('rank').head(10)
# Largest rank value seen at each site. The smallest of these maxima is used below
# as the fill value for features a site does not have.
max_rank = result[result['site']!='MCRI'][['site','rank']].groupby('site').max().reset_index()
max_rank.columns = ['site','max_rank']
from itertools import product
# Every (site, Feature) combination over the non-MCRI sites, so features missing at a site can be filled in.
allsite_faature = pd.DataFrame(list(product(result[result['site']!='MCRI']['site'].unique(), result[result['site']!='MCRI']['Feature'].unique())), columns=['site', 'Feature'])

In [None]:
max_rank

In [None]:
# NOTE(review): the max_rank merge has no lasting effect; the next cell keeps only site/Feature/rank.
allsite_faature2 = allsite_faature.merge(max_rank, on ='site', how='left').merge(result, on =['site','Feature'], how='left')

In [None]:
allsite_faature2 = allsite_faature2[['site','Feature','rank']]
# A feature absent at a site gets rank = min over sites of max_rank.
# NOTE(review): confirm this shared penalty is intended, rather than each site's own max_rank.
allsite_faature2 = allsite_faature2.fillna(min(max_rank['max_rank']))

In [None]:
# Sum ranks per feature over sites and keep the 30 smallest sums. Distinct ranks for
# the same site/feature (e.g. several years) would all be summed.
allsite_faature3 = allsite_faature2[['site','Feature','rank',]].drop_duplicates()
allsite_faature3 = allsite_faature3.drop_duplicates()[['Feature','rank']].groupby('Feature').sum().reset_index().sort_values('rank').head(30)
# Index the table by summed rank.
allsite_faature3.index = allsite_faature3['rank']
allsite_faature3 = allsite_faature3.drop('rank',axis=1)
# Display positions 21-30 of the list.
allsite_faature3.iloc[20:30,:]

In [None]:
# Show top features for each site-year
ttf = top_features(result, max_num_features=10).reset_index().drop('year',axis=1)
# Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0;
# use Styler.hide(axis='index') when it exists.
ttf_style = ttf.style
ttf_style.hide(axis='index') if hasattr(ttf_style, 'hide') else ttf_style.hide_index()

In [None]:
# How often each value appears across the top-10 tables. The melted variable
# column is named 'rank' (the pivot's column-axis name), so the counts end up in
# 'rank'. Site names are melted too, because 'site' is still a column of ttf.
maxmaxfcount = ttf.melt().groupby('value').count().sort_values('rank',ascending=False)
maxffindex = maxmaxfcount[maxmaxfcount['rank']>2].index
siteflist = result[['site', 'Feature']].drop_duplicates()
# Number of sites that have each frequently ranked feature at all. Series.isin
# replaces a per-row `x in list(maxffindex)` test that rebuilt the list for every row.
maxmaxfexi = siteflist[siteflist['Feature'].isin(maxffindex)].groupby('Feature').count()
# NOTE(review): maxffindex uses > 2 but d1 keeps > 3; the inner merge means only > 3 survives.
d1 = maxmaxfcount[maxmaxfcount['rank']>3].merge(maxmaxfexi, left_index=True, right_index=True)

In [None]:
# One loinc.org request per LAB feature; other names pass through lablonic2name unchanged.
d1['name'] = [lablonic2name(x) for x in d1.index]
d1

In [None]:
# Distinct features that appear anywhere in UTSW's top-50 lists.
topf = top_features(result, max_num_features=50).reset_index()
topf = topf[topf['site']=='UTSW']
topf.drop(['site', 'year'],axis=1).stack().unique()

In [None]:
# Top-10 'MED' features per site-year, each tagged with its overall importance rank.
topmed = top_features_med(result, max_num_features=10)
topmed

In [None]:
# Keep only the text after the last ':' of each feature string (NaN padding
# passes through), then hide the index. Styler.hide_index() was removed in pandas 2.0.
topmed_style = topmed.drop('year',axis=1).applymap(lambda x: x.split(':')[-1] if isinstance(x, str) else x).style
topmed_style.hide(axis='index') if hasattr(topmed_style, 'hide') else topmed_style.hide_index()

In [None]:
# UTSW's top medication feature names with the '(rank)' suffix stripped.
[x.split('(')[0] for x in topmed[topmed['site']=='UTSW'].drop(['site','year'],axis=1).stack().unique()]

In [None]:
# Keep a medication only where its overall importance rank is within this cutoff.
TOP_MED_RANK = 30


def _atc_if_top(cell, max_rank=TOP_MED_RANK):
    """Return the ATC code of a ``'MED:ATC:<code>(<rank>)'`` cell, or NaN when rank > max_rank.

    Non-string cells (NaN padding from the pivot) pass through unchanged.
    """
    if not isinstance(cell, str):
        return cell
    rank = float(cell.split('(')[-1].split(')')[0])
    return cell.split('(')[0].split(':')[-1] if rank <= max_rank else np.nan


# Count how often each ATC code appears among the site-year rows of `topmed`.
# rename_axis + reset_index(name=...) gives the columns ['index', 'value'] on both
# pandas < 2.0 and >= 2.0. (pandas 2.0 changed value_counts().reset_index() to
# produce ['value', 'count'], which broke topmedrank['index'] below.)
topmedrank = (
    topmed.drop(['site', 'year'], axis=1)
    .applymap(_atc_if_top)
    .melt()['value']
    .value_counts()
    .rename_axis('index')
    .reset_index(name='value')
)
topmedrank['name'] = topmedrank['index'].apply(atc2name2)
topmedrank['name(ATC)'] = topmedrank['name'] + '(' + topmedrank['index'] + ')'

df_high_corr = pd.read_csv('df_high_corr.csv')
df_low_corr = pd.read_csv('df_low_corr.csv')

# Keep the text after the last ':' of each id, and only positively correlated rows.
# These filtered frames are reused by the average-correlation cell further down.
df_high_corr['index2'] = [x.split(':')[-1] for x in df_high_corr['index']]
df_low_corr['index2'] = [x.split(':')[-1] for x in df_low_corr['index']]
df_high_corr = df_high_corr[df_high_corr['corr'] > 0].copy()
df_low_corr = df_low_corr[df_low_corr['corr'] > 0].copy()
df_high_corr_count = df_high_corr.groupby('index2').size().rename('count_high').reset_index()
df_low_corr_count = df_low_corr.groupby('index2').size().rename('count_low').reset_index()

# Attach the counts. Only the count columns are zero-filled (not the whole frame).
topmedrankX1 = (topmedrank.merge(df_high_corr_count, left_on='index', right_on='index2', how='left')
                .drop('index2', axis=1)
                .fillna({'count_high': 0}))
topmedrankX1['count_high'] = topmedrankX1['count_high'].astype(int)
topmedrankX2 = (topmedrankX1.merge(df_low_corr_count, left_on='index', right_on='index2', how='left')
                .drop('index2', axis=1)
                .fillna({'count_low': 0}))
topmedrankX2['count_low'] = topmedrankX2['count_low'].astype(int)

pd.set_option('display.max_colwidth', None)
topmedrankX2['name(ATC)[high:low]'] = (topmedrankX2['name(ATC)'] + '[' + topmedrankX2['count_high'].astype(str)
                                       + ':' + topmedrankX2['count_low'].astype(str) + ']')
# Comma-join every annotated name that shares the same count ('value').
topmedrankX2['combinestr'] = topmedrankX2.groupby('value')['name(ATC)[high:low]'].transform(lambda s: ','.join(s))
topmedrankX2[['value', 'combinestr']].drop_duplicates()

In [None]:
# Styler.hide_index() was removed in pandas 2.0; prefer Styler.hide(axis='index').
combined_style = topmedrankX2[['value', 'combinestr']].drop_duplicates().style
combined_style.hide(axis='index') if hasattr(combined_style, 'hide') else combined_style.hide_index()

In [None]:
# Mean 'corr' per index2 over the rows kept (corr > 0) by the topmedrank cell above.
df_low_corr_avg = df_low_corr[['index2', 'corr']].groupby('index2').mean().reset_index()
df_low_corr_avg.columns = ['index2', 'low_avg_corr']
df_high_corr_avg = df_high_corr[['index2', 'corr']].groupby('index2').mean().reset_index()
df_high_corr_avg.columns = ['index2', 'high_avg_corr']

In [None]:
# auc per site/year: the 10 highest, labelled "<site>_<year>" (Feature is overwritten with that label)
plotdata = result
plotdata = plotdata.astype({'year': 'str'})
plotdata = (plotdata>>mutate(Feature=X.site+'_'+X.year)>>select('Feature','auc')).drop_duplicates()
ax = plot_importance(plotdata, importance_type='auc', max_num_features = 10)

In [None]:
# average auc per site
# NOTE(review): drop_duplicates() runs before the mean, so years with identical AUC are counted once.
plotdata = result
plotdata = plotdata.astype({'year': 'str'})
plotdata = (plotdata>>mutate(Feature=X.site)>>select('Feature','auc')).drop_duplicates().groupby('Feature').mean().reset_index()
ax = plot_importance(plotdata, importance_type='auc', max_num_features = 10)

In [None]:
site = 'KUMC'
# NOTE: overwrites the string year = '3000' from the loading cell with an int.
year = 3000
shap_data = result >> mask(X.site==site) >> mask(X.year==year)
#plot feature importance
# `importance_type` is also read by the top_n_SHAP cell below.
importance_type = 'Importances'
#importance_type = 'minmax_SHAP'
#importance_type = 'varSHAP'
ax = plot_importance(shap_data, importance_type=importance_type, max_num_features = 10)
#ax.figure.savefig('data/'+site+'/model_'+site+'_'+str(year)+"_feature_"+importance_type+".png")

In [None]:
# Depends on `importance_type` from the previous cell.
top_n_SHAP(result, 'KUMC', 3000, importance_type=importance_type, max_num_features = 4, numgraphcol=2)

In [None]:
# one_feature_SHAP returns the Figure as the cell's last expression, so Jupyter may
# render it a second time after plt.show(); append ';' to suppress that.
one_feature_SHAP(result, 'AGE', 'KUMC')

In [None]:
one_feature_SHAP(result, 'AGE', 'MCRI')

In [None]:
one_feature_SHAP_allyear_allsite(result, 'AGE', numgraphcol=2)

In [None]:
# NOTE(review): presumably the Chloride plot; the '#Chloride' label sits on the SYSTOLIC cell below — confirm.
one_feature_SHAP_allyear_allsite(result, 'LAB::2075-0(mmol/L)')

In [None]:
#Calcium
one_feature_SHAP_allyear_allsite(result, 'LAB::17861-6(mg/dL)')

In [None]:
#Potassium
myfig = one_feature_SHAP_allyear_allsite(result, 'LAB::2823-3(mmol/L)')
myfig.savefig("SHAP2823_3_potassium.svg")

In [None]:
#Chloride
# NOTE(review): the label above does not match the feature plotted here (SYSTOLIC).
one_feature_SHAP_allyear_allsite(result, 'SYSTOLIC')

In [None]:
# Mean AGE importance per site, ascending.
result2 = result[result['Feature'] == 'AGE']
result2[['site', 'Importances']].drop_duplicates().groupby('site').mean().sort_values('Importances')

In [None]:
# Displays every AGE row; consider result2.head().
result2

In [None]:
# NOTE(review): other lab features are keyed like 'LAB::2823-3(mmol/L)'. Unless `result`
# also holds a bare '2823-3' feature, this selects no rows and draws empty panels.
one_feature_SHAP_allyear_allsite(result, '2823-3')

In [None]:
one_feature_SHAP(result, 'SYSTOLIC', 'MCRI')

In [None]:
from catboost import CatBoost, Pool
stg = 'stg01'
fs = 'rmscrbun'
oversample='raw'
model_type = 'catd'
rmcol = '005'
site = 'MCRI'
suffix=''
# The original also set year = '2011' here, but this assignment immediately overwrote it.
year=3000
# All artifacts for this site/year/config share one path stem.
run_tag = '_'.join([site, str(year), stg, fs, oversample])
data_dir = 'data/' + site + '/'
# pickle.load can execute arbitrary code from the file; only load trusted pipeline outputs.
# `with` closes the handle, which the bare open() never did.
with open(data_dir + 'model_' + model_type + '_' + run_tag + '.pkl', 'rb') as fh:
    model = pickle.load(fh)
X_train = pd.read_pickle(data_dir + 'X_train_' + run_tag + suffix + '.pkl')
y_train = pd.read_pickle(data_dir + 'y_train_' + run_tag + suffix + '.pkl')

In [None]:
# Treat boolean columns as categorical features.
is_cat = (X_train.dtypes == bool)
cat_features_index = np.where(is_cat)[0]

In [None]:
# In the lines shown, `pool` is only referenced by the commented-out plot_tree call.
pool = Pool(X_train, y_train, cat_features=cat_features_index, feature_names=list(X_train.columns))

In [None]:
# Written in JSON format despite the .txt extension; the json.load cell reads it back.
model.save_model('testtree.txt', format="json", export_parameters=None)

In [None]:
tfea = 'SYSTOLIC'

# Position of `tfea` in the model's feature list (the name is left over from an AGE analysis).
# Raises IndexError if `tfea` is not a model feature.
ageidx = np.where(np.array(model.feature_names_) == tfea)[0][0]

In [None]:
import json
# Read back the JSON model dump written by save_model above.
# `with` closes the file even if parsing fails.
with open('testtree.txt') as f:
    tree = json.load(f)

In [None]:
# Map the flat feature index of `tfea` to its position in the model's float_features
# list, which is the index used by oblivious-tree splits. Reset first so a stale value
# from an earlier run (e.g. a different tfea) can never be picked up silently.
ageidx2 = None
for i, float_feature in enumerate(tree['features_info']['float_features']):
    if float_feature['flat_feature_index'] == ageidx:
        print(i)
        ageidx2 = i
if ageidx2 is None:
    raise ValueError('feature %r (flat index %d) is not a float feature of the model' % (tfea, ageidx))

In [None]:
#model.plot_tree(tree_idx=38,pool=pool)

In [None]:
# Label rate per SYSTOLIC bin: sum of y_train over the bin divided by the bin size
# (the positive fraction, assuming y_train is 0/1 — confirm). spt = all non-missing SYSTOLIC.
# The three cells below differ only in their cut points; a helper taking the
# thresholds would remove the copy-paste.
#tree['oblivious_trees'][5]
sp0 = X_train['SYSTOLIC']<93.33
sp1 = np.logical_and(X_train['SYSTOLIC']>=93.33, X_train['SYSTOLIC']<=109.25)
sp2 = X_train['SYSTOLIC']>109.25
spt = np.logical_not(np.isnan(X_train['SYSTOLIC']))
p0 = y_train[sp0].sum()/sp0.sum()
p1 = y_train[sp1].sum()/sp1.sum()
p2 = y_train[sp2].sum()/sp2.sum()
pt = y_train[spt].sum()/spt.sum()
print(p0, p1, p2, pt)

In [None]:
# Same as above with the upper cut at 108.25.
#tree['oblivious_trees'][5]
sp0 = X_train['SYSTOLIC']<93.33
sp1 = np.logical_and(X_train['SYSTOLIC']>=93.33, X_train['SYSTOLIC']<=108.25)
sp2 = X_train['SYSTOLIC']>108.25
spt = np.logical_not(np.isnan(X_train['SYSTOLIC']))
p0 = y_train[sp0].sum()/sp0.sum()
p1 = y_train[sp1].sum()/sp1.sum()
p2 = y_train[sp2].sum()/sp2.sum()
pt = y_train[spt].sum()/spt.sum()
print(p0, p1, p2, pt)

In [None]:
# NOTE(review): here sp1 = [108.25, 110.25] overlaps sp2 (> 108.25), and 93.33-108.25
# belongs to no bin, unlike the two cells above — confirm these cut points are intended.
#tree['oblivious_trees'][5]
sp0 = X_train['SYSTOLIC']<93.33
sp1 = np.logical_and(X_train['SYSTOLIC']>=108.25, X_train['SYSTOLIC']<=110.25)
sp2 = X_train['SYSTOLIC']>108.25
spt = np.logical_not(np.isnan(X_train['SYSTOLIC']))
p0 = y_train[sp0].sum()/sp0.sum()
p1 = y_train[sp1].sum()/sp1.sum()
p2 = y_train[sp2].sum()/sp2.sum()
pt = y_train[spt].sum()/spt.sum()
print(p0, p1, p2, pt)


In [None]:
# NOTE(review): `vline`/`vlinelabel` are built by the split-collection cell further
# BELOW. On a fresh kernel, Run All raises NameError here; run that cell first, or move it above.
# Overlays the first 7 split borders found for `tfea` on the SHAP curve.
#myfig = zero_feature_SHAP(result, tfea, 'MCRI', int(year), vline=vline[:7], vlinelabel=vlinelabel[:7])
myfig = zero_feature_SHAP(result, tfea, 'MCRI', 3000, vline=vline[:7], vlinelabel=vlinelabel[:7])
print(list(zip(vline,vlinelabel))[:7])
# NOTE(review): the file name mentions 2011/2013, but year 3000 is plotted above.
myfig.savefig("SHAP_MCRI_2011_overelay2013.svg")

In [None]:
# Same borders overlaid on UIOWA's curve; int(year) uses `year` from the catboost load cell.
zero_feature_SHAP(result, tfea, 'UIOWA', int(year), vline=vline[:7], vlinelabel=vlinelabel[:7])
print(list(zip(vline,vlinelabel))[:7])

In [None]:
# Collect every split border on float feature `ageidx2` across all oblivious trees.
# vline[k] is the border value; vlinelabel[k] is the index of the tree it came from.
vline = []
vlinelabel = []
# Not used in the lines shown.
rank=0
for i in range(len(tree['oblivious_trees'])):
    for j in range(len(tree['oblivious_trees'][i]['splits'])):
        # Only float-feature splits carry 'float_feature_index'.
        if 'float_feature_index' in tree['oblivious_trees'][i]['splits'][j].keys():
#            print(tree['oblivious_trees'][i]['splits'][j]['float_feature_index'])
            if tree['oblivious_trees'][i]['splits'][j]['float_feature_index'] == ageidx2:
                print(i, j, tree['oblivious_trees'][i]['splits'][j])
                vline.append(tree['oblivious_trees'][i]['splits'][j]['border'])
                vlinelabel.append(i)