In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import pandas as pd
import numpy as np
import tensorflow as tf
import random
import pickle
import copy

import rtdl
sys.modules['rtdl.rtdl'] = rtdl

from explain_icu import TheWrapper
import matplotlib.pyplot as plt

In [None]:
def load_explanations(folder_list, ftype=None, subset=None):
    """Load pickled per-instance explanations of one explainer type from several folders.

    Parameters
    ----------
    folder_list : iterable of str
        Trial output folders. Each must contain one pickle per instance named
        ``{ExplainerName}_encounter_1000000{i}.0explanation.pkl``.
    ftype : str
        Explainer key: 'wshap', 'anch', 'comte', 'featpert', 'manualperm' or 'tsr'.
    subset : iterable of int
        Instance indices ``i`` to load from every folder (required).

    Returns
    -------
    dict
        Folder name -> list of explanations in ``subset`` order. Anchor explanations
        are kept as stored; every other explanation has its last axis restricted to
        the module-level ``USED_FEATURE_LOCS`` and is then squeezed.

    Raises
    ------
    ValueError
        If ``ftype`` is not a known key or ``subset`` is not given.
    """
    name_map = {
        'wshap': 'WindowSHAP',
        'anch': 'Anchors',
        'comte': 'CoMTE',
        'featpert': 'FeaturePerturbation',
        'manualperm': 'ManualPermutation',
        'tsr': 'TSR',
    }
    if ftype not in name_map:
        raise ValueError(f"Unknown explanation type {ftype!r}; expected one of {sorted(name_map)}")
    if subset is None:
        raise ValueError("subset is required: pass the instance indices to load from each folder")
    mapped_name = name_map[ftype]

    trial_run_exps = {}
    for folder_name in folder_list:
        collected_exps = []
        for exp_i in subset:
            exp_path = f"{folder_name}/{mapped_name}_encounter_1000000{exp_i}.0explanation.pkl"
            # pickle can execute arbitrary code on load; only point this at trusted outputs.
            with open(exp_path, "rb") as exp_file:
                exp_package = pickle.load(exp_file)

            if ftype == 'wshap':
                raw_exp = exp_package['wshap']
            elif ftype == 'anch':
                raw_exp = exp_package[0]
            elif ftype == 'comte':
                raw_exp = exp_package[0]
                print(f"comte 01: {exp_package[1]}")
            else:
                # 'featpert', 'manualperm' and 'tsr' pickles are used as loaded.
                raw_exp = exp_package

            if ftype == 'anch':
                # Anchor explanations are kept as stored (not sliced).
                print(f"anch exp: {raw_exp}")
            else:
                # Restrict the last axis to USED_FEATURE_LOCS, then drop singleton axes.
                raw_exp = raw_exp[:, :, USED_FEATURE_LOCS].squeeze()
                print(f"{ftype} exp shape: {raw_exp.shape}")

            collected_exps.append(raw_exp)
        trial_run_exps[folder_name] = collected_exps
    return trial_run_exps
             

In [None]:
def make_cases_into_pdframes(EXP_SETS):
    """Turn every explanation of every set into a DataFrame with ``col_names`` columns.

    57-wide explanations are widened to 59 columns: the case index is prepended as
    a column and a zero column is appended. 59-wide explanations are used unchanged.
    Any other width raises NotImplementedError. Relies on the module-level
    ``col_names``.
    """
    frames = []
    for exp_set in EXP_SETS:
        for case_id in range(len(exp_set)):
            case_arr = exp_set[case_id].squeeze()
            width = case_arr.shape[-1]
            if width == 57:
                n_rows = case_arr.shape[0]
                case_arr = np.column_stack(([case_id] * n_rows, case_arr))
                case_arr = np.column_stack((case_arr, [0] * n_rows))
            elif width != 59:
                raise NotImplementedError(f"Size {case_arr.shape[-1]} not supported")
            frames.append(pd.DataFrame(case_arr, columns=col_names))
    return frames
            

In [None]:
def get_different_feats(list_of_exps, list_of_origs, tolerance=0.0000001, subset=None):
    """Find, per explanation, the feature columns that differ from the matching original.

    Parameters
    ----------
    list_of_exps : list of array-like
        Explanations (e.g. counterfactuals) with 57 features on the last axis.
    list_of_origs : list
        Original data. With 59 columns these must be DataFrames; their first and
        last columns are dropped. With 57 columns they are used as-is.
    tolerance : float
        A feature counts as changed when its summed absolute difference over
        time exceeds this value.
    subset : sequence of int, optional
        ``list_of_origs[subset[i]]`` is compared with ``list_of_exps[i]``. If it
        is None, explanation ``i`` is compared with original ``i``.

    Returns
    -------
    list of list of int
        Changed feature indices per explanation, in ascending order.

    Raises
    ------
    NotImplementedError
        For feature-count combinations other than 57/59 and 57/57.
    """
    if len(list_of_exps) == 0:
        return []

    n_exp_feats = list_of_exps[0].shape[-1]
    n_orig_feats = list_of_origs[0].shape[-1]
    if n_exp_feats == 57 and n_orig_feats == 59:
        list_of_origs = [x.iloc[:, 1:-1] for x in list_of_origs]
        list_of_exps = [x.squeeze() for x in list_of_exps]
    elif not (n_exp_feats == 57 and n_orig_feats == 57):
        raise NotImplementedError(
            f"Unsupported feature counts: explanations have {n_exp_feats}, "
            f"originals have {n_orig_feats} (expected 57/59 or 57/57)"
        )

    if subset is None:
        subset = range(len(list_of_exps))

    all_changed_locs = []
    for i, exp_data in enumerate(list_of_exps):
        matching_orig = list_of_origs[subset[i]]
        print(f"exp_data: {exp_data.shape} and matching_orig: {matching_orig.shape}")
        by_feat_diffs = np.abs(exp_data - matching_orig).sum(axis=0)
        # NaN sums compare False, so such features are not reported (same as before).
        changed = np.flatnonzero(np.asarray(by_feat_diffs) > tolerance)
        all_changed_locs.append(changed.tolist())
    return all_changed_locs
    


In [None]:
def make_minmax_dict(exps_for_inst, orig_data, all_feature_names, vertical_space=0.1):
    """Compute padded per-feature y-limits covering an original and its explanations.

    Parameters
    ----------
    exps_for_inst : iterable of array-like
        Explanations ``[time, feature]`` for one instance; any number, including none.
    orig_data : array-like or pandas.DataFrame
        Original data ``[time, feature]``. A DataFrame's min/max skip NaN; a NumPy
        array's propagate NaN.
    all_feature_names : sequence of str
        Feature name per column; used as the dict keys.
    vertical_space : float
        Fraction of each feature's range added below the minimum and above the maximum.

    Returns
    -------
    dict
        ``feature name -> (ymin, ymax)``, both scalars.
    """
    # Collect per-feature extremes of the original and each explanation as rows,
    # then reduce across rows. This works for any number of explanations.
    # (Previously a second explanation crashed: the running extremes turned 1-D
    # after the first pass and could not be concatenated along axis=1.)
    all_mins = [np.asarray(orig_data.min(axis=0))]
    all_maxs = [np.asarray(orig_data.max(axis=0))]
    for explanation_of_orig in exps_for_inst:
        all_mins.append(np.asarray(explanation_of_orig.min(axis=0)))
        all_maxs.append(np.asarray(explanation_of_orig.max(axis=0)))
    feature_mins = np.vstack(all_mins).min(axis=0)
    feature_maxs = np.vstack(all_maxs).max(axis=0)

    # Pad both ends by a fraction of the range.
    changes_scaled = (feature_maxs - feature_mins) * vertical_space
    feature_mins = feature_mins - changes_scaled
    feature_maxs = feature_maxs + changes_scaled

    return {
        all_feature_names[i]: (feature_mins[i], feature_maxs[i])
        for i in range(len(feature_mins))
    }
        

In [None]:
def visualise_counterfactual(cf_data, orig_data, changed_idxs, all_feature_names, 
                             explanation_output_folder, image_name_prefix, minmax_dict, 
                             n_img_horizontal=1, timepoint_tolerance=0.06, plot_cf=False):
    """Save a figure of the features a counterfactual changed, one subplot per feature.

    Parameters
    ----------
    cf_data : np.ndarray
        Counterfactual values ``[time, feature]``. Drawn only if ``plot_cf`` is True.
    orig_data : np.ndarray
        Original values ``[time, feature]``. Column 0 supplies the x-axis
        timepoints (presumably hours since admit — confirm against column order).
    changed_idxs : sequence of int
        Feature columns to plot, in the given order. If empty, no figure is made.
    all_feature_names : sequence of str
        Feature name per column. Looked up in ``true_display_names``,
        ``FEAT_AXIS_DSIPLAY_NAME`` (falls back to the display name) and ``minmax_dict``.
    explanation_output_folder, image_name_prefix : str
        The figure is saved to ``{folder}/{prefix}_explanation.png``.
    minmax_dict : dict
        ``feature name -> (ymin, ymax)``.
    n_img_horizontal : int
        Subplots per row.
    timepoint_tolerance : float
        Minimum gap between labelled x ticks, as a fraction of the largest timepoint.
    plot_cf : bool
        If True, also draw the counterfactual trace (default False keeps the
        original-only figures).
    """
    assert len(all_feature_names) == cf_data.shape[-1] == orig_data.shape[-1], (
        f"Shapes dont align; \n"
        f"len(all_feature_names)={len(all_feature_names)} \n"
        f"cf_data.shape[-1]={cf_data.shape[-1]} \n"
        f"orig_data.shape[-1]={orig_data.shape[-1]} \n"
    )

    if len(changed_idxs) == 0:
        # plt.subplots(0, ...) would raise; nothing changed, so there is nothing to draw.
        print(f"No changed features for {image_name_prefix}; skipping figure")
        return

    ########## layout
    # Ceiling division: a partially filled last row still needs its own row of axes.
    n_img_vertical = -(-len(changed_idxs) // n_img_horizontal)
    remainder = len(changed_idxs) % n_img_horizontal

    fig_height = n_img_vertical * 2
    fig_width = n_img_horizontal * 8
    print(f"fig_height: {fig_height} fig_width: {fig_width}")
    # squeeze=False keeps `ax` 2-D, so ax[row, col] also works for a single subplot.
    figure, ax = plt.subplots(n_img_vertical, n_img_horizontal, layout='constrained',
                              figsize=(fig_width, fig_height), squeeze=False)
    if remainder > 0:
        for extra_col in range(remainder, n_img_horizontal):
            figure.delaxes(ax[n_img_vertical - 1, extra_col])

    ######### x-axis tick labels
    timepoints = orig_data[:, 0]
    time_diff_req = timepoints.max() * timepoint_tolerance

    time_tick_lbls = ["" for _ in timepoints]
    most_recent_time = -100
    for idx, time_val in enumerate(timepoints):
        # Label a tick only if it is far enough from the previous label to avoid overlap.
        if time_val - most_recent_time > time_diff_req:
            time_tick_lbls[idx] = round(time_val, 1)
            most_recent_time = round(time_val, 1)

    ######### one subplot per changed feature
    for plot_num, change_loc in enumerate(changed_idxs):
        f_name = all_feature_names[change_loc]
        display_name = true_display_names[f_name]
        minval, maxval = minmax_dict[f_name]

        main_axis = ax[plot_num // n_img_horizontal, plot_num % n_img_horizontal]

        main_axis.plot(timepoints, orig_data[:, change_loc], color='b', label='Feature values')
        if plot_cf:
            main_axis.plot(timepoints, cf_data[:, change_loc], color='orange', linestyle='--',
                           label='Counterfactual values')
        main_axis.set_title(f"{display_name}")
        main_axis.set_xticks(timepoints, labels=time_tick_lbls)

        # Categorical features get named ticks (as in visualise_featureattribution).
        if f_name == 'avpu':
            main_axis.set_yticks([0, 1, 2, 3])
            main_axis.set_yticklabels(['Alert', 'Responds to Voice', 'Responds to Pain', 'Unresponsive'])
        elif f_name == 'disoriented':
            main_axis.set_yticks([0, 1])
            main_axis.set_yticklabels(['No', 'Yes'])

        main_axis.set_ylabel(FEAT_AXIS_DSIPLAY_NAME.get(f_name, display_name), color='b', loc='center')
        main_axis.set_ylim([minval, maxval])
        main_axis.tick_params(axis='y', colors='b')

    image_location = f"{explanation_output_folder}/{image_name_prefix}_explanation.png" 
    print(f"Saving figure to {image_location}")
    figure.savefig(image_location)
    # Close this specific figure; a bare plt.close() targets whichever figure is current.
    plt.close(figure)

In [None]:
input_data = pd.read_pickle("new_input2.pkl")
# Close the file after loading; pickle can execute code, so only open trusted files.
with open("new_background.pkl", "rb") as background_file:
    data_to_save = pickle.load(background_file)

In [None]:
# Unpack the saved background bundle into the module-level names used below.
(background_x, background_y_lbl, pad_in, wrapper_df,
 background_sizes, col_names, to_explain_shape, test_comb) = (
    data_to_save[key] for key in (
        'background_x', 'background_y_lbl', 'pad_in', 'wrapper_df',
        'background_sizes', 'col_names', 'to_explain_shape', 'test_comb',
    )
)

In [None]:
# Human-readable title for each feature column. Used as subplot titles by the
# plotting functions and to rename rows of the exported lab tables.
true_display_names = {
 'hours_since_admit': 'Hours Since Admit',
 'hr': 'Heart rate',
 'rr': 'Respiratory rate',
 'sbp': 'SBP',
 'dbp': 'DBP',
 'o2sat': 'Oxygen saturation',
 'temp_c': 'Temperature (Celsius)',
 'avpu': 'AVPU mental status',
 'disoriented': 'Disoriented',
 'bmi': 'BMI',
 'fio2_final': 'Delivered FiO2',
 'braden_activity': 'Braden Scale - Activity',
 'braden_friction': 'Braden Scale - Friction',
 'braden_mobility': 'Braden Scale - Mobility',
 'braden_moisture': 'Braden Scale - Moisture',
 'braden_nutrition': 'Braden Scale - Nutrition',
 'braden_sensory': 'Braden Scale - Sensory',
 'braden_scale': 'Sum Total of Braden Scale',
 'urine_output_sum': 'Urine output sum over 24 hours',
 'albumin': 'Albumin',
 'alk_phos': 'Alkaline phosphatase',
 'anion_gap': 'Anion Gap',
 'bands_pct': 'Bands',
 'bili_total': 'Total bilirubin',
 'bun': 'BUN',
 'calcium': 'Calcium',
 'chloride': 'Chloride',
 'co2': 'CO2',
 'creatinine': 'Creatinine',
 'eosinophils_pct': 'Eosinophils',
 'gluc_ser': 'Glucose',
 'hb': 'Hemoglobin',
 'inr': 'INR',
 'lactate': 'Lactate',
 'lipase': 'Lipase',
 'lymphocytes_pct': 'Lymphocytes',
 'magnesium': 'Magnesium',
 'mcv': 'Mean Corpuscular Volume',
 'monocytes_pct': 'Monocytes',
 'neutrophils_pct': 'Neutrophils',
 'pco2_art': 'PCO2 (Arterial)',
 'pco2_ven': 'PCO2 (Venous)',
 'ph_art': 'pH (Arterial)',
 'ph_ven': 'pH (Venous)',
 'phosphate': 'Inorganic Phosphate',
 'platelet_count': 'Platelets',
 'po2_art': 'PO2 (Arterial)',
 'potassium': 'Potassium',
 'ptt': 'PTT',
 'rdw': 'RBC Distribution Width',
 'sgot': 'SGOT',
 'sodium': 'Sodium',
 'total_protein': 'Total Protein',
 'wbc': 'WBC count',
 'age': 'Age'
}

In [None]:
# Build the model wrapper once, in single-process mode.
model = TheWrapper(
    wrapper_df, background_sizes, col_names, to_explain_shape,
    multiproc=False, n_workers=1,
)

In [None]:
# Column positions 1..57 of the 59-column layout (indices 0..58): drops the first
# and last columns, leaving 57 features.
USED_FEATURE_LOCS = [*range(1, 58)]

In [None]:
N_REPEATED_TRIALS = 1

OUTS_RANGE_COMTE = range(N_REPEATED_TRIALS)
TRIAL_FOLDER = 'final_comte/'
COMTE_SUBSET = list(range(10))

# One output folder per repeated trial: final_comte/outs0, final_comte/outs1, ...
folders_COMTE = [f"{TRIAL_FOLDER}outs{trial}" for trial in OUTS_RANGE_COMTE]

comte_exps = load_explanations(folders_COMTE, ftype='comte', subset=COMTE_SUBSET)

In [None]:
# When True, the original-data cell below keeps only columns 1..-1 (drops the
# first and last columns) of each encounter's frame.
CHANGE_59_TO_57 = True
# Column names for that trimmed layout.
cols_names_57 = np.array(col_names[1:-1])

In [None]:
# One DataFrame per encounter, in first-appearance order. Optionally trimmed to
# the layout without the first and last columns.
combined_frame = input_data['wrapper_df']['combined_dat']
original_data = []
for encid in combined_frame['encounter_id'].unique():
    encounter_rows = combined_frame[combined_frame['encounter_id'] == encid]
    original_data.append(encounter_rows.iloc[:, 1:-1] if CHANGE_59_TO_57 else encounter_rows)

In [None]:
EXCLUSION_FIELDS = ['hours_since_admit', 'outcome_ward24hr', 'timeofday']
# Positions of the excluded fields within cols_names_57.
_cols_57_list = list(cols_names_57)
EXCLUSION_LOCS = [_cols_57_list.index(field) for field in EXCLUSION_FIELDS]

# Generate counterfactual images

In [None]:
def gen_counterfactual_images(explanation_dict, original_data, cols_names_used, subset, output_folder='testing01'):
    """Plot every CoMTE counterfactual and collect per-case y-limits and changed features.

    Parameters
    ----------
    explanation_dict : dict
        Trial folder path -> list of counterfactual arrays ``[time, feature]``.
    original_data : list of pandas.DataFrame
        Original data per encounter; case ``i`` matches ``original_data[subset[i]]``.
    cols_names_used : sequence of str
        Feature name per column.
    subset : sequence of int
        Maps each case position to its index in ``original_data``.
    output_folder : str
        Folder figures are written to (default ``'testing01'``).

    Returns
    -------
    (dict, dict)
        Both keyed by the folder's last path component (e.g. ``'outs0'``).
        ``minmax_by_trial[outs][case]`` is the y-limit dict of a case.
        ``comte_locs[outs][case]`` lists its changed feature indices, with
        ``EXCLUSION_LOCS`` removed.
    """
    minmax_by_trial = {}
    comte_locs = {}
    for trial_folder_path, list_of_exps_for_trial in explanation_dict.items():
        per_i_diff_locs = get_different_feats(list_of_exps_for_trial, original_data, subset=subset)
        print(f"per_i_diff_locs: {per_i_diff_locs}")
        outs_ids = trial_folder_path.split('/')[-1]

        # Remove excluded features by rebuilding each list. The old remove()-while-
        # iterating loop skipped the element after every removal.
        for case_id, xset in enumerate(per_i_diff_locs):
            for x in xset:
                if x in EXCLUSION_LOCS:
                    print(f"Dropping {cols_names_used[x]} at loc {x} from Case: {case_id} Folder: {outs_ids}")
            per_i_diff_locs[case_id] = [x for x in xset if x not in EXCLUSION_LOCS]

        minmax_by_case = {}
        for case_id, case_data in enumerate(list_of_exps_for_trial):
            # Use the same original for the limits and the plot (both via subset).
            orig_case = original_data[subset[case_id]]
            minmax_vals = make_minmax_dict([case_data], orig_case, cols_names_used)
            minmax_by_case[case_id] = minmax_vals

            image_name = f"case{subset[case_id]}_comte_{outs_ids}"

            visualise_counterfactual(case_data, 
                                     orig_case.to_numpy(), 
                                     per_i_diff_locs[case_id], 
                                     cols_names_used, 
                                     explanation_output_folder=output_folder, 
                                     image_name_prefix=image_name, 
                                     minmax_dict=minmax_vals, 
                                     n_img_horizontal=1)
        minmax_by_trial[outs_ids] = minmax_by_case
        comte_locs[outs_ids] = per_i_diff_locs
    return minmax_by_trial, comte_locs

In [None]:
# Y-axis label (with units where given) for the feature-value axis of each plot.
# The misspelled name is referenced by the plotting functions, so it is kept as is.
# NOTE(review): not every feature has an entry; a plotting function that indexes
# this dict directly raises KeyError for the missing ones. 'mcv2' has no match in
# true_display_names and may be stale.
FEAT_AXIS_DSIPLAY_NAME = {
    'hr': 'Heart rate, bpm',
    'rr': 'RR, bpm',
    'sbp': 'SBP, mmHg',
    'dbp': 'DBP, mmHg',
    'o2sat': 'O2 saturation, %',
    'fio2_final': 'FiO2, %',
    'lactate': 'Lactate, mmol/L',
    'creatinine': 'Creatinine mg/dL',
    'bili_total': 'Bilirubin total, mg/dL',
    'sgot': 'SGOT, U/L',
    'temp_c': 'Temperature, \u00B0C',
    'avpu': 'AVPU, status',
    'disoriented': 'Disoriented, status',
    'bmi': 'BMI, kg/m\u00b2',
    'braden_moisture': 'Braden Scale - Moisture, score',
    'braden_nutrition': 'Braden Scale - Nutrition, score',
    'braden_activity': 'Braden Scale - Activity, score',
    'bands_pct': 'PCT',
    'hb': 'HB',
    'lymphocytes_pct': 'lymphocytes_pct',
    'mcv': 'mcv',
    'mcv2': 'mcv2',
    'braden_friction': 'braden_friction',
    'gluc_ser': 'gluc_ser',
    'inr': 'inr',
    'sodium': 'sodium',
    'age': 'age',
    'bun': 'bun',
    'braden_scale': 'Braden Scale - Total, score',
    'platelet_count': 'platelet_count',
}

In [None]:
# Render the CoMTE figures. Keep the per-case y-limits and changed-feature
# locations for the feature-attribution plots further down.
minmax_dict, comte_locs = gen_counterfactual_images(
    comte_exps,
    original_data,
    cols_names_57,
    subset=COMTE_SUBSET,
)

# Running counterfactuals through model

In [None]:
# Run the model wrapper on the whole background batch at once.
# NOTE(review): `model_out` is not read in any visible cell.
model_out = model(input_data['background_x'])

In [None]:
# Score each background instance on its last `background_sizes[idx]` rows
# (presumably the unpadded part — confirm).
# NOTE(review): the predictions are discarded, and the wrapper keeps the last
# shape assigned here.
for idx, instance in enumerate(input_data['background_x']):
    n_rows = input_data['background_sizes'][idx]
    model.to_explain_shape = (n_rows, 59)
    val = model(instance[-n_rows:])

# WindowSHAP - Making visuals

In [None]:
# Presumably the number of instances explained per trial folder — confirm.
N_INSTANCES_EXPLAINED = 10
# Re-assigns the value set in the CoMTE cell (same value).
N_REPEATED_TRIALS = 1

# Output-folder index ranges (outs0, outs1, ...) for each explainer.
OUTS_RANGE_WSHAP = range(0,1)
OUTS_RANGE_ANCH = range(0,1)
OUTS_RANGE_FEATPERM = range(0,N_REPEATED_TRIALS)
OUTS_RANGE_TSR = range(0,N_REPEATED_TRIALS)
OUTS_RANGE_MPERM = range(0,N_REPEATED_TRIALS)

In [None]:
# Close the file after loading (the bare open() leaked the handle).
with open("explanation_outputs/outs10/WindowSHAP_encounter_10000005.0explanation.pkl", "rb") as wshap_file:
    x = pickle.load(wshap_file)

In [None]:
# Instance indices to load from each folder. This equals COMTE_SUBSET ([0..9]),
# so case positions line up with the CoMTE y-limits and locations.
EXPLAINED_SUBSET = list(range(N_INSTANCES_EXPLAINED))

folders_ANCHORS = [f'trial11_vis_outs/outs{y}' for y in OUTS_RANGE_ANCH]

folders_WSHAP = [f'trial13_vis_outs/outs{y}' for y in OUTS_RANGE_WSHAP]

folders_FEATPERM = [f'trial9_vis_outs/outs{y}' for y in OUTS_RANGE_FEATPERM]

folders_TSR = [f'trial9_vis_outs/outs{y}' for y in OUTS_RANGE_TSR]

folders_MPERM = [f'trial9_vis_outs/outs{y}' for y in OUTS_RANGE_MPERM]

# load_explanations iterates `subset`, so it must be passed (omitting it crashed).
wshap_exps = load_explanations(folders_WSHAP, ftype='wshap', subset=EXPLAINED_SUBSET)

manualperm_exps = load_explanations(folders_MPERM, ftype='manualperm', subset=EXPLAINED_SUBSET)

# NOTE(review): the anchor-evaluation cells read `anchor_exps`, which was never
# defined. It is loaded here from folders_ANCHORS — confirm this is the intended source.
anchor_exps = load_explanations(folders_ANCHORS, ftype='anch', subset=EXPLAINED_SUBSET)


In [None]:
def get_top_k_featattrib(np_exp_tensor, k=10):
    """Select the k features with the largest total absolute attribution.

    Parameters
    ----------
    np_exp_tensor : array-like
        Attributions whose squeezed form is ``[time, feature]``. A squeezed 1-D
        result is treated as a single time step.
    k : int
        Number of features to return. Clamped to ``[0, n_features]``.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``top_k_locs``: feature indices, in no particular order (``np.argpartition``).
        ``top_vals``: the corresponding columns of the squeezed tensor, ``[time, k]``.
    """
    # Squeeze once and use the same 2-D array for both the scores and the returned
    # values. The old code indexed the unsqueezed tensor for `top_vals`.
    exp_2d = np.asarray(np_exp_tensor).squeeze()
    if exp_2d.ndim == 1:
        exp_2d = exp_2d[np.newaxis, :]
    per_feat_vals = np.abs(exp_2d).sum(axis=0)

    k = max(0, min(k, per_feat_vals.shape[0]))
    if k == 0:
        top_k_locs = np.array([], dtype=int)
    else:
        top_k_locs = np.argpartition(per_feat_vals, -k)[-k:]
    print(f"k: {k} and then top_k_locs: {top_k_locs}")
    top_vals = exp_2d[:, top_k_locs]

    return top_k_locs, top_vals
    

In [None]:
def visualise_featureattribution(importance_data, orig_data, top_k_locs, all_feature_names, 
                             explanation_output_folder, image_name_prefix, minmax_dict, 
                             n_img_horizontal=1, timepoint_tolerance=0.06, orig_only=False):
    """Save a figure of feature traces, each with its attribution bars overlaid.

    Parameters
    ----------
    importance_data : np.ndarray
        Attributions ``[time, feature]``. The whole array is z-scored before
        plotting; if its std is zero it is only centred.
    orig_data : np.ndarray
        Original values ``[time, feature]``. Column 0 supplies the x-axis timepoints.
    top_k_locs : sequence of int
        Feature columns to plot, one subplot each, in the given order. If empty,
        no figure is made.
    all_feature_names : sequence of str
        Feature name per column. Looked up in ``true_display_names``,
        ``FEAT_AXIS_DSIPLAY_NAME`` (falls back to the display name) and ``minmax_dict``.
    explanation_output_folder, image_name_prefix : str
        The figure is saved to ``{folder}/{prefix}_explanation.png``.
    minmax_dict : dict
        ``feature name -> (ymin, ymax)`` for the feature-value axis.
    n_img_horizontal : int
        Subplots per row.
    timepoint_tolerance : float
        Minimum gap between labelled x ticks, as a fraction of the largest timepoint.
    orig_only : bool
        If True, draw only the feature values (no importance overlay).
    """
    assert len(all_feature_names) == importance_data.shape[-1] == orig_data.shape[-1], \
    f"Shapes dont align; len(all_feature_names)={len(all_feature_names)} importance_data.shape[-1]={importance_data.shape[-1]} orig_data.shape[-1]={orig_data.shape[-1]} \n"

    if len(top_k_locs) == 0:
        # plt.subplots(0, ...) would raise.
        print(f"No features to plot for {image_name_prefix}; skipping figure")
        return

    ########## layout
    # Ceiling division: a partially filled last row still needs its own row of axes.
    n_img_vertical = -(-len(top_k_locs) // n_img_horizontal)
    remainder = len(top_k_locs) % n_img_horizontal

    fig_height = n_img_vertical * 2
    fig_width = n_img_horizontal * 8
    # squeeze=False keeps `ax` 2-D, so ax[row, col] also works for a single subplot.
    figure, ax = plt.subplots(n_img_vertical, n_img_horizontal, layout='constrained',
                              figsize=(fig_width, fig_height), squeeze=False)
    if remainder > 0:
        for extra_col in range(remainder, n_img_horizontal):
            figure.delaxes(ax[n_img_vertical - 1, extra_col])

    ######### x-axis tick labels
    timepoints = orig_data[:, 0]
    time_diff_req = timepoints.max() * timepoint_tolerance

    time_tick_lbls = ["" for _ in timepoints]
    most_recent_time = -100
    for idx, time_val in enumerate(timepoints):
        # Label a tick only if it is far enough from the previous label to avoid overlap.
        if time_val - most_recent_time > time_diff_req:
            time_tick_lbls[idx] = round(time_val, 1)
            most_recent_time = round(time_val, 1)

    ######### Normalize importance data (z-score over the whole array)
    imp_std = importance_data.std()
    importance_data = importance_data - importance_data.mean()
    if imp_std > 0:
        # Skip the division for a zero spread; it would produce NaN/inf y-limits.
        importance_data = importance_data / imp_std

    ######### shared importance y-range over the plotted features, padded by 10%
    min_importance = importance_data[:, top_k_locs].min()
    min_importance = min_importance - (np.abs(min_importance) * 0.1)
    max_importance = importance_data[:, top_k_locs].max()
    max_importance = max_importance + (np.abs(max_importance) * 0.1)

    for plot_num, top_loc in enumerate(top_k_locs):
        f_name = all_feature_names[top_loc]
        display_name = true_display_names[f_name]
        feat_importance_data = importance_data[:, top_loc]
        feat_og_data = orig_data[:, top_loc]

        bar_width = 10 / feat_og_data.shape[0]

        minval, maxval = minmax_dict[f_name]
        print(f"display_name: {display_name} min: {minval} max: {maxval}")

        main_axis = ax[plot_num // n_img_horizontal, plot_num % n_img_horizontal]

        if not orig_only:
            overlay_plot = main_axis.twinx()
            overlay_plot.axhline(y=0, color='red', linestyle=":", alpha=0.25)
            # Draw the value line above the bars, and hide the main axis background
            # so the bars underneath stay visible.
            overlay_plot.set_zorder(1)
            main_axis.set_zorder(2)
            main_axis.patch.set_visible(False)

        main_axis.plot(timepoints, feat_og_data, color='b', label='Feature values')
        main_axis.set_title(f"{display_name}")
        main_axis.set_xticks(timepoints, labels=time_tick_lbls)

        # Categorical features get named ticks on the feature-value axis.
        if f_name == 'avpu':
            main_axis.set_yticks([0, 1, 2, 3])
            main_axis.set_yticklabels(['Alert', 'Responds to Voice', 'Responds to Pain', 'Unresponsive'])
        elif f_name == 'disoriented':
            # These were previously set on the importance twin axis, which does not
            # exist when orig_only=True (NameError) and is the wrong axis anyway.
            main_axis.set_yticks([0, 1])
            main_axis.set_yticklabels(['No', 'Yes'])

        main_axis.set_ylabel(FEAT_AXIS_DSIPLAY_NAME.get(f_name, display_name), color='b', loc='center')
        main_axis.set_ylim([minval, maxval])
        main_axis.tick_params(axis='y', colors='b')

        if not orig_only:
            overlay_plot.bar(timepoints, feat_importance_data, color='r', label='Importance Values', 
                             linestyle='dashed', width=bar_width)
            overlay_plot.set_ylabel("Importance Values", color='r')
            overlay_plot.set_ylim([min_importance, max_importance])
            overlay_plot.tick_params(axis='y', colors='r')

    image_location = f"{explanation_output_folder}/{image_name_prefix}_explanation.png" 
    print(f"Saving figure to {image_location}")
    figure.savefig(image_location)
    # Close this specific figure; a bare plt.close() targets whichever figure is current.
    plt.close(figure)

In [None]:
def gen_featureattribution_images(explanation_dict, original_data, cols_names_used, minmax_dict, comte_locs, img_prefix="wshap", orig_only=False, k=5, output_folder='testing01'):
    """Save one feature-attribution figure per case for every trial folder.

    Each case plots the union of its top-``k`` attributed features (minus
    ``EXCLUSION_LOCS``) and its CoMTE-changed features, ordered by total absolute
    attribution, largest first.

    Parameters
    ----------
    explanation_dict : dict
        Trial folder path -> list of per-case attribution arrays ``[time, feature]``.
    original_data : list of pandas.DataFrame
        Original data, indexed by case position.
    cols_names_used : sequence of str
        Feature name per column.
    minmax_dict, comte_locs : dict
        Keyed by the folder's last path component (e.g. ``'outs0'``), then by case:
        per-case y-limit dicts and changed-feature index lists, as returned by
        ``gen_counterfactual_images``.
    img_prefix : str
        Inserted into each image name.
    orig_only : bool
        Passed through to ``visualise_featureattribution``.
    k : int
        Number of top attributed features considered per case (default 5).
    output_folder : str
        Folder figures are written to (default ``'testing01'``).
    """
    for trial_folder_path, list_of_exps_for_trial in explanation_dict.items():

        outs_ids = trial_folder_path.split('/')[-1]
        minmax_vals_for_trial = minmax_dict[outs_ids]
        trial_comte_locs = comte_locs[outs_ids]

        for case_id, case_data in enumerate(list_of_exps_for_trial):
            minmax_for_case = minmax_vals_for_trial[case_id]

            image_name = f"case{case_id}_{img_prefix}_{outs_ids}"

            case_top_k_locs, _ = get_top_k_featattrib(np_exp_tensor=case_data, k=k)
            filtered_case_locs = [loc for loc in case_top_k_locs if loc not in EXCLUSION_LOCS]

            # Deduplicated union. dtype=int keeps it usable as an index array even
            # when empty (np.array([]) is float and cannot index).
            case_locs = np.array(
                list(set(filtered_case_locs + list(trial_comte_locs[case_id]))), dtype=int
            )
            if case_locs.size == 0:
                print(f"No features to plot for {image_name}; skipping")
                continue

            importance_vals = np.sum(np.abs(case_data), axis=0)[case_locs]
            # Largest total attribution first.
            ordered_fs = [case_locs[x] for x in np.argsort(importance_vals)][::-1]

            visualise_featureattribution(case_data, 
                                     original_data[case_id].to_numpy(), 
                                     ordered_fs, 
                                     cols_names_used, 
                                     explanation_output_folder=output_folder, 
                                     image_name_prefix=image_name, 
                                     minmax_dict=minmax_for_case, 
                                     n_img_horizontal=1,
                                     orig_only=orig_only)
            


In [None]:
# WindowSHAP figures that show only the original feature traces (no importance overlay).
gen_featureattribution_images(
    explanation_dict=wshap_exps,
    original_data=original_data,
    cols_names_used=cols_names_57,
    minmax_dict=minmax_dict,
    comte_locs=comte_locs,
    orig_only=True,
)

# Evaluating Anchor Explanations on real data

In [None]:
# NOTE(review): `real_data` is not defined anywhere in this notebook (hidden kernel
# state). It needs a 'combined_dat' frame — confirm which input it refers to.
all_ids = real_data['combined_dat']['encounter_id'].unique()
# Row count of the combined frame (the old code bound the frame itself to this name).
total_len = len(real_data['combined_dat'])

In [None]:
# Split real_data['combined_dat'] into one DataFrame per consecutive run of rows
# that share an encounter_id. This assumes the rows are grouped by encounter.
list_of_uniques = []
latest_id = all_ids[0]

current_sect = []
for i, row_data in real_data['combined_dat'].iterrows():
    if i % 10000 == 0:
        print(f"i: {i}")
    if row_data['encounter_id'] == latest_id:
        current_sect.append(row_data)
    else:
        try:
            this_id_group = pd.concat(current_sect, axis=1).transpose()
            list_of_uniques.append(this_id_group)
            print(f"Added patient {latest_id} with {this_id_group.shape} rows")
        except Exception as ex:
            # Best-effort: report and move on to the next encounter.
            print(f"Error: {ex}")
        current_sect = [row_data]
        latest_id = row_data['encounter_id']

# The loop only emits a group when the id changes, so the final encounter must be
# flushed here (it was previously dropped).
if current_sect:
    this_id_group = pd.concat(current_sect, axis=1).transpose()
    list_of_uniques.append(this_id_group)
    print(f"Added patient {latest_id} with {this_id_group.shape} rows")

In [None]:
# Pick the anchor explanation to evaluate: the first trial folder, instance 1.
# NOTE(review): `anchor_exps` must already exist in the kernel (folder -> list of
# anchor explanations, each with a 'names' entry).
folder_run = list(anchor_exps.keys())[0]
instance_num = 1
# eval_rules stops after this many instances.
MAX_NUM = 1000

# The rule strings of the chosen explanation.
anchor_rules = anchor_exps[folder_run][instance_num]['names']

# Comparison used for each operator that may appear in an anchor rule.
_ANCHOR_OPS = {
    '>': lambda a, b: a > b,
    '>=': lambda a, b: a >= b,
    '<': lambda a, b: a < b,
    '<=': lambda a, b: a <= b,
    '=': lambda a, b: a == b,
}


def _parse_anchor_rule(rule):
    """Split one anchor rule into ``(feat_name, time, op, value_str)`` conditions.

    Accepts one-sided rules ``"<feat>_<t> <op> <val>"`` and two-sided rules
    ``"<lo> <op> <feat>_<t> <op> <hi>"``.
    """
    tokens = rule.split(" ")
    if len(tokens) == 3:
        feat_time, op, val = tokens
        raw_conds = [(feat_time, op, val)]
    elif len(tokens) == 5:
        lo, lo_op, feat_time, hi_op, hi = tokens
        # "lo < x" is the condition "x > lo" (likewise for "<=").
        flipped = {'<': '>', '<=': '>='}
        if lo_op not in flipped:
            raise ValueError(f"Cannot parse anchor rule {rule!r}")
        raw_conds = [(feat_time, flipped[lo_op], lo), (feat_time, hi_op, hi)]
    else:
        raise ValueError(f"Cannot parse anchor rule {rule!r}")

    parsed = []
    for feat_time, op, val in raw_conds:
        feat_name, time = feat_time.rsplit("_", 1)
        parsed.append((feat_name, int(time), op, val))
    return parsed


def eval_rules(list_of_uniques, rules_text=None, max_num=None):
    """Measure how often an anchor rule holds and what the model predicts when it does.

    Parameters
    ----------
    list_of_uniques : iterable of pandas.DataFrame
        One frame per instance. Rows are time steps and columns are feature names.
    rules_text : list of str, optional
        Anchor rule strings. Defaults to the module-level ``anchor_rules``.
    max_num : int, optional
        Maximum number of instances to score. Defaults to the module-level ``MAX_NUM``.

    Returns
    -------
    tuple of 6 int
        ``(total_n_instances, n_inst_rule_applies, n_inst_rule_applies_pos,
        n_inst_rule_applies_neg, n_pred_pos, n_rule_fails_but_pos)``.
        "pos" means the model predicted label 1.

    Raises
    ------
    ValueError
        If a rule cannot be parsed, uses an unsupported operator, or names a
        feature missing from ``col_names``.

    Notes
    -----
    Sets ``model.to_explain_shape`` for every scored instance.
    """
    if rules_text is None:
        rules_text = anchor_rules
    if max_num is None:
        max_num = MAX_NUM

    rules = []
    for rule in rules_text:
        for feat_name, time, sign_char, val in _parse_anchor_rule(rule):
            if sign_char not in _ANCHOR_OPS:
                # Unknown operators used to be skipped silently, so the rule "held".
                raise ValueError(f"Unsupported comparison {sign_char!r} in anchor rule {rule!r}")
            # Fail early if the feature is not part of the model's columns.
            feat_loc = col_names.index(feat_name)
            rules.append([feat_name, feat_loc, time, sign_char, float(val)])

    total_n_instances = 0
    n_inst_rule_applies = 0
    n_inst_rule_applies_pos = 0
    n_inst_rule_applies_neg = 0
    n_pred_pos = 0
    n_rule_fails_but_pos = 0

    for i, test_instance in enumerate(list_of_uniques):
        if i >= max_num:
            break

        np_inst = test_instance.to_numpy()
        model.to_explain_shape = np_inst.shape
        prediction = model(np_inst)
        pred_lbl = np.argmax(prediction)
        if pred_lbl == 1:
            n_pred_pos += 1

        total_n_instances += 1

        rule_applies = True
        for r_f_name, _, r_time, r_sign_char, r_val in rules:
            # A condition on a time step the instance does not have cannot hold.
            if test_instance.shape[0] <= r_time or not _ANCHOR_OPS[r_sign_char](
                test_instance.iloc[r_time][r_f_name], r_val
            ):
                rule_applies = False
                break

        if rule_applies:
            n_inst_rule_applies += 1
            if pred_lbl == 1:
                n_inst_rule_applies_pos += 1
            else:
                n_inst_rule_applies_neg += 1
        elif pred_lbl == 1:
            # Count only failures the model still predicts positive, as the name
            # says (previously every failure was counted).
            n_rule_fails_but_pos += 1

    return (total_n_instances,
            n_inst_rule_applies,
            n_inst_rule_applies_pos,
            n_inst_rule_applies_neg,
            n_pred_pos,
            n_rule_fails_but_pos)


In [None]:
# Evaluate the selected anchor rule on the real encounters.
rule_eval_counts = eval_rules(list_of_uniques)
(total_n_instances, n_inst_rule_applies, n_inst_rule_applies_pos,
 n_inst_rule_applies_neg, n_pred_pos, n_rule_fails_but_pos) = rule_eval_counts

In [None]:
# Precision: share of rule matches the model predicts positive.
# Coverage: share of scored instances the rule matches.
# Report NaN instead of raising ZeroDivisionError when a denominator is zero.
n_applied_total = n_inst_rule_applies_pos + n_inst_rule_applies_neg
precision = n_inst_rule_applies_pos / n_applied_total if n_applied_total else float('nan')
coverage = n_inst_rule_applies / total_n_instances if total_n_instances else float('nan')

In [None]:
#roughly 5% of instances should have positive label
print(f"applied_pos to applied ratio/PRECISION: {precision}")
print(f"COVERAGE: {coverage}")
print(f"n_inst_rule_applies: {n_inst_rule_applies}, n_inst_rule_applies_pos:{n_inst_rule_applies_pos}, n_inst_rule_applies_neg:{n_inst_rule_applies_neg}")
print(f"total_n_instances: {total_n_instances}")


In [None]:
# Close the file after loading (the bare open() leaked the handle).
with open("new_background.pkl", "rb") as background_file:
    back_exs = pickle.load(background_file)
# One DataFrame per background instance, with the model's column names.
b_matches = [pd.DataFrame(x, columns=col_names) for x in back_exs['background_x']]

In [None]:
# eval_rules returns six counts; only the first four are needed here. Unpacking
# exactly four raised "too many values to unpack".
tot1, app1, apppos1, appneg1, *_ = eval_rules(b_matches)
cov = app1 / tot1 if tot1 else float('nan')
prc = apppos1 / app1 if app1 else float('nan')

print(f"Precision: {prc} \t\t Coverage: {cov}")

# Explain Integrated Gradients

In [None]:
# Close the file after loading (the bare open() leaked the handle).
with open('trial22_vis_outs/IntegratedGrad', 'rb') as ig_file:
    IG_importances = pickle.load(ig_file)

In [None]:
# Each IG record's 'importance' frame, with its first and last columns dropped,
# keyed like the other explainers' trial folders.
ig_exps = {
    'trial22_vis_outs/outs0': [
        record['importance'].to_numpy()[:, 1:-1] for record in IG_importances
    ]
}

In [None]:
# Copy of the CoMTE-changed locations, extended with the features each case's
# anchor explanation mentions.
comte_plus_anchors = copy.deepcopy(comte_locs)
#defined via manual examination of explanations
anchor_feat_locs = [
    [],
    [],
    [],
    ['lactate', 'rr', 'hr', 'braden_scale', 'fio2_final'],
    ['rr', 'fio2_final'],
    ['hr', 'braden_nutrition', 'fio2_final'],
    ['rr', 'fio2_final'],
    ['rr', 'fio2_final'],
    [],
    [],
]

names_57 = list(cols_names_57)
for case_idx, caselist in enumerate(comte_plus_anchors['outs0']):
    for feat in anchor_feat_locs[case_idx]:
        feat_loc = names_57.index(feat)
        if feat_loc not in caselist:
            caselist.append(feat_loc)
comte_plus_anchors

In [None]:
# NOTE(review): `comte_plus_anchors`, built in the previous cell, is not used here;
# the IG figures get only the CoMTE locations. Confirm whether it was meant to be
# passed as `comte_locs` instead.
gen_featureattribution_images(
    ig_exps, original_data, cols_names_57, minmax_dict, comte_locs,
    img_prefix="IG",
)

In [None]:
# Close the files after loading (the bare open() calls leaked the handles).
with open('real_background.pkl', 'rb') as background_file:
    real_background = pickle.load(background_file)
with open('new_input2.pkl', 'rb') as input_file:
    real_input2 = pickle.load(input_file)

# Rows with a recorded creatinine value.
real_input2['orig_dat_noise'][['encounter_id', 'hours_since_admit', 'creatinine']][real_input2['orig_dat_noise']['creatinine'].notnull()]

In [None]:
# Count the encounters with at least one recorded creatinine value in the noisy
# prediction input. (The old first line repeated the previous cell's expression
# without displaying it, so it was a no-op and has been removed.)
with open('2000_01_2_new_pred_dict_noise_20241031.pickle', 'rb') as input3_file:
    testinginput3 = pickle.load(input3_file)
noise3 = testinginput3['orig_dat_noise']
check3 = noise3[['encounter_id', 'hours_since_admit', 'creatinine']][noise3['creatinine'].notnull()]
print(f"without any: {len(check3['encounter_id'].unique())}")

In [None]:
new_input = pd.read_pickle('2000_high_new_pred_dict_noise_20241031.pickle')
tempid = 'EAAAL91N'
# Creatinine readings over time for one encounter.
sample_frame = new_input['orig_dat_sample']
sample_frame.loc[sample_frame['encounter_id'] == tempid, ['encounter_id', 'hours_since_admit', 'creatinine']]

In [None]:
# Full combined dataset. Its 'combined_dat' frame is cut into fixed-length
# per-encounter records in the next cell.
all_data_set = pd.read_pickle('copy_deterioration_data_20240417-1657_combined_input_dict.pickle')

In [None]:
# Collect the last `min_rec_length` rows of every encounter that has at least that
# many rows. Encounters are consecutive runs of equal encounter_id.
min_rec_length = 28
complied_records = []  # (sic: "compiled") name kept because later cells read it

combined_df = all_data_set['combined_dat']
n_rows_total = len(combined_df)

working_id = combined_df.iloc[0]['encounter_id']
working_data = []


def _flush_record(rows):
    """Append the last `min_rec_length` rows of one encounter if it is long enough."""
    pd_record = pd.concat(rows, axis=1).transpose()
    if len(pd_record) >= min_rec_length:
        complied_records.append(pd_record.iloc[-min_rec_length:, :])


for row_i in range(n_rows_total):
    if row_i % 1000 == 0:
        print(f"On {row_i}/{n_rows_total} count: {len(complied_records)}")

    row = combined_df.iloc[row_i]
    row_id = row['encounter_id']

    if row_id == working_id:
        working_data.append(row)
    else:
        _flush_record(working_data)
        working_id = row_id
        working_data = [row]

# The loop only flushes when the id changes, so the final encounter is flushed
# here (it was previously dropped).
if working_data:
    _flush_record(working_data)

In [None]:
# Stack the equal-length 2-D records along a new leading axis (one slice per record).
stacked_recs = np.stack(complied_records, axis=0)

In [None]:
conditions = ["braden_nutrition_14 = 2"]
which_case = 5

# Vectorised comparison for each operator a condition string may use.
CONDITION_OPS = {
    '>': np.greater,
    '>=': np.greater_equal,
    '<': np.less,
    '<=': np.less_equal,
    '=': np.equal,
}

eid = real_input2['orig_dat_noise']['encounter_id'].unique()[which_case]

match = real_input2['orig_dat_noise'][real_input2['orig_dat_noise']['encounter_id'] == eid]
# A view of stacked_recs trimmed to this encounter's length. It is only read;
# the boolean filtering below makes copies.
local_background = stacked_recs[:, -match.shape[0]:, :]
start_size = local_background.shape[0]

for c in conditions:
    f_t, symbol, val = c.split(" ")
    feat, time = f_t.rsplit("_", 1)
    col_loc = col_names.index(feat)

    if symbol not in CONDITION_OPS:
        raise NotImplementedError(f"Unsupported operator {symbol!r} in condition {c!r}")
    # Compare, don't assign. The old '=' branch was a chained assignment that wrote
    # `val` into stacked_recs through the view and produced a scalar "mask".
    applied_conds = np.asarray(
        CONDITION_OPS[symbol](local_background[:, int(time), col_loc], float(val)), dtype=bool
    )
    local_background = local_background[applied_conds]
    print(f"Applying condition {feat} at time {time} {symbol} {val}. New size: {local_background.shape}")

model.to_explain_shape = match.shape
n_local = local_background.shape[0]
if n_local == 0:
    print("No background records satisfy the conditions")
    res = np.zeros(0, dtype=bool)
else:
    if n_local > 1000:
        # Score in chunks of about 1000 records.
        res_holder = []
        for chunk in np.array_split(local_background, (n_local // 1000) + 1):
            res_holder.append(model(chunk))
        res = np.concatenate(res_holder, axis=0)
    else:
        res = model(local_background)
    # True where the class-1 score exceeds the class-0 score.
    res = res[:, 1] > res[:, 0]

### Checking Anchor

In [None]:
# Close the file after loading (the bare open() leaked the handle).
with open('old_outputs/trial20_vis_outs/outs0/Anchors_encounter_10000005.0explanation.pkl', 'rb') as anchor_file:
    temp = pickle.load(anchor_file)

In [None]:
# Close the file after loading. The old `tempins.keys()` and `.unique()` lines were
# neither the cell's last expression nor assigned, so they had no effect and are removed.
with open('new_input2.pkl', 'rb') as input_file:
    tempins = pickle.load(input_file)

# NOTE(review): hardcoded list; presumably equal to
# tempins['orig_dat_noise']['encounter_id'].unique() — confirm.
ids = ['EAAAMYJB', 'EAAAIG9E', 'EAAAJ8PR', 'EAAAM4E5', 'EAAANTEY',
       'EAAAKMHX', 'EAAAL0KH', 'EAAAL91N', 'EAAANMQ0', 'EAAALRCT']

case_id = 5
holder = tempins['orig_dat_noise'][tempins['orig_dat_noise']['encounter_id'] == ids[case_id]]
# Map each row position ("written time") to its hours_since_admit value.
for x, t in enumerate(holder['hours_since_admit']):
    print(f"Mapping written time {x} to real time {t}")

In [None]:
# Export one CSV per encounter: rows are display names, columns are observations.
with open('new_input2.pkl', 'rb') as input_file:
    tempins = pickle.load(input_file)
unique_ids = tempins['orig_dat_noise']['encounter_id'].unique()

care_about_cols_labs = ['hours_since_admit', 'hr', 'rr', 'sbp', 'dbp', 'o2sat', 'temp_c', 'avpu', 'disoriented',
       'fio2_final', 'braden_nutrition', 'braden_scale', 'lactate', 'creatinine']

AVPU_LABELS = {0: 'Alert', 1: 'Responds to Voice', 2: 'Responds to Pain', 3: 'Unresponsive'}
DISORIENTED_LABELS = {0: 'No', 1: 'Yes'}

noise_df = tempins['orig_dat_noise']
noise_df['hours_since_admit'] = noise_df['hours_since_admit'].round(1)
# Replace whole columns. The chained `df['col'].loc[mask] = ...` form does not
# reliably write back (it is a no-op under pandas Copy-on-Write). Unmapped values,
# including NaN, pass through unchanged.
noise_df['avpu'] = noise_df['avpu'].map(lambda v: AVPU_LABELS.get(v, v))
noise_df['disoriented'] = noise_df['disoriented'].map(lambda v: DISORIENTED_LABELS.get(v, v))


def _merge_same_hour_columns(table):
    """Merge adjacent observation columns under 1 hour apart whose values never clash.

    Row 0 of ``table`` holds each column's hours since admit. Two columns clash if
    any other row has a value in both. A merged column keeps the label and time of
    the earliest column in its run and takes each value from whichever column has it.
    """
    kept = []
    for label in table.columns:
        col = table[label]
        if kept:
            prev = kept[-1]
            same_hour = abs(col.iloc[0] - prev.iloc[0]) < 1
            clashes = (prev.iloc[1:].notna() & col.iloc[1:].notna()).any()
            if same_hour and not clashes:
                kept[-1] = prev.where(prev.notna(), col)
                continue
        kept.append(col)
    if not kept:
        return table
    return pd.concat(kept, axis=1)


for inum, eid in enumerate(unique_ids):
    match_recs = noise_df[noise_df['encounter_id'] == eid]
    desired_data1 = match_recs[care_about_cols_labs].rename(columns=true_display_names).transpose()

    # drop observation columns that hold nothing besides the time row
    to_drop = [colname for colname in desired_data1.columns
               if desired_data1[colname].iloc[1:].isna().all()]
    print(f"to_drop: {to_drop}")
    desired_data1 = desired_data1.drop(to_drop, axis=1)

    # Merge columns in the same hour if they have no conflicts. The old loop
    # subtracted times the wrong way round, treated "every row has a value" as "no
    # conflict", and never dropped the merged columns.
    desired_data1 = _merge_same_hour_columns(desired_data1)

    desired_data1 = desired_data1.fillna(value="-")
    desired_data1.to_csv(f'testing01/case{inum}_labs_subsetvars1.csv', header=False)

    print("---------------------------------")