In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Dense, GRU, Embedding, Dropout, Masking, GlobalMaxPooling1D, TimeDistributed
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback
import tensorflow.keras.backend as K
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

tf.test.is_gpu_available()

True

# Preprocessing

## Load Data

In [3]:
admissions = pd.read_csv('../data/original/ADMISSIONS.csv')
diagnoses = pd.read_csv('../data/original/DIAGNOSES_ICD.csv', dtype={'ICD9_CODE': str})

## Filtering longitudinal

Only keep patients with 2 or more visits where they were given a diagnosis.

In [4]:
diag_freqs = (
    diagnoses.dropna()
    .loc[:, ['SUBJECT_ID', 'HADM_ID']]
    .drop_duplicates()
    .SUBJECT_ID
    .value_counts()
)
selected_subjs = diag_freqs[diag_freqs >= 2].index.values
selected_subjs.shape

(7499,)

In [5]:
admissions_long = admissions.loc[admissions.SUBJECT_ID.isin(selected_subjs)]
admissions_long.shape

(19917, 20)

In [6]:
diagnoses_long = (
    diagnoses.loc[diagnoses.SUBJECT_ID.isin(selected_subjs)]
    .loc[:, ["SUBJECT_ID", "HADM_ID", "ICD9_CODE"]]
    .dropna()
    .merge(admissions_long[['HADM_ID', 'ADMITTIME']], how='left', on=['HADM_ID'])
    .astype({'ADMITTIME': 'datetime64[ns]', 'ICD9_CODE': str})
    .sort_values(['SUBJECT_ID', 'ADMITTIME'])
)

print(diagnoses_long.shape)
diagnoses_long.head()

(259995, 4)


Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,ADMITTIME
627,17,194023,7455,2134-12-27 07:15:00
628,17,194023,45829,2134-12-27 07:15:00
629,17,194023,V1259,2134-12-27 07:15:00
630,17,194023,2724,2134-12-27 07:15:00
619,17,161087,4239,2135-05-09 14:11:00


## Understand ICD9 for MIMIC-III

It seems like the diagnoses code length varies, and does not follow the same convention as ICD-9 exactly (there is no dot delimiting group from actual values, and length of each code seems to vary.

In [7]:
for i in range(3, 6):
    mask = diagnoses_long.ICD9_CODE.apply(len) == i
    print(diagnoses_long[mask].shape)
    print(diagnoses_long[mask].ICD9_CODE.sort_values().unique()[:50])

(10457, 4)
['035' '042' '075' '096' '118' '135' '138' '185' '193' '217' '220' '226'
 '243' '260' '261' '262' '267' '311' '316' '317' '319' '325' '326' '340'
 '390' '412' '430' '431' '436' '449' '452' '462' '463' '470' '475' '481'
 '485' '486' '490' '496' '500' '501' '502' '514' '515' '538' '541' '542'
 '566' '570']
(118808, 4)
['0030' '0049' '0051' '0074' '0085' '0088' '0090' '0091' '0092' '0093'
 '0239' '0270' '0272' '0310' '0311' '0312' '0319' '0329' '0338' '0340'
 '0360' '0380' '0382' '0383' '0388' '0389' '0391' '0398' '0400' '0412'
 '0413' '0414' '0415' '0416' '0417' '0419' '0463' '0470' '0478' '0479'
 '0490' '0491' '0498' '0499' '0521' '0529' '0530' '0539' '0542' '0543']
(130730, 4)
['00581' '00841' '00843' '00845' '00847' '00863' '00869' '01085' '01190'
 '01194' '01215' '01300' '01325' '01330' '01803' '01805' '01880' '01890'
 '01894' '01895' '01896' '03285' '03289' '03810' '03811' '03812' '03819'
 '03840' '03842' '03843' '03844' '03849' '04082' '04100' '04101' '04102'
 '04103' '0

We observe that the length of each code varies depending on the values after the **dot** in the ICD-9 code. If the length is 3, then it is implictly a 0 (e.g. `'035'` indicates the code 035.0), and length of 4 means there is only one value after the dot (`0049` indicates 004.9), and length of 5 means there are 2 values after the dot.

## Clean ICD9 Code

We create a function to correct the ICD-9 code by adding a dot at the right place. This will usually be after the 3rd digit, except for "External causes of injury (V)", which is 4-digits. We also add the implicit 0 whenever needed.

In [8]:
def correct_icd9(code):
    if 'E' in code:
        return code[:4] + '.' + code[4:]
    
    if len(code) == 3:
        code += '0'
    
    return code[:3] + '.' + code[3:]

diagnoses_long.ICD9_CODE = diagnoses_long.ICD9_CODE.apply(correct_icd9)
diagnoses_long.head()

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,ADMITTIME
627,17,194023,745.5,2134-12-27 07:15:00
628,17,194023,458.29,2134-12-27 07:15:00
629,17,194023,V12.59,2134-12-27 07:15:00
630,17,194023,272.4,2134-12-27 07:15:00
619,17,161087,423.9,2135-05-09 14:11:00


## Grouping diagnosis by admission

In [9]:
grouped_diag = (
    diagnoses_long.groupby(['SUBJECT_ID', 'HADM_ID'])
    .apply(lambda df: df.ICD9_CODE.tolist())
    .reset_index()
    .rename({0: 'DIAGNOSES'}, axis=1)
)

grouped_diag['DIAGNOSES_STR'] = grouped_diag.DIAGNOSES.apply(lambda ls: " ".join(ls))
print(grouped_diag.shape)

grouped_diag.head()

(19911, 4)


Unnamed: 0,SUBJECT_ID,HADM_ID,DIAGNOSES,DIAGNOSES_STR
0,17,161087,"[423.9, 511.9, 785.51, 458.9, 311.0, 722.0, 71...",423.9 511.9 785.51 458.9 311.0 722.0 719.46 272.4
1,17,194023,"[745.5, 458.29, V12.59, 272.4]",745.5 458.29 V12.59 272.4
2,21,109451,"[410.71, 785.51, 578.1, 584.9, 403.91, 428.0, ...",410.71 785.51 578.1 584.9 403.91 428.0 459.2 5...
3,21,111970,"[038.8, 785.52, 403.91, 427.31, 707.09, 511.9,...",038.8 785.52 403.91 427.31 707.09 511.9 682.3 ...
4,23,124321,"[225.2, 348.5, 780.39, 424.1, 401.9, 272.0, 27...",225.2 348.5 780.39 424.1 401.9 272.0 272.4 V45...


In [10]:
grouped_diag['GROUP'] = (
    grouped_diag.DIAGNOSES
    .apply(lambda ls: [code.split('.')[0] for code in ls])
)
grouped_diag.head()

Unnamed: 0,SUBJECT_ID,HADM_ID,DIAGNOSES,DIAGNOSES_STR,GROUP
0,17,161087,"[423.9, 511.9, 785.51, 458.9, 311.0, 722.0, 71...",423.9 511.9 785.51 458.9 311.0 722.0 719.46 272.4,"[423, 511, 785, 458, 311, 722, 719, 272]"
1,17,194023,"[745.5, 458.29, V12.59, 272.4]",745.5 458.29 V12.59 272.4,"[745, 458, V12, 272]"
2,21,109451,"[410.71, 785.51, 578.1, 584.9, 403.91, 428.0, ...",410.71 785.51 578.1 584.9 403.91 428.0 459.2 5...,"[410, 785, 578, 584, 403, 428, 459, 507, 427, ..."
3,21,111970,"[038.8, 785.52, 403.91, 427.31, 707.09, 511.9,...",038.8 785.52 403.91 427.31 707.09 511.9 682.3 ...,"[038, 785, 403, 427, 707, 511, 682, 998, 008, ..."
4,23,124321,"[225.2, 348.5, 780.39, 424.1, 401.9, 272.0, 27...",225.2 348.5 780.39 424.1 401.9 272.0 272.4 V45...,"[225, 348, 780, 424, 401, 272, 272, V45, V45, ..."


## Tokenize diagnoses

In [11]:
tokenizer = Tokenizer(filters="", lower=False)
tokenizer.fit_on_texts(grouped_diag.DIAGNOSES_STR)

grouped_diag['SEQUENCES'] = tokenizer.texts_to_sequences(grouped_diag.DIAGNOSES_STR)
grouped_diag.head()

Unnamed: 0,SUBJECT_ID,HADM_ID,DIAGNOSES,DIAGNOSES_STR,GROUP,SEQUENCES
0,17,161087,"[423.9, 511.9, 785.51, 458.9, 311.0, 722.0, 71...",423.9 511.9 785.51 458.9 311.0 722.0 719.46 272.4,"[423, 511, 785, 458, 311, 722, 719, 272]","[260, 34, 125, 46, 25, 1338, 878, 7]"
1,17,194023,"[745.5, 458.29, V12.59, 272.4]",745.5 458.29 V12.59 272.4,"[745, 458, V12, 272]","[271, 58, 384, 7]"
2,21,109451,"[410.71, 785.51, 578.1, 584.9, 403.91, 428.0, ...",410.71 785.51 578.1 584.9 403.91 428.0 459.2 5...,"[410, 785, 578, 584, 403, 428, 459, 507, 427, ...","[41, 125, 157, 4, 21, 2, 472, 29, 3, 64, 5, 6,..."
3,21,111970,"[038.8, 785.52, 403.91, 427.31, 707.09, 511.9,...",038.8 785.52 403.91 427.31 707.09 511.9 682.3 ...,"[038, 785, 403, 427, 707, 511, 682, 998, 008, ...","[465, 35, 21, 3, 434, 34, 553, 77, 57, 742, 19..."
4,23,124321,"[225.2, 348.5, 780.39, 424.1, 401.9, 272.0, 27...",225.2 348.5 780.39 424.1 401.9 272.0 272.4 V45...,"[225, 348, 780, 424, 401, 272, 272, V45, V45, ...","[607, 208, 56, 53, 1, 15, 7, 17, 448, 36]"


## Create Label encoder

In [12]:
label_enc = MultiLabelBinarizer()
label_enc.fit(grouped_diag.GROUP)

MultiLabelBinarizer(classes=None, sparse_output=False)

## Grouping by subjects

In [13]:
grouped_subjs = (
    grouped_diag.loc[:, ['SUBJECT_ID', 'SEQUENCES', 'GROUP']]
    .groupby('SUBJECT_ID')
    .agg({
        'GROUP': lambda groups: groups.tolist(),
        'SEQUENCES': lambda seqs: seqs.tolist()
    })
)

print(grouped_subjs.shape)
grouped_subjs.head()

(7499, 2)


Unnamed: 0_level_0,GROUP,SEQUENCES
SUBJECT_ID,Unnamed: 1_level_1,Unnamed: 2_level_1
17,"[[423, 511, 785, 458, 311, 722, 719, 272], [74...","[[260, 34, 125, 46, 25, 1338, 878, 7], [271, 5..."
21,"[[410, 785, 578, 584, 403, 428, 459, 507, 427,...","[[41, 125, 157, 4, 21, 2, 472, 29, 3, 64, 5, 6..."
23,"[[225, 348, 780, 424, 401, 272, 272, V45, V45,...","[[607, 208, 56, 53, 1, 15, 7, 17, 448, 36], [5..."
34,"[[E879, 410, 428, 425, 427, 997, 426, 414], [4...","[[337, 41, 2, 50, 3, 70, 289, 5], [48, 72, 427..."
36,"[[998, 998, 415, 453, 996, 496, 414, V45, 401,...","[[473, 89, 679, 171, 274, 14, 5, 17, 1, 90, 10..."


## Moving target forward in time

Since we are predicting the labels in the next admission, we are removing the first label from each subject records, as well as remove last visits. This will effectively move target forward in time by one time step, since by removing the first label the first visit is now matched with the second visit's group (which is what we are trying to predict).

Keep in mind that visits here has already been embedded into count vectors (vectors of integer associated with an index in the `word_index` of our keras `Tokenizer`).

In [14]:
visits = grouped_subjs.SEQUENCES.tolist()
labels = grouped_subjs.GROUP.tolist()

visits = [visit[:-1] for visit in visits]
labels = [label[1:] for label in labels]

## Pad Visits

Since we have a list of list of list, we need to pad it on multiple level.

In [15]:
n_visits_max = max([len(visit) for visit in visits])
n_diagnosis_max = max([max(map(len, visit)) for visit in visits])

print("Max number of visits per patient:", n_visits_max)
print("Max number of diagnosis in a visit:", n_diagnosis_max)

Max number of visits per patient: 41
Max number of diagnosis in a visit: 39


In [16]:
visits_padded = pad_sequences([
    pad_sequences(visit, maxlen=n_diagnosis_max) 
    for visit in visits
])

visits_padded.shape

(7499, 41, 39)

## Pad labels

In [17]:
labels_padded = pad_sequences([label_enc.transform(label) for label in labels])
labels_padded.shape

(7499, 41, 939)

## Train test split

In [18]:
train_visits, test_visits, train_labels, test_labels = train_test_split(
    visits_padded, labels_padded, test_size=0.1, random_state=2019)

print(train_visits.shape)
print(test_visits.shape)
print(train_labels.shape)
print(test_labels.shape)

(6749, 41, 39)
(750, 41, 39)
(6749, 41, 939)
(750, 41, 939)


# Modelling

## Evaluation metric

In [19]:
def top_k_recall(y_true, y_pred, use_tqdm=True, k=30):
    pred_flat = y_pred.reshape(-1, 939)
    true_flat = y_true.reshape(-1, 939)
    
    all_patients_recall = []

    for adm_idx in tqdm(range(true_flat.shape[0]), disable=not use_tqdm):
        true_indices = np.argwhere(true_flat[adm_idx] == 1).reshape(-1)

        # If this admission does not have any diagnosis, then it
        # is a dummy admission created by padding from keras
        if true_indices.shape[0] > 0:
            pred_indices = pred_flat[adm_idx].argsort()[-k:]

            intersection_count = len(np.intersect1d(pred_indices, true_indices))

            recall = intersection_count / len(true_indices)
            all_patients_recall.append(recall)

    all_patients_recall = np.array(all_patients_recall)
    return all_patients_recall

## Custom callback to monitor top-k recall

In [20]:
class TopKRecallCallback(Callback):
    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        
    def on_train_begin(self, logs={}):
        self.test_recalls_at_10 = []
        self.test_recalls_at_20 = []
        self.test_recalls_at_30 = []

    def on_epoch_end(self, epoch, logs={}):
        y_pred = self.model.predict(self.X_test, batch_size=256)
        y_test = self.y_test
        
        _test_recall_at_10 = top_k_recall(y_test, y_pred, k=10, use_tqdm=False).mean()
        _test_recall_at_20 = top_k_recall(y_test, y_pred, k=20, use_tqdm=False).mean()
        _test_recall_at_30 = top_k_recall(y_test, y_pred, k=30, use_tqdm=False).mean()
        
        self.test_recalls_at_10.append(_test_recall_at_10)
        self.test_recalls_at_20.append(_test_recall_at_20)
        self.test_recalls_at_30.append(_test_recall_at_30)

        print(f"\ntest_top_k_recall: {_test_recall_at_10:.4f}@10; {_test_recall_at_20:.4f}@20; {_test_recall_at_30:.4f}@30")

        return

## Building Model

In [22]:
def build_model(num_words, embedding_dim=200):
    input1 = Input(shape=(None, None))
    x = Embedding(num_words, embedding_dim, mask_zero=False)(input1)
    x = K.sum(x, axis=2)
    x = Masking(0)(x)
    
    x = GRU(200, return_sequences=True)(x)
    x = GRU(200, return_sequences=True)(x)
    
    x = Dense(300, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(939, activation='softmax')(x)

    model = Model(inputs=input1, outputs=x)
    model.compile('adam', loss='binary_crossentropy')
    
    return model

In [23]:
num_words = len(tokenizer.word_index) + 1
model = build_model(num_words)
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, None)]      0         
_________________________________________________________________
embedding (Embedding)        (None, None, None, 200)   976200    
_________________________________________________________________
tf_op_layer_Sum (TensorFlowO [(None, None, 200)]       0         
_________________________________________________________________
masking (Masking)            (None, None, 200)         0         
_________________________________________________________________
gru (GRU)                    (None, None, 200)         241200    
_________________________________________________________________
gru_1 (GRU)                  (None, None, 200)         241200    
_________________________________________________________________
dense (Dense)                (None, None, 300)         60300 

In [24]:
callbacks = [
    TopKRecallCallback(X_test=test_visits, y_test=test_labels)
]

model.fit(
    train_visits, 
    train_labels, 
    batch_size=64, 
    epochs=10, 
    verbose=1,
    callbacks=callbacks,
    validation_split=0.1
)

Train on 6074 samples, validate on 675 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fa0681a9950>

# Evaluation

## Eval on train

In [25]:
%%time
with tf.device('/gpu:0'):
    train_pred = model.predict(train_visits, batch_size=512)

CPU times: user 24.1 s, sys: 49.2 s, total: 1min 13s
Wall time: 5.52 s


In [26]:
all_patients_recall = top_k_recall(y_true=train_labels, y_pred=train_pred)
all_patients_recall.mean()

100%|██████████| 276709/276709 [00:03<00:00, 71779.65it/s]


0.5649448928782117

## Eval on test

In [27]:
%%time
with tf.device('/gpu:0'):
    test_pred = model.predict(test_visits, batch_size=512)

CPU times: user 1.9 s, sys: 3.46 s, total: 5.35 s
Wall time: 333 ms


In [28]:
all_patients_recall = top_k_recall(y_true=test_labels, y_pred=test_pred)
all_patients_recall.mean()

100%|██████████| 30750/30750 [00:00<00:00, 58337.70it/s]


0.565019502320133

# Saving Data

In [30]:
admissions_long.to_csv('../generated_data/ADMISSIONS_LONGITUDINAL.CSV')