# Tutorial for Med-BERT
Step-by-step guide to the PyTorch implementation of [Med-BERT](https://www.nature.com/articles/s41746-021-00455-y) 

<sub><sup><em>Rasmy, Laila, et al. "Med-BERT: pretrained contextualized embeddings on large-scale structured electronic health records for disease prediction." NPJ Digital Medicine 4.1 (2021): 1-13.</em></sup></sub>

------------------------------------------

## Introduction

The goal of Med-BERT is to obtain good representations of electronic health records (EHR) that can be used to make predictions in downstream tasks. <br>
To do so, we leverage the pretraining and fine-tuning paradigm with a transformer architecture $^{1}$.  
Originally developed for natural language processing (NLP), transformers have since achieved state-of-the-art results in fields such as computer vision $^{2}$ and speech recognition $^{3}$. <br>
Recently, a variant of the transformer called BERT $^{4}$ has also been applied to medical data, and to electronic health records in particular $^{5-7}$.<br> 
Countless tutorials already explain the theory and basic concepts behind transformers and BERT as well as their application to NLP, so here we focus specifically on using BERT for EHR.


## 1. Data Preparation

All we need as input are the diagnosis codes and the dates of the visits (hospital or GP visits) at which these codes were assigned.<br>
For each patient, we then construct a nested list that consists of:
1. Patient ID
2. The length of stay (LOS) of each visit
3. Diagnosis codes
4. The visit number of each diagnosis code

#### Example:
Assume that patient 0 with id ```'p0'``` has two visits lasting 5 and 20 days. The first visit has two codes, ```['M432', 'D321']```, and the second visit has one code, ```['S839']```.<br>
Then the first entry of the data list looks as follows:<br>
```['p0', [5, 20], ['M432', 'D321','S839'], [1, 1, 2]]```
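
A minimal sketch of how such an entry could be assembled from per-visit records. The `build_patient_entry` helper and the `(length_of_stay, codes)` tuple layout are assumptions made for illustration; they are not part of the Med-BERT code.

```python
def build_patient_entry(patient_id, visits):
    """Flatten per-visit records into [id, los_list, codes, visit_numbers].

    `visits` is a chronologically ordered list of (length_of_stay, codes)
    tuples. This input layout is hypothetical and only used for illustration.
    """
    los, codes, visit_numbers = [], [], []
    for visit_number, (length_of_stay, visit_codes) in enumerate(visits, start=1):
        los.append(length_of_stay)
        codes.extend(visit_codes)
        # every code carries the number of the visit in which it was assigned
        visit_numbers.extend([visit_number] * len(visit_codes))
    return [patient_id, los, codes, visit_numbers]


entry = build_patient_entry('p0', [(5, ['M432', 'D321']), (20, ['S839'])])
print(entry)
# ['p0', [5, 20], ['M432', 'D321', 'S839'], [1, 1, 2]]
```

Note that the LOS list has one element per visit, while the code and visit-number lists have one element per code.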

The BERT model takes inputs of size ```(batch_size, sequence_len, hidden_dim)```.<br>
To get the data into this format, we first tokenize it and then map the tokens to vector embeddings.

In [39]:
import pickle
import torch
import pprint 
pp = pprint.PrettyPrinter()

with open('../tutorial/example_data.pkl', 'rb') as f:
    example_data = pickle.load(f)

### 1.1 Tokenization
In this step we assign an integer to each unique code and truncate sequences that exceed a maximum length.<br>
Special tokens are also needed to pad sequences to a common length, to mask inputs, and to represent codes that are missing from the vocabulary and might appear in the future.
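
To illustrate the role of the unknown token: once a vocabulary is fixed (for example after pretraining), a code that was never seen can be mapped to the `UNK` id instead of receiving a new id. The sketch below assumes a hypothetical `encode_frozen` helper and a toy vocabulary; it is not part of the tokenizer class defined in this tutorial.

```python
def encode_frozen(codes, vocabulary):
    """Map codes to ids without growing the vocabulary (hypothetical helper)."""
    unk_id = vocabulary['UNK']
    # dict.get falls back to the UNK id for out-of-vocabulary codes
    return [vocabulary.get(code, unk_id) for code in codes]


# toy vocabulary, made up for this example
toy_vocabulary = {'PAD': 0, 'MASK': 1, 'UNK': 2, 'M432': 3, 'D321': 4}
print(encode_frozen(['M432', 'X999', 'D321'], toy_vocabulary))
# [3, 2, 4]
```

The vocabulary is left unchanged by this lookup, so ids stay consistent with whatever model was trained on them.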

We create a class to take care of tokenization:

In [14]:
class EHRTokenizer():
    # First, we define the vocabulary dict with the special tokens
    def __init__(self, vocabulary=None):
        if vocabulary is None:
            self.vocabulary = {
                'PAD':0, # padding token
                'MASK':1, # masking token (for masked language modeling)
                'UNK':2, # unknown token (for out-of-vocabulary tokens)
            }
            # Unlike the original BERT, Med-BERT does not use the 'CLS' and 'SEP' tokens
        else:
            self.vocabulary = vocabulary
    def __call__(self, seq):
        return self.batch_encode(seq)

    def encode(self, seq):
        # create a new token for each new code
        for code in seq:
            if code not in self.vocabulary:
                self.vocabulary[code] = len(self.vocabulary)
        return [self.vocabulary[code] for code in seq]

    def batch_encode(self, seqs, max_len=None):
        # we construct a dictionary to store the tokenized data
        # (if max_len is None, it defaults to the longest code sequence, see below)
        pat_ids = [seq[0] for seq in seqs]
        los_seqs = [seq[1] for seq in seqs]
        code_seqs = [seq[2] for seq in seqs] # icd codes
        visit_seqs = [seq[3] for seq in seqs]
        if max_len is None:
            max_len = max(len(seq) for seq in code_seqs)
        output_code_seqs = []
        output_visit_seqs = []
        for code_seq, visit_seq in zip(code_seqs, visit_seqs):
            # truncation
            if len(code_seq)>max_len:
                code_seq = code_seq[:max_len]
                visit_seq = visit_seq[:max_len]
            # Tokenizing
            tokenized_code_seq = self.encode(code_seq)
            output_code_seqs.append(tokenized_code_seq)
            output_visit_seqs.append(visit_seq)
        tokenized_data_dic = {'pats':pat_ids, 'los':los_seqs, 'codes':output_code_seqs, 
                            'segments':output_visit_seqs}
        return tokenized_data_dic

    def save_vocab(self, dest):
        # save the vocabulary
        print(f"Writing vocab to {dest}")
        torch.save(self.vocabulary, dest)

Now, let's run the tokenization:

In [42]:
Tokenizer = EHRTokenizer()
tokenized_data_dic = Tokenizer.batch_encode(example_data, max_len=20)
torch.save(tokenized_data_dic, '../tutorial/tokenized.pt')
Tokenizer.save_vocab('../tutorial/vocab.pt')
# Let's look at the tokenized data
for k,v in tokenized_data_dic.items():
    pp.pprint(f"{k} {v[:3]}")
# Let's look at the vocabulary
print('Vocabulary:')
for i, (k,v) in enumerate(Tokenizer.vocabulary.items()):
    if i>5:
        break
    print(k, v)

Writing vocab to ../tutorial/vocab.pt
'pats [0, 1, 2]'
'los [[1, 24, 18, 20], [27, 12], [22, 1, 18]]'
'codes [[3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18, 19, 20]]'
'segments [[1, 1, 2, 2, 3, 3, 4, 4], [1, 1, 1, 2], [1, 1, 2, 2, 2, 3]]'
Vocabulary:
PAD 0
MASK 1
UNK 2
M29.7 3
C49.2 4
P12.3 5
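
The tokenized sequences above still differ in length, since `batch_encode` truncates but does not pad. Below is a minimal sketch of how they could be padded with the `PAD` id (0) before batching, together with a mask that marks the real tokens. The `pad_sequences` helper is hypothetical and not part of the tokenizer above.

```python
def pad_sequences(seqs, max_len, pad_id=0):
    """Right-pad each sequence to max_len and build a matching attention mask."""
    padded, masks = [], []
    for seq in seqs:
        seq = seq[:max_len]  # guard against sequences longer than max_len
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_id] * n_pad)
        masks.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    return padded, masks


# the first three code sequences from the output above
codes = [[3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18, 19, 20]]
padded, masks = pad_sequences(codes, max_len=8)
print(padded[1])  # [11, 12, 13, 14, 0, 0, 0, 0]
print(masks[1])   # [1, 1, 1, 1, 0, 0, 0, 0]
```

The same helper could be applied to the `segments` lists so that codes and visit numbers stay aligned; visit numbers start at 1, so 0 remains unambiguous as padding.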


### References

<sub><sup>[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).</sup></sub><br>
<sub><sup>[2] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).</sup></sub><br>
<sub><sup>[3] Polyak, Adam, et al. "Speech resynthesis from discrete disentangled self-supervised representations." arXiv preprint arXiv:2104.00355 (2021).</sup></sub><br>
<sub><sup>[4] Devlin, Jacob, et al. "BERT: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).</sup></sub><br>
<sub><sup>[5] Li, Yikuan, et al. "BEHRT: transformer for electronic health records." Scientific Reports 10.1 (2020): 1-12.</sup></sub><br>
<sub><sup>[6] Shang, Junyuan, et al. "Pre-training of graph augmented transformers for medication recommendation." arXiv preprint arXiv:1906.00346 (2019).</sup></sub><br>
<sub><sup>[7] Pang, Chao, et al. "CEHR-BERT: Incorporating temporal information from structured EHR data to improve prediction tasks." Machine Learning for Health. PMLR, 2021.</sup></sub><br>