# Tutorial for Med-BERT
Step-by-step guide to the PyTorch implementation of [Med-BERT](https://www.nature.com/articles/s41746-021-00455-y) 

<sub><sup><em>Rasmy, Laila, et al. "Med-BERT: pretrained contextualized embeddings on large-scale structured electronic health records for disease prediction." NPJ Digital Medicine 4.1 (2021): 1-13.</em></sup></sub>

------------------------------------------

## Introduction

The goal of Med-BERT is to learn good representations of electronic health records (EHR) that can be used to make predictions in downstream tasks. <br>
To do so, we leverage the pretraining and fine-tuning paradigm with a transformer architecture $^{1}$.  
Originally developed for natural language processing (NLP), transformers have proven their versatility by achieving state-of-the-art results in fields such as computer vision $^2$ and speech recognition $^3$. <br>
Recently, BERT $^{4}$, a transformer variant, has also been applied to medical data, and to electronic health records in particular $^{5-7}$.<br> 
Many tutorials already explain the theory and basic concepts behind transformers and BERT and their application to NLP, so here we focus on using BERT for EHR specifically.


## 1. Data Preparation

A toy dataset is stored in `data/processed/pretrain/synthea500`.

In [1]:
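import torch

# Dictionary with attribute-style access, used to create configs
class DotDict(dict):
    def __getattr__(self, attr):
        try:
            return self[attr]
        except KeyError as err:
            # Raise AttributeError so that hasattr() and copying behave as expected
            raise AttributeError(attr) from err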

### Tokenization
In this step we assign an integer to each unique code and bring all sequences to the same length so that the data can be passed as a tensor.<br>
Alternatively, we can use dynamic padding during training, which pads to the longest sequence in the batch instead of the longest sequence in the data. Because the cost of self-attention grows quadratically with sequence length, this can speed up training.<br>
Additional special tokens are needed to mask inputs and to handle codes that are not in the vocabulary but may appear in the future.
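
Dynamic padding can be sketched as a custom `collate_fn` for a PyTorch `DataLoader`. This is a minimal illustration and not the `medbert` implementation: the function name `pad_collate`, the toy samples, and the assumption that each sample is a dictionary with a variable-length `concept` tensor are hypothetical, and the padding id 0 is assumed.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_ID = 0  # assumed id of the [PAD] token


def pad_collate(batch):
    """Pad every sequence to the longest one in this batch only."""
    concepts = [sample["concept"] for sample in batch]
    padded = pad_sequence(concepts, batch_first=True, padding_value=PAD_ID)
    attention_mask = (padded != PAD_ID).long()
    return {"concept": padded, "attention_mask": attention_mask}


# Hypothetical toy data: three patients with 2, 5 and 3 codes.
samples = [
    {"concept": torch.tensor([3, 5])},
    {"concept": torch.tensor([3, 5, 6, 4, 7])},
    {"concept": torch.tensor([3, 8, 4])},
]
loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
shapes = [tuple(batch["concept"].shape) for batch in loader]
print(shapes)  # [(2, 5), (1, 3)]
```

The second batch is padded to length 3 rather than 5, which is where the savings come from when sequence lengths vary a lot.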

Features are stored in a dictionary. Each key (for example `concept`) maps to a list of lists, where every inner list holds the sequence of one patient.

In [2]:
train_features = torch.load("data/processed/pretrain/synthea500/train.pt")
val_features = torch.load("data/processed/pretrain/synthea500/val.pt")
print(val_features['concept'][4][:12])

['314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007', '314529007']


In [3]:
from medbert.features.tokenizer import EHRTokenizer
tokenizer_config = DotDict({
    'sep_tokens': True, # should we add [SEP] tokens?
    'cls_token': True, # should we add a [CLS] token?
    'padding': True, # should we pad the sequences?
    'truncation': 100}) # how long should the longest sequence be
tokenizer = EHRTokenizer(config=tokenizer_config)
train_tokenized = tokenizer(train_features) 
print(train_tokenized['concept'][:3])

Encoding patients: 391it [00:00, 4035.80it/s]

tensor([[ 3,  5,  6,  4,  6,  7,  4,  8,  9,  4,  9,  4,  7,  9,  4,  9,  6,  4,
          9,  4,  9,  4,  9,  4,  9,  4,  9, 10,  4,  9,  4,  9,  4, 11,  9,  4,
          9, 12,  6,  4, 13,  9, 10,  4,  9,  7,  4, 11,  9,  4,  6, 11,  9,  4,
          9, 14,  4, 11,  7,  9,  4, 15, 16, 17, 18, 19,  4,  9,  6,  4, 11, 20,
          9,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3, 21,  4,  5,  7,  4, 10,  7,  4, 22,  4,  7,  4, 23, 24, 20,  4, 25,
          4, 26,  4,  6,  4, 27, 28, 29, 30,  4, 11, 27, 12,  4, 11, 27,  4, 11,
         27,  4, 11, 27,  4, 27, 11,  6,  4, 11, 31, 27,  4, 11, 27,  4, 32, 33,
         34,  4, 27, 12,  6,  4, 27, 35,  4, 36,  4, 37, 38, 39,  4, 27, 11,  4,
         36,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3, 40,  9, 41, 42, 11, 43, 14,  4, 44,  9, 41, 42, 13,  6,  4, 31,  7,
         ...


After tokenizing the training set, the vocabulary should be frozen and saved, so that the validation set is encoded with the same code-to-integer mapping.

In [4]:
tokenizer.freeze_vocabulary()
# tokenizer.save_vocab("../output/vocabulary.pt")
print({k:v for k,v in tokenizer.vocabulary.items() if v < 10})
# Now we tokenize the validation set:
val_tokenized = tokenizer(val_features)

{'[PAD]': 0, '[MASK]': 1, '[UNK]': 2, '[CLS]': 3, '[SEP]': 4, '224299000': 5, '73595000': 6, '160903007': 7, '59621000': 8, '314076': 9}


Encoding patients: 112it [00:00, 3881.08it/s]
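
The effect of freezing can be illustrated with a small stand-alone vocabulary. This is a sketch under assumptions, not the `EHRTokenizer` code: the class `ToyVocabulary` and its methods are hypothetical, and mapping unseen codes to `[UNK]` after freezing is the assumed behaviour. The special-token ids follow the vocabulary printed above.

```python
class ToyVocabulary:
    """Hypothetical stand-in for a tokenizer vocabulary; not the EHRTokenizer API."""

    def __init__(self):
        self.token_to_id = {"[PAD]": 0, "[MASK]": 1, "[UNK]": 2, "[CLS]": 3, "[SEP]": 4}
        self.frozen = False

    def freeze(self):
        self.frozen = True

    def encode(self, codes):
        ids = []
        for code in codes:
            if code not in self.token_to_id:
                if self.frozen:
                    # Frozen: unseen codes fall back to [UNK] instead of getting a new id.
                    ids.append(self.token_to_id["[UNK]"])
                    continue
                # Not frozen: assign the next free integer to the new code.
                self.token_to_id[code] = len(self.token_to_id)
            ids.append(self.token_to_id[code])
        return ids


vocab = ToyVocabulary()
train_ids = vocab.encode(["224299000", "73595000", "224299000"])
vocab.freeze()
val_ids = vocab.encode(["73595000", "999999"])  # "999999" was never seen in training
print(train_ids)  # [5, 6, 5]
print(val_ids)    # [6, 2]
```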


### Dataset Creation

The dataset takes care of several things:
* masking input tokens and creating the corresponding targets
* constructing an attention mask to ignore padding tokens
* returning a single item (patient) via `__getitem__`
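
These three responsibilities can be sketched in a small PyTorch `Dataset`. This is a simplified illustration and not the `MLM_PLOS_Dataset` implementation: the class name `ToyMLMDataset`, the `ignore_index` of -100, and the rule that ids below 5 are special tokens are assumptions. BERT-style masking often also replaces some selected tokens with random ids or leaves them unchanged; the sketch omits that step.

```python
import torch
from torch.utils.data import Dataset


class ToyMLMDataset(Dataset):
    """Hypothetical sketch of a masked-language-modelling dataset."""

    def __init__(self, concepts, mask_id=1, pad_id=0, n_special=5,
                 masked_ratio=0.15, ignore_index=-100, seed=0):
        self.concepts = concepts
        self.mask_id = mask_id
        self.pad_id = pad_id
        self.n_special = n_special  # ids below this value are special tokens
        self.masked_ratio = masked_ratio
        self.ignore_index = ignore_index
        self.generator = torch.Generator().manual_seed(seed)

    def __len__(self):
        return len(self.concepts)

    def __getitem__(self, index):
        original = self.concepts[index]
        # Only ordinary codes are eligible; special tokens are never masked.
        eligible = original >= self.n_special
        draws = torch.rand(original.shape, generator=self.generator)
        masked = eligible & (draws < self.masked_ratio)

        inputs = original.clone()
        inputs[masked] = self.mask_id
        # The target holds the original token at masked positions and
        # ignore_index everywhere else, so the loss can skip those positions.
        target = torch.full_like(original, self.ignore_index)
        target[masked] = original[masked]
        attention_mask = (original != self.pad_id).long()
        return {"concept": inputs, "target": target, "attention_mask": attention_mask}


concepts = torch.tensor([[3, 5, 6, 4, 7, 9, 4, 0, 0]])
sample = ToyMLMDataset(concepts, masked_ratio=0.5)[0]
print(sample["attention_mask"])  # tensor([1, 1, 1, 1, 1, 1, 1, 0, 0])
```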

In [77]:
from medbert.features import dataset
import importlib
importlib.reload(dataset)

dataset_config = DotDict({
    'masked_ratio': 0.99, # 0.15 usually
    'ignore_special_tokens': True,
})

train_dataset = dataset.MLM_PLOS_Dataset(train_tokenized, vocabulary=tokenizer.vocabulary, dataset_config=dataset_config, min_los=3, masked_ratio=0.8)
val_dataset = dataset.MLM_PLOS_Dataset(val_tokenized, vocabulary=tokenizer.vocabulary, dataset_config=dataset_config, min_los=3, masked_ratio=0.8)

In [78]:
tokenizer.vocabulary['[MASK]']

1

In [79]:
train_tokenized['concept'][0]

tensor([ 3,  5,  6,  4,  6,  7,  4,  8,  9,  4,  9,  4,  7,  9,  4,  9,  6,  4,
         9,  4,  9,  4,  9,  4,  9,  4,  9, 10,  4,  9,  4,  9,  4, 11,  9,  4,
         9, 12,  6,  4, 13,  9, 10,  4,  9,  7,  4, 11,  9,  4,  6, 11,  9,  4,
         9, 14,  4, 11,  7,  9,  4, 15, 16, 17, 18, 19,  4,  9,  6,  4, 11, 20,
         9,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [80]:
ds_sample = train_dataset[0]

In [86]:
ds_sample['concept']

tensor([  3,   1,   1,   4,   1,   7,   4,   8,   1,   4,   1,   4,   1,   1,
          4,   9, 277,   4,   1,   4,   1,   4,   9,   4,   1,   4,   1,   1,
          4, 286,   4,  70,   4,  11,   9,   4,   9,   1,   6,   4,   1,   1,
        216,   4,   1,   1,   4,   1,   1,   4,   1,   1,   1,   4,   1,   1,
          4,  11,   1, 236,   4,   1,   1,   1,   1,   1,   4,   1,   1,   4,
         11,   1,   1,   4,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0])

In [87]:
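# Select the positions that were replaced by [MASK] and show their targets
mask_mask = ds_sample['concept'] == tokenizer.vocabulary['[MASK]']
print(ds_sample['target'][mask_mask])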

tensor([ 5,  6,  6,  9,  9,  7,  9,  9,  9,  9,  9, 10, 12, 13,  9,  9,  7, 11,
         9,  6, 11,  9,  9, 14,  7, 15, 16, 17, 18, 19,  9,  6, 20,  9])


In [91]:
test_masking_target(train_tokenized['concept'][0], ds_sample)

True
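
The helper `test_masking_target` is not defined in the cells shown here. A check of this kind could look roughly as follows. The function name `check_masking_target` and its exact rules are hypothetical, and the `ignore_index` of -100 (the default of PyTorch's `CrossEntropyLoss`) is an assumption about how unsupervised positions are marked. The check also assumes that every changed position is supervised.

```python
import torch


def check_masking_target(original, sample, mask_id=1, ignore_index=-100):
    """Return True if the masked sample is consistent with the original sequence."""
    inputs, target = sample["concept"], sample["target"]
    supervised = target != ignore_index
    # Every [MASK] position must be supervised ...
    masks_supervised = bool(supervised[inputs == mask_id].all())
    # ... every supervised position must hold the original token ...
    targets_match = torch.equal(target[supervised], original[supervised])
    # ... and unsupervised positions must be passed through unchanged.
    rest_unchanged = torch.equal(inputs[~supervised], original[~supervised])
    return masks_supervised and targets_match and rest_unchanged


original = torch.tensor([3, 5, 6, 4, 7, 0])
sample = {
    "concept": torch.tensor([3, 1, 6, 4, 1, 0]),
    "target": torch.tensor([-100, 5, -100, -100, 7, -100]),
}
print(check_masking_target(original, sample))  # True
```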

In [92]:
train_tokenized
    

{'concept': tensor([[  3,   5,   6,  ...,   0,   0,   0],
        [  3,  21,   4,  ...,   0,   0,   0],
        [  3,  40,   9,  ...,   4,  51,   4],
        ...,
        [  3, 128, 239,  ...,  39,  67,   4],
        [  3,  11,   4,  ...,   4,   0,   0],
        [  3,  22,   4,  ...,   0,   0,   0]]), 'age': tensor([[ 0., 19., 19.,  ...,  0.,  0.,  0.],
        [ 0., 14., 14.,  ...,  0.,  0.,  0.],
        [ 0., 19., 64.,  ..., 68., 68., 68.],
        ...,
        [ 0., 19., 32.,  ..., 49., 49., 49.],
        [ 0.,  0.,  0.,  ..., 60.,  0.,  0.],
        [ 0., 15., 15.,  ...,  0.,  0.,  0.]]), 'abspos': tensor([[      0.,  879144.,  879144.,  ...,       0.,       0.,       0.],
        [      0.,  753360.,  753360.,  ...,       0.,       0.,       0.],
        [      0.,  637104., 1037784.,  ..., 1073232., 1073376., 1073376.],
        ...,
        [      0.,  698568.,  813984.,  ...,  960144.,  960144.,  960144.],
        [      0.,  548136.,  548136.,  ..., 1069944.,       0.,       0

### Set up model and optimizer

In [11]:
from transformers import BertForPreTraining, BertConfig
from torch.optim import AdamW

num_attention_heads = 6
hidden_size = num_attention_heads * 10  # must be divisible by num_attention_heads
intermediate_size = hidden_size * 4  # 4x hidden size, as in the original paper

model = BertForPreTraining(
    BertConfig(
        vocab_size=len(train_dataset.vocabulary),
        type_vocab_size=int(train_dataset.max_segments),
        max_position_embeddings=1024,
        hidden_size=hidden_size,
        num_hidden_layers=6,
        linear=True,  # extra key, not a standard BertConfig argument
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
    )
)

In [90]:
optimizer = AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.01,
)

In [91]:
from medbert.trainer import trainer
import importlib
importlib.reload(trainer)

ehr_trainer = trainer.EHRTrainer(
    model=model,
    optimizer=optimizer,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    args=DotDict({
        'epochs': 2,
        'batch_size': 32,
        'effective_batch_size': 128,
    }),
    cfg=DotDict({'scheduler': None,
                 'collate_fn': None,
                 'run_name':'ptest'}))
ehr_trainer.train()

[INFO] Run name not provided. Using random run name: 4ab7e77e0d4f4e4a8d14267137857956
[INFO] Run folder: ../runs\4ab7e77e0d4f4e4a8d14267137857956


Train 0:   0%|          | 0/13 [00:00<?, ?it/s]

Train 0:  31%|███       | 4/13 [00:13<00:29,  3.28s/it, loss=6.65]

Train loss 1: 6.64593505859375


Train 0:  62%|██████▏   | 8/13 [00:26<00:16,  3.24s/it, loss=6.64]

Train loss 2: 6.6392292976379395


Train 0:  92%|█████████▏| 12/13 [00:35<00:02,  2.45s/it, loss=6.62]

Train loss 3: 6.616518497467041


Train 0: 100%|██████████| 13/13 [00:36<00:00,  2.81s/it, loss=6.62]
Validation: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]


[INFO] Epoch 0 train loss: 6.123594724214994
[INFO] Epoch 0 val loss: 6.615511775016785
[INFO] Epoch 0 metrics: None



Train 1:  31%|███       | 4/13 [00:08<00:18,  2.08s/it, loss=6.61]

Train loss 1: 6.607713937759399


Train 1:  62%|██████▏   | 8/13 [00:15<00:08,  1.71s/it, loss=6.59]

Train loss 2: 6.594194531440735


Train 1:  92%|█████████▏| 12/13 [00:20<00:01,  1.38s/it, loss=6.58]

Train loss 3: 6.57567298412323


Train 1: 100%|██████████| 13/13 [00:21<00:00,  1.63s/it, loss=6.58]
Validation: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s]

[INFO] Epoch 1 train loss: 6.08540967794565
[INFO] Epoch 1 val loss: 6.569050073623657
[INFO] Epoch 1 metrics: None
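
The trainer is configured with a `batch_size` of 32 and an `effective_batch_size` of 128, and the log reports a training loss every four mini-batches, which is consistent with gradient accumulation. How `EHRTrainer` implements this is not shown here; the following is a generic sketch with a hypothetical toy model and random data.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # hypothetical toy model standing in for BERT
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

batch_size, effective_batch_size = 32, 128
accumulation_steps = effective_batch_size // batch_size  # 4

# Hypothetical toy data: 8 mini-batches of 32 samples each.
batches = [(torch.randn(batch_size, 4), torch.randn(batch_size, 1)) for _ in range(8)]

optimizer_steps = 0
optimizer.zero_grad()
for step, (x, y) in enumerate(batches, start=1):
    loss = nn.functional.mse_loss(model(x), y)
    # Scale so the summed gradients equal the gradient of the mean over 128 samples.
    (loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        # Update the weights only once every accumulation_steps mini-batches.
        optimizer.step()
        optimizer.zero_grad()
        optimizer_steps += 1

print(optimizer_steps)  # 2
```

Accumulating gradients this way trades wall-clock time for memory: each forward pass only holds 32 samples, while each weight update reflects 128.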






### References

<sub><sup>[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).</sup></sub> <br>
<sub><sup>[2] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).</sup></sub> <br>
<sub><sup>[3] Polyak, Adam, et al. "Speech resynthesis from discrete disentangled self-supervised representations." arXiv preprint arXiv:2104.00355 (2021).</sup></sub><br>
<sub><sup>[4] Devlin, Jacob, et al. "BERT: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).</sup></sub><br>
<sub><sup>[5] Li, Yikuan, et al. "BEHRT: transformer for electronic health records." Scientific Reports 10.1 (2020): 1-12.</sup></sub><br>
<sub><sup>[6] Shang, Junyuan, et al. "Pre-training of graph augmented transformers for medication recommendation." arXiv preprint arXiv:1906.00346 (2019).</sup></sub><br>
<sub><sup>[7] Pang, Chao, et al. "CEHR-BERT: Incorporating temporal information from structured EHR data to improve prediction tasks." Machine Learning for Health. PMLR, 2021.</sup></sub><br>