In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import random

In [2]:
vitals = pd.read_hdf("all_hourly_data.h5", 'vitals_labs_mean')
print(vitals.shape)

(2200954, 104)


In [3]:
interventions = pd.read_hdf("all_hourly_data.h5",'interventions')
print(interventions.shape)

(2200954, 14)


In [4]:
patients = pd.read_hdf("all_hourly_data.h5", 'patients')
print(patients.shape)

(34472, 28)


In [5]:
# Prepare labels for los > 7 task
patients["los_7"] = patients['los_icu'] > 7
patients.drop(columns=['los_icu'], inplace=True)

In [6]:
patients = patients[["gender","ethnicity","age","los_7","mort_icu"]].reset_index()

In [7]:
patients.los_7 = patients.los_7.apply(int)

In [8]:
print(f"class imbalance for length of stay prediction: {patients.los_7.mean()}")
print(f"class imbalance for icu mortality prediction: {patients.mort_icu.mean()}")

class imbalance for length of stay prediction: 0.053521698770016245
class imbalance for icu mortality prediction: 0.06558946391274077


WINDOW_SIZE determines how much information we want to use for the prediction tasks. For example, a window length of 24 means that we will use the sequence of the first 24 readings for vitals and interventions.

In [9]:
WINDOW_SIZE = 24

In [10]:
interventions = interventions.reset_index()

In [11]:
interventions = interventions[interventions.hours_in < WINDOW_SIZE]

In [12]:
interventions

Unnamed: 0,subject_id,hadm_id,icustay_id,hours_in,vent,vaso,adenosine,dobutamine,dopamine,epinephrine,isuprel,milrinone,norepinephrine,phenylephrine,vasopressin,colloid_bolus,crystalloid_bolus,nivdurations
0,3,145834,211552,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
1,3,145834,211552,1,1,1,0,0,1,0,0,0,0,1,0,0,0,0
2,3,145834,211552,2,1,1,0,0,1,0,0,0,0,1,0,0,0,0
3,3,145834,211552,3,1,1,0,0,0,0,0,0,0,1,0,0,0,0
4,3,145834,211552,4,1,1,0,0,0,0,0,0,1,1,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2200946,99999,113369,246512,19,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2200947,99999,113369,246512,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2200948,99999,113369,246512,21,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2200949,99999,113369,246512,22,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [13]:
vitals = vitals.reset_index()

In [14]:
vitals = vitals[vitals.hours_in < WINDOW_SIZE]

In [15]:
vitals.shape

(808539, 108)

In [16]:
vitals.columns = vitals.columns.to_flat_index()
vitals.columns = list(map(lambda x: x[0] + x[1] if x[1] == "" else x[0] + " " + x[1] , vitals.columns.tolist()))

In [17]:
vitals

Unnamed: 0,subject_id,hadm_id,icustay_id,hours_in,alanine aminotransferase mean,albumin mean,albumin ascites mean,albumin pleural mean,albumin urine mean,alkaline phosphate mean,...,total protein mean,total protein urine mean,troponin-i mean,troponin-t mean,venous pvo2 mean,weight mean,white blood cell count mean,white blood cell count urine mean,ph mean,ph urine mean
0,3,145834,211552,0,25.0,1.8,,,,73.0,...,,,,,,,14.842857,,7.40,5.0
1,3,145834,211552,1,,,,,,,...,,,,,,,,,,
2,3,145834,211552,2,,,,,,,...,,,,,,,,,7.26,
3,3,145834,211552,3,,,,,,,...,,,,,,,,,,
4,3,145834,211552,4,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2200946,99999,113369,246512,19,,,,,,,...,,,,,,,,,,
2200947,99999,113369,246512,20,,,,,,,...,,,,,,,,,,
2200948,99999,113369,246512,21,,,,,,,...,,,,,,,,,,
2200949,99999,113369,246512,22,,,,,,,...,,,,,,,,,,


In [18]:
vitals.isna().mean(axis=0)

subject_id                           0.000000
hadm_id                              0.000000
icustay_id                           0.000000
hours_in                             0.000000
alanine aminotransferase mean        0.968198
                                       ...   
weight mean                          0.954340
white blood cell count mean          0.883480
white blood cell count urine mean    0.995364
ph mean                              0.859961
ph urine mean                        0.978932
Length: 108, dtype: float64

In [19]:
patients

Unnamed: 0,subject_id,hadm_id,icustay_id,gender,ethnicity,age,los_7,mort_icu
0,3,145834,211552,M,WHITE,76.526792,0,0
1,4,185777,294638,F,WHITE,47.845047,0,0
2,6,107064,228232,F,WHITE,65.942297,0,0
3,9,150750,220597,M,UNKNOWN/NOT SPECIFIED,41.790228,0,1
4,11,194540,229441,F,WHITE,50.148295,0,0
...,...,...,...,...,...,...,...,...
34467,99983,117390,286606,M,UNKNOWN/NOT SPECIFIED,78.576624,0,0
34468,99991,151118,226241,M,WHITE,47.729259,0,0
34469,99992,197084,242052,F,WHITE,65.772155,0,0
34470,99995,137810,229633,F,WHITE,88.698942,0,0


In [20]:
interventions_list = sorted(list(set(interventions.columns.tolist())-set(['subject_id', 'hadm_id', 'icustay_id', 'hours_in'])))

In [21]:
interventions_mapping = dict(zip(interventions_list, range(1, 1+len(interventions_list))))

In [22]:
interventions_mapping

{'adenosine': 1,
 'colloid_bolus': 2,
 'crystalloid_bolus': 3,
 'dobutamine': 4,
 'dopamine': 5,
 'epinephrine': 6,
 'isuprel': 7,
 'milrinone': 8,
 'nivdurations': 9,
 'norepinephrine': 10,
 'phenylephrine': 11,
 'vaso': 12,
 'vasopressin': 13,
 'vent': 14}

In [23]:
subjects = patients.subject_id.tolist()

In [24]:
bins= [0,16,35,60,80,100]
labels = ['child','young_adult','middle_age','old','very_old']
def bucket_age(age_column):
    bucketed = pd.cut(age_column, bins=bins, labels=labels, right=False).astype(str)
    return bucketed

In [25]:
patients.age = bucket_age(patients.age)

In [26]:
patients.age

0               old
1        middle_age
2               old
3        middle_age
4        middle_age
            ...    
34467           old
34468    middle_age
34469           old
34470      very_old
34471           old
Name: age, Length: 34472, dtype: object

In [27]:
patients[["gender","ethnicity", "age"]]

Unnamed: 0,gender,ethnicity,age
0,M,WHITE,old
1,F,WHITE,middle_age
2,F,WHITE,old
3,M,UNKNOWN/NOT SPECIFIED,middle_age
4,F,WHITE,middle_age
...,...,...,...
34467,M,UNKNOWN/NOT SPECIFIED,old
34468,M,WHITE,middle_age
34469,F,WHITE,old
34470,F,WHITE,very_old


In [28]:
l = patients[["gender","ethnicity", "age"]].values.tolist()
cat_values = set([item for sublist in l for item in sublist])
demo_mapping = dict(zip(cat_values, range(15,15 + len(cat_values))))
demographics = patients[["gender","ethnicity", "age"]].applymap(lambda x: demo_mapping[x])

In [29]:
patients[["gender","ethnicity", "age"]] = demographics

In [30]:
patients

Unnamed: 0,subject_id,hadm_id,icustay_id,gender,ethnicity,age,los_7,mort_icu
0,3,145834,211552,43,21,61,0,0
1,4,185777,294638,25,21,41,0,0
2,6,107064,228232,25,21,61,0,0
3,9,150750,220597,43,53,41,0,1
4,11,194540,229441,25,21,41,0,0
...,...,...,...,...,...,...,...,...
34467,99983,117390,286606,43,53,61,0,0
34468,99991,151118,226241,43,21,41,0,0
34469,99992,197084,242052,25,21,61,0,0
34470,99995,137810,229633,25,21,16,0,0


### SDPRL setup LOS_7  task
For each patient, match that patient with a random patient with a different outcome, that patient's modals will be used as SDPRL's negative samples

In [31]:
all_patients = patients.subject_id.values.tolist()

In [32]:
pos_patients = patients.loc[patients.los_7 == 1, "subject_id"].values.tolist()
neg_patients = patients.loc[patients.los_7 == 0, "subject_id"].values.tolist()

In [33]:
print(len(pos_patients), len(neg_patients))

1845 32627


In [34]:
random.sample(neg_patients,1)

[26723]

In [35]:
def matchmaking(patients, pos_patients, neg_patients):
    pairings = []
    for patient in patients:
        if patient in pos_patients:
            pairings.append((patient, random.sample(neg_patients,1)[0]))
        else:
            pairings.append((patient, random.sample(pos_patients,1)[0]))
    return pairings

In [36]:
pairings = matchmaking(all_patients, pos_patients, neg_patients)

Given the pairings, prepare the modals in a pytorch dataloader friendly format

In [37]:
interventions_demo = interventions.merge(patients,how="left", on="subject_id")

In [38]:
interventions_demo = interventions_demo.drop(["hours_in","los_7","mort_icu"],axis=1)

In [None]:
demo_seqs = []
for pair in pairings:
    interv_list1 = []
    interv_list2 = []
    tmp1 = interventions_demo[interventions_demo.subject_id == pair[0]]
    tmp2 = interventions_demo[interventions_demo.subject_id == pair[1]]
    for _, i in tmp1.iterrows():
        interv_demo = interventions_demo.columns[i.eq(1).values].tolist()
        intervention_at_timestamp = list(map(lambda x: interventions_mapping[x], interv_demo))
        interv_list1.append( i[["gender","ethnicity","age"]].values.tolist() + intervention_at_timestamp)
    for _, i in tmp2.iterrows():
        interv_demo = interventions_demo.columns[i.eq(1).values].tolist()
        intervention_at_timestamp = list(map(lambda x: interventions_mapping[x], interv_demo))
        interv_list2.append( i[["gender","ethnicity","age"]].values.tolist() + intervention_at_timestamp)
    demo_seqs.append((interv_list1,interv_list2))

In [41]:
with open("demograph_interventions_SDPRL.pkl","wb") as f:
    pickle.dump(demo_seqs,f)

In [42]:
len(demo_seqs)

34472

In [43]:
vitals_zero = vitals.fillna(0)

In [44]:
l = patients[["gender","ethnicity", "age"]].values.tolist()
cat_values = set([item for sublist in l for item in sublist])

In [45]:
demo_mapping = dict(zip(cat_values, range(len(cat_values))))


In [46]:
demographics = patients[["gender","ethnicity", "age"]].applymap(lambda x: demo_mapping[x])

In [47]:
patients[["gender","ethnicity", "age"]] = demographics

In [48]:
vitals_zero_demo = vitals_zero.merge(patients,how="left", on="subject_id")

In [49]:
vitals_columns = [x for x in vitals.columns.tolist() if x not in ['subject_id', 'hadm_id', 'icustay_id', 'hours_in']]

In [50]:
demo_columns = ["gender","ethnicity","age"]

In [51]:
vitals_zero_demo[["subject_id"] + vitals_columns + demo_columns]

Unnamed: 0,subject_id,alanine aminotransferase mean,albumin mean,albumin ascites mean,albumin pleural mean,albumin urine mean,alkaline phosphate mean,anion gap mean,asparate aminotransferase mean,basophils mean,...,troponin-t mean,venous pvo2 mean,weight mean,white blood cell count mean,white blood cell count urine mean,ph mean,ph urine mean,gender,ethnicity,age
0,3,25.0,1.8,0.0,0.0,0.0,73.0,20.666667,69.0,0.0,...,0.0,0.0,0.0,14.842857,0.0,7.40,5.0,28,6,46
1,3,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,28,6,46
2,3,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,7.26,0.0,28,6,46
3,3,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,28,6,46
4,3,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,28,6,46
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
808534,99999,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,10,5,46
808535,99999,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,10,5,46
808536,99999,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,10,5,46
808537,99999,0.0,0.0,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.0,0.0,0.0,0.000000,0.0,0.00,0.0,10,5,46


In [52]:
len(vitals_columns)

104

In [None]:
demo_vseqs = []
for pair in pairings:
    v_zero_demo1 = vitals_zero_demo[vitals_zero_demo.subject_id == pair[0]]
    v_zero_demo1 = v_zero_demo1[vitals_columns + demo_columns].values
    
    v_zero_demo2 = vitals_zero_demo[vitals_zero_demo.subject_id == pair[1]]
    v_zero_demo2 = v_zero_demo2[vitals_columns + demo_columns].values
    demo_vseqs.append((v_zero_demo1,v_zero_demo2))
    

In [58]:
with open("demograph_vitals_zero_SDPRL.pkl","wb") as f:
    pickle.dump(demo_vseqs,f)