**Purpose:**

Fine-tunes the SQuAD-pretrained `distilbert-base-uncased-distilled-squad` model on the COVID-QA dataset, then runs a sample question through the fine-tuned model.

In [None]:
!pip install transformers
import time
import os
import contextlib
import torch
import nltk
nltk.download('punkt')

from google.colab import drive
drive.mount('/content/drive')

In [None]:
import json
from pathlib import Path

def read_squad(path):
    """Load a SQuAD-format JSON file and flatten it into parallel lists.

    Each (context, question, answer) triple becomes one entry. A question with
    several gold answers therefore appears several times, and a question with an
    empty answer list contributes nothing.

    Args:
        path: str or Path of a JSON file whose top-level 'data' list holds
            articles -> 'paragraphs' -> 'qas' -> 'answers'.

    Returns:
        Tuple (contexts, questions, answers) of equal-length lists. Each answer is
        the raw answer dict from the file.
    """
    def _walk(articles):
        # Yield one (context, question, answer) triple per gold answer, in file order.
        for article in articles:
            for paragraph in article['paragraphs']:
                paragraph_context = paragraph['context']
                for qa in paragraph['qas']:
                    qa_question = qa['question']
                    for gold in qa['answers']:
                        yield paragraph_context, qa_question, gold

    with open(Path(path), 'rb') as handle:
        payload = json.load(handle)

    triples = list(_walk(payload['data']))
    contexts = [triple[0] for triple in triples]
    questions = [triple[1] for triple in triples]
    answers = [triple[2] for triple in triples]
    return contexts, questions, answers

# Covid-QA train/validation splits (SQuAD-format JSON) on the mounted Drive.
DATA_DIR = Path('/content/drive/My Drive/colab_files/data/Covid-QA')
train_contexts, train_questions, train_answers = read_squad(DATA_DIR / 'Covid-QA-train.json')
val_contexts, val_questions, val_answers = read_squad(DATA_DIR / 'Covid-QA-val.json')

In [None]:
def add_end_idx(answers, contexts, max_shift=2):
    """Add an exclusive 'answer_end' character index to every answer, in place.

    SQuAD-style labels are sometimes off by a character or two. The gold text is
    therefore looked for at 'answer_start' shifted left by 0..max_shift
    characters. On the first match, 'answer_start' is corrected and 'answer_end'
    is set. If no shift matches, the label is kept as given and 'answer_end' is
    computed from it, so every answer leaves with an 'answer_end' key.

    Args:
        answers: answer dicts with 'text' and 'answer_start'; mutated in place.
        contexts: context strings, parallel to ``answers``.
        max_shift: largest leftward correction to try (default 2).
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        for shift in range(max_shift + 1):
            # A negative slice start would silently index from the end of the string.
            if start_idx - shift < 0:
                continue
            if context[start_idx - shift:end_idx - shift] == gold_text:
                answer['answer_start'] = start_idx - shift
                answer['answer_end'] = end_idx - shift
                break
        else:
            # No shift matched: keep the label rather than leaving 'answer_end'
            # unset, which would make later lookups of answer['answer_end'] fail.
            answer['answer_end'] = end_idx

# Mutates the answer dicts in place. 'answer_end' is added, and 'answer_start' is
# corrected by up to two characters, when the gold text is found near 'answer_start'.
add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)

In [None]:
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering

# Load the tokenizer and the model from the same checkpoint, so the vocabulary
# always matches the model's embeddings.
# DistilBERT's forward pass takes no token type ids, so none are requested.
MODEL_NAME = "distilbert-base-uncased-distilled-squad"
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
model = DistilBertForQuestionAnswering.from_pretrained(MODEL_NAME)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=451.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=265481570.0, style=ProgressStyle(descri…




In [None]:
# Encode (context, question) pairs, so the context is the FIRST sequence.
# add_token_positions relies on this: its char_to_token lookups use the default
# sequence index. Pairs longer than the tokenizer's max length are truncated, and
# every example is padded to the longest one in its split.
# NOTE(review): the inference cell encodes (question, context) instead — confirm
# which order the fine-tuned model should see.
TOKENIZE_KWARGS = {'truncation': True, 'padding': True}
train_encodings = tokenizer(train_contexts, train_questions, **TOKENIZE_KWARGS)
val_encodings = tokenizer(val_contexts, val_questions, **TOKENIZE_KWARGS)

In [None]:
def add_token_positions(encodings, answers, max_length=None):
    """Add token-level 'start_positions' / 'end_positions' labels to ``encodings``.

    Character offsets from each answer are mapped to token indices with
    ``encodings.char_to_token``. Leading and trailing whitespace inside the gold
    text is skipped first: whitespace characters map to no token and would
    otherwise be mistaken for truncation. A start or end that still maps to no
    token is treated as truncated and labelled with ``max_length``.

    Args:
        encodings: tokenizer output (e.g. a BatchEncoding) providing
            ``char_to_token(batch_index, char_index)`` and ``update``; mutated.
        answers: answer dicts with 'text', 'answer_start' and exclusive
            'answer_end'. Assumes answer_end - answer_start == len(text), as
            add_end_idx produces.
        max_length: label for truncated answers; defaults to the global
            ``tokenizer.model_max_length``.
    """
    if max_length is None:
        max_length = tokenizer.model_max_length

    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        text = answer['text']
        leading_ws = len(text) - len(text.lstrip())
        trailing_ws = len(text) - len(text.rstrip())
        start_char = answer['answer_start'] + leading_ws
        end_char = answer['answer_end'] - 1 - trailing_ws
        if end_char < start_char:
            # Empty or all-whitespace answer: fall back to the raw offsets.
            start_char = answer['answer_start']
            end_char = answer['answer_end'] - 1

        start_token = encodings.char_to_token(i, start_char)
        end_token = encodings.char_to_token(i, end_char)
        # None means the character lies outside the encoded (truncated) window.
        start_positions.append(max_length if start_token is None else start_token)
        end_positions.append(max_length if end_token is None else end_token)
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# Adds 'start_positions' / 'end_positions' token labels to each encoding in place.
# Answers whose offsets map to no token are labelled tokenizer.model_max_length.
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)

In [None]:
import torch

class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset over tokenizer encodings, yielding one tensor dict per example.

    Args:
        encodings: mapping (e.g. a BatchEncoding or a plain dict) from feature
            name to a per-example list. All lists have the same length, and the
            mapping includes 'input_ids'.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Convert only the requested example, so the whole dataset is never held as tensors.
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        # Item access (not .input_ids) so plain dicts work as well as BatchEncoding.
        return len(self.encodings['input_ids'])

# Wrap the labelled encodings as torch Datasets; train_dataset feeds the DataLoader below.
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

In [None]:
from torch.utils.data import DataLoader

# Training hyperparameters.
SEED = 42
NUM_EPOCHS = 3
BATCH_SIZE = 16
LEARNING_RATE = 5e-5

torch.manual_seed(SEED)  # makes dropout and the DataLoader shuffle order reproducible

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)
model.train()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# torch.optim.AdamW replaces the deprecated (now removed) transformers.AdamW.
# eps=1e-6 and weight_decay=0.0 match that class's defaults.
optim = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions, end_positions=end_positions)
        # With positions supplied, the loss is the first output (tuple or ModelOutput).
        loss = outputs[0]
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
    print("epoch {}/{}: mean loss {:.4f}".format(
        epoch + 1, NUM_EPOCHS, epoch_loss / max(len(train_loader), 1)))

model.eval();  # trailing ';' suppresses the very long module repr

DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            

In [None]:
# Inference:
import time
start_time = time.perf_counter()
context = 'The models suggested that without a vaccine, school closures would be unlikely to affect the pandemic, an estimated 35 000 to 60 000 ventilators would be needed, up to an estimated 7.3 billion surgical masks or respirators would be required, and perhaps most important, if vaccine development did not start before the virus was introduced, it was unlikely that a significant number of hospitalizations and deaths could be averted due to the time it takes to develop, test, manufacture, and distribute a vaccine.'
question = "How many surgical masks or respirators have past studies projected will be required for a pandemic in the United States?"

# NOTE(review): training encoded (context, question) pairs, but this cell encodes
# (question, context). Confirm which order the fine-tuned model should see.
encoding = tokenizer(question, context)

input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    outputs = model(torch.tensor([input_ids]).to(device),
                    attention_mask=torch.tensor([attention_mask]).to(device))
# Index rather than tuple-unpack: newer transformers return a ModelOutput, and
# unpacking one yields its keys, not the logits.
start_scores, end_scores = outputs[0][0], outputs[1][0]

# Pick the best span with start <= end. When the two independent argmaxes are
# already ordered, this is exactly (argmax start, argmax end).
seq_len = start_scores.shape[-1]
span_scores = start_scores.unsqueeze(1) + end_scores.unsqueeze(0)
invalid = torch.tril(torch.ones(seq_len, seq_len, device=span_scores.device), diagonal=-1).bool()
span_scores = span_scores.masked_fill(invalid, float('-inf'))
answer_start, answer_end = divmod(int(torch.argmax(span_scores)), seq_len)

ans_tokens = input_ids[answer_start:answer_end + 1]
answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)

print("\nQuestion ", question)
print("\nAnswer Tokens: ")
print(answer_tokens)

# decode() is roughly convert_tokens_to_string(convert_ids_to_tokens(ids)). Special
# tokens are skipped so a span touching [CLS]/[SEP] doesn't leak them into the answer.
answer_tokens_to_string = tokenizer.decode(ans_tokens, skip_special_tokens=True)

print("\nAnswer : ", answer_tokens_to_string)

end_time = time.perf_counter()

print("\nExecution Time: {} seconds.".format(end_time - start_time))


Question  How many surgical masks or respirators have past studies projected will be required for a pandemic in the United States?

Answer Tokens: 
['7', '.', '3', 'billion']

Answer :  7. 3 billion

Execution Time: 0.019294023513793945 seconds.


In [None]:
#model.save_pretrained('/content/drive/My Drive/colab_files/trained_models/Squad_CovidQA_Model')

In [None]:
#tokenizer.save_pretrained('/content/drive/My Drive/colab_files/trained_models/Squad_CovidQA_Model')

('/content/drive/My Drive/colab_files/trained_models/Squad_CovidQA_Model/vocab.txt',
 '/content/drive/My Drive/colab_files/trained_models/Squad_CovidQA_Model/special_tokens_map.json',
 '/content/drive/My Drive/colab_files/trained_models/Squad_CovidQA_Model/added_tokens.json')

In [None]:
# Summarise rather than dump: every key holds a per-example list, and printing them
# in full produces an unreadable wall of token ids.
{key: len(values) for key, values in val_encodings.items()}

{'input_ids': [[101, 17070, 26243, 28518, 2638, 1998, 16371, 14321, 2891, 1006, 1056, 1007, 8909, 2063, 10752, 25456, 2024, 5041, 1011, 8674, 3424, 24093, 2140, 5850, 2008, 20544, 25605, 15403, 16770, 1024, 1013, 1013, 7479, 1012, 13316, 5638, 1012, 17953, 2213, 1012, 9152, 2232, 1012, 18079, 1013, 7610, 2278, 1013, 4790, 1013, 7610, 2278, 28154, 21926, 12376, 2629, 1013, 21146, 1024, 20069, 2063, 2487, 2063, 17788, 14526, 2063, 2692, 22203, 16147, 2629, 2278, 2620, 16703, 2581, 2050, 26976, 2094, 2629, 2278, 21486, 2487, 2050, 2475, 14141, 2549, 7875, 2575, 2497, 2509, 6048, 1024, 12277, 1010, 1044, 6672, 9743, 1025, 5035, 1010, 16480, 3619, 6679, 3070, 1025, 16480, 1010, 7042, 14856, 3058, 1024, 2760, 1011, 5840, 1011, 2322, 9193, 1024, 2184, 1012, 28977, 2692, 1013, 1058, 18613, 12740, 17465, 2487, 6105, 1024, 10507, 1011, 2011, 10061, 1024, 16371, 14321, 20049, 3207, 11698, 2015, 2031, 2042, 4703, 4453, 2004, 3424, 24093, 2140, 6074, 1012, 1999, 3522, 2086, 1010, 17070, 26243, 2851

In [None]:
#!wget -P data/squad/ https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json

--2020-09-19 20:11:50--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.109.153, 185.199.108.153, 185.199.111.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4370528 (4.2M) [application/json]
Saving to: ‘data/squad/dev-v2.0.json’


2020-09-19 20:11:51 (19.8 MB/s) - ‘data/squad/dev-v2.0.json’ saved [4370528/4370528]



In [None]:
from transformers.data.processors.squad import SquadV2Processor
import urllib.request

SQUAD_DIR = "./data/squad/"
SQUAD_DEV_FILE = "dev-v2.0.json"
SQUAD_DEV_URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"

# Fetch the dev set on a fresh runtime. The shell download cell is commented out,
# so without this the cell only worked from leftover files.
squad_dev_path = os.path.join(SQUAD_DIR, SQUAD_DEV_FILE)
if not os.path.exists(squad_dev_path):
    os.makedirs(SQUAD_DIR, exist_ok=True)
    urllib.request.urlretrieve(SQUAD_DEV_URL, squad_dev_path)

# this processor loads the SQuAD2.0 dev set examples
processor = SquadV2Processor()
examples = processor.get_dev_examples(SQUAD_DIR, filename=SQUAD_DEV_FILE)
print(len(examples))

100%|██████████| 35/35 [00:03<00:00,  8.77it/s]

11873





In [None]:
# Number of SQuAD 2.0 dev examples loaded (same value the loading cell prints).
print(len(examples))

11873


In [None]:
# Inspect the `qas_id` of the first loaded SQuAD 2.0 dev example.
examples[0].qas_id

'56ddde6b9a695914005b9628'

In [None]:
# Readable preview of the first validation example instead of dumping every token id.
tokenizer.decode(val_encodings['input_ids'][0][:64])

{'input_ids': [[101, 17070, 26243, 28518, 2638, 1998, 16371, 14321, 2891, 1006, 1056, 1007, 8909, 2063, 10752, 25456, 2024, 5041, 1011, 8674, 3424, 24093, 2140, 5850, 2008, 20544, 25605, 15403, 16770, 1024, 1013, 1013, 7479, 1012, 13316, 5638, 1012, 17953, 2213, 1012, 9152, 2232, 1012, 18079, 1013, 7610, 2278, 1013, 4790, 1013, 7610, 2278, 28154, 21926, 12376, 2629, 1013, 21146, 1024, 20069, 2063, 2487, 2063, 17788, 14526, 2063, 2692, 22203, 16147, 2629, 2278, 2620, 16703, 2581, 2050, 26976, 2094, 2629, 2278, 21486, 2487, 2050, 2475, 14141, 2549, 7875, 2575, 2497, 2509, 6048, 1024, 12277, 1010, 1044, 6672, 9743, 1025, 5035, 1010, 16480, 3619, 6679, 3070, 1025, 16480, 1010, 7042, 14856, 3058, 1024, 2760, 1011, 5840, 1011, 2322, 9193, 1024, 2184, 1012, 28977, 2692, 1013, 1058, 18613, 12740, 17465, 2487, 6105, 1024, 10507, 1011, 2011, 10061, 1024, 16371, 14321, 20049, 3207, 11698, 2015, 2031, 2042, 4703, 4453, 2004, 3424, 24093, 2140, 6074, 1012, 1999, 3522, 2086, 1010, 17070, 26243, 2851

In [None]:
# SquadDataset defines no `answer_qids` attribute, so `val_dataset.answer_qids`
# raised AttributeError on a fresh run. Display the dataset object instead, which
# is what the saved output shows.
val_dataset

<__main__.SquadDataset at 0x7f3b6928b828>