#Imports
Upload in content:


*   config.py
*   test_dataset.py
*   dataset.py
*   dev-v1.1.json
*   model





In [None]:
pip install transformers

In [39]:
import json
import torch
from torch.utils.data import DataLoader
from torch.nn import Module
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from tqdm import tqdm

from test_dataset import Test_SquadDataset
from config import conf

model: Module = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

#Load Model

In [40]:
if not conf['TRAIN_MODEL']:
    filepath = conf['MODELS_FOLDER'] + '/' + conf['MODEL_LOAD_NAME']
    #if not using colab gpu: add map_location=torch.device('cpu') to torch.load
    model.load_state_dict(torch.load(filepath,map_location=torch.device('cpu')))
    print(f"Loaded model at {filepath}")


Loaded model at ./models/model_0125.pt


#Test set preprocessing

In [41]:
import pandas as pd
import math
def select_portion(data: pd.DataFrame, 
                    selection_ratio:float) -> pd.DataFrame:
   
    n_samples = data.shape[0]
    print('number of samples in original dataset', n_samples)
    n_selection = math.ceil(n_samples * selection_ratio) 
    n_rest = n_samples - n_selection

    selection_data = data.iloc[:n_selection] 
    print('number of samples in reduced dataset: ', selection_data.shape[0])
    return selection_data

In [42]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

test_data = Test_SquadDataset.from_json('dev-v1.1.json', tokenizer) 
test_dataset=Test_SquadDataset(select_portion(test_data.data,0.01), tokenizer)

number of samples in original dataset 10570
number of samples in reduced dataset:  106


#Test

In [43]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

test_loader = DataLoader(test_dataset, batch_size=32)
test_iter = tqdm(test_loader)

start_tokens=[]
end_tokens=[]
with torch.no_grad():

  for test_batch in test_iter:
    input_ids = test_batch['input_ids'].to(device)
    attention_mask = test_batch['attention_mask'].to(device)
    
    outputs = model(input_ids, 
                    attention_mask=attention_mask)
            
    start_pred = torch.argmax(outputs['start_logits'], dim=1)
    end_pred = torch.argmax(outputs['end_logits'], dim=1)    
    
    start_tokens+=[i.item() for i in start_pred]
    end_tokens+=[i.item() for i in end_pred]
    

100%|██████████| 4/4 [00:30<00:00,  7.65s/it]


#Results

In [50]:
import re
def replace_hashtag(text: str) -> str:
    return re.compile('#')\
        .sub('', text)

In [52]:
ans={}
for i in range(len(test_dataset)):
  q_id=test_dataset.data['id'][i]
  q_text=test_dataset.data['question'][i]
  span=test_dataset.encodings[i].tokens[start_tokens[i]:end_tokens[i]+1]
  pred=replace_hashtag(' '.join(word for word in(span)))
  ans[q_id]=(q_text, pred)

for i in ans.keys():
  print('\n\n',i, '\n',ans[i][0], '\n',ans[i][1])



 56be4db0acb8001400a502ec 
 Which NFL team represented the AFC at Super Bowl 50? 
 denver broncos


 56be4db0acb8001400a502ed 
 Which NFL team represented the NFC at Super Bowl 50? 
 carolina panthers


 56be4db0acb8001400a502ee 
 Where did Super Bowl 50 take place? 
 


 56be4db0acb8001400a502ef 
 Which NFL team won Super Bowl 50? 
 carolina panthers


 56be4db0acb8001400a502f0 
 What color was used to emphasize the 50th anniversary of the Super Bowl? 
 gold


 56be8e613aeaaa14008c90d1 
 What was the theme of Super Bowl 50? 
 golden anniversary "


 56be8e613aeaaa14008c90d2 
 What day was the game played on? 
 february 7 , 2016


 56be8e613aeaaa14008c90d3 
 What is the AFC short for? 
 american football conference


 56bea9923aeaaa14008c91b9 
 What was the theme of Super Bowl 50? 
 golden anniversary "


 56bea9923aeaaa14008c91ba 
 What does AFC stand for? 
 american football conference


 56bea9923aeaaa14008c91bb 
 What day was the Super Bowl played on? 
 february 7 , 2016


 56bea

#In requested format

In [53]:
ans={}
for i in range(len(test_dataset)):
  q_id=test_dataset.data['id'][i]
  span=test_dataset.encodings[i].tokens[start_tokens[i]:end_tokens[i]+1]
  pred=replace_hashtag(' '.join(word for word in(span)))
  ans[q_id]= pred

ans

{'56be4db0acb8001400a502ec': 'denver broncos',
 '56be4db0acb8001400a502ed': 'carolina panthers',
 '56be4db0acb8001400a502ee': '',
 '56be4db0acb8001400a502ef': 'carolina panthers',
 '56be4db0acb8001400a502f0': 'gold',
 '56be4e1facb8001400a502f6': 'cam newton',
 '56be4e1facb8001400a502f9': 'eight',
 '56be4e1facb8001400a502fa': '1995',
 '56be4eafacb8001400a50302': 'von miller',
 '56be4eafacb8001400a50303': 'two',
 '56be4eafacb8001400a50304': 'broncos',
 '56be5333acb8001400a5030a': 'cbs',
 '56be5333acb8001400a5030b': '$ 5 million',
 '56be5333acb8001400a5030c': 'cold play',
 '56be5333acb8001400a5030d': 'beyonce and bruno mars',
 '56be5333acb8001400a5030e': 'super bowl 50',
 '56be8e613aeaaa14008c90d1': 'golden anniversary "',
 '56be8e613aeaaa14008c90d2': 'february 7 , 2016',
 '56be8e613aeaaa14008c90d3': 'american football conference',
 '56bea9923aeaaa14008c91b9': 'golden anniversary "',
 '56bea9923aeaaa14008c91ba': 'american football conference',
 '56bea9923aeaaa14008c91bb': 'february 7 , 20