In [None]:
import json
import re
import time
from pathlib import Path

import pandas as pd
from torch.utils.data import DataLoader

In [None]:
import torch
import transformers
from transformers import AutoTokenizer,BertTokenizerFast, BertForQuestionAnswering, DistilBertForQuestionAnswering

In [None]:
# Load the SQuAD v2.0 dev set. JSON text is UTF-8 and the passages may contain
# non-ASCII characters, so decode explicitly instead of relying on the platform's
# default locale encoding.
with open('/content/drive/MyDrive/Spring22/CS769/Project/squad/dev-v2.0.json' , 'r', encoding='utf-8') as fp:
  data = json.load(fp)

In [None]:
# Unwrap the list of articles stored under the top-level 'data' key of the loaded JSON.
# Guarded so that re-running this cell does not fail on the already-unwrapped list.
if isinstance(data, dict):
  data = data['data']

In [None]:
# SQuAD-trained DistilBERT checkpoint (per the directory name); the tokenizer and the
# weights are loaded from the same directory, so keep the path in one place.
MODEL_DIR = '/content/drive/MyDrive/Spring22/CS769/Project/FInal/DistilBERT/distillbert-test-squad-trained'

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = DistilBertForQuestionAnswering.from_pretrained(MODEL_DIR)
# Switch to inference mode; the trailing ';' keeps the (very long) model repr out of
# the cell output.
model.eval();

In [None]:
# Alternative checkpoint: to evaluate the BERT model instead of DistilBERT, uncomment
# the lines below and run this cell in place of the DistilBERT loading cell.
# tokenizer = AutoTokenizer.from_pretrained('/content/drive/MyDrive/Spring22/CS769/Project/FInal/BERT/bert-test-squad-trained')

# model = BertForQuestionAnswering.from_pretrained('/content/drive/MyDrive/Spring22/CS769/Project/FInal/BERT/bert-test-squad-trained')
# model.eval()

In [None]:
def predict(context,query):
  """Return the model's highest-scoring answer span for `query` over `context`.

  Uses the notebook-level `tokenizer` and `model`. The start and end positions are
  chosen independently by argmax, so the returned string can be empty (end before
  start) or consist of special tokens such as '[CLS]'.

  Args:
    context: passage text.
    query: question text.

  Returns:
    The decoded answer span as a string.
  """
  # Encode as the pair (question, passage). Truncate only the passage so an over-long
  # context is cut to the tokenizer's maximum length instead of overflowing the model.
  inputs = tokenizer.encode_plus(query, context, return_tensors='pt', truncation='only_second')

  # Pure inference: no autograd graph is needed.
  with torch.no_grad():
    outputs = model(**inputs)

  answer_start = torch.argmax(outputs[0])      # argmax of the start logits
  answer_end = torch.argmax(outputs[1]) + 1    # +1 makes the slice end exclusive

  answer_ids = inputs['input_ids'][0][answer_start:answer_end]
  return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))

def normalize_text(s):
  """Normalize an answer string for SQuAD-style comparison.

  Lower-cases, strips ASCII punctuation, removes the articles 'a' / 'an' / 'the', and
  collapses whitespace to single spaces (trimming both ends).
  """
  import string, re

  def remove_articles(text):
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, truth):
  """Return 1 if `prediction` and `truth` are equal after `normalize_text`, else 0."""
  return int(normalize_text(prediction) == normalize_text(truth))

def compute_f1(prediction, truth):
  """Token-level F1 between a predicted and a gold answer string (SQuAD definition).

  Both strings are normalized with `normalize_text` and split on whitespace. Overlap
  is a multiset intersection, so a token repeated in both strings is counted once per
  matching occurrence. If either side has no tokens, the score is 1 when both are
  empty and 0 otherwise.
  """
  from collections import Counter

  pred_tokens = normalize_text(prediction).split()
  truth_tokens = normalize_text(truth).split()

  if len(pred_tokens) == 0 or len(truth_tokens) == 0:
    return int(pred_tokens == truth_tokens)

  # Multiset overlap, as in the official SQuAD evaluation script. A plain `set`
  # intersection under-counts repeated tokens (F1("x x", "x x") would be 0.5).
  num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
  if num_common == 0:
    return 0

  prec = num_common / len(pred_tokens)
  rec = num_common / len(truth_tokens)
  return 2 * (prec * rec) / (prec + rec)

In [None]:
def give_an_answer(context,query,answer,verbose=False):
  """Predict an answer for `query` over `context` and score it against one gold answer.

  Args:
    context: passage text passed to `predict`.
    query: question text.
    answer: a single gold answer string ('' for an unanswerable question).
    verbose: if True, print the question, prediction, gold answer and both scores.

  Returns:
    (prediction, f1_score, em_score): the predicted string plus its token-level F1
    and exact-match score against `answer`.
  """
  prediction = predict(context,query)
  # A bare '[CLS]' span is treated as "no answer" and mapped to '' so it can match
  # the empty gold answer used for unanswerable questions.
  if prediction == '[CLS]':
    prediction = ''
  em_score = compute_exact_match(prediction, answer)
  f1_score = compute_f1(prediction, answer)

  # Diagnostics are opt-in: this function runs once per question in the evaluation
  # loop, and unconditional printing floods the notebook output.
  if verbose:
    print(f"Question: {query}")
    print(f"Prediction: {prediction}")
    print(f"True Answer: {answer}")
    print(f"EM: {em_score}")
    print(f"F1: {f1_score}")
  return prediction , f1_score , em_score

In [None]:
# Bucket every dev-set question id by question word and by answerability.
#   qa_dict: question word -> list of question ids. A question goes to the FIRST key of
#            `qa_keys` (in list order) that occurs in it as a whole word, else 'misc'.
#   answerable_dict: keyed by the string form of `is_impossible`, so 'True' holds the
#            UNanswerable question ids and 'False' the answerable ones.
qa_dict = {'what':[] , 'where': [], 'how': [], 'why':[], 'when': [], 'which':[],  'misc': [], 'who' : []}
qa_keys = ['what', 'where', 'how', 'why', 'when', 'which', 'who']
answerable_dict = {'True' : [] , 'False' : []}
for passage in data:
  for paragraph in passage['paragraphs']:
    for qa in paragraph['qas']:
      answerable_dict['True' if qa['is_impossible'] else 'False'].append(qa['id'])

      # Match whole words only: a substring test would file "show" under 'how' and
      # "whole" under 'who'.
      question_words = set(re.findall(r"\w+", qa['question'].lower()))
      for key in qa_keys:
        if key in question_words:
          qa_dict[key].append(qa['id'])
          break
      else:
        qa_dict['misc'].append(qa['id'])
print({key : len(answerable_dict[key]) for key in answerable_dict.keys()})

In [None]:
from collections import defaultdict
import numpy as np

# Evaluate the model on the first N_EVAL_ARTICLES articles of the dev set.
# Scoring follows the official SQuAD v2.0 convention:
#   - the model is run ONCE per question;
#   - the prediction is scored as the MAX F1 / EM over that question's gold answers;
#   - unanswerable questions (no gold answers) are scored against ''.
N_EVAL_ARTICLES = 18
# Character-level cut applied to every context before prediction (kept from the
# original run). NOTE(review): this can drop gold answers located after this offset;
# token-level truncation would be more faithful.
MAX_CONTEXT_CHARS = 512

predictions = {}              # question id -> predicted answer string
scores = defaultdict(list)    # question id -> [f1, em]
count = 0                     # number of questions scored
net_f1 = 0
net_em = 0

for passage in data[:N_EVAL_ARTICLES]:
  for paragraph in passage['paragraphs']:
    context = paragraph['context'][:MAX_CONTEXT_CHARS]
    for qa in paragraph['qas']:
      gold_answers = [answer['text'] for answer in qa['answers']] or ['']
      # give_an_answer runs the model and scores against one gold answer; reuse its
      # prediction for the remaining gold answers instead of re-running inference.
      prediction, f1_score, em_score = give_an_answer(context, qa['question'], gold_answers[0])
      for gold in gold_answers[1:]:
        f1_score = max(f1_score, compute_f1(prediction, gold))
        em_score = max(em_score, compute_exact_match(prediction, gold))

      predictions[qa['id']] = prediction
      scores[qa['id']] = [f1_score, em_score]
      count += 1
      net_f1 += f1_score
      net_em += em_score

print('Total F1 score is ' , net_f1/count)
print('Net EM score is ' , net_em/count)
print(len(predictions))

In [None]:
# Persist the question-word -> question-id buckets built earlier.
with Path('/content/drive/MyDrive/Spring22/CS769/Project/FInal/mapping.json').open('w') as p:
  json.dump(qa_dict, p)

In [None]:
# Persist the question id -> predicted answer map from the evaluation run.
with Path('/content/drive/MyDrive/Spring22/CS769/Project/FInal/DistilBERT/predictions.json').open('w') as fp:
  json.dump(predictions, fp)

In [None]:
# Persist the question id -> [f1, em] map from the evaluation run.
with Path('/content/drive/MyDrive/Spring22/CS769/Project/FInal/DistilBERT/scores.json').open('w') as f:
  json.dump(scores, f)

In [None]:
# Spot-check: is this question id filed under 'True', i.e. marked is_impossible?
any(qid == '5ad39d53604f3c001a3fe8d1' for qid in answerable_dict['True'])

In [None]:
# Average F1 / EM over the answerable questions (answerable_dict['False']) that were
# actually scored by the evaluation loop, i.e. that have an entry in `scores`.
# NOTE: the `what_` prefix is historical; this is not restricted to "what" questions.
what_f1 = 0
what_em = 0
ct = 0
# `qid`, not `id`: a module-level loop variable named `id` would shadow the builtin
# for the rest of the notebook session.
for qid in answerable_dict['False']:
  if qid in scores:
    ct += 1
    what_f1 += scores[qid][0]
    what_em += scores[qid][1]

if ct:
  print('F1: ' , what_f1/ct)
  print('EM: ' , what_em/ct)
else:
  print('No answerable questions have been scored yet.')


In [None]:
# EXAMPLE: take the first paragraph of the first article and display its first question.

context, qas = data[0]['paragraphs'][0]['context'], data[0]['paragraphs'][0]['qas']
qas[0]

In [None]:
# Score the example paragraph (`context` / `qas` from the EXAMPLE cell) the same way as
# the SQuAD convention: one prediction per question, scored as the max F1 / EM over
# its gold answers ('' for unanswerable questions).
count = 0
net_f1 = 0
net_em = 0
for qa in qas:
  gold_answers = [answer['text'] for answer in qa['answers']] or ['']
  # give_an_answer returns THREE values (prediction, f1, em). It also runs the model,
  # so call it once and reuse the prediction for the remaining gold answers.
  prediction, f1_score, em_score = give_an_answer(context, qa['question'], gold_answers[0])
  for gold in gold_answers[1:]:
    f1_score = max(f1_score, compute_f1(prediction, gold))
    em_score = max(em_score, compute_exact_match(prediction, gold))
  count += 1
  net_f1 += f1_score
  net_em += em_score

print('Total F1 score is ' , net_f1/count)
print('Net EM score is ' , net_em/count)