In [1]:
import collections
import json
import os
import pickle
from collections import Counter

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig ,DistilBertTokenizerFast, DistilBertForQuestionAnswering
from transformers import DistilBertModel, DistilBertConfig

# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
class ermQA(torch.utils.data.Dataset):
    """Map-style dataset over pre-computed encodings stored in a pickle file.

    The file ``processed_data/{filename}.pickle`` must unpickle to a mapping
    (anything exposing ``.items()`` and key lookup) from field name to a
    per-example indexable sequence, and must contain an ``'input_ids'`` field,
    which defines the dataset length.

    Args:
        filename: Base name (without extension) of the pickle inside
            ``processed_data/``.
    """

    def __init__(self, filename):
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point this at files produced by a trusted
        # preprocessing step.
        with open(f"processed_data/{filename}.pickle", "rb") as f:
            self.encodings = pickle.load(f)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of field name -> ``torch.Tensor``."""
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        """Return the number of examples (length of the ``'input_ids'`` field)."""
        # Key lookup instead of attribute access: it matches the mapping API
        # already used in __getitem__ and also works when the pickle holds a
        # plain dict rather than an object exposing `.input_ids`.
        return len(self.encodings['input_ids'])

In [3]:
class DistilBERTEncoder(torch.nn.Module):
    """Wrapper around the pretrained ``distilbert-base-uncased`` encoder.

    Args:
        frozen: When True, the encoder's parameters are excluded from gradient
            computation and the encoder is kept in eval mode, even when a
            parent module calls ``.train()``.
    """

    def __init__(self, frozen=True):
        super(DistilBERTEncoder, self).__init__()
        self.frozen = frozen
        self.encoder = DistilBertModel.from_pretrained('distilbert-base-uncased', output_hidden_states = True)
        self.encoder.to(device)
        if frozen:
            # `module.requires_grad = False` only sets a plain attribute on the
            # module and freezes nothing; requires_grad_() updates every parameter.
            self.encoder.requires_grad_(False)
            self.encoder.eval()

    def train(self, mode=True):
        """Switch train/eval mode, keeping a frozen encoder in eval mode.

        Without this override, a parent ``model.train()`` call would put the
        frozen encoder back into training mode.
        """
        super(DistilBERTEncoder, self).train(mode)
        if self.frozen:
            self.encoder.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """Return the encoder's ``last_hidden_state`` for the given inputs.

        Args:
            input_ids: Token id tensor passed positionally to the encoder.
            attention_mask: Attention mask tensor forwarded as ``attention_mask``.

        Returns:
            The per-token embeddings from the final encoder layer.
        """
        output = self.encoder(input_ids, attention_mask = attention_mask)
        embedding = output.last_hidden_state # [batch, seq_len, hidden_size]

        return embedding
    
    
class SimpleReader(torch.nn.Module):
    """Classifier head on top of a trainable ``DistilBERTEncoder``.

    The embedding of the first token of each sequence is passed through a
    linear layer followed by a sigmoid, so the output lies in (0, 1).

    Args:
        in_features: Size of the encoder's per-token embedding fed to the
            linear layer.
        out_features: Number of sigmoid outputs per example.
    """

    def __init__(self, in_features=768, out_features=1):
        super(SimpleReader, self).__init__()
        # Attribute names are kept as-is so existing checkpoints still load.
        self.encoder = DistilBERTEncoder(frozen=False)
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return per-example sigmoid outputs of shape ``[batch, out_features]``.

        Note that the returned values are post-sigmoid, not raw logits.
        """
        embeddings = self.encoder(input_ids, attention_mask = attention_mask)  # [batch, seq_len, in_features]
        # Integer indexing on dim 1 already drops that dimension, so no squeeze
        # is needed to obtain a 2-D tensor.
        first_token_embedding = embeddings[:, 0, :]  # [batch, in_features]
        scores = self.linear(first_token_embedding)  # [batch, out_features]
        return self.sigmoid(scores)  # [batch, out_features]

In [4]:
# Training data: reshuffled every epoch.
TRAIN_BATCH_SIZE = 8
train_dataset = ermQA('medication_qa_train')
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)


# Validation data: iterated in a fixed order. The reported validation loss is
# a mean of per-batch means, so shuffling could change it from run to run
# whenever the last batch is smaller than VAL_BATCH_SIZE.
VAL_BATCH_SIZE = 8
val_dataset = ermQA('medication_qa_val')
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

In [5]:
# Build the reader and move it to the selected device.
model = SimpleReader()
model.to(device)

# Training bookkeeping. `metadata` starts empty for a fresh run; when it holds
# the dict written next to a checkpoint, training resumes after that epoch.
metadata = {}

if metadata != {}:
    START_EPOCH = metadata['epoch'] + 1
    train_loss = metadata['train_loss']
    val_loss = metadata['valid_loss']
else:
    START_EPOCH = 0
    train_loss, val_loss = [], []

In [6]:
NUM_EPOCHS = 2
sr_loss_func = nn.BCELoss()
optim = AdamW(model.parameters(), lr=3e-5)

# Create the checkpoint directory up front so that a missing folder cannot
# make torch.save fail after a full epoch of training.
CHECKPOINT_DIR = 'models/SimpleReader'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _answerability_loss(batch):
    """Run the reader on one collated batch and return its BCE loss tensor.

    Args:
        batch: Dict from the DataLoader with 'input_ids', 'attention_mask',
            'start_positions' and 'end_positions' tensors.

    Returns:
        The scalar loss between the model output and the answerability target.
    """
    # The collated values are used directly as tensors; wrapping them in
    # torch.tensor(...) again would only add a copy.
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    sr_out = model(input_ids, attention_mask)

    # Target is 1.0 where start < end and 0.0 otherwise.
    # NOTE(review): a span with start == end is labelled 0 here - confirm this
    # matches how unanswerable examples are encoded in the preprocessed data.
    answerability = (batch['start_positions'] < batch['end_positions']).float().view(-1, 1).to(device)
    return sr_loss_func(sr_out, answerability)


for epoch in range(START_EPOCH, START_EPOCH+NUM_EPOCHS):
    # Train
    model.train()
    batch_loss = []
    for batch in tqdm(train_loader):
        sr_loss = _answerability_loss(batch)

        # Record the loss, then take one optimizer step
        batch_loss.append(sr_loss.item())
        optim.zero_grad()
        sr_loss.backward()
        optim.step()

    train_loss.append(np.mean(batch_loss))

    # Validation
    model.eval()
    batch_loss = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            sr_loss = _answerability_loss(batch)
            batch_loss.append(sr_loss.item())

    val_loss.append(np.mean(batch_loss))

    print(f'Epoch: {epoch}, train_loss: {train_loss[-1]}, val_loss: {val_loss[-1]}')

    model_name = f'{CHECKPOINT_DIR}/m_1_uf_e_{len(val_loss)}_vl_{round(val_loss[-1], 4)}'
    metadata = {
        'epoch': epoch,
        'train_loss': train_loss,
        'valid_loss': val_loss
    }

    # Check point: written every epoch, so the state at an early stop is
    # always on disk as well.
    # NOTE: torch.save(model, ...) pickles the whole module object; loading it
    # requires the same class definitions to be importable.
    torch.save(model, f'{model_name}.model')

    with open(f'{model_name}_metadata.pickle', 'wb') as f:
        pickle.dump(metadata, f)

    # Early Stopping: stop once validation loss has risen for two consecutive
    # epochs (only checked after more than three recorded epochs).
    if len(val_loss) > 3 and val_loss[-1] > val_loss[-2] > val_loss[-3]:
        print(f'Early stopping at epoch {epoch}: validation loss increased twice in a row')
        break

  import sys
  del sys.path[0]
  
100%|██████████| 4536/4536 [33:05<00:00,  2.28it/s]
100%|██████████| 1512/1512 [03:18<00:00,  7.63it/s]


Epoch: 0, train_loss: 0.26788594718321146, val_loss: 0.18841451284858748


  0%|          | 6/4536 [00:02<32:58,  2.29it/s]


KeyboardInterrupt: 