In [1]:
!pip install transformers



In [2]:
import pandas as pd
import json
import numpy as np
from collections import Counter
import pickle
from tqdm import tqdm
import seaborn as sns
import collections

import torch
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig ,DistilBertTokenizerFast, DistilBertForQuestionAnswering

from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from transformers import DistilBertModel, DistilBertConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

import pickle
import torch.optim as optim

from google.colab import drive 
drive.mount('/content/drive')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


device(type='cuda')

In [3]:
class ermQA(torch.utils.data.Dataset):
    def __init__(self, filename):
        with open(f"/content/drive/Shareddrives/NLP/EHReader/processed_data/{filename}.pickle", "rb") as f:
            self.encodings = pickle.load(f)

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

In [4]:
class DistilBERTEncoder(torch.nn.Module):
    def __init__(self, frozen=True):
        super(DistilBERTEncoder, self).__init__()
        self.encoder = DistilBertModel.from_pretrained('distilbert-base-uncased', output_hidden_states = True)
        self.encoder.to(device)
        if frozen:
            self.encoder.requires_grad = False
            self.encoder.eval()

    def forward(self, input_ids, attention_mask):
        output = self.encoder(input_ids, attention_mask = attention_mask)
        embedding = output.last_hidden_state # [batch, 512, 3072]

        return embedding
    
class DeepReader(torch.nn.Module):
    def __init__(self, embed_size=768, num_heads=1):
        super(DeepReader, self).__init__()
        self.encoder = DistilBERTEncoder(frozen=False)
        
        self.key_linear = nn.Linear(embed_size, embed_size)
        self.value_linear = nn.Linear(embed_size, embed_size)
        self.query_linear = nn.Linear(embed_size, embed_size)
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)

        # Feed forward neural network (FFN)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU(),
            nn.Linear(embed_size, 2)
        )
    
    def forward(self, input_ids, attention_mask):
        embeddings = self.encoder(input_ids, attention_mask = attention_mask)
        
        key = self.key_linear(embeddings)
        value = self.value_linear(embeddings)
        query = self.query_linear(embeddings)
        embedding, _ = self.attention(query=query, key=key, value=value)
        
        ffn_output = self.linear_relu_stack(embedding)
        output = nn.functional.softmax(ffn_output, dim=1)

        return output

In [5]:
TRAIN_BATCH_SIZE = 32
train_dataset = ermQA('medication_qa_train')
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)


VAL_BATCH_SIZE = 32
val_dataset = ermQA('medication_qa_val')
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=True)

In [6]:
model = DeepReader()
model.to(device)
metadata = dict()

if metadata == dict():
    START_EPOCH = 0
    train_loss = []
    val_loss = []
else:
    START_EPOCH = metadata['epoch'] + 1
    train_loss = metadata['train_loss']
    val_loss = metadata['valid_loss']

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
NUM_EPOCHS = 2
dr_loss_func = nn.CrossEntropyLoss()
optim = AdamW(model.parameters(), lr=3e-5)

for epoch in range(START_EPOCH, START_EPOCH+NUM_EPOCHS):
    # Train
    model.train()
    batch_loss = []
    for batch in tqdm(train_loader):
        torch.cuda.empty_cache()
        
        # Forward 
        input_ids = torch.tensor(batch['input_ids']).to(device)
        attention_mask = torch.tensor(batch['attention_mask']).to(device)
        dr_out = model(input_ids, attention_mask)

        # Calculate dr loss
        start_logits, end_logits = dr_out.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        start_loss = dr_loss_func(start_logits,  batch['start_positions'].to(device))
        end_loss = dr_loss_func(end_logits, batch['end_positions'].to(device))
        dr_loss = (start_loss + end_loss) / 2

        # Calculate loss and backward
        batch_loss.append(dr_loss.item())
        optim.zero_grad()
        dr_loss.backward()
        optim.step()

    train_loss.append(np.mean(batch_loss))
  
    # Validation
    model.eval()
    batch_loss = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            torch.cuda.empty_cache()

            # Forward 
            input_ids = torch.tensor(batch['input_ids']).to(device)
            attention_mask = torch.tensor(batch['attention_mask']).to(device)
            dr_out = model(input_ids, attention_mask)

            # Calculate dr loss
            start_logits, end_logits = dr_out.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1).contiguous()
            end_logits = end_logits.squeeze(-1).contiguous()

            start_loss = dr_loss_func(start_logits,  batch['start_positions'].to(device))
            end_loss = dr_loss_func(end_logits, batch['end_positions'].to(device))
            dr_loss = (start_loss + end_loss) / 2

            # Calculate loss and backward
            batch_loss.append(dr_loss.item())

    val_loss.append(np.mean(batch_loss))
    
    print(f'Epoch: {epoch}, train_loss: {train_loss[-1]}, val_loss: {val_loss[-1]}')
    
    model_name = f'/content/drive/Shareddrives/NLP/EHReader/DeepReader/m_1_uf_e_{len(val_loss)}_vl_{round(val_loss[-1], 4)}'
    metadata = {
        'epoch': epoch,
        'train_loss': train_loss,
        'valid_loss': val_loss
    }
  
    # Early Stopping
    if len(val_loss) > 3:
        if val_loss[-1] > val_loss[-2] > val_loss[-3]:
            torch.save(model, f'{model_name}.model')
            
            with open(f'{model_name}_metadata.pickle', 'wb') as f:
                pickle.dump(metadata, f)
            
    # Check point
    torch.save(model, f'{model_name}.model') 

    with open(f'{model_name}_metadata.pickle', 'wb') as f:
        pickle.dump(metadata, f)

  import sys
  del sys.path[0]
  
100%|██████████| 1134/1134 [24:09<00:00,  1.28s/it]
100%|██████████| 378/378 [02:13<00:00,  2.83it/s]


Epoch: 0, train_loss: 6.2094117909722435, val_loss: 5.741673818971745


 43%|████▎     | 485/1134 [10:20<13:50,  1.28s/it]


KeyboardInterrupt: ignored