# In [None]:
from utils import generate_suffix_dict, generate_prefix_dict, to_json
from train_test_split import train_val_test
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from time import time
from scipy.stats import linregress
from IPython.core.interactiveshell import InteractiveShell

# Directory that holds the processed pickle files read and written below.
proc_path = 'data/processed/'

# Encounter-level table, loaded once at import time and read by
# merge_encounters() as a module-level global.
encounters = pd.read_pickle(f'{proc_path}encounters.pickle')

def merge_encounters(df):
    '''
    Attach encounter-level fields to every row of ``df``.

    Rows of the module-level ``encounters`` table where ``sbo_poa`` is true
    are selected, four of their columns are kept and suffixed with ``_enc``,
    and the result is inner-joined on the index with ``df`` after the third
    index level of ``df`` has been moved into a regular column.
    ``any_sbo_surg_enc`` is cast to int before returning.

    Original author's note: the resulting dataframe is (656726, 29) with
    index (mrn, id); time is not standardized.
    '''
    enc_fields = ['adm_datetime', 'any_sbo_surg', 'age', 'surg_datetime']

    # Original author's note: encounters contains 6500 unique encounters.
    poa_encounters = encounters[encounters.sbo_poa]
    enc = poa_encounters[enc_fields].add_suffix('_enc')

    # Index level 2 becomes a column so both sides share the same index.
    measurements = df.reset_index(level=2)
    merged_df = enc.merge(
        measurements,
        how='inner',
        left_index=True,
        right_index=True,
    )

    return merged_df.astype({'any_sbo_surg_enc': int})

def standardize_times(df):
    """
    Convert timestamps to whole hours per encounter and aggregate by hour.

    Steps, in order:

    1. ``hour_since_adm``: whole hours between each row's ``datetime`` and
       the earliest ``datetime`` within the same (index level 0, level 1)
       group.
    2. ``event_hour_enc``: whole hours from that earliest timestamp to
       ``surg_datetime_enc``; where that is missing, the group's last
       observed hour capped at 21 days (504 h) is used instead.
    3. ``time_to_event_enc`` = ``event_hour_enc - hour_since_adm - 1``.
    4. ``filter_populations`` is applied with a 504 h surgery cutoff, then
       rows at or after ``event_hour_enc`` are removed.
    5. Rows are aggregated per (mrn, id, hour_since_adm): mean for the
       'vitals' and 'labs' column groups, sum for 'io' and 'occ', max for
       'enc' (groups as returned by ``generate_suffix_dict``).

    Only columns in those five groups appear in the result; anything else
    (for instance ``time_of_day`` if it is not in one of the groups) is not
    carried over.

    NOTE: helper columns are added to the ``df`` object that is passed in
    before it is rebound, so the caller's frame is modified.
    NOTE(review): ``time_of_day`` is ``hour_since_adm`` plus the starting
    hour and is not reduced modulo 24 -- confirm whether that is intended.

    Params
    ------
    df (pd.DataFrame) -- indexed by (mrn, id) with at least the columns
        ``datetime``, ``surg_datetime_enc``, ``adm_datetime_enc`` and
        ``any_sbo_surg_enc``

    Return dataframe indexed by (mrn, id, hour_since_adm)
    """
    start = time()
    seconds_per_hour = 3600
    nonsurg_cutoff = 21*24

    # Built-in 'min' avoids calling a Python lambda once per group.
    df['min_datetime'] = df.datetime.groupby(level=[0, 1]).transform('min')

    elapsed = df.datetime - df.min_datetime
    df['hour_since_adm'] = (
        (elapsed.dt.total_seconds() // seconds_per_hour).astype(int)
    )

    # Last observed hour of each encounter, capped at the cutoff.
    df['max_datetime_hour'] = (
        df.hour_since_adm
        .groupby(level=[0, 1])
        .transform('max')
        .clip(upper=nonsurg_cutoff)
    )

    # A missing surgery time yields NaN here, which is then replaced by the
    # capped last observed hour.
    to_surgery = df.surg_datetime_enc - df.min_datetime
    df['event_hour_enc'] = (
        (to_surgery.dt.total_seconds() // seconds_per_hour)
        .fillna(df['max_datetime_hour'])
    )

    df['time_to_event_enc'] = (
        df.event_hour_enc - df.hour_since_adm - 1
    )

    df['time_of_day'] = df.hour_since_adm + df.min_datetime.dt.hour

    df = filter_populations(df, surg_cutoff=21*24)

    # Filter out measurements after surgery
    df = df[df.hour_since_adm < df.event_hour_enc]

    # `axis` must be passed by keyword: pandas 2.0 no longer accepts it
    # positionally.
    df = df.drop(
        ['datetime', 'min_datetime',
         'max_datetime_hour', 'surg_datetime_enc',
         'adm_datetime_enc', 'event_hour_enc'],
        axis=1,
    )

    col_dict = generate_suffix_dict(df)
    keys = ['mrn', 'id', 'hour_since_adm']
    flat = df.reset_index()

    mean_columns = (
        flat.loc[:, col_dict['vitals'] + col_dict['labs'] + keys]
        .groupby(keys)
        .agg('mean')
    )

    sum_columns = (
        flat.loc[:, col_dict['io'] + col_dict['occ'] + keys]
        .groupby(keys)
        .agg('sum')
    )

    enc = (
        flat[col_dict['enc'] + keys]
        .groupby(keys)
        .agg('max')
    )

    print(f'Finished standardizing times in {round(time()-start)}s')
    return pd.concat([enc, mean_columns, sum_columns], axis=1)

def filter_populations(df, surg_cutoff):
    '''
    Drop encounters whose surgery happened after ``surg_cutoff`` hours.

    Rows are grouped by the first two index levels. A group is removed when
    every one of its rows has ``any_sbo_surg_enc == 1`` and
    ``event_hour_enc > surg_cutoff``. Patient/encounter counts and the
    surgery-label distribution are printed before and after filtering.

    Return the filtered dataframe.
    '''
    def _report(frame, header, encounters_label):
        # Print the header, the two counts and the label distribution.
        flat = frame.reset_index()
        print(header)
        print("Number of patients: " + str(flat.mrn.nunique()))
        print(encounters_label + str(flat['id'].nunique()))
        label_share = (
            flat[['id', 'any_sbo_surg_enc']]
            .drop_duplicates()
            .any_sbo_surg_enc
            .value_counts()
            .transform(lambda x: x/x.sum())
        )
        print(label_share)

    def _keep(group):
        # True unless all rows are surgical AND past the cutoff.
        late_surgery = (
            (group.event_hour_enc > surg_cutoff)
            & (group.any_sbo_surg_enc == 1)
        )
        return not late_surgery.all()

    _report(df, '\nBefore', 'Number of encounters: ')

    kept = df.groupby(level=[0, 1]).filter(_keep)

    _report(kept, '\nAfter', 'Number of  encounters: ')

    return kept

def fix_occurrences(df):
    """
    Turn the occurrence columns of ``df`` into 0/1 indicators.

    The columns listed under the 'occ' key of ``generate_suffix_dict(df)``
    have NaN replaced by 0, are cast to bool and then to int, and are
    written back into ``df`` (the passed frame is modified and returned).
    """
    occ_cols = generate_suffix_dict(df)['occ']
    indicators = df[occ_cols].fillna(0).astype(bool).astype(int)
    df.loc[:, occ_cols] = indicators
    return df

def create_rolling_matrix(data, window):
    """
    Return the trailing windows of every column of ``data``.

    ``data`` has shape (n, m): n time steps, m columns. The result has
    shape (m, n, window); ``result[j, i, :]`` holds the ``window`` values of
    column j ending at row i, oldest first. The series is padded with
    ``window - 1`` leading zeros so that the first rows still get a full
    window and the output does not shrink.

    Params
    ------
    data (np.ndarray) -- 2-D numeric array of shape (n, m)
    window (int) -- window length, must be >= 1

    Raises ValueError if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError('window must be at least 1')

    data = np.asarray(data)
    n, m = data.shape
    padded = np.concatenate([np.zeros((window - 1, m)), data])

    # offsets[i] lists the padded-row indices i, i+1, ..., i+window-1, so a
    # single fancy-index gathers every window without a Python loop.
    offsets = np.arange(n)[:, None] + np.arange(window)[None, :]
    windows = padded[offsets]  # (n, window, m)

    # (n, window, m) -> (m, n, window), as a fresh C-contiguous float array
    return np.ascontiguousarray(windows.transpose(2, 0, 1), dtype=float)

def expsum(data, zerolifes):
    """
    Vectorized form of exponentially weighted sum.
    We truncate the exponential sum when the weight
    reaches 0.05 and we parameterize by the number of
    steps to get there, the "zerolife".

    For a zerolife ``L`` the value ``t`` steps in the past is weighted by
    ``0.05 ** (t / L)`` for ``t = 0 .. L``; older values get weight 0.

    Params
    ------
    data (pd.DataFrame) -- df of size (n,m)
    zerolifes (list) -- list of length z of positive ints,
        a zerolife of 3 will have a window_len of 3+1

    Return dataframe of size (n, m*z) sharing the index of ``data``.
    Columns are named ``ems{zerolife}_{col}`` and ordered zerolife-major:
    all columns for the first zerolife, then all for the second, and so on.

    Raises ValueError if ``zerolifes`` is empty or contains a value < 1.
    """
    zerolifes = list(zerolifes)
    if not zerolifes:
        raise ValueError('zerolifes must not be empty')
    if min(zerolifes) < 1:
        raise ValueError('every zerolife must be at least 1')

    z = len(zerolifes)
    n, m = data.shape
    # Take max of zerolifes + 1 to get window length
    max_zerolife = max(zerolifes)
    window_len = max_zerolife + 1
    rolling_matrix = create_rolling_matrix(data.values, window_len)

    # Create matrix of shape (window_len, z)
    # Buffer with 0 to accommodate diff win lengths
    expvals = np.empty((window_len, z))
    for zi, zerolife in enumerate(zerolifes):
        buffer_len = max_zerolife - zerolife
        # Windows are oldest-first, so the weights are reversed to put the
        # weight 1 on the newest (last) element.
        decay = [0.05**(t/zerolife) for t in np.arange(zerolife + 1)][::-1]
        expvals[:, zi] = np.array([0] * buffer_len + decay)

    # (m, n, window_len) @ (window_len, z) -> (m, n, z)
    weighted = rolling_matrix @ expvals

    # Flatten to (n, z*m) with zerolife as the slow axis so that the values
    # line up with the zerolife-major labels generated below. Flattening
    # (n, m, z) instead would interleave zerolifes within each column and
    # mislabel the output whenever z > 1 and m > 1.
    flat = np.transpose(weighted, (1, 2, 0)).reshape((n, z * m))

    df = pd.DataFrame(flat, index=data.index)
    df.columns = [f'ems{zerolife}_{col}' for zerolife in zerolifes for col in data]
    return df
    
def summary_stats(df):
    """
    Build per-hour summary features for one group of rows.

    Called once per (index level 0, level 1) group by
    preprocess_exp_weights. Column groups come from
    ``generate_suffix_dict(df)`` under the keys 'vitals', 'labs', 'io',
    'occ' and 'enc'.

    Features produced:
      * ``curr_*`` -- the vitals/labs/io/occ values as given
      * the 'enc' columns unchanged
      * ``ema{h}_*`` -- ``ewm(halflife=h).mean()`` for vitals (h = 6, 24, 72)
        and labs (h = 12, 48, 144)
      * ``ems{z}_*`` -- ``expsum`` of the io/occ columns (NaN -> 0) for
        z = 6, 24, 72
      * ``tsl_*`` -- number of consecutive rows, up to and including the
        current one, without a value (0 when a value is present)

    The EMA/EMS/TSL features are computed on a frame reindexed to every
    hour from 0 to the largest ``hour_since_adm`` and then left-joined back
    onto the rows of the input.
    """

    col_dict = generate_suffix_dict(df)

    max_hour = df.reset_index().hour_since_adm.max()
    
    # Keep the untouched values and the original index for the final merge.
    df_copy = df.copy()
    curr = df_copy[col_dict['vitals'] + col_dict['labs'] + col_dict['io'] + col_dict['occ']].add_prefix('curr_')
    enc = df_copy[col_dict['enc']]
    
    # Drop the first two index levels and reindex to 0..max_hour so that
    # hours without any row become all-NaN rows.
    df = df.reset_index(level=[0,1], drop=True).reindex(np.arange(max_hour + 1))
    
        
    # Optimized tsl (time since last value), computed with a running counter
    io_df  = df[col_dict['io'] + col_dict['occ']]
    num_df = df[col_dict['vitals'] + col_dict['labs']]
    # io/occ: a cell counts as "no value" when it is NaN or 0;
    # vitals/labs: only when it is NaN.
    io_nan_mask = (io_df.fillna(0) == 0).astype(int).values
    num_nan_mask = num_df.isna().astype(int).values
    io_tsl_arr  = np.zeros((io_df.shape[0], io_df.shape[1]))
    num_tsl_arr = np.zeros((num_df.shape[0], num_df.shape[1]))
    # counter[i] = (counter[i-1] + 1) where the cell has no value, else 0.
    # At i == 0 the index i-1 is -1, i.e. the last row, which has not been
    # written yet and is still zero, so the first row gets 1 or 0.
    for i in range(io_df.shape[0]):
        io_tsl_arr[i,:]  = (1 + io_tsl_arr[i-1,:])*io_nan_mask[i,:]
        num_tsl_arr[i,:] = (1 + num_tsl_arr[i-1,:])*num_nan_mask[i,:]
        
    # Overwrite the selected frames with the counters; they are prefixed
    # with 'tsl_' in the return statement.
    io_df.loc[:,:] = io_tsl_arr
    num_df.loc[:,:] = num_tsl_arr
    
    
    # Truncated exponentially weighted sums of io/occ, one expsum call per
    # zerolife, concatenated column-wise.
    ems = pd.concat(
        (expsum(df[col_dict['io'] + col_dict['occ']].fillna(0), zerolifes = [z] ) 
        for z in [6,24,72]), 
        axis=1
    )
    
    # Exponentially weighted moving averages of the vitals.
    ema_vitals = pd.concat(
        (df[col_dict['vitals']].ewm(halflife=halflife).mean().add_prefix(f'ema{halflife}_')
         for halflife in [6, 24, 72]), 
        axis=1
    )
    
    # Same for the labs, with longer half-lives.
    ema_labs = pd.concat(
        (df[col_dict['labs']].ewm(halflife=halflife).mean().add_prefix(f'ema{halflife}_') 
         for halflife in [12, 48, 144]), 
        axis=1
    )
    
    # Add variance?
    
    # Left join keeps exactly the rows of the input group.
    # NOTE(review): the left side keeps the original index while the right
    # side is indexed by hour only; presumably pandas aligns them on the
    # shared hour level -- confirm.
    return pd.merge(
        pd.concat([enc, curr],axis=1), 
        pd.concat([ema_vitals, ema_labs, ems, 
                   num_df.add_prefix('tsl_'),
                   io_df.add_prefix('tsl_')], axis=1),
        left_index=True, right_index=True, how='left'
    )
    
def fill_na(df, train_means):
    """
    Fill missing values in a copy of ``df``.

    PROTOCOL
    Curr - ffill then fill with training set means
    EMA - mean (should only be at beginning of encounter)

    Column groups are taken from ``generate_prefix_dict(df)`` under the
    keys 'curr' and 'ema'. Other columns are returned unchanged; the EMS
    columns in particular are not filled here.

    NOTE(review): the forward fill runs over the whole frame in row order
    and is not grouped by encounter, so leading gaps of one encounter can
    receive the last values of the rows stacked above it -- confirm whether
    a per-encounter fill is wanted.

    Params
    ------
    df (pd.DataFrame) -- frame to fill; not modified
    train_means (pd.Series) -- per-column means of the training set,
        indexed by column name

    Return the filled copy.
    """
    df = df.copy()
    pref_dict = generate_prefix_dict(df)
    curr_cols = pref_dict['curr']
    ema_cols = pref_dict['ema']

    # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
    df.loc[:, curr_cols] = (
        df[curr_cols]
        .ffill()
        .fillna(train_means.loc[curr_cols])
    )

    df.loc[:, ema_cols] = (
        df[ema_cols]
        .fillna(train_means.loc[ema_cols])
    )

    return df
    


def pipe_print(df, stage):
    """
    Print a short progress report and hand ``df`` back unchanged.

    Meant for use inside a ``DataFrame.pipe`` chain: prints the stage name,
    the shape of ``df`` and the number of distinct ``mrn`` and ``id`` values
    found after ``reset_index()``.
    """
    flat = df.reset_index()
    n_patients = flat.mrn.nunique()
    n_encounters = flat.id.nunique()

    print('\n')
    print(f'>>> PIPE PRINT: {stage}')
    print(f'Shape: {df.shape}')
    print(f'Number of Patients: {n_patients}')
    print(f'Number of Encounters: {n_encounters}')
    return df

def preprocess_exp_weights(
        rebuild=False, 
        testing=False, 
        time_to_event=False, 
        time_varying=False):
    '''
    Build (or load) the exponentially weighted train/val/test datasets.

    rebuild -- Rebuild the dataframes from the raw pickle and save them;
               otherwise the previously saved train/val/test pickles are read
    testing -- Use less data, less columns, and the ``_mini`` file names;
               the split index JSON is not overwritten
    time_to_event -- Keep ``time_to_event_enc`` among the feature columns
    time_varying -- Use ``surg_enc`` (1 only on rows of surgical encounters
                    where ``time_to_event_enc == 0``) as the label instead
                    of the per-encounter ``any_sbo_surg_enc``

    Returns (x_train, y_train, x_val, y_val, x_test, y_test, x_cols) where
    x_cols is the list of feature column names.

    NOTE(review): ``scale`` is neither defined nor imported in the visible
    part of this file; it is assumed to be provided elsewhere.
    NOTE(review): ``surg_enc`` is derived after scaling on the rebuild
    path; confirm that ``scale`` leaves ``time_to_event_enc`` untouched,
    otherwise the ``== 0`` comparison will not match.
    ''' 
    
    # Print out flag values
    print('FLAGS')
    print('Rebuild: ' + str(rebuild))
    print('Testing: ' + str(testing))
    print('Time to Event: ' + str(time_to_event))
    print('Time Varying: ' + str(time_varying))

    start = time()

    # Every saved artifact has a '_mini' twin used when testing.
    suffix = '_mini' if testing else ''
    split_names = ('train', 'val', 'test')
    
    if rebuild:
        print('\n>>> Rebuilding file...')

        if testing:
            sbo = (pd.read_pickle(proc_path + 'sbo_mini.pickle')
                   [['pulse_vitals', 'sodium_labs', 
                     'tube_output_io', 'stool_occ']])
            epsilon = 0.05
        else:
            # `axis` must be passed by keyword: pandas 2.0 no longer
            # accepts it positionally.
            sbo = (pd.read_pickle(proc_path + 'sbo.pickle')
                   .drop(['drain_output_io', 'ip_blood_administration_volume_io',
                          'maintenance_iv_bolus_volume_io', 'rectal_tube_output_io'],
                         axis=1))
            epsilon = 0.01
            
        # Initial processing; each pipe_print announces the stage that
        # follows it with the labels used by the original pipeline.
        sbo_presummary = (sbo
            .pipe(pipe_print, 'Merge Encounters')
            .pipe(merge_encounters)
            .pipe(pipe_print, 'Standardize Times')
            .pipe(standardize_times)
            .pipe(pipe_print, 'Summary Stats')
            .pipe(fix_occurrences)
        )
        
        # Write out file before doing summary stats
        sbo_presummary.to_pickle(
            f'{proc_path}sbo_exp_presumm{suffix}_full.pickle'
        )
            
        # Main processing, one encounter at a time
        sbo_exp_weights = (
            sbo_presummary
            .groupby(level=[0,1])
            .apply(summary_stats)
        )
        
        print('\n>>> Train/val/test Split')
        train, val, test, idx_dict = train_val_test(sbo_exp_weights, epsilon)
        if not testing:
            to_json(idx_dict, 'references/idx_dict_exp_weights.json')
            
        # Means come from the training split only and are reused for
        # val/test so no information leaks from those splits.
        train_means = train.mean(axis=0)
        train = fill_na(train, train_means)
        val   = fill_na(val, train_means)
        test  = fill_na(test, train_means)
        
        # Scale data
        train, val, test = scale(train, val, test)
        
        # Write out preprocessed dataframes for visualization
        sbo_exp_weights.to_pickle(
            f'{proc_path}sbo_exp_weights{suffix}.pickle'
        )
        for name, frame in zip(split_names, (train, val, test)):
            frame.to_pickle(f'{proc_path}{name}_exp_weights{suffix}.pickle')
    else:
        print('\n>>> Reading in file...')
        train, val, test = (
            pd.read_pickle(f'{proc_path}{name}_exp_weights{suffix}.pickle')
            for name in split_names
        )

    # TODO: Decide if filtering out 90+ (age_enc < 90)

    # Row-level surgery indicator. The original code referenced undefined
    # names (train48/val48/test48) here and raised NameError; the frames
    # actually in scope are train/val/test.
    for frame in (train, val, test):
        frame['surg_enc'] = (
            frame['any_sbo_surg_enc'].astype(bool)
            & (frame['time_to_event_enc'] == 0)
        ).astype(int)
    
    # Depending on goal, return different columns
    cols_to_drop = ['any_sbo_surg_enc', 'surg_enc']
    if not time_to_event:
        cols_to_drop.append('time_to_event_enc')
                        
    surg_label = 'surg_enc' if time_varying else 'any_sbo_surg_enc'
    
    # TODO: Decide if need to fill na for now because some values not in val set
    x_train = train.drop(cols_to_drop, axis=1)
    x_val   =   val.drop(cols_to_drop, axis=1)
    x_test  =  test.drop(cols_to_drop, axis=1)
    
    y_train = train[surg_label]
    y_val   =   val[surg_label]
    y_test  =  test[surg_label]
    x_cols = list(x_train.columns)
    print('\n\n')
    print('Finished processing in ' + str(round(time()-start)))
    return x_train, y_train, x_val, y_val, x_test, y_test, x_cols

if __name__ == '__main__':
    # Script entry point: rebuild every artifact from the mini dataset.
    (x_train, y_train,
     x_val, y_val,
     x_test, y_test,
     x_cols) = preprocess_exp_weights(rebuild=True, testing=True)
    
    