# Tutorial for Med-BERT
Step-by-step guide to the PyTorch implementation of [Med-BERT](https://www.nature.com/articles/s41746-021-00455-y) 

<sub><sup><em>Rasmy, Laila, et al. "Med-BERT: pretrained contextualized embeddings on large-scale structured electronic health records for disease prediction." NPJ Digital Medicine 4.1 (2021): 1-13.</em></sup></sub>

------------------------------------------

## Introduction

The goal of Med-BERT is to learn good representations of electronic health records (EHR) that can be used to make predictions for downstream tasks. <br>
To do so, we leverage the pretraining/fine-tuning paradigm with a transformer architecture $^{1}$.  
Originally developed for natural language processing, transformers have since achieved state-of-the-art results in other fields, such as computer vision $^{2}$ and speech processing $^{3}$. <br>
More recently, a transformer variant called BERT $^{4}$ has also been applied to medical data, and to electronic health records in particular $^{5-7}$.<br> 
There are countless tutorials explaining the theory and basic concepts behind transformers and BERT, as well as their application to NLP, so here we focus on using BERT for EHR specifically.


## 1. Data Preparation


You can obtain synthetic data from https://github.com/synthetichealth/synthea and process the resulting CSV files with https://github.com/kirilklein/ehr_preprocess.<br>
This produces a format that `main_data_pretrain.py` turns into input features for our pipeline (see `data/raw/synthea500`).<br>
Here, we start from the formatted data stored in `data/processed/pretrain/synthea500`.<br>


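```python
import torch

# Small dict subclass that allows attribute access (cfg.key); used to create configs
class DotDict(dict):
    def __getattr__(self, attr):
        try:
            return self[attr]
        except KeyError as e:
            # raise AttributeError so that hasattr() and getattr(..., default) behave as expected
            raise AttributeError(attr) from e
```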

### Tokenization

We start with features stored in a dictionary. Each value in that dictionary (for example `'concept'`, the medical codes) is a list of lists, where every inner list holds the entries for one patient.
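To make the layout concrete, here is a toy dictionary in the same shape. The patients are made up (the codes are borrowed from outputs in this tutorial), and the `'segment'` key is an assumption suggested by the dataset's `max_segments` attribute; the real files may contain other keys.

```python
# Toy example of the nested feature layout: one inner list per patient.
# Patients are invented; the 'segment' key is a hypothetical visit index per code.
features = {
    'concept': [
        ['314529007', '73595000', '160903007'],  # patient 0
        ['224299000', '314076'],                 # patient 1
    ],
    'segment': [
        [0, 0, 1],
        [0, 1],
    ],
}

num_patients = len(features['concept'])
# every feature list is aligned: same number of patients for every key
for key, per_patient in features.items():
    assert len(per_patient) == num_patients, key
print(num_patients, features['concept'][1])
```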

```python
train_features = torch.load("data/processed/pretrain/synthea500/train.pt")
val_features = torch.load("data/processed/pretrain/synthea500/val.pt")
print("Here we can see the SNOMED-CT codes for patient number 4 in the validation set: ")
print(val_features['concept'][4][:12])  # first 12 codes of this patient
```

```
Here we can see the SNOMED-CT codes for patient number 4 in the validation set: 
['314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007']
```


Next, we assign an integer to each unique code and bring all sequences to the same length so that the data can be passed as a tensor.<br>
Alternatively, we can use dynamic padding during training: each batch is padded only to its own longest sequence instead of to the longest sequence in the data (or the truncation length). Because the cost of self-attention grows quadratically with sequence length, this can speed up training.<br>
We also need special tokens: [MASK] to **mask** inputs, [UNK] for codes that are not in the vocabulary (e.g. new codes that appear in the future), [PAD] for padding, and [CLS] and [SEP] as in BERT.
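These steps can be sketched in a few lines of plain Python and PyTorch. This is a minimal illustration, not the `EHRTokenizer` implementation: the helper names `build_vocab`, `encode` and `pad_batch` are hypothetical, and [SEP] tokens are left out for brevity. `pad_batch` shows dynamic padding, i.e. padding a batch only to its own longest sequence.

```python
import torch

# Special tokens come first, so their ids match the vocabulary printed later in this tutorial
SPECIAL_TOKENS = ['[PAD]', '[MASK]', '[UNK]', '[CLS]', '[SEP]']

def build_vocab(patients):
    """Assign an integer id to every unique code, after the special tokens."""
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for codes in patients:
        for code in codes:
            if code not in vocab:
                vocab[code] = len(vocab)
    return vocab

def encode(codes, vocab, truncation=100):
    """Prepend [CLS], map codes to ids (unknown codes -> [UNK]) and truncate."""
    ids = [vocab['[CLS]']] + [vocab.get(c, vocab['[UNK]']) for c in codes]
    return ids[:truncation]

def pad_batch(batch, pad_id=0):
    """Dynamic padding: pad only up to the longest sequence in this batch."""
    max_len = max(len(ids) for ids in batch)
    return torch.tensor([ids + [pad_id] * (max_len - len(ids)) for ids in batch])

patients = [['314529007', '73595000'], ['224299000']]
vocab = build_vocab(patients)
batch = pad_batch([encode(p, vocab) for p in patients])
print(batch)
# tensor([[3, 5, 6],
#         [3, 7, 0]])
```

With dynamic padding, a function like `pad_batch` would typically be passed as `collate_fn` to a `torch.utils.data.DataLoader`, so padding happens per batch instead of once for the whole dataset.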

```python
from medbert.features.tokenizer import EHRTokenizer

tokenizer_config = DotDict({
    'sep_tokens': True,  # add [SEP] tokens?
    'cls_token': True,   # add a [CLS] token?
    'padding': True,     # pad the sequences to a common length?
    'truncation': 100,   # maximum sequence length; longer sequences are truncated
})
tokenizer = EHRTokenizer(config=tokenizer_config)
train_tokenized = tokenizer(train_features)
print(train_tokenized['concept'][:3])
```

```
Encoding patients: 391it [00:00, 2580.67it/s]

tensor([[ 3,  5,  6,  4,  6,  7,  4,  8,  9,  4,  9,  4,  7,  9,  4,  9,  6,  4,
          9,  4,  9,  4,  9,  4,  9,  4,  9, 10,  4,  9,  4,  9,  4, 11,  9,  4,
          9, 12,  6,  4, 13,  9, 10,  4,  9,  7,  4, 11,  9,  4,  6, 11,  9,  4,
          9, 14,  4, 11,  7,  9,  4, 15, 16, 17, 18, 19,  4,  9,  6,  4, 11, 20,
          9,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3, 21,  4,  5,  7,  4, 10,  7,  4, 22,  4,  7,  4, 23, 24, 20,  4, 25,
          4, 26,  4,  6,  4, 27, 28, 29, 30,  4, 11, 27, 12,  4, 11, 27,  4, 11,
         27,  4, 11, 27,  4, 27, 11,  6,  4, 11, 31, 27,  4, 11, 27,  4, 32, 33,
         34,  4, 27, 12,  6,  4, 27, 35,  4, 36,  4, 37, 38, 39,  4, 27, 11,  4,
         36,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3, 40,  9, 41, 42, 11, 43, 14,  4, 44,  9, 41, 42, 13,  6,  4, 31,  7,
```

After tokenizing the training set, freeze and save the vocabulary, so that the validation (and any later) data are encoded with the same code-to-integer mapping.
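Freezing matters because the model's embedding table is sized from the training vocabulary (`vocab_size=len(train_dataset.vocabulary)` in the model setup). We assume `freeze_vocabulary()` behaves roughly like the hypothetical `ToyVocabulary` sketched here: before freezing, unseen codes get new ids; afterwards, they all map to [UNK].

```python
class ToyVocabulary:
    """Hypothetical illustration of a vocabulary that can be frozen."""

    def __init__(self):
        self.token2id = {'[PAD]': 0, '[MASK]': 1, '[UNK]': 2, '[CLS]': 3, '[SEP]': 4}
        self.frozen = False

    def freeze(self):
        self.frozen = True

    def lookup(self, code):
        if code in self.token2id:
            return self.token2id[code]
        if self.frozen:
            # after freezing, all unseen codes share the [UNK] id
            return self.token2id['[UNK]']
        # before freezing, a new code gets the next free id
        self.token2id[code] = len(self.token2id)
        return self.token2id[code]

vocab = ToyVocabulary()
train_ids = [vocab.lookup(c) for c in ['224299000', '73595000']]
vocab.freeze()
val_ids = [vocab.lookup(c) for c in ['73595000', '999999999']]  # second code is made up
print(train_ids, val_ids, len(vocab.token2id))  # [5, 6] [6, 2] 7
```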

```python
tokenizer.freeze_vocabulary()
# tokenizer.save_vocab("../output/vocabulary.pt")
print({k: v for k, v in tokenizer.vocabulary.items() if v < 10})
# Now we tokenize the validation set:
val_tokenized = tokenizer(val_features)
```

```
{'[PAD]': 0, '[MASK]': 1, '[UNK]': 2, '[CLS]': 3, '[SEP]': 4, '224299000': 5, '73595000': 6, '160903007': 7, '59621000': 8, '314076': 9}

Encoding patients: 112it [00:00, 4259.06it/s]
```


### Dataset Creation

The dataset takes care of several things:
* masking input tokens and creating the corresponding target
* constructing an attention mask so that padding tokens are ignored
* returning a single item (patient) via `__getitem__`
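A minimal sketch of the last two points, assuming inputs that are already padded with id 0 ([PAD]). `ToyEHRDataset` and the `'attention_mask'` key are hypothetical and need not match `MLM_PLOS_Dataset`; masking is left out here.

```python
import torch
from torch.utils.data import Dataset

class ToyEHRDataset(Dataset):
    """Hypothetical minimal dataset: one padded patient sequence per item."""

    def __init__(self, concepts, pad_id=0):
        self.concepts = concepts  # LongTensor of shape (num_patients, seq_len)
        self.pad_id = pad_id

    def __len__(self):
        return self.concepts.shape[0]

    def __getitem__(self, index):
        concept = self.concepts[index]
        # 1 for real tokens, 0 for padding, so attention can ignore padded positions
        attention_mask = (concept != self.pad_id).long()
        return {'concept': concept, 'attention_mask': attention_mask}

ds = ToyEHRDataset(torch.tensor([[3, 5, 6, 4], [3, 7, 0, 0]]))
item = ds[1]
print(item['attention_mask'])  # tensor([1, 1, 0, 0])
```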

```python
from medbert.features import dataset

dataset_config = DotDict({
    'masked_ratio': 0.30,  # fraction of tokens to mask (BERT uses 0.15)
    'ignore_special_tokens': True,  # skip special tokens when selecting tokens to mask
})
train_dataset = dataset.MLM_PLOS_Dataset(train_tokenized, vocabulary=tokenizer.vocabulary, min_los=3, **dataset_config)
val_dataset = dataset.MLM_PLOS_Dataset(val_tokenized, vocabulary=tokenizer.vocabulary, min_los=3, **dataset_config)
patient = train_dataset[0]  # one patient as a dict (e.g. 'concept', 'target')
```

The masking is adapted from BERT:
* Select len(sequence) * masked_ratio tokens
* Replace 80% of the selected tokens with the mask token [MASK] (id 1)
* Replace 10% of the selected tokens with a random token
* Keep the remaining 10% unchanged
* The target contains the original tokens at the selected positions; all other positions are filled with an ignore index (e.g. -100) so that they do not contribute to the loss
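The masking steps above can be sketched for a single 1-D sequence as follows. This is an illustration, not the code of `MLM_PLOS_Dataset`: `mask_tokens` is a hypothetical helper that assumes ids 0-4 are the special tokens (as in the vocabulary printed earlier), skips them when selecting positions (mirroring `ignore_special_tokens`), and draws random replacements only from non-special ids.

```python
import torch

def mask_tokens(concept, vocab_size, masked_ratio=0.15, mask_id=1, num_special=5,
                ignore_index=-100, generator=None):
    """BERT-style masking of a 1-D LongTensor: select positions, then 80/10/10 replacement."""
    concept = concept.clone()
    target = torch.full_like(concept, ignore_index)
    # candidate positions: non-special tokens (ids >= num_special)
    candidates = (concept >= num_special).nonzero(as_tuple=True)[0]
    n_select = int(len(candidates) * masked_ratio)
    perm = torch.randperm(len(candidates), generator=generator)
    selected = candidates[perm[:n_select]]
    target[selected] = concept[selected]  # the model must predict the original tokens here

    roll = torch.rand(n_select, generator=generator)
    to_mask = selected[roll < 0.8]                    # ~80% -> [MASK]
    to_random = selected[(roll >= 0.8) & (roll < 0.9)]  # ~10% -> random token
    concept[to_mask] = mask_id
    concept[to_random] = torch.randint(num_special, vocab_size, (len(to_random),),
                                       generator=generator)
    # the remaining ~10% keep their original token
    return concept, target

g = torch.Generator().manual_seed(0)
seq = torch.tensor([3, 5, 6, 4, 7, 8, 9, 4, 10, 11, 12, 4])
masked, target = mask_tokens(seq, vocab_size=50, masked_ratio=0.5, generator=g)
selected = target != -100
print(seq[selected], masked[selected])
```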

Comparing the original tokenized features with the masked features at the selected positions, you can see that some tokens were replaced by 1 (the mask token), others by random tokens, and some were left unchanged.

```python
mask_mask = patient['target'] != -100  # positions selected for masking
train_tokenized['concept'][0][mask_mask]  # original tokens at these positions
```

```
tensor([ 7,  8,  7,  6,  9,  9, 10,  6,  9,  7, 11,  9, 11,  9,  9])
```

```python
patient['concept'][mask_mask]  # masked sequence at the same positions
```

```
tensor([ 64,   1,   7,   6,   1, 166,  67,   1,   1,   1,  11,  98,  11,   1,
        134])
```

```python
print(patient['target'][mask_mask])  # target
```

```
tensor([ 7,  8,  7,  6,  9,  9, 10,  6,  9,  7, 11,  9, 11,  9,  9])
```
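The -100 used for unselected positions matches the default `ignore_index` of PyTorch's cross-entropy loss, which Hugging Face's BERT models use for the masked-LM loss, so those positions simply do not contribute. A quick check with toy logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 6)                 # 4 positions, toy vocabulary of 6 tokens
target = torch.tensor([-100, 2, -100, 5])  # only positions 1 and 3 were selected

loss_all = F.cross_entropy(logits, target)  # ignore_index defaults to -100
loss_selected = F.cross_entropy(logits[[1, 3]], target[[1, 3]])
print(torch.allclose(loss_all, loss_selected))  # True
```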


### Set up the model and optimizer

```python
from transformers import BertForPreTraining, BertConfig
from torch.optim import AdamW

num_attention_heads = 3
hidden_size = num_attention_heads * 10  # must be divisible by the number of heads
intermediate_size = hidden_size * 4  # 4x the hidden size, as in the original Transformer/BERT

model = BertForPreTraining(
    BertConfig(
        vocab_size=len(train_dataset.vocabulary),
        type_vocab_size=int(train_dataset.max_segments),  # number of distinct segment ids
        max_position_embeddings=1024,  # maximum sequence length for position embeddings
        hidden_size=hidden_size,
        num_hidden_layers=3,
        linear=True,  # not a standard BertConfig option: stored on the config, unused by BertForPreTraining
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
    )
)
```
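`hidden_size` is chosen as a multiple of `num_attention_heads` because multi-head attention splits the hidden dimension evenly across heads, which lets all heads be computed in parallel as one batched operation (Hugging Face's BERT self-attention raises an error otherwise). A shape-level sketch with random tensors, not the library's actual code:

```python
import torch

batch_size, seq_len = 2, 5
num_attention_heads = 3
hidden_size = num_attention_heads * 10         # 30
head_dim = hidden_size // num_attention_heads  # 10 dimensions per head

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
# split the hidden dimension into heads: (batch, heads, seq, head_dim)
heads = hidden_states.view(batch_size, seq_len, num_attention_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 3, 5, 10])

# scaled dot-product scores for all heads at once (queries = keys here for brevity)
scores = heads @ heads.transpose(-2, -1) / head_dim ** 0.5
print(scores.shape)  # torch.Size([2, 3, 5, 5])
```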

```python
optimizer = AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.01,
)
```

```python
from medbert.trainer import trainer
import importlib
importlib.reload(trainer)  # only needed when editing the trainer code inside a running notebook

ehr_trainer = trainer.EHRTrainer(
    model=model,
    optimizer=optimizer,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    args=DotDict({
        'epochs': 2,
        'batch_size': 64,
        'effective_batch_size': 128,
    }),
    cfg=DotDict({
        'scheduler': None,
        'collate_fn': None,
        'run_name': 'ptest',
    }),
)
ehr_trainer.train()
```

```
[INFO] Run name not provided. Using random run name: 3c5b0d19972b43909606e7d8dc518858
[INFO] Run folder: ../runs\3c5b0d19972b43909606e7d8dc518858

Train 0:  29%|██▊       | 2/7 [00:00<00:02,  2.42it/s, loss=6.63]
Train loss 1: 6.631284713745117
Train 0:  57%|█████▋    | 4/7 [00:01<00:01,  2.79it/s, loss=6.63]
Train loss 2: 6.633287668228149
Train 0: 100%|██████████| 7/7 [00:02<00:00,  3.11it/s, loss=6.63]
Train loss 3: 6.633130073547363
Validation: 100%|██████████| 2/2 [00:00<00:00,  5.88it/s]
[INFO] Epoch 0 train loss: 5.685057844434466
[INFO] Epoch 0 val loss: 6.625337839126587
[INFO] Epoch 0 metrics: None

Train 1:  29%|██▊       | 2/7 [00:00<00:01,  3.27it/s, loss=6.62]
Train loss 1: 6.62165904045105
Train 1:  57%|█████▋    | 4/7 [00:01<00:00,  3.11it/s, loss=6.62]
Train loss 2: 6.615631818771362
Train 1: 100%|██████████| 7/7 [00:02<00:00,  3.36it/s, loss=6.61]
Train loss 3: 6.612935304641724
Validation: 100%|██████████| 2/2 [00:00<00:00,  6.17it/s]
[INFO] Epoch 1 train loss: 5.671493189675467
[INFO] Epoch 1 val loss: 6.6075599193573
[INFO] Epoch 1 metrics: None
```
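The `batch_size` and `effective_batch_size` arguments suggest that `EHRTrainer` accumulates gradients over 128 / 64 = 2 batches before each optimizer step. We have not verified the trainer's internals; the following is a generic sketch of gradient accumulation in plain PyTorch, using a toy linear model and random data instead of BERT.

```python
import torch
from torch import nn
from torch.optim import AdamW

torch.manual_seed(0)
model = nn.Linear(8, 1)  # toy stand-in for the BERT model
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

batch_size, effective_batch_size = 64, 128
accumulation_steps = effective_batch_size // batch_size  # 2

data = torch.randn(7 * batch_size, 8)  # toy data: 7 batches
targets = torch.randn(7 * batch_size, 1)
num_batches = len(data) // batch_size
updates = 0
for step in range(num_batches):
    x = data[step * batch_size:(step + 1) * batch_size]
    y = targets[step * batch_size:(step + 1) * batch_size]
    loss = nn.functional.mse_loss(model(x), y)
    # scale so the accumulated gradient approximates the mean over the effective batch
    (loss / accumulation_steps).backward()
    # step after every accumulation_steps batches, and once more for a leftover batch
    if (step + 1) % accumulation_steps == 0 or step + 1 == num_batches:
        optimizer.step()
        optimizer.zero_grad()
        updates += 1
print(updates)  # 4
```

Accumulation trades wall-clock time for memory: each forward/backward pass only holds 64 sequences, while the optimizer sees gradients averaged over roughly 128.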






### References

<sub><sup>[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).</sup></sub><br>
<sub><sup>[2] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).</sup></sub><br>
<sub><sup>[3] Polyak, Adam, et al. "Speech resynthesis from discrete disentangled self-supervised representations." arXiv preprint arXiv:2104.00355 (2021).</sup></sub><br>
<sub><sup>[4] Devlin, Jacob, et al. "BERT: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).</sup></sub><br>
<sub><sup>[5] Li, Yikuan, et al. "BEHRT: Transformer for electronic health records." Scientific Reports 10.1 (2020): 1-12.</sup></sub><br>
<sub><sup>[6] Shang, Junyuan, et al. "Pre-training of graph augmented transformers for medication recommendation." arXiv preprint arXiv:1906.00346 (2019).</sup></sub><br>
<sub><sup>[7] Pang, Chao, et al. "CEHR-BERT: Incorporating temporal information from structured EHR data to improve prediction tasks." Machine Learning for Health. PMLR, 2021.</sup></sub><br>