In [None]:
import pandas as pd
import xgboost as xgb
import numpy as np
import sklearn
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline
from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTENC
import matplotlib.pyplot as plt
from PIL import Image
from scipy.interpolate import BSpline, make_interp_spline, interp1d
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
import csv
from dfply import *
from xgboost import XGBClassifier
import itertools
import os
import logging
import time

In [None]:
def drop_too_much_nan(site, year, newdfs, threshold, keep_med=True):
    """Drop feature columns that are (almost) always missing in both FLAG classes.

    A value counts as missing when it is NaN/None or ``False`` (``False`` is
    mapped to NaN before counting). A column that is absent from a frame is
    counted as missing for every row of that frame.

    Parameters
    ----------
    site, year : used only for the progress message.
    newdfs : list of DataFrame
        Frames that each contain a ``FLAG`` column (expected to hold 0/1).
        The list is updated in place with the reduced frames.
    threshold : float
        A column is removed when its missing rate is ``>= 1 - threshold`` in
        the FLAG==0 rows AND in the FLAG==1 rows.
    keep_med : bool
        When true, columns whose name contains ``'MED'`` are never removed.

    Returns
    -------
    (newdfs, remlist, flag0nan, flag1nan, flag0total, flag1total)
        The reduced frames, the removed column names, per-column missing
        counts for negative/positive rows, and the negative/positive totals.
    """
    print('Remove sparse feature on site '+site+":"+str(year), flush = True)

    # Sorted union of all column names; the label and the identifiers are never
    # candidates. Filtering in Python avoids comparing a numpy array with a str,
    # whose result depends on the numpy version when the array is empty.
    non_features = ('FLAG', 'PATID', 'ENCOUNTERID')
    all_names = [name for newdf in newdfs for name in newdf.columns]
    allcols = [name for name in np.unique(all_names) if name not in non_features]

    flag0nan = {key: 0 for key in allcols}
    flag1nan = {key: 0 for key in allcols}
    flag0total = 0
    flag1total = 0

    for newdf in newdfs:
        btX = newdf.replace(False, np.nan)
        # Read the label from the untouched frame so that a boolean FLAG is not
        # turned into NaN by the replace above.
        positive = newdf['FLAG']
        negative = np.logical_not(positive)
        n_negative = negative.sum()
        n_positive = positive.sum()
        flag0total += n_negative
        flag1total += n_positive
        for col in allcols:
            if col in newdf.columns:
                # .isna() works for every dtype; np.isnan raises TypeError on
                # object columns (e.g. True/NaN flags).
                missing = btX[col].isna()
                flag0nan[col] += np.logical_and(negative, missing).sum()
                flag1nan[col] += np.logical_and(positive, missing).sum()
            else:
                flag0nan[col] += n_negative
                flag1nan[col] += n_positive

    min_missing = 1-threshold
    remlist = []
    for col in allcols:
        # A class without rows never qualifies (the former 0/0 -> NaN compared
        # False as well, but emitted a runtime warning).
        sparse_negative = flag0total > 0 and flag0nan[col]/flag0total >= min_missing
        sparse_positive = flag1total > 0 and flag1nan[col]/flag1total >= min_missing
        if sparse_negative and sparse_positive:
            remlist.append(col)

    if keep_med:
        remlist = [x for x in remlist if 'MED' not in x]

    for i in range(len(newdfs)):
        newdfs[i] = newdfs[i].drop(remlist,axis=1, errors='ignore')

    return newdfs, remlist, flag0nan, flag1nan, flag0total, flag1total

In [None]:
def drop_too_much_nan_positive(site, year, newdfs, threshold, keep_med=True):
    """Drop feature columns that are (almost) always missing in the FLAG==1 rows.

    Only the positive rows decide whether a column is removed. A value counts
    as missing when it is NaN/None or ``False`` (``False`` is mapped to NaN
    before counting). A column that is absent from a frame is counted as
    missing for every positive row of that frame.

    Parameters
    ----------
    site, year : used only for the progress message.
    newdfs : list of DataFrame
        Frames that each contain a ``FLAG`` column (expected to hold 0/1).
        The list is updated in place with the reduced frames.
    threshold : float
        A column is removed when its missing rate among positive rows is
        ``>= 1 - threshold``.
    keep_med : bool
        When true, columns whose name contains ``'MED'`` are never removed.

    Returns
    -------
    (newdfs, remlist, flag0nan, flag1nan, flag0total, flag1total)
        ``flag0nan`` is returned for signature parity only: negative-row
        missing counts are not collected here, so every value stays 0.
    """
    print('Remove sparse feature on site '+site+":"+str(year), flush = True)

    # Sorted union of all column names; the label and the identifiers are never
    # candidates. Filtering in Python avoids comparing a numpy array with a str,
    # whose result depends on the numpy version when the array is empty.
    non_features = ('FLAG', 'PATID', 'ENCOUNTERID')
    all_names = [name for newdf in newdfs for name in newdf.columns]
    allcols = [name for name in np.unique(all_names) if name not in non_features]

    flag0nan = {key: 0 for key in allcols}
    flag1nan = {key: 0 for key in allcols}
    flag0total = 0
    flag1total = 0

    for newdf in newdfs:
        btX = newdf.replace(False, np.nan)
        # Read the label from the untouched frame so that a boolean FLAG is not
        # turned into NaN by the replace above.
        positive = newdf['FLAG']
        n_positive = positive.sum()
        flag0total += np.logical_not(positive).sum()
        flag1total += n_positive
        for col in allcols:
            if col in newdf.columns:
                # .isna() works for every dtype; np.isnan raises TypeError on
                # object columns (e.g. True/NaN flags).
                flag1nan[col] += np.logical_and(positive, btX[col].isna()).sum()
            else:
                flag1nan[col] += n_positive

    min_missing = 1-threshold
    remlist = []
    for col in allcols:
        # Without any positive row nothing qualifies (the former 0/0 -> NaN
        # compared False as well, but emitted a runtime warning).
        if flag1total > 0 and flag1nan[col]/flag1total >= min_missing:
            remlist.append(col)

    if keep_med:
        remlist = [x for x in remlist if 'MED' not in x]

    for i in range(len(newdfs)):
        newdfs[i] = newdfs[i].drop(remlist,axis=1, errors='ignore')

    return newdfs, remlist, flag0nan, flag1nan, flag0total, flag1total

In [None]:
def bt_ckd(site, year, newdf, base_dir="/home/hchan2/AKI/AKI_Python/"):
    """Left-join the per-site ``p0_ckdgroup_<site>.pkl`` table onto ``newdf``.

    Parameters
    ----------
    site : str
        Site name; selects ``<base_dir>/data/<site>/p0_ckdgroup_<site>.pkl``.
    year : used only for the progress / error messages.
    newdf : DataFrame
        Frame with ``PATID`` and ``ENCOUNTERID`` columns (the join keys).
    base_dir : str
        Project root that contains the ``data`` directory. The default keeps
        the location that used to be hard-coded.

    Returns
    -------
    DataFrame
        ``newdf`` merged (``how='left'``) with the pickled table, or ``newdf``
        itself, unchanged, when the pickle file does not exist. In that case an
        error is also appended to ``BT.log`` in the current directory.
    """
    #lab_num
    print('Merging ckd_info on site '+site+":"+str(year), flush = True)
    ckd_path = os.path.join(base_dir, 'data', site, 'p0_'+'ckdgroup'+'_'+site+'.pkl')
    try:
        # NOTE: pickle files can execute code on load; only read trusted project data.
        efgr2 = pd.read_pickle(ckd_path)
    except FileNotFoundError:
        logging.basicConfig(filename='BT.log', filemode='a')
        print('No efgr table!!!!! '+site+":"+str(year), flush = True)
        logging.error('No efgr table!!!!! '+site+":"+str(year))
        # Flush and close the log handlers so the message reaches the file now.
        logging.shutdown()
        return newdf
    return pd.merge(newdf, efgr2, on=['PATID', 'ENCOUNTERID'], how='left')

In [None]:
def bt_postprocess(site, year, newdf):
    """Drop identifier/time columns, sanitise column names, drop empty columns.

    Steps, in order:
    1. remove ``PATID``, ``ENCOUNTERID`` and the ``*SINCE_ADMIT*`` columns
       listed below (columns that are not present are ignored);
    2. replace ``<``, ``>``, ``[`` and ``]`` in column names with ``st``,
       ``bt``, ``lb`` and ``rb`` (presumably because the downstream model
       rejects these characters in feature names -- confirm);
    3. drop every column that is entirely NaN.

    ``site`` and ``year`` are used only for the progress message. The input
    frame is not modified; a new frame is returned.
    """
    print('Finishing on site '+site+":"+str(year), flush = True)
    newdf = newdf.drop(['PATID', 'ENCOUNTERID', 'AKI1_SINCE_ADMIT', 'SINCE_ADMIT', 'DAYS_SINCE_ADMIT','DAYS_SINCE_ADMIT_x'],axis=1, errors='ignore')
    # regex=False forces a literal replacement: '[' and ']' are regex
    # metacharacters and the default of `regex` differs between pandas versions.
    for symbol, token in (('<', 'st'), ('>', 'bt'), ('[', 'lb'), (']', 'rb')):
        newdf.columns = newdf.columns.str.replace(symbol, token, regex=False)
    return newdf.dropna(axis=1, how='all')

#    newdf_debug['drop'] = newdf.copy()

In [None]:
def flag_convert(dataX, stg):
    """Return a copy of ``dataX`` whose ``FLAG`` column is recoded for ``stg``.

    * ``'stg010'`` -- rows with FLAG 2 or 3 are removed; FLAG is left as is.
    * ``'stg23'``  -- rows with FLAG 1 are removed; FLAG becomes ``FLAG > 1``.
    * ``'stg123'`` -- rows with FLAG 0 are removed; FLAG becomes ``FLAG > 1``.
    * ``'stg01'``  -- all rows kept; FLAG becomes ``FLAG > 0``.
    * anything else -- all rows kept; FLAG becomes ``FLAG > 1``.

    The recoded FLAG is stored as 0/1 integers. ``dataX`` is not modified.
    """
    stage_flags = dataX['FLAG']

    # The only mode that filters without recoding the label.
    if stg == 'stg010':
        return dataX[(stage_flags != 2) & (stage_flags != 3)].copy()

    # Modes that exclude one stage before binarising.
    excluded_stage = {'stg23': 1, 'stg123': 0}.get(stg)
    if excluded_stage is None:
        result = dataX.copy()
    else:
        result = dataX[stage_flags != excluded_stage].copy()

    cutoff = 0 if stg == 'stg01' else 1
    result['FLAG'] = (result['FLAG'] > cutoff) * 1
    return result

In [None]:
def filter_subsample(df, site, yearX, stg, frac=1000):
    """Keep only the rows of ``df`` whose encounter is in a stored sample.

    Loads ``data/<site>/bt3sample_<site>_<yearX>_<stg>_<frac>.pkl`` (relative
    to the working directory) and inner-joins it with ``df`` on ``PATID`` and
    ``ENCOUNTERID``; the sample's columns come first in the result.
    """
    sample_file = ('data/'+site+'/bt3sample_'+site+'_'+str(yearX)
                   +"_"+stg+'_'+str(frac)+'.pkl')
    sampled_keys = pd.read_pickle(sample_file)
    join_keys = ['PATID', 'ENCOUNTERID']
    return pd.merge(sampled_keys, df, on=join_keys, how='inner')

In [None]:
def generate_subsample(site, yearX, stg, frac=1000, random_state=None):
    """Draw a FLAG-stratified sample of encounters for a site and pickle it.

    For every admission year found in the site's onset table, the yearly
    ``data/<site>/bt3_<site>_<year>.pkl`` frame is loaded, its FLAG recoded
    with ``flag_convert(data, stg)`` and reduced to PATID/ENCOUNTERID/FLAG.
    Years that cannot be loaded or processed are reported and skipped.

    Parameters
    ----------
    site : str
        Site name used to build the input and output paths.
    yearX :
        Accepted for signature compatibility; see the note in the body.
    stg : str
        Stage mode handed to ``flag_convert``; also part of the output name.
    frac : int
        Approximate number of encounters to keep. When the pooled data has at
        most ``frac`` rows, everything is kept.
    random_state : optional
        Passed to ``DataFrame.sample`` for reproducible sampling. The default
        (``None``) keeps the previous, unseeded behaviour.

    Raises
    ------
    ValueError
        When no yearly frame could be loaded at all.

    Writes ``data/<site>/bt3sample_<site>_3000_<stg>_<frac>.pkl`` containing
    the sampled PATID/ENCOUNTERID pairs. Returns ``None``.
    """
    print('Processing '+site+str(frac)+stg)
    # NOTE(review): the caller-supplied yearX is overridden here, so the output
    # file name always carries 3000. Kept as is because other code in this file
    # also hard-codes 3000 in its file names -- confirm this is intended.
    yearX = 3000

    onset = pd.read_pickle("/home/hchan2/AKI/AKI_Python/"+'data/'+site+'/p0_onset_'+site+'.pkl')
    years = list(pd.to_datetime(onset['ADMIT_DATE']).dt.year.unique())
    bt_list = list()

    for year in years:
        try:
            data = pd.read_pickle('data/'+site+'/bt3_'+site+'_'+str(year)+'.pkl')
            data = flag_convert(data, stg)
            bt_list.append(data[['PATID','ENCOUNTERID', 'FLAG']].copy())
        except FileNotFoundError:
            print(str(year)+' not exists')
        except Exception as err:
            # Still best-effort, but do not report unrelated failures as a
            # missing file (and let KeyboardInterrupt/SystemExit through).
            print(str(year)+' skipped: '+repr(err))

    if not bt_list:
        raise ValueError('No bt3 data could be loaded for site '+site)

    bt_all = pd.concat(bt_list, ignore_index=True).drop_duplicates()

    n_rows = bt_all.shape[0]
    if frac < n_rows:
        # Same fraction from every FLAG class. Iterating the groups (instead of
        # groupby.apply) keeps the FLAG column regardless of pandas version.
        s1 = frac/n_rows
        parts = [grp.sample(frac=s1, random_state=random_state) for _, grp in bt_all.groupby('FLAG')]
        bt_sam = pd.concat(parts).reset_index(drop=True) if parts else bt_all.iloc[0:0]
    else:
        bt_sam = bt_all.copy()
    bt_sam = bt_sam.drop('FLAG',axis=1)

    bt_sam.to_pickle('data/'+site+'/bt3sample_'+site+'_'+str(yearX)+"_"+stg+'_'+str(frac)+'.pkl')

In [None]:
def combinebt(site, yearX, stg, frac=50000):
    """Pool the sampled, common-feature rows of all years of a site and pickle them.

    For every admission year found in the site's onset table the yearly
    ``data/<site>/bt3_<site>_<year>.pkl`` frame is loaded, recoded with
    ``flag_convert``, restricted to the columns listed in the ``Feature``
    column of ``common_feature.pkl`` and filtered with ``filter_subsample``.
    Years that cannot be loaded or processed are reported and skipped.

    The pooled frame gets NaN replaced by ``False`` in its object-dtype
    columns, is passed through ``bt_postprocess`` and is written, without
    duplicate rows, to ``data/<site>/bt3_<site>_<stg>_3000_<frac>.pkl``.

    Parameters
    ----------
    site : str
        Site name used to build all paths.
    yearX :
        Forwarded to ``filter_subsample`` and ``bt_postprocess``.
        NOTE(review): the output file name hard-codes 3000 instead of using
        this value -- confirm this is intended.
    stg : str
        Stage mode for ``flag_convert``; part of the sample and output names.
    frac : int
        Sample size tag selecting which stored sample file is used.

    Raises
    ------
    ValueError
        When no yearly frame could be loaded at all.
    """

    print('Combine bt on site '+site+":"+str(frac), flush = True)

    onset = pd.read_pickle("/home/hchan2/AKI/AKI_Python/"+'data/'+site+'/p0_onset_'+site+'.pkl')
    years = list(pd.to_datetime(onset['ADMIT_DATE']).dt.year.unique())
    bt_list = list()
    common_feature = pd.read_pickle('common_feature.pkl')['Feature']

    for year in years:
        try:
            data = pd.read_pickle('data/'+site+'/bt3_'+site+'_'+str(year)+'.pkl')
            data = flag_convert(data, stg)
            data = data[data.columns.intersection(common_feature)]
            data = filter_subsample(data, site, yearX, stg, frac)
            bt_list.append(data.copy())
        except FileNotFoundError:
            print(str(year)+' not exists')
        except Exception as err:
            # Still best-effort, but do not report unrelated failures as a
            # missing file (and let KeyboardInterrupt/SystemExit through).
            print(str(year)+' skipped: '+repr(err))

    if not bt_list:
        raise ValueError('No bt3 data could be loaded for site '+site)

    bt_all = pd.concat(bt_list, ignore_index=True)

    # replace nan in boolean columns with False
    bt_bool = bt_all.select_dtypes('O').columns
    bt_all[bt_bool] = bt_all[bt_bool].fillna(False)

    bt_all = bt_postprocess(site, yearX, bt_all)
    bt_all.drop_duplicates().to_pickle('data/'+site+'/bt3_'+site+'_'+stg+'_3000_'+str(frac)+'.pkl')
        

In [None]:
def combinebtpos(site, yearX, stg, threshold=0.01):
    """Pool all years of a site, drop positive-sparse features and pickle the result.

    For every admission year found in the site's onset table the yearly
    ``data/<site>/bt3_<site>_<year>.pkl`` frame is loaded and recoded with
    ``flag_convert``. Years that cannot be loaded or processed are reported
    and skipped. ``drop_too_much_nan_positive`` then removes sparse columns
    from all yearly frames using ``threshold``.

    The pooled frame gets NaN replaced by ``False`` in its object-dtype
    columns, is passed through ``bt_ckd`` and ``bt_postprocess`` and is
    written to ``data/<site>/bt3pos_<site>_<stg>_3000.pkl``.

    Parameters
    ----------
    site : str
        Site name used to build all paths.
    yearX :
        Forwarded to the helper functions, which use it in their messages.
    stg : str
        Stage mode for ``flag_convert``; part of the output name.
    threshold : float
        Sparsity threshold handed to ``drop_too_much_nan_positive``.

    Raises
    ------
    ValueError
        When no yearly frame could be loaded at all.
    """

    onset = pd.read_pickle("/home/hchan2/AKI/AKI_Python/"+'data/'+site+'/p0_onset_'+site+'.pkl')
    years = list(pd.to_datetime(onset['ADMIT_DATE']).dt.year.unique())
    bt_list = list()

    for year in years:
        try:
            data = pd.read_pickle('data/'+site+'/bt3_'+site+'_'+str(year)+'.pkl')
            data = flag_convert(data, stg)
            bt_list.append(data.copy())
        except FileNotFoundError:
            print(str(year)+' not exists')
        except Exception as err:
            # Still best-effort, but do not report unrelated failures as a
            # missing file (and let KeyboardInterrupt/SystemExit through).
            print(str(year)+' skipped: '+repr(err))

    if not bt_list:
        raise ValueError('No bt3 data could be loaded for site '+site)

    bt_list, remlist, flag0nan, flag1nan, flag0total, flag1total = drop_too_much_nan_positive(site, yearX, bt_list, threshold)
    bt_all = pd.concat(bt_list, ignore_index=True)
    # replace nan in boolean columns with False
    bt_bool = bt_all.select_dtypes('O').columns
    bt_all[bt_bool] = bt_all[bt_bool].fillna(False)

    bt_all = bt_ckd(site, yearX, bt_all)
    bt_all = bt_postprocess(site, yearX, bt_all)
    bt_all.to_pickle('data/'+site+'/bt3pos_'+site+'_'+stg+'_3000.pkl')

In [None]:
#def calculate_sparse(sites):
# Per-site sparsity scan: for every site, compute for each non-object column the
# share of FLAG==1 rows in which that column is NaN (False counts as NaN).
# Results are kept in module-level dicts that the following cell reads:
#   sparsity[site]   -> {column: missing share among FLAG==1 rows}
#   cols[site]       -> all column labels of the site's frame
#   shape_list[site] -> number of rows of the site's frame
if __name__ == "__main__":
    sites = ['MCRI', 'MCW', 'UIOWA', 'UMHC', 'UNMC', 'UofU', 'UPITT', 'UTHSCSA', 'KUMC', 'UTSW']
    sparsity = dict()
    cols = dict()
    shape_list = dict()
    for site in sites:
        print('Processing '+site)
        newdf = pd.read_pickle("/home/hchan2/AKI/AKI_Python/"+'data/'+site+'/bt3_'+site+'_3000.pkl')
        # Treat False like a missing value before counting NaNs.
        btX = newdf.replace(False, np.nan) 
        # NOTE(review): btX_objcol is not used anywhere in the visible notebook.
        btX_objcol = btX.select_dtypes('O')
#        bt_list.append(data.copy())
        cols[site] = btX.columns
        sp_site = dict()
        for col in cols[site]:
            # Object columns are skipped: np.isnan does not accept object dtype.
            if btX[col].dtype != 'O':
#                sp_site[col] = (np.isnan(btX[col])).sum()/btX.shape[0]
                # (# FLAG==1 rows where col is NaN) / (# FLAG==1 rows)
                sp_site[col] = np.logical_and(btX['FLAG']==1,np.isnan(btX[col])).sum()/(btX['FLAG']==1).sum()
        sparsity[site] = sp_site
        shape_list[site] = btX.shape[0]

In [None]:
# Cross-site variance of the per-column missing share. Reads the module-level
# `cols` and `sparsity` dicts, so the cell that fills them must have run first.
if __name__ == "__main__":
    # Sorted union of the column labels seen at any site.
    allcols = np.unique([y for x in list(cols.values()) for y in list(x)])
    var_list = dict()
    for col in allcols:
        sp_col = list()
        for sp in sparsity.values():
            if col in sp:
                sp_col = sp_col + [sp[col]]
            else:
                # No value recorded for this site (column absent or skipped as
                # object dtype): counted as fully missing.
                sp_col = sp_col + [1]
        var_list[col] = np.var(sp_col)
    # One row per column label, the variance in column 0, sorted ascending.
    var_list = pd.DataFrame(var_list, index=[0]).T.sort_values(by=0)
    var_list['rank'] = var_list.rank()
    # Move the column labels from the index into a regular 'index' column.
    var_list = var_list.reset_index()
    var_list.to_pickle('spdf1.pkl')        
#    return var_list

In [None]:
if __name__ == "__main__":
    sites = ['MCRI', 'MCW', 'UIOWA', 'UMHC', 'UNMC', 'UofU', 'UPITT', 'UTHSCSA', 'KUMC', 'UTSW']
    # calculate_sparse() is not defined in this notebook any more: its `def`
    # and `return` lines are commented out and its body now runs at cell level,
    # saving the result to 'spdf1.pkl'. Calling it unconditionally raised a
    # NameError on a fresh kernel, so fall back to the persisted result.
    if 'calculate_sparse' in globals():
        var_list = calculate_sparse(sites)
    else:
        var_list = pd.read_pickle('spdf1.pkl')

In [None]:
# site = 'KUMC'
# yearX = 3000
# frac = 50000
# combinebt(site, 3000, 'stg01', 50000)

In [None]:
# bt_sam = pd.read_pickle('data/'+site+'/bt3sample_'+site+'_'+str(yearX)+"_"+stg+'_'+str(frac)+'.pkl')

# bt_sam.shape[0]

In [None]:
# bt_allX = pd.read_pickle('data/'+site+'/bt3_'+site+'_'+'stg01'+'_3000_'+str(50000)+'.pkl')

In [None]:
# bt_allX.shape[0]