In [13]:
import pandas as pd
import polars as pl
import numpy as np
import os
import pickle

#### Basic overview of generated embedding .pkl file

In [14]:
def load_pickle(filepath: str):
    """Read and return the object stored in a pickle file.

    Args:
        filepath (str): Location of the ``.pkl`` file to read.

    Returns:
        Any: The unpickled object.

    Note:
        Unpickling can execute arbitrary code; only load files this project produced.
    """
    with open(filepath, "rb") as handle:
        return pickle.load(handle)

def save_pickle(target: dict, filepath: str, fname: str = "mm_feat.pkl"):
    """Pickle an object to ``<filepath>/<fname>``.

    Args:
        target (dict): Object to serialize. This notebook passes feature dicts,
            but any picklable object works.
        filepath (str): Directory to write into. It must already exist;
            this function does not create it.
        fname (str): File name inside ``filepath``. Defaults to ``"mm_feat.pkl"``.

    Returns:
        str: Full path of the file that was written.
    """
    out_path = os.path.join(filepath, fname)
    # "wb" truncates, so any existing file at out_path is overwritten.
    with open(out_path, "wb") as f:
        pickle.dump(target, f)
    return out_path

#### Get patient IDs from pkl file

In [15]:
# Regenerated feature file (prep_data_us).
pt_embs = load_pickle(filepath="../outputs/prep_data_us/mmfair_feat.pkl")

In [16]:
# Number of patients (top-level keys) in the regenerated feature file.
len(pt_embs)

20130

#### Replace the notes data in the regenerated embeddings with the notes from the previous run

In [17]:
# Column metadata and features of the regenerated run, plus the previous (prep_old) features.
cols = load_pickle(filepath="../outputs/prep_data_us/mmfair_cols.pkl")
embeddings = load_pickle(filepath="../outputs/prep_data_us/mmfair_feat.pkl")
emb_old = load_pickle(filepath="../outputs/prep_old/prep_data/mmfair_feat.pkl")

In [18]:
# The last N static columns; in the run shown below these are the one-hot
# gender/race/marital-status/insurance indicators.
N_DEMOGRAPHIC_COLS = 14
static_cols = cols['static_cols']
# Compute positions directly: list.index() per column is O(n^2) and would return
# the first match if a column name were duplicated. Clamp at 0 for short lists.
indices = list(range(max(0, len(static_cols) - N_DEMOGRAPHIC_COLS), len(static_cols)))
print(indices)
print([static_cols[i] for i in indices])

[86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
['gender_F', 'race_group_Hispanic_Latino', 'race_group_Black', 'race_group_White', 'race_group_Asian', 'race_group_Other', 'marital_status_Married', 'marital_status_Single', 'marital_status_Widowed', 'marital_status_Divorced', 'insurance_Medicare', 'insurance_Medicaid', 'insurance_Private', 'insurance_Other']


In [19]:
# Patient counts: regenerated vs. previous run.
print(len(embeddings), len(emb_old))

20130 20596


In [20]:
print(len(embeddings), len(emb_old))
### Patient IDs present in embeddings but not in emb_old
only_in_new = set(embeddings) - set(emb_old)
print(len(only_in_new))
### Patient IDs present in emb_old but not in embeddings (kept for inspection)
only_in_old = set(emb_old) - set(embeddings)
print(len(only_in_old))

20130 20596
0
466


In [21]:
# Inspect one patient's 'dynamic_0' array in the previous-run features,
# for comparison with the regenerated features in the next cell.
emb_old[10296921]['dynamic_0']

array([[ 54.71428571,  88.        ,  98.14285714,  18.14285714,
         95.57142857,  27.34285714],
       [ 48.        ,  97.66666667,  98.        ,  17.33333333,
         86.        ,  65.43333333],
       [ 49.5       ,  96.        ,  99.75      ,  18.5       ,
         87.5       ,  -1.        ],
       [ 55.2       , 104.        ,  99.4       ,  18.4       ,
         98.4       ,  19.02      ],
       [ 43.75      ,  79.        ,  74.5       ,  14.5       ,
         73.        ,  73.1       ]])

In [22]:
# Same patient's 'dynamic_0' array in the regenerated features.
embeddings[10296921]['dynamic_0']

array([[ 54.71428571,  88.        ,  98.14285714,  18.14285714,
         95.57142857,  98.2       ],
       [ 48.        ,  97.66666667,  98.        ,  17.33333333,
         86.        ,  98.65      ],
       [ 49.5       ,  96.        ,  99.75      ,  18.5       ,
         87.5       ,  98.65      ],
       [ 55.2       , 104.        ,  99.4       ,  18.4       ,
         98.4       ,  99.1       ],
       [ 58.66666667, 105.66666667,  99.66666667,  19.66666667,
         97.66666667,  97.8       ]])

In [23]:
# Take the 'notes' entry for every patient from the previous run; all other
# entries keep their regenerated values. Validate first so a missing patient
# fails before any entry is modified, rather than part-way through the loop.
missing_in_old = set(embeddings) - set(emb_old)
assert not missing_in_old, f"{len(missing_in_old)} patients in embeddings are missing from emb_old"
for pt_key in embeddings:
    embeddings[pt_key]['notes'] = emb_old[pt_key]['notes']

In [24]:
### Export the corrected embeddings.
# NOTE: this overwrites ../outputs/prep_data_us/mmfair_feat.pkl, the same file
# loaded into `embeddings` above; the pre-correction version is not kept.
save_pickle(target=embeddings, filepath="../outputs/prep_data_us", fname="mmfair_feat.pkl")

#### Compare SHAP values with the input embeddings

In [None]:
# Reload features, column metadata and test-set IDs from prep_data
# (a different directory from the prep_data_us files used above).
embeddings = load_pickle("../outputs/prep_data/mmfair_feat.pkl")
cols = load_pickle("../outputs/prep_data/mmfair_cols.pkl")
test_ids = (
    pl.read_csv("../outputs/prep_data/testing_ids_ext_stay_7.csv")
    .select("subject_id")
    .to_numpy()
    .flatten()
)

In [None]:
# SHAP explanations for the ext_stay_7 concat_static_timeseries_notes run.
shap_embeddings = load_pickle(
    filepath="../outputs/explanations/ext_stay_7_concat_static_timeseries_notes/shap_ext_stay_7_concat_static_timeseries_notes.pkl"
)

#### SHAP values inference

In [None]:
### Tabular data: SHAP value vs. actual input value for each static feature of one patient
BATCH_KEY = "batch_0"
SAMPLE_IDX = 9
# NOTE(review): assumes row SAMPLE_IDX of batch_0 corresponds to test_ids[SAMPLE_IDX]
# (i.e. batch_0 is the first batch, in test-ID order); confirm against the explainer.
shap_static = shap_embeddings[BATCH_KEY]['static']
pt_static = embeddings[test_ids[SAMPLE_IDX]]['static']
# Range over the same batch that is indexed, so this works when only batch_0 exists.
for i in range(shap_static.shape[2]):
    print(f"{cols['static_cols'][i]} -> SHAP {shap_static[SAMPLE_IDX][0][i]}; Actual {pt_static[0][i]}")

print(pt_static.shape)

### Compare embedding keys

In [None]:
# NOTE(review): emb_old is not reloaded in this section. It still holds the features
# from ../outputs/prep_old/prep_data/mmfair_feat.pkl (loaded in an earlier cell), while
# `embeddings` now comes from ../outputs/prep_data. Confirm this is the intended comparison.
print(len(embeddings), len(emb_old))
### Patient IDs present in embeddings but not in emb_old
only_in_new = set(embeddings) - set(emb_old)
print(len(only_in_new))
### Patient IDs present in emb_old but not in embeddings
only_in_old = set(emb_old) - set(embeddings)
print(len(only_in_old))

In [None]:
# Forward slashes keep the path portable. Backslashes work only on Windows and form
# invalid escape sequences ("\o", "\e", "\p") in a normal string literal.
risk_dict = load_pickle("../outputs/evaluation/ext_stay_7_concat_static_timeseries/pf_ext_stay_7_concat_static_timeseries.pkl")

In [None]:
# Fairness results for three ext_stay_7 input configurations.
fair_dict = load_pickle(
    filepath="../outputs/fairness/ext_stay_7_None_timeseries/pf_ext_stay_7_None_timeseries.pkl"
)
fair_dict_cst = load_pickle(
    filepath="../outputs/fairness/ext_stay_7_concat_static_timeseries/pf_ext_stay_7_concat_static_timeseries.pkl"
)
fair_dict_cstn = load_pickle(
    filepath="../outputs/fairness/ext_stay_7_concat_static_timeseries_notes/pf_ext_stay_7_concat_static_timeseries_notes.pkl"
)

In [None]:
# Raw fairness results for the ext_stay_7_None_timeseries run.
fair_dict

In [None]:
### Select the keys that start with fair_ (taken from fair_dict; assumed present in all three runs)
fair_keys = [key for key in fair_dict if key.startswith("fair_")]
### Keep only the fair_* entries of each run
fair_dict_f = {key: fair_dict[key] for key in fair_keys}
fair_dict_cstf = {key: fair_dict_cst[key] for key in fair_keys}
fair_dict_cstnf = {key: fair_dict_cstn[key] for key in fair_keys}
### Side-by-side table. All three frames share the same fair_* column names, so `keys`
### labels each run; otherwise the concatenated columns are indistinguishable duplicates.
fair_df = pd.concat(
    [pd.DataFrame(fair_dict_f), pd.DataFrame(fair_dict_cstf), pd.DataFrame(fair_dict_cstnf)],
    axis=1,
    keys=["None_timeseries", "concat_static_timeseries", "concat_static_timeseries_notes"],
)

In [None]:
# Transpose so each fair_* metric becomes a row (one per run concatenated above).
fair_df.T

In [None]:
# Sanity check: these per-sample entries should have equal lengths.
tuple(len(risk_dict[name]) for name in ('risk_quantile', 'y_prob', 'test_ids'))

In [None]:
# Inspect the 'yd_idx' entry of the evaluation results (semantics defined by the evaluation code; not shown here).
risk_dict['yd_idx']

In [None]:
# Show a compact summary instead of the whole dict. It holds one entry per patient
# with feature arrays, so rendering all of it floods the notebook output.
first_pt_id = next(iter(embeddings))
len(embeddings), first_pt_id, list(embeddings[first_pt_id].keys())

In [None]:
# Column-name metadata stored alongside the features (e.g. cols['static_cols'], used above).
cols

#### Recode some of the embeddings

In [None]:
# Persist the current embeddings to ../outputs/processed_data/mmfair_feat.pkl via the shared helper.
save_pickle(embeddings, "../outputs/processed_data", "mmfair_feat.pkl")

In [None]:
# NOTE(review): `id_val` is not defined anywhere in this notebook, so this cell raised
# NameError on a fresh kernel. Default to the first patient; set it explicitly to inspect another.
id_val = next(iter(embeddings))
# Second element of every entry in this patient's 'notes' list.
extr_list = [item[1] for item in embeddings[id_val]['notes']]
print(extr_list)

#### Test training IDs

In [None]:
# Train / validation / test subject-ID splits for the in-hospital-death task.
train_hosp_death, val_hosp_death, test_hosp_death = (
    pd.read_csv(split_path)
    for split_path in (
        '../outputs/prep_data/training_ids_in_hosp_death.csv',
        '../outputs/prep_data/validation_ids_in_hosp_death.csv',
        '../outputs/prep_data/testing_ids_in_hosp_death.csv',
    )
)

In [None]:
# ICU-admission training IDs, compared below against the in-hospital-death training split.
train_icu = pd.read_csv('../outputs/prep_data/training_ids_icu_admission.csv')

In [None]:
# Subjects that appear in both training splits.
overlap = set(train_icu['subject_id']) & set(train_hosp_death['subject_id'])
print(len(overlap))

In [None]:
# How many subject IDs from each in-hospital-death split have features in `embeddings`.
emb_ids = list(embeddings)
overlap, overlapv, overlapt = (
    set(emb_ids) & set(split_df['subject_id'])
    for split_df in (train_hosp_death, val_hosp_death, test_hosp_death)
)
print(len(overlap), len(overlapv), len(overlapt), len(train_hosp_death), len(embeddings))