In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import re
import textwrap
from sklearn.metrics import classification_report
from pprint import pprint


import tensorflow as tf
from tensorflow import keras
from transformers import BertTokenizer, TFBertModel


In [2]:
# Mount Google Drive; the CSV paths used below are relative to it ("drive/MyDrive/...").
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# All diagnosis splits live in one directory; keep it in a single constant so the
# notebook can be repointed at another copy of the data in one place.
# NOTE(review): the path is relative, so it presumably relies on the working directory
# being /content (where Drive is mounted as ./drive) — confirm when running outside Colab.
DIAGNOSES_DIR = "drive/MyDrive/MIDS/medal_mimic_subset/diagnoses"

diagnoses_train = pd.read_csv(f"{DIAGNOSES_DIR}/train.csv")
diagnoses_validation = pd.read_csv(f"{DIAGNOSES_DIR}/valid.csv")
diagnoses_test = pd.read_csv(f"{DIAGNOSES_DIR}/test.csv")

In [7]:
# Abbreviation inventory; its `abbreviation` column becomes the lookup vocabulary below.
total_abbreviations = pd.read_csv(
    filepath_or_buffer="drive/MyDrive/MIDS/medal_mimic_subset/total_abbreviations.csv"
)

In [8]:
# Row counts per split, printed in train / validation / test order.
for split_df in (diagnoses_train, diagnoses_validation, diagnoses_test):
    print(len(split_df))

61079
9258
14287


In [9]:
# Spot-check: print the last train-split note (in file order) for SUBJECT_ID 29487.
# NOTE(review): this echoes raw clinical text into the saved output — clear it before sharing.
patient_notes = diagnoses_train.loc[diagnoses_train.SUBJECT_ID == 29487, "TEXT"]
print(patient_notes.values[-1])

Chief Complaint : 
   24 Hour Events : 
    -  BP stable off pressors ,  no bolus requirement
    -  UOP adequate ,  still net positive
    -  tolerated 8 - 28 for 5 hrs ,  ABGs improved
    -  admin Lasix 40mg IV this am
   Pt intubated ,  sedated ,  opening eyes but only intermittently responding
   to commands . 
   Allergies : 
   Iodine
   Anaphylaxis ; 
   Cipro  ( Oral )   ( Ciprofloxacin Hcl ) 
   Rash ; 
   Sulfonamides
   Rash ; 
   Morphine
   Nausea / Vomiting
   Codeine
   Nausea / Vomiting
   Levofloxacin
   Anaphylaxis ; 
   Last dose of Antibiotics : 
   Vancomycin  -  2163 - 1 - 19 12 : 45 PM
   Piperacillin / Tazobactam  ( Zosyn )   -  2163 - 1 - 22 04 : 20 AM
   Infusions : 
   Fentanyl  -  25 mcg / hour
   Other ICU medications : 
   Other medications : 
   Flowsheet Data as of  2163 - 1 - 22 06 : 28 AM
   Vital signs
   Hemodynamic monitoring
   Fluid balance
                                                                  24 hours
                                

In [10]:
# Set of known abbreviations; has_any_abbreviation() tests note tokens against it.
abbreviations = set(total_abbreviations["abbreviation"])

def has_any_abbreviation(text, vocabulary=None):
    """
    Return True if any whitespace-separated token of ``text`` is a known abbreviation.

    Matching is exact and case-sensitive on whole tokens (``str.split()`` with no
    argument), so a token with punctuation glued to it (e.g. "CRF,") does not match.

    Args:
        text: note text to scan.
        vocabulary: container of abbreviations to look for (a set gives O(1) lookups).
            Defaults to the notebook-level ``abbreviations`` set.

    Returns:
        bool: True as soon as one token is found in ``vocabulary``.
    """
    if vocabulary is None:
        vocabulary = abbreviations
    # any() stops at the first hit and avoids building a per-note set plus an
    # intersection just to test for non-emptiness.
    return any(token in vocabulary for token in text.split())

def _keep_notes_with_abbreviations(notes_df):
    """Return the rows of ``notes_df`` whose TEXT contains a known abbreviation, re-indexed 0..n-1."""
    keep_mask = notes_df.TEXT.apply(has_any_abbreviation)
    # reset_index(drop=True) makes index labels equal row positions in the subset.
    return notes_df.loc[keep_mask].reset_index(drop=True)


diagnoses_train_subset = _keep_notes_with_abbreviations(diagnoses_train)
diagnoses_validation_subset = _keep_notes_with_abbreviations(diagnoses_validation)
diagnoses_test_subset = _keep_notes_with_abbreviations(diagnoses_test)

In [11]:
# How many notes per split survived the abbreviation filter.
split_sizes = (
    ("Train", diagnoses_train_subset, diagnoses_train),
    ("Validation", diagnoses_validation_subset, diagnoses_validation),
    ("Test", diagnoses_test_subset, diagnoses_test),
)
for split_name, kept_df, full_df in split_sizes:
    print(f"{split_name}: {len(kept_df)} out of {len(full_df)}")

Train: 60444 out of 61079
Validation: 9147 out of 9258
Test: 14132 out of 14287


In [12]:
# Cased BERT checkpoint; the same name is the default checkpoint for create_diagnosis_bert_model.
model_checkpoint = 'bert-base-cased'
bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [13]:
MAX_LENGTH = 512


def head_and_tail(tokens, total_size=MAX_LENGTH):
    """
    Shorten a token sequence to at most ``total_size`` items by keeping both ends.

    A sequence no longer than ``total_size`` is returned unchanged (the same object).
    A longer sequence keeps its first ``total_size // 2`` items and its last
    ``total_size - total_size // 2`` items. The result therefore has exactly
    ``total_size`` items, including for odd sizes.

    Args:
        tokens: 1-D sequence supporting ``len()`` and slicing (a NumPy array, or a
            tensor NumPy can convert).
        total_size: maximum output length; defaults to MAX_LENGTH.

    Returns:
        ``tokens`` itself when short enough, otherwise a new NumPy array.

    Raises:
        ValueError: if ``total_size`` is negative.
    """
    if total_size < 0:
        raise ValueError(f"total_size must be non-negative, got {total_size}")
    if len(tokens) <= total_size:
        return tokens
    head_size = total_size // 2
    # Give the odd leftover item to the tail so the output is exactly total_size long.
    tail_size = total_size - head_size
    head = tokens[:head_size]
    # Slice from an explicit start index: a negative slice such as tokens[-0:]
    # would return the whole sequence when the tail size is 0.
    tail = tokens[len(tokens) - tail_size:]
    return np.concatenate([head, tail])


def tokenize(texts):
    """
    Encode each text with ``bert_tokenizer`` and clip it to MAX_LENGTH via ``head_and_tail``.

    Each text is tokenized on its own, padded up to MAX_LENGTH. No ``truncation`` argument
    is passed, so over-long notes presumably come back unclipped and head_and_tail keeps
    their first and last tokens instead. NOTE(review): confirm this against the installed
    transformers version.

    Args:
        texts: iterable of note strings.

    Returns:
        tuple ``(input_ids, token_type_ids, attention_mask)`` of NumPy arrays, one row
        per input text, in the same order as the model's inputs.
    """
    field_names = ("input_ids", "token_type_ids", "attention_mask")
    collected = {name: [] for name in field_names}

    for text in texts:
        encoding = bert_tokenizer(
            text, padding="max_length", return_tensors="tf", max_length=MAX_LENGTH
        )
        # return_tensors="tf" adds a batch axis of size 1; [0] takes the single row.
        for name in field_names:
            collected[name].append(head_and_tail(encoding[name][0]))

    return tuple(np.array(collected[name]) for name in field_names)


# Sanity-check tokenize() on the first three training notes: the row length should be
# MAX_LENGTH, and decoding shows which text survives the head/tail clipping.
examples = diagnoses_train_subset.TEXT.iloc[:3].tolist()
input_ids, token_type_ids, attention_mask = tokenize(examples)

shown = 2
print(len(input_ids[shown]))
print(examples[shown])
print(bert_tokenizer.decode(input_ids[shown]))


512
No significant events overnight
   Renal failure ,  Chronic  ( Chronic renal failure ,  CRF ,  Chronic kidney
   disease ) 
   Assessment : 
   U / O remains extremely low  ~  10 cc / hr ,  total body overloaded w /  4 + 
   pitting edema x all 4 extremities
   Action : 
   All meds renally dosed ,  no fluid boluses overnight
   Response : 
   Plan : 
   Cont to trend changes in BUN / CR ,  renally dose all meds ,  nephrology may
   need to re evaluate if urine output does not improve . 
   Pain control  ( acute pain ,  chronic pain ) 
   Assessment : 
   Sedated on fent / midaz grimaces during turns / repositioning
   Action : 
   Fent boluses prior to turning ,  lido patch off  @  00 : 00
   Response : 
   Continues to experience pain
   Plan : 
   Continue w /  current pain / sedation regimen ,  ortho consult to evaluate
   for septic L hip . 

[CLS] No significant events overnight Renal failure, Chronic ( Chronic renal failure, CRF, Chronic kidney disease ) Assessment : U / O r

In [14]:
# Tokenize every split once up front. This is slow: one tokenizer call per note.
# Each result is an (input_ids, token_type_ids, attention_mask) tuple of arrays.
mimic_train_subset_inputs = tokenize(list(diagnoses_train_subset["TEXT"]))
mimic_validation_subset_inputs = tokenize(list(diagnoses_validation_subset["TEXT"]))
mimic_test_subset_inputs = tokenize(list(diagnoses_test_subset["TEXT"]))

In [15]:
# Label vocabulary = every ICD9 code seen in the training subset.
# ICD9_ID holds several codes per record, separated by ';'.
unique_icd9_codes = {
    code
    for icd9_field in diagnoses_train_subset.ICD9_ID
    for code in icd9_field.split(";")
}

# Sorted so each code gets a stable column position in the label matrices.
diagnosis_labels = sorted(unique_icd9_codes)
diagnosis_labels_indexes = {code: position for position, code in enumerate(diagnosis_labels)}

print(f"{len(diagnosis_labels)} unique diagnosis codes")


3842 unique diagnosis codes


In [16]:
def get_labels(df, label_index=None, dtype=np.float64):
    """
    Build a multi-hot label matrix for the records in ``df``.

    Row ``i`` of the result belongs to the ``i``-th row of ``df`` by position, not by
    index label. Column ``j`` is 1 when the code mapped to ``j`` appears in that row's
    ';'-separated ``ICD9_ID`` string. Codes missing from the label vocabulary (i.e. not
    seen in the training set) are ignored.

    Args:
        df: DataFrame with an ``ICD9_ID`` column of ';'-separated ICD9 codes.
        label_index: mapping from ICD9 code to column position. Defaults to the
            notebook-level ``diagnosis_labels_indexes`` built from the train set.
        dtype: dtype of the returned matrix. float64 matches the previous behaviour.
            float32 halves memory; for the ~60k x 3842 train subset that is roughly
            0.9 GB instead of 1.9 GB.

    Returns:
        np.ndarray of shape ``(len(df), len(label_index))`` holding 0s and 1s.
    """
    if label_index is None:
        label_index = diagnosis_labels_indexes
    output = np.zeros((len(df), len(label_index)), dtype=dtype)
    # enumerate yields positional row numbers. This stays correct even when df's index
    # is not a 0..n-1 RangeIndex (e.g. a filtered frame without reset_index).
    for row_pos, icd9_field in enumerate(df["ICD9_ID"]):
        for icd9 in icd9_field.split(';'):
            col = label_index.get(icd9)
            # skip codes that are not present in the train set
            if col is not None:
                output[row_pos, col] = 1

    return output

# One multi-hot label matrix per split. Each is aligned row-for-row with its
# *_subset frame, since those were re-indexed 0..n-1.
diagnosis_labels_train = get_labels(df=diagnoses_train_subset)
diagnosis_labels_validation = get_labels(df=diagnoses_validation_subset)
diagnosis_labels_test = get_labels(df=diagnoses_test_subset)

In [17]:
# Inspect the positive label columns of the first training record.
positive_positions = np.where(diagnosis_labels_train[0] == 1)
print(f"Label indices: {positive_positions}")
first_label_index = int(positive_positions[0][0])
print("First label index:", first_label_index)
print(f"First label: {diagnosis_labels[first_label_index]}")


Label indices: (array([ 284,  296,  300,  567,  569,  722,  888,  891, 1547, 1558, 1577,
       1705, 1871, 2263, 2472, 3628, 3651, 3656, 3754, 3792, 3819]),)
First label index: 284
First label: 1970


In [18]:
def create_diagnosis_bert_model(
    model_checkpoint=model_checkpoint,
    n_classes=len(diagnosis_labels),
    hidden_size=201,
    dropout=0.3,
    learning_rate=0.00005,
    metrics=None,
):
    """
    Build and compile a multi-label ICD9 classifier on top of a pretrained BERT encoder.

    The BERT pooler output feeds a ReLU hidden layer with dropout, then a sigmoid layer
    with one unit per diagnosis code. Training uses binary cross-entropy, i.e. one
    independent yes/no decision per code.

    Args:
        model_checkpoint: checkpoint name passed to ``TFBertModel.from_pretrained``.
        n_classes: number of output units; defaults to the number of train-set ICD9 codes.
        hidden_size: width of the intermediate dense layer.
        dropout: dropout rate applied after the hidden layer.
        learning_rate: Adam learning rate.
        metrics: metrics passed to ``compile``. Defaults to ``'accuracy'`` plus
            precision, recall and AUC.

    Returns:
        A compiled ``tf.keras.Model``. It takes ``[input_ids, token_type_ids,
        attention_mask]``, each of shape ``(batch, MAX_LENGTH)``, and outputs
        ``(batch, n_classes)`` probabilities.

    NOTE(review): with thousands of sparse labels and unweighted BCE, the model may
    drift toward predicting all zeros. Consider class weighting or a focal loss;
    neither is changed here.
    """
    if metrics is None:
        # 'accuracy' alone is misleading here. For a binary cross-entropy loss Keras
        # scores every output unit independently, and a record has only a few positive
        # codes out of thousands (21 of 3842 for the first training record). Predicting
        # all zeros therefore already scores ~99.5%. Precision/recall/AUC reflect the
        # positive class. 'accuracy' is kept so existing history keys stay available.
        metrics = [
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc'),
        ]

    # Inputs are listed in the same order tokenize() returns its arrays, so its output
    # tuple can be fed to fit() positionally.
    input_ids = tf.keras.layers.Input(shape=(MAX_LENGTH,), dtype=tf.int64, name='input_ids_layer')
    token_type_ids = tf.keras.layers.Input(shape=(MAX_LENGTH,), dtype=tf.int64, name='token_type_ids_layer')
    attention_mask = tf.keras.layers.Input(shape=(MAX_LENGTH,), dtype=tf.int64, name='attention_mask_layer')

    bert_inputs = {'input_ids': input_ids,
                   'token_type_ids': token_type_ids,
                   'attention_mask': attention_mask}

    bert_model = TFBertModel.from_pretrained(model_checkpoint)
    bert_out = bert_model(bert_inputs)

    # Element 1 of the BERT output is the pooler output used for classification.
    pooler_token = bert_out[1]

    hidden = tf.keras.layers.Dense(hidden_size, activation='relu', name='hidden_layer')(pooler_token)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    # Sigmoid (not softmax): each code is an independent binary decision.
    classification = tf.keras.layers.Dense(n_classes, activation='sigmoid', name='classification_layer')(hidden)

    classification_model = tf.keras.Model(inputs=[input_ids, token_type_ids, attention_mask], outputs=[classification])

    # from_logits=False because the output layer already applies the sigmoid.
    classification_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                                 loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                                 metrics=metrics)

    return classification_model

In [19]:
# Build the classifier with the notebook's checkpoint and one output per train-set ICD9 code.
diagnosis_bert_model = create_diagnosis_bert_model(
    model_checkpoint=model_checkpoint,
    n_classes=len(diagnosis_labels),
)
diagnosis_bert_model.summary()

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 attention_mask_layer (Inpu  [(None, 512)]                0         []                            
 tLayer)                                                                                          
                                                                                                  
 input_ids_layer (InputLaye  [(None, 512)]                0         []                            
 r)                                                                                               
                                                                                                  
 token_type_ids_layer (Inpu  [(None, 512)]                0         []                            
 tLayer)                                                                                      

In [20]:
# Fine-tune on the train subset and track the validation subset after each epoch.
# batch_size=32 is spelled out; it is the Keras default when batch_size is omitted.
diagnosis_history = diagnosis_bert_model.fit(
    x=mimic_train_subset_inputs,
    y=diagnosis_labels_train,
    batch_size=32,
    epochs=2,
    validation_data=(mimic_validation_subset_inputs, diagnosis_labels_validation),
)

Epoch 1/2
Epoch 2/2
