In [1]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 MB 9.9 MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 61.0 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 65.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)
[K     |████████████████████████████████| 61 kB 667 kB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 99.1 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
 

In [1]:
import pandas as pd
import json
import numpy as np
from collections import Counter
import pickle
from tqdm import tqdm
import seaborn as sns
import collections

import torch
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig ,DistilBertTokenizerFast, DistilBertForQuestionAnswering

from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from transformers import DistilBertModel, DistilBertConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

import pickle
import torch.optim as optim

from google.colab import drive 
drive.mount('/content/drive')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

Mounted at /content/drive


device(type='cuda')

In [2]:
class ermQA(torch.utils.data.Dataset):
    def __init__(self, filename):
        with open(f"/content/drive/Shareddrives/NLP/EHReader/processed_data/{filename}.pickle", "rb") as f:
            self.encodings = pickle.load(f)

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

In [3]:
class DistilBERTEncoder(torch.nn.Module):
    def __init__(self, frozen=True):
        super(DistilBERTEncoder, self).__init__()
        self.encoder = DistilBertModel.from_pretrained('distilbert-base-uncased', output_hidden_states = True)
        self.encoder.to(device)
        if frozen:
            self.encoder.requires_grad = False
            self.encoder.eval()

    def forward(self, input_ids, attention_mask):
        output = self.encoder(input_ids, attention_mask = attention_mask)
        embedding = output.last_hidden_state # [batch, 512, 3072]

        return embedding
    
class DeepReader(torch.nn.Module):
    def __init__(self, embed_size=768, num_heads=1):
        super(DeepReader, self).__init__()
        self.encoder = DistilBERTEncoder(frozen=False)
        
        # Self attention on passage 
        self.passage_key_linear = nn.Linear(embed_size, embed_size)
        self.passage_value_linear = nn.Linear(embed_size, embed_size)
        self.passage_query_linear = nn.Linear(embed_size, embed_size)
        self.passage_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)

        # Self attention on question 
        self.question_key_linear = nn.Linear(embed_size, embed_size)
        self.question_value_linear = nn.Linear(embed_size, embed_size)
        self.question_query_linear = nn.Linear(embed_size, embed_size)
        self.question_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)

        # Cross attention 
        self.cross_query_linear = nn.Linear(embed_size, embed_size)
        self.cross_key_linear = nn.Linear(embed_size, embed_size)
        self.cross_value_linear = nn.Linear(embed_size, embed_size)
        self.cross_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)

        # Feed forward neural network (FFN)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU(),
            nn.Linear(embed_size, 2)
        )
    
    def forward(self, input_ids, attention_mask):
        embeddings = self.encoder(input_ids, attention_mask = attention_mask)
        
        token_split_index = self.generate_token_split_index(input_ids)

        passage, question = torch.tensor_split(embeddings, token_split_index, dim=1)

        passage_key = self.passage_key_linear(passage)
        passage_value = self.passage_value_linear(passage)
        passage_query = self.passage_query_linear(passage)
        passage_after_attention, _ = self.passage_attention(query=passage_query, key=passage_key, value=passage_value)

        question_key = self.question_key_linear(question)
        question_value = self.question_value_linear(question)
        question_query = self.question_query_linear(question)
        question_after_attention, _ = self.question_attention(query=question_query, key=question_key, value=question_value)

        cross_query = self.cross_query_linear(passage_after_attention)
        cross_key = self.cross_key_linear(question_after_attention)
        cross_value = self.cross_value_linear(question_after_attention)
        cross_attention_embedding, _ = self.cross_attention(query=cross_query, key=cross_key, value=cross_value)

        ffn_output = self.linear_relu_stack(cross_attention_embedding)
        output = nn.functional.softmax(ffn_output, dim=1)

        return output
    
    def generate_token_split_index(self, input_ids):
        token_split_index = []
        sep_tokens = (input_ids == 102).nonzero(as_tuple=True)
        used_samples = set()
        for i, index in zip(sep_tokens[0], sep_tokens[1]):
            if i.item() not in used_samples:
                token_split_index.append(index.item())
                used_samples.add(i.item())
        return token_split_index

In [4]:
TRAIN_BATCH_SIZE = 1
train_dataset = ermQA('medication_qa_train')
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)


VAL_BATCH_SIZE = 1
val_dataset = ermQA('medication_qa_val')
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=True)

In [5]:
model = DeepReader()
model.to(device)
metadata = dict()

if metadata == dict():
    START_EPOCH = 0
    train_loss = []
    val_loss = []
else:
    START_EPOCH = metadata['epoch'] + 1
    train_loss = metadata['train_loss']
    val_loss = metadata['valid_loss']

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
NUM_EPOCHS = 2
dr_loss_func = nn.CrossEntropyLoss()
optim = AdamW(model.parameters(), lr=3e-5)

for epoch in range(START_EPOCH, START_EPOCH+NUM_EPOCHS):
    # Train
    model.train()
    batch_loss = []
    current_loss = []
    for batch in tqdm(train_loader):
        torch.cuda.empty_cache()
        
        # Forward 
        input_ids = torch.tensor(batch['input_ids']).to(device)
        attention_mask = torch.tensor(batch['attention_mask']).to(device)
        dr_out = model(input_ids, attention_mask)

        # Calculate dr loss
        start_logits, end_logits = dr_out.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        start_loss = dr_loss_func(start_logits,  batch['start_positions'].to(device))
        end_loss = dr_loss_func(end_logits, batch['end_positions'].to(device))
        dr_loss = (start_loss + end_loss) / 2

        # Calculate loss and backward
        current_loss.append(dr_loss)
        if len(current_loss) == 32:
          total_loss = sum(current_loss) / 32 
          batch_loss.append(total_loss.item())
          optim.zero_grad()
          total_loss.backward()
          optim.step()
          current_loss = []
          if len(batch_loss) % 100 == 0:
            print(np.mean(batch_loss))

    train_loss.append(np.mean(batch_loss))
  
    # Validation
    model.eval()
    batch_loss = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            torch.cuda.empty_cache()

            # Forward 
            input_ids = torch.tensor(batch['input_ids']).to(device)
            attention_mask = torch.tensor(batch['attention_mask']).to(device)
            dr_out = model(input_ids, attention_mask)

            # Calculate dr loss
            start_logits, end_logits = dr_out.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1).contiguous()
            end_logits = end_logits.squeeze(-1).contiguous()

            start_loss = dr_loss_func(start_logits,  batch['start_positions'].to(device))
            end_loss = dr_loss_func(end_logits, batch['end_positions'].to(device))
            dr_loss = (start_loss + end_loss) / 2

            # Calculate loss and backward
            current_loss.append(dr_loss)
            if len(current_loss) == 32:
              total_loss = sum(current_loss) / 32 
              batch_loss.append(total_loss.item())
              current_loss = []

    val_loss.append(np.mean(batch_loss))
    
    print(f'Epoch: {epoch}, train_loss: {train_loss[-1]}, val_loss: {val_loss[-1]}')
    
    model_name = f'/content/drive/Shareddrives/NLP/EHReader/DeepReader/m_1_uf_e_{len(val_loss)}_vl_{round(val_loss[-1], 4)}'
    metadata = {
        'epoch': epoch,
        'train_loss': train_loss,
        'valid_loss': val_loss
    }
  
    # Early Stopping
    if len(val_loss) > 3:
        if val_loss[-1] > val_loss[-2] > val_loss[-3]:
            torch.save(model, f'{model_name}.model')
            
            with open(f'{model_name}_metadata.pickle', 'wb') as f:
                pickle.dump(metadata, f)
            
    # Check point
    torch.save(model, f'{model_name}.model') 

    with open(f'{model_name}_metadata.pickle', 'wb') as f:
        pickle.dump(metadata, f)

  import sys
  
  from ipykernel import kernelapp as app
  9%|▉         | 3197/36288 [03:08<19:28, 28.31it/s]

5.714346036911011


 18%|█▊        | 6396/36288 [06:17<18:00, 27.67it/s]

5.713166468143463


 26%|██▋       | 9596/36288 [09:27<16:11, 27.48it/s]

5.710565487543742


 35%|███▌      | 12796/36288 [12:36<14:06, 27.75it/s]

5.707761924266816


 44%|████▍     | 15996/36288 [15:46<12:00, 28.18it/s]

5.706523044586182


 53%|█████▎    | 19197/36288 [18:55<10:06, 28.19it/s]

5.705400544007619


 62%|██████▏   | 22396/36288 [22:05<08:15, 28.03it/s]

5.706202387809753


 71%|███████   | 25600/36288 [25:15<12:25, 14.33it/s]

5.7070645874738695


 79%|███████▉  | 28799/36288 [28:24<04:25, 28.22it/s]

5.705572285652161


 88%|████████▊ | 31999/36288 [31:33<02:27, 29.10it/s]

5.706065826892853


 97%|█████████▋| 35199/36288 [34:43<00:38, 28.29it/s]

5.7061049565401945


100%|██████████| 36288/36288 [35:48<00:00, 16.89it/s]
100%|██████████| 12096/12096 [02:42<00:00, 74.58it/s]


Epoch: 0, train_loss: 5.705892985459989, val_loss: 5.709470722410414


  1%|          | 223/36288 [00:12<34:42, 17.31it/s]


KeyboardInterrupt: ignored