In [1]:
import pandas as pd
import numpy as np
import csv
import matplotlib.pyplot as plt
import seaborn as sns
import random
import re
import textwrap
from sklearn.metrics import classification_report
from pprint import pprint


import tensorflow as tf
from tensorflow import keras
from transformers import TFBertModel, BertTokenizerFast


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
diagnoses_train = pd.read_csv("drive/MyDrive/MIDS/medal_mimic_subset/diagnoses/train.csv")
diagnoses_validation = pd.read_csv("drive/MyDrive/MIDS/medal_mimic_subset/diagnoses/valid.csv")
diagnoses_test = pd.read_csv("drive/MyDrive/MIDS/medal_mimic_subset/diagnoses/test.csv")

In [4]:
total_abbreviations = pd.read_csv("drive/MyDrive/MIDS/medal_mimic_subset/total_abbreviations.csv")

In [5]:
print(len(diagnoses_train))
print(len(diagnoses_validation))
print(len(diagnoses_test))

61079
9258
14287


In [6]:
# diagnoses_train.head()
# print(diagnoses_train[diagnoses_train.SUBJECT_ID == 29487].TEXT[0])
print(diagnoses_train[diagnoses_train.SUBJECT_ID == 29487].TEXT.values[-1])
# diagnoses_train.HOSPITAL_EXPIRE_FLAG.value_counts()

Chief Complaint : 
   24 Hour Events : 
    -  BP stable off pressors ,  no bolus requirement
    -  UOP adequate ,  still net positive
    -  tolerated 8 - 28 for 5 hrs ,  ABGs improved
    -  admin Lasix 40mg IV this am
   Pt intubated ,  sedated ,  opening eyes but only intermittently responding
   to commands . 
   Allergies : 
   Iodine
   Anaphylaxis ; 
   Cipro  ( Oral )   ( Ciprofloxacin Hcl ) 
   Rash ; 
   Sulfonamides
   Rash ; 
   Morphine
   Nausea / Vomiting
   Codeine
   Nausea / Vomiting
   Levofloxacin
   Anaphylaxis ; 
   Last dose of Antibiotics : 
   Vancomycin  -  2163 - 1 - 19 12 : 45 PM
   Piperacillin / Tazobactam  ( Zosyn )   -  2163 - 1 - 22 04 : 20 AM
   Infusions : 
   Fentanyl  -  25 mcg / hour
   Other ICU medications : 
   Other medications : 
   Flowsheet Data as of  2163 - 1 - 22 06 : 28 AM
   Vital signs
   Hemodynamic monitoring
   Fluid balance
                                                                  24 hours
                                

In [7]:
abbreviations = set(total_abbreviations.abbreviation)

def has_any_abbreviation(text):
    words = set(text.split())
    found = words.intersection(abbreviations)
    return bool(found)

diagnoses_train_subset = diagnoses_train.loc[diagnoses_train.TEXT.apply(has_any_abbreviation)].reset_index(drop=True)
diagnoses_validation_subset = diagnoses_validation.loc[diagnoses_validation.TEXT.apply(has_any_abbreviation)].reset_index(drop=True)
diagnoses_test_subset = diagnoses_test.loc[diagnoses_test.TEXT.apply(has_any_abbreviation)].reset_index(drop=True)

In [8]:
print(f"Train: {len(diagnoses_train_subset)} out of {len(diagnoses_train)}")
print(f"Validation: {len(diagnoses_validation_subset)} out of {len(diagnoses_validation)}")
print(f"Test: {len(diagnoses_test_subset)} out of {len(diagnoses_test)}")

Train: 60444 out of 61079
Validation: 9147 out of 9258
Test: 14132 out of 14287


In [9]:
with open("drive/MyDrive/MIDS/medal_mimic_subset/diagnoses/diag_to_idx.csv") as f:
    diagnosis_to_idx = {diag: int(idx) for diag, idx in csv.reader(f)}
    idx_to_diagnosis = {int(idx): diag for diag, idx in diagnosis_to_idx.items()}

print(f"{len(diagnosis_to_idx)} unique diagnosis codes")

1204 unique diagnosis codes


In [10]:
model_checkpoint = 'NLP4H/ms_bert'
bert_tokenizer = BertTokenizerFast.from_pretrained(model_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/499 [00:00<?, ?B/s]

In [11]:
MAX_LENGTH = 512


def head_and_tail(tokens, total_size=MAX_LENGTH):
    """Build an array of size <total_size> using the head and tail ends of the input array"""
    half = int(total_size / 2)
    if len(tokens) <= total_size:
        return tokens
    head = tokens[:half]
    tail = tokens[-half:]
    return np.concatenate([head, tail])


def tokenize(texts):
    """
    Tokenize an array of text inputs for bert, and take the head and tail of
    each set of bert inputs. Returns a tuple of three arrays: input_ids, token_type_ids,
    and attention_mask.
    """
    input_ids_list = []
    token_type_ids_list = []
    attention_mask_list = []

    for text in texts:
        bert_output = bert_tokenizer(
            text, padding="max_length", return_tensors="tf", max_length=MAX_LENGTH
        )
        input_ids = bert_output["input_ids"][0]
        token_type_ids = bert_output["token_type_ids"][0]
        attention_mask = bert_output["attention_mask"][0]

        input_ids_list.append(head_and_tail(input_ids))
        token_type_ids_list.append(head_and_tail(token_type_ids))
        attention_mask_list.append(head_and_tail(attention_mask))

    return np.array(input_ids_list), np.array(token_type_ids_list), np.array(attention_mask_list)


def get_labels(df):
    """
    Get array of labels for each record in the dataframe as a one-hot encoded array, e.g.
    each ICD9 code appears as a 0 or 1 mapping to the diagnosis_labels array. If the ICD9
    code does not appear in the training set, it is ignored.
    """
    labels_by_hadm_id = {row.HADM_ID: row.ICD9_ID for _, row in df.iterrows()}
    output = np.zeros((len(df), len(diagnosis_to_idx)))
    for sample_idx, row in enumerate(df.iterrows()):
        for icd9 in row[1].ICD9_ID.split(';'):
            # skip codes that are not present in the train set
            grouped_icd9 = icd9[:4] if icd9.startswith("V") or icd9.startswith("E") else icd9[:3]
            if (index := diagnosis_to_idx.get(grouped_icd9)) is None:
                continue

            output[sample_idx, index] = 1

    return output


def preprocess_data(batch_df):
    return tokenize(batch_df.TEXT), get_labels(batch_df)


input_examples, label_examples = preprocess_data(diagnoses_train_subset[:3])
print("Labels:", [idx_to_diagnosis[idx] for idx in np.where(label_examples[1] == 1)[0]])
print("Inputs:", bert_tokenizer.decode(input_examples[0][1]))

Labels: ['584', '707', '428', '276', '403', '293', '585', '427', '530', '197', '486', '038', '785', '995', '518', '599', '424', 'V104', 'V105', '274', 'V436']
Inputs: [CLS] chief complaint : 24 hour events : ekg - at 2163 - 1 - 17 11 : 30 am - afib, rbbb, no acute st - t waves changes arterial line - start 2163 - 1 - 17 04 : 15 pm - ce negative - vanc added - urine grew out pseudomonas, did not double cover per id - mri hip ordered, not performed - uable to wean pressors overnight pt doing okay this am, still some l hip pain, but otherwise, breathing comfortably, no cp / sob / abd pain. allergies : iodine anaphylaxis ; cipro ( oral ) ( ciprofloxacin hcl ) rash ; sulfonamides rash ; morphine nausea / vomiting codeine nausea / vomiting levofloxacin anaphylaxis ; last dose of antibiotics : vancomycin - 2163 - 1 - 17 09 : 45 am piperacillin / tazobactam ( zosyn ) - 2163 - 1 - 18 04 : 12 am infusions : other icu medications : heparin tid, allopurinol, famotidine, tamoxifen, dilaudid, zofran

In [12]:
BATCH_SIZE = 32


class DataGeneratorFromDataframe(tf.keras.utils.Sequence):

    def __init__(self, df, batch_size=BATCH_SIZE, shuffle=True):
        self.df = df
        self.n_examples = len(df)
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Initialize row order, call on_epoch_end to shuffle row indices
        self.row_order = np.arange(self.n_examples)
        self.on_epoch_end()

    def __len__(self):
        # Return the number of batches in the full dataset
        return self.n_examples // self.batch_size

    def __getitem__(self, idx):
        batch_start = idx * self.batch_size
        batch_end = (idx + 1) * self.batch_size

        # Indices to skip are the ones in the shuffled row_order before and
        # after the chunk we'll use for this batch
        batch_idx = self.row_order[batch_start : batch_end]
        batch_df = self.df.iloc[batch_idx, :].copy()
        batch_data = preprocess_data(batch_df)

        return batch_data

    def on_epoch_end(self):
        if self.shuffle:
            self.row_order = list(np.random.permutation(self.row_order))

In [13]:
# Create an instance of our data generator, for our training data file and size

train_data_generator = DataGeneratorFromDataframe(diagnoses_train_subset)
valid_data_generator = DataGeneratorFromDataframe(diagnoses_validation_subset)
test_data_generator = DataGeneratorFromDataframe(diagnoses_test_subset)

train_data_generator[0]

((array([[  101,  2516,  1024, ..., 22137,  1024,   102],
         [  101,  3438, 10930, ...,     0,     0,     0],
         [  101,  3515,  1061, ...,  1013,  1051,   102],
         ...,
         [  101,  2423,  1061, ...,     0,     0,     0],
         [  101, 18234,  3096, ...,     0,     0,     0],
         [  101, 13866,  2003, ...,  3431,  1012,   102]], dtype=int32),
  array([[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]], dtype=int32),
  array([[1, 1, 1, ..., 1, 1, 1],
         [1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 1, 1, 1],
         ...,
         [1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 0, 0, 0],
         [1, 1, 1, ..., 1, 1, 1]], dtype=int32)),
 array([[0., 1., 1., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 

In [14]:
def create_diagnosis_bert_model(
    model_checkpoint=model_checkpoint,
    n_classes=len(diagnosis_to_idx),
    hidden_size=256,
    dropout=0.3,
    learning_rate=0.00005,
):
    """
    Classification model with BERT that uses the Pooler Output for classification.
    """
    keras.backend.clear_session()

    input_ids = tf.keras.layers.Input(shape=(MAX_LENGTH,), dtype=tf.int64, name='input_ids_layer')
    token_type_ids = tf.keras.layers.Input(shape=(MAX_LENGTH,), dtype=tf.int64, name='token_type_ids_layer')
    attention_mask = tf.keras.layers.Input(shape=(MAX_LENGTH,), dtype=tf.int64, name='attention_mask_layer')

    bert_inputs = {'input_ids': input_ids,
                   'token_type_ids': token_type_ids,
                   'attention_mask': attention_mask}

    bert_model = TFBertModel.from_pretrained(model_checkpoint)
    bert_out = bert_model(bert_inputs)

    pooler_token = bert_out[1]

    hidden = tf.keras.layers.Dense(hidden_size, activation='relu', name='hidden_layer')(pooler_token)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    classification = tf.keras.layers.Dense(n_classes, activation='sigmoid', name='classification_layer')(hidden)

    classification_model = tf.keras.Model(inputs=[input_ids, token_type_ids, attention_mask], outputs=[classification])

    classification_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                                 loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                                 metrics='accuracy')

    return classification_model

In [15]:
diagnosis_bert_model = create_diagnosis_bert_model(model_checkpoint='drive/MyDrive/MIDS/model_checkpoints/embeddings_ms/msbert_model_pretrained')
diagnosis_bert_model.summary()

All model checkpoint layers were used when initializing TFBertModel.

All the layers of TFBertModel were initialized from the model checkpoint at drive/MyDrive/MIDS/model_checkpoints/embeddings_ms/msbert_model_pretrained.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 attention_mask_layer (Inpu  [(None, 512)]                0         []                            
 tLayer)                                                                                          
                                                                                                  
 input_ids_layer (InputLaye  [(None, 512)]                0         []                            
 r)                                                                                               
                                                                                                  
 token_type_ids_layer (Inpu  [(None, 512)]                0         []                            
 tLayer)                                                                                      

In [16]:
checkpoint_dir = 'drive/MyDrive/MIDS/model_checkpoints/diagnosis_ms_downstream/'
checkpoint_filepath = checkpoint_dir + 'weights.{epoch:02d}-{val_accuracy:.2f}.model.keras'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
)

In [17]:
diagnosis_history = diagnosis_bert_model.fit(
    train_data_generator,
    validation_data=valid_data_generator,
    batch_size=BATCH_SIZE,
    epochs=3,
    callbacks=[model_checkpoint_callback],
)

Epoch 1/3



Epoch 2/3
Epoch 3/3


In [18]:
import json
with open('drive/MyDrive/MIDS/model_checkpoints/diagnosis_ms_downstream/history.json', 'w') as f:
    json.dump(diagnosis_history.history, f)

In [None]:
# # Skip training by running this cell
# checkpoint_name = "weights.04-0.13.model.keras"
# diagnosis_bert_model = tf.keras.models.load_model(
#     f"{checkpoint_dir}{checkpoint_name}",
#     custom_objects={"TFBertModel": TFBertModel},
# )

In [19]:
def compute_top_k_recall(labels, predictions, k=10):
    # Get indices of top-k predictions
    idxs = np.argsort(predictions, axis=1)[:, ::-1][:, :k]

    # Gather top-k labels
    top_k_labels = labels[np.arange(labels.shape[0])[:, np.newaxis], idxs]

    # Compute recall for each sample
    sum_top_k_labels = np.sum(top_k_labels, axis=1)
    sum_labels = np.sum(labels, axis=1)
    recall_per_sample = sum_top_k_labels / sum_labels

    # Compute mean recall across all samples
    return np.mean(recall_per_sample)


def evaluate(model, test_data):
    top_5_recall = 0
    top_10_recall = 0
    top_30_recall = 0
    count = 0

    print(f"Evaluating top k recall of {len(test_data)} batches")

    for i, (inputs, labels) in enumerate(test_data, start=1):
        predictions = model.predict(inputs, verbose=0)
        top_5_recall += compute_top_k_recall(labels, predictions, k=5)
        top_10_recall += compute_top_k_recall(labels, predictions, k=10)
        top_30_recall += compute_top_k_recall(labels, predictions, k=30)
        count += 1
        if count % 10 == 0:
            print(f"{count} batches evaluated")

    return {
        "top_5_recall": top_5_recall / count,
        "top_10_recall": top_10_recall / count,
        "top_30_recall": top_30_recall / count,
    }

evaluate(diagnosis_bert_model, test_data_generator)

Evaluating top k recall of 441 batches
10 batches evaluated
20 batches evaluated
30 batches evaluated
40 batches evaluated
50 batches evaluated
60 batches evaluated
70 batches evaluated
80 batches evaluated
90 batches evaluated
100 batches evaluated
110 batches evaluated
120 batches evaluated
130 batches evaluated
140 batches evaluated
150 batches evaluated
160 batches evaluated
170 batches evaluated
180 batches evaluated
190 batches evaluated
200 batches evaluated
210 batches evaluated
220 batches evaluated
230 batches evaluated
240 batches evaluated
250 batches evaluated
260 batches evaluated
270 batches evaluated
280 batches evaluated
290 batches evaluated
300 batches evaluated
310 batches evaluated
320 batches evaluated
330 batches evaluated
340 batches evaluated
350 batches evaluated
360 batches evaluated
370 batches evaluated
380 batches evaluated
390 batches evaluated
400 batches evaluated
410 batches evaluated
420 batches evaluated
430 batches evaluated
440 batches evaluated


{'top_5_recall': 0.21063682174126674,
 'top_10_recall': 0.3462350079343712,
 'top_30_recall': 0.6173280643321545}