**Purpose:**

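Extractive question answering with DistilBERT fine-tuned on SQuAD (`distilbert-base-uncased-distilled-squad`): first on a hand-written context/question pair, then on the Covid-QA validation set, scored with SQuAD-style exact match (EM) and token-level F1.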


In [None]:
%%capture
# Colab setup: install transformers, import dependencies, mount Google Drive.
# NOTE(review): %%capture hides ALL output of this cell, including anything drive.mount
# prints — consider moving the mount into its own cell.
# NOTE(review): the transformers version is unpinned; pin it for reproducibility.
!pip install transformers

import time
import sys
import os
import contextlib

from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import torch

# The Covid-QA data and the project helper modules used later are read from Drive paths.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the SQuAD-fine-tuned DistilBERT QA model together with the tokenizer shipped with the
# SAME checkpoint, so vocabulary and tokenizer settings are guaranteed to match the model.
# No return_token_type_ids override: DistilBERT does not use token type ids, and its forward()
# does not accept them (get_prediction passes the tokenizer output straight into model(**inputs)).
QA_CHECKPOINT = 'distilbert-base-uncased-distilled-squad'
tokenizer = DistilBertTokenizer.from_pretrained(QA_CHECKPOINT)
model = DistilBertForQuestionAnswering.from_pretrained(QA_CHECKPOINT)

In [None]:
# Inference on a single hand-written context/question pair.

start_time = time.time()
context = "The US has passed the peak on new coronavirus cases, " \
          "President Donald Trump said and predicted that some states would reopen this month. " \
          "The US has over 637,000 confirmed Covid-19 cases and over 30,826 deaths, the highest for any country in the world."

question = "What was President Donald Trump's prediction?"

encoding = tokenizer.encode_plus(question, context)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

# Pure inference: no_grad skips building the autograd graph (less memory, faster).
with torch.no_grad():
    outputs = model(torch.tensor([input_ids]), attention_mask=torch.tensor([attention_mask]))

# Index the output rather than tuple-unpacking it: newer transformers versions return a
# dict-like ModelOutput that cannot be unpacked into tensors, while positional indexing works
# for both the old tuple and ModelOutput (0 = start logits, 1 = end logits).
start_scores, end_scores = outputs[0], outputs[1]

answer_start = int(torch.argmax(start_scores))
answer_end = int(torch.argmax(end_scores)) + 1  # +1: slice end is exclusive
ans_tokens = input_ids[answer_start:answer_end]
answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)

print("\nQuestion ", question)
print("\nAnswer Tokens: ")
print(answer_tokens)

# Re-joins WordPiece pieces (e.g. 're', '##open' -> 'reopen').
answer_tokens_to_string = tokenizer.convert_tokens_to_string(answer_tokens)

print("\nAnswer : ", answer_tokens_to_string)

end_time = time.time()

print("\nExecution Time: {} seconds.".format(end_time - start_time))


Question  What was President Donald Trump's prediction?

Answer Tokens: 
['some', 'states', 'would', 're', '##open', 'this', 'month']

Answer :  some states would reopen this month

Execution Time: 0.13521170616149902 seconds.


In [None]:
from transformers.data.processors.squad import SquadV2Processor

# Covid-QA validation split. It is read with the SQuAD v2 processor's dev-set loader,
# so the file is expected to be in SQuAD 2.0 JSON format.
COVID_QA_DIR = "/content/drive/My Drive/colab_files/data/Covid-QA/"
COVID_QA_VAL_FILE = "Covid-QA-val.json"

processor = SquadV2Processor()
examples = processor.get_dev_examples(COVID_QA_DIR, filename=COVID_QA_VAL_FILE)
print(len(examples))

100%|██████████| 29/29 [00:01<00:00, 15.12it/s]

215





In [None]:
# generate some maps to help us identify examples of interest
# Lookup tables over the loaded examples:
#   qid_to_example_index: qas_id -> position of the example in `examples`
#   qid_to_has_answer:    qas_id -> True when the example lists at least one answer
# plus the qids partitioned into answerable / unanswerable lists.
qid_to_example_index = dict((ex.qas_id, pos) for pos, ex in enumerate(examples))
qid_to_has_answer = dict((ex.qas_id, bool(ex.answers)) for ex in examples)

answer_qids = [qid for qid in qid_to_has_answer if qid_to_has_answer[qid]]
no_answer_qids = [qid for qid in qid_to_has_answer if not qid_to_has_answer[qid]]

In [None]:
def display_example(qid):
    """Print one loaded example: its position, question, context and gold answer texts.

    Args:
        qid: the example's qas_id (a key of the global ``qid_to_example_index``).
    """
    from pprint import pprint

    idx = qid_to_example_index[qid]
    ex = examples[idx]
    question = ex.question_text
    context = ex.context_text
    true_answers = [ans['text'] for ans in ex.answers]

    print(f'Example {idx} of {len(examples)}\n---------------------')
    print(f"Q: {question}\n")
    print("Context:")
    pprint(context)
    print(f"\nTrue Answers:\n{true_answers}")

#display_example(answer_qids[0])

In [None]:
import sys
# Make the project-local helper modules stored on Drive importable.
# Guarded so that re-running this cell does not keep appending duplicate sys.path entries.
MODULES_DIR = '/content/drive/My Drive/colab_files/modules'
if MODULES_DIR not in sys.path:
    sys.path.append(MODULES_DIR)

import infersent_glove_context_generation as ig

import time
import os
import contextlib
import torch
import nltk
# Punkt tokenizer data, required by nltk.word_tokenize.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
def get_prediction(qid):
    """Predict an answer string for one loaded example.

    The example's full document and question are passed to the project helper
    ``ig.generate_context_from_doc`` (module ``infersent_glove_context_generation``)
    to build the context. Presumably this is a shorter, question-relevant excerpt;
    confirm in that module. The DistilBERT QA model then picks the argmax start
    and end tokens over ``question + context``.

    Args:
        qid: the example's qas_id (a key of ``qid_to_example_index``).

    Returns:
        The decoded answer span as a string, with special tokens ([CLS]/[SEP])
        dropped. It is the empty string when the predicted end precedes the start.
    """
    example = examples[qid_to_example_index[qid]]
    question = example.question_text
    doc_text = example.context_text

    # The context generator is chatty. Silence both its prints (stdout) and its warnings
    # (stderr); the evaluation cell's output shows those warnings otherwise flooding the log.
    with open(os.devnull, "w") as devnull, \
            contextlib.redirect_stdout(devnull), \
            contextlib.redirect_stderr(devnull):
        context = ig.generate_context_from_doc(doc_text, question)

    inputs = tokenizer.encode_plus(question, context, return_tensors='pt')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs[0] / outputs[1] are the start / end scores over the input tokens.
    answer_start = torch.argmax(outputs[0])
    answer_end = torch.argmax(outputs[1]) + 1  # +1: slice end is exclusive

    answer_ids = inputs['input_ids'][0][answer_start:answer_end]
    # Drop [CLS]/[SEP] so they never end up in the answer text, matching the first inference cell.
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids, skip_special_tokens=True)
    return tokenizer.convert_tokens_to_string(answer_tokens)

In [None]:
# these functions are heavily influenced by the HF squad_metrics.py script
def normalize_text(s):
    """Canonicalize an answer string before comparison (SQuAD-style).

    Applied in order: lowercase, delete every character in string.punctuation,
    replace the whole words "a"/"an"/"the" with a space, then collapse all
    whitespace runs to single spaces and strip the ends.
    """
    import string, re

    lowered = s.lower()
    no_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct, flags=re.UNICODE)
    return " ".join(no_articles.split())

def compute_exact_match(prediction, truth):
    """Return 1 if the two strings are equal after normalize_text, otherwise 0."""
    same = normalize_text(prediction) == normalize_text(truth)
    return 1 if same else 0

def compute_f1(prediction, truth):
    """Token-level F1 between a predicted and a gold answer (SQuAD definition).

    Both strings are normalized with normalize_text and split on whitespace.
    Overlap is a multiset intersection: a token that appears k_p times in the
    prediction and k_t times in the truth contributes min(k_p, k_t) matches.
    For example, "x x" vs "x x" scores 1.0, where a plain set intersection
    would give 0.5.

    Returns:
        Precision/recall harmonic mean as a float in [0, 1]. When either side
        is empty (no-answer), returns the int 1 if both are empty, else 0.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    # multiset overlap, as in the official SQuAD / HF squad_metrics evaluation
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    # if there are no common tokens then f1 = 0
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

def get_gold_answers(example):
    """Collect the non-empty gold answer texts of a SQuAD 2.0-style example.

    If the example has no non-empty answer text, it is treated as unanswerable.
    Its only correct prediction is then the empty string, so [""] is returned.
    """
    texts = [ans["text"] for ans in example.answers if ans["text"]]
    return texts or [""]

In [None]:
# Show the first answerable question id (the next cell evaluates exactly this question).
answer_qids[0]

5232

In [None]:
# Sanity check on a single answerable question before running the full evaluation.
start_time = time.time()

first_qid = answer_qids[0]
prediction = get_prediction(first_qid)
example = examples[qid_to_example_index[first_qid]]
gold_answers = get_gold_answers(example)

# Best score over all acceptable gold answers, as in SQuAD evaluation.
em_score = max(compute_exact_match(prediction, gold) for gold in gold_answers)
f1_score = max(compute_f1(prediction, gold) for gold in gold_answers)

report = (
    f"Question: {example.question_text}\n"
    f"Prediction: {prediction}\n"
    f"True Answers: {gold_answers}\n"
    f"EM: {em_score} \t F1: {f1_score}"
)
print(report)

print("\nExecution time: {}".format(time.time() - start_time))


Question: Why are nucleosides analogs used for chemotheraphy?
Prediction: they inhibit cellular dna / rna polymerases
True Answers: ['they inhibit cellular DNA/RNA polymerases']
EM: 0 	 F1: 0.7272727272727272

Execution time: 10.858467817306519


In [None]:
def evaluate_model(qids=None):
    """Score the QA pipeline on a set of question ids and report average EM and F1.

    For each qid, ``get_prediction`` produces an answer string. It is compared
    against every gold answer of the example, and the best (max) EM and F1 over
    those gold answers are kept.

    Args:
        qids: question ids to evaluate. Defaults to ``answer_qids`` (every
            answerable question in the loaded examples).

    Returns:
        ``(avg_em, avg_f1)``. Both are NaN when ``qids`` is empty, instead of
        raising ZeroDivisionError.
    """
    if qids is None:
        qids = answer_qids

    em_scores = []
    f1_scores = []

    for qid in qids:
        prediction = get_prediction(qid)
        example = examples[qid_to_example_index[qid]]
        gold_answers = get_gold_answers(example)
        em_scores.append(max(compute_exact_match(prediction, answer) for answer in gold_answers))
        f1_scores.append(max(compute_f1(prediction, answer) for answer in gold_answers))

    if em_scores:
        avg_em = sum(em_scores) / len(em_scores)
        avg_f1 = sum(f1_scores) / len(f1_scores)
    else:
        # Nothing to average: NaN signals "no data" rather than a misleading score of 0.
        avg_em = avg_f1 = float("nan")

    print("\nAvg EM: {}".format(avg_em))
    print("\nAvg F1: {}".format(avg_f1))
    return avg_em, avg_f1

In [None]:
# Full evaluation over all answerable questions, with wall-clock timing.
start_time = time.time()
evaluate_model()
elapsed = time.time() - start_time
print("\n\nExecution time: {}".format(elapsed))

  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (sentences[i], i))
  Replacing by "</s>"..' % (senten


Avg EM: 0.16279069767441862

Avg F1: 0.294424725711961


Execution time: 2226.6570670604706


Avg EM: 0.16279069767441862

Avg F1: 0.294424725711961

Execution time: 2226.6570670604706