# Imports

In [2]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader, Subset
from torch.nn import Module
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from pathlib import Path
from datetime import datetime

from config import conf
from dataset import SquadDataset
from training import train_loop, evaluate

# Sanity check: displays whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

# Temporary helper for selecting a portion of the dataset

In [3]:
import pandas as pd
import math
def select_portion(data: pd.DataFrame, portion_val: int) -> pd.DataFrame:
    """Return the first ``portion_val`` rows of ``data`` as a new, re-indexed frame.

    Used to shrink a split for quick experiments. The result gets a fresh
    0..n-1 index, so label-based lookups such as ``frame['id'][i]`` agree with
    positional indices ``i`` (which is how encodings and predictions are
    indexed). Without the reset, the slice would keep the source frame's labels.

    Args:
        data: Source frame; it is not modified.
        portion_val: Number of leading rows to keep. Values larger than
            ``len(data)`` keep every row; negative values follow Python slice
            semantics (drop that many rows from the end).

    Returns:
        A new DataFrame with at most ``portion_val`` rows and a RangeIndex.
    """
    print('number of samples in original dataset', data.shape[0])

    selection_data = data.iloc[:portion_val].reset_index(drop=True)
    print('number of samples in reduced dataset: ', selection_data.shape[0])
    return selection_data

# Datasets, tokenizer and model

In [4]:
# Pretrained checkpoint shared by the tokenizer and the model.
MODEL_NAME = 'distilbert-base-uncased'
# Optional size caps for quick experiments; None keeps the full split.
# NOTE: with N_VAL_SAMPLES set, all reported validation metrics refer to this
# reduced subset only, not to the full validation split.
N_TRAIN_SAMPLES = None
N_VAL_SAMPLES = 100

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
dataset = SquadDataset.from_json(conf['DATASET_FILE'], tokenizer)
train_dataset, val_dataset = dataset.train_val_split(conf['TRAIN_RATIO'])

# Re-wrap the truncated frames in SquadDataset so they are re-tokenized and
# keep the same interface (.data / .encodings) as the full splits.
if N_TRAIN_SAMPLES is not None:
    train_dataset = SquadDataset(select_portion(train_dataset.data, N_TRAIN_SAMPLES), tokenizer)
if N_VAL_SAMPLES is not None:
    val_dataset = SquadDataset(select_portion(val_dataset.data, N_VAL_SAMPLES), tokenizer)

model = DistilBertForQuestionAnswering.from_pretrained(MODEL_NAME)

number of samples in original dataset 21899
number of samples in reduced dataset:  100


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

In [5]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Adam step size used for fine-tuning (the value this notebook was run with).
LEARNING_RATE = 5e-5
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader = DataLoader(train_dataset,
                          batch_size=conf['BATCH_SIZE'],
                          shuffle=True)
# Left unshuffled: the answer-decoding cell pairs pred[i] with
# val_dataset.encodings[i], which presumes evaluate() returns predictions in
# loader order — TODO confirm in training.evaluate.
val_loader = DataLoader(val_dataset,
                        batch_size=conf['BATCH_SIZE'])

# Train or evaluate

In [6]:
if conf['TRAIN_MODEL']:
    train_loss, train_acc, val_loss, val_acc = train_loop(model, train_loader, val_loader, opt, device)

    if conf['SAVE_MODEL']:
        Path(conf['MODELS_FOLDER']).mkdir(parents=True, exist_ok=True)
        filepath = f"{conf['MODELS_FOLDER']}/model_{datetime.today().strftime('%m%d')}.pt"
        torch.save(model.state_dict(), filepath)
        print(f'Model saved in {filepath}')

    # train_loop returns no predictions, but the answer-decoding cell below
    # reads `pred` and `ref`; one evaluation pass makes that cell work after
    # training too (previously it raised NameError in this mode).
    _, _, _, pred, ref = evaluate(model, val_loader, device)

else:
    filepath = conf['MODELS_FOLDER'] + '/' + conf['MODEL_LOAD_NAME']
    # map_location remaps saved tensors onto the current device, so a
    # checkpoint written on a GPU also loads on a CPU-only machine.
    model.load_state_dict(torch.load(filepath, map_location=device))
    print(f'Loaded model at {filepath}')

    n_val, val_loss, val_acc, pred, ref = evaluate(model, val_loader, device)
    # evaluate() appears to return totals over n_val samples; convert to means.
    val_loss /= n_val
    val_acc /= n_val
    print(f'\n\nValidation loss: {val_loss:.3f}')
    print(f'\nValidation accuracy: {val_acc:.3f}')

Loaded model at ./models/model_0125.pt


100%|██████████| 9/9 [00:03<00:00,  2.66it/s]



Validation loss: 1.144

Validation accuracy: 0.765





# Predicted answers

In [7]:
# BERT's WordPiece tokenizer splits some words into sub-tokens; every piece after the first is prefixed with ##.

import re
def replace_hashtag(text: str) -> str:
    """Merge space-joined WordPiece tokens back into whole words.

    Continuation pieces carry a ``##`` prefix (e.g. ``'as ##cle ##pius'``).
    Deleting each `` ##`` glues the piece onto the previous token. A ``##`` at
    the very start of ``text`` is stripped as well; this happens when a
    predicted span begins in the middle of a word, and the original version
    left that marker in the output.

    Args:
        text: WordPiece tokens joined with single spaces.

    Returns:
        The text with continuation markers removed.
    """
    return re.sub(r'^##| ##', '', text)

In [8]:
def decode_span(tokens, start, end) -> str:
    """Join the tokens in the inclusive span [start, end] into readable text."""
    return replace_hashtag(' '.join(tokens[start:end + 1]))


val_frame = val_dataset.data
ans = {}
for i in range(len(val_dataset)):
    tokens = val_dataset.encodings[i].tokens
    # .iloc keeps the lookup positional, matching encodings[i] / pred[i] / ref[i];
    # plain ['id'][i] would look up the index *label* i instead.
    q_id = val_frame['id'].iloc[i]
    q_text = val_frame['question'].iloc[i]
    pred_ans = decode_span(tokens, pred[i][0], pred[i][1])
    ref_ans = decode_span(tokens, ref[i][0], ref[i][1])
    ans[q_id] = (q_text, pred_ans, ref_ans)

for q_id, (q_text, pred_ans, ref_ans) in ans.items():
    print('\n\nQuestion id: ', q_id)
    print('Question text: ', q_text)
    print('Predicted answer: ', pred_ans)
    print('True answer: ', ref_ans)




Question id:  5728b3aeff5b5019007da4e6
Question text:  Who was described at the prophetic deity of the Delphic Oracle?
Predicted answer:  apollo
True answer:  apollo


Question id:  5728b3aeff5b5019007da4e7
Question text:  What is the name of Apollo's son?
Predicted answer:  asclepius
True answer:  asclepius


Question id:  5728b3aeff5b5019007da4e8
Question text:  Who created the lyre for Apollo?
Predicted answer:  hermes
True answer:  hermes


Question id:  5728b3aeff5b5019007da4e9
Question text:  What was the term for hymns sung to Apollo?
Predicted answer:  paeans
True answer:  paeans


Question id:  5728b449ff5b5019007da4f6
Question text:  Who was the Titan goddess of the moon?
Predicted answer:  selene
True answer:  selene


Question id:  5728b449ff5b5019007da4f7
Question text:  In Hellenestic times, Greeks identified Apollo Helios as what name?
Predicted answer:  helios
True answer:  helios


Question id:  5728b449ff5b5019007da4f8
Question text:  What was the name of Apollo's s