In [None]:
import re

import torch
from transformers import AutoTokenizer,BertTokenizerFast, BertForQuestionAnswering

In [None]:
# WordPiece tokenizer from the base checkpoint.
# NOTE(review): the QA weights below come from a different checkpoint name;
# this presumably relies on both sharing the bert-base-uncased vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode (dropout disabled) can be chained in one statement.
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased-finetuned-squad').eval()
model

In [None]:
def predict(context, query, max_length=512):
  """Extract the answer span for `query` from `context` with the QA model.

  Uses the module-level `tokenizer` and `model`. The pair is encoded as
  (query, context), so the question occupies the first segment.

  Args:
    context: passage to search for the answer.
    query: question text.
    max_length: maximum encoded length in tokens (default 512, BERT-base's
      input limit). Only the context is truncated, never the question.

  Returns:
    The predicted answer text. Returns '' when the best end index does not
    follow the best start index, or when the chosen span holds only special
    tokens (e.g. the [CLS] position used for "no answer").
  """
  inputs = tokenizer(query, context, return_tensors='pt',
                     truncation='only_second', max_length=max_length)

  # Pure inference: skip building the autograd graph.
  with torch.no_grad():
    outputs = model(**inputs)

  # outputs[0] / outputs[1] are the start / end logits over token positions.
  answer_start = int(torch.argmax(outputs[0]))
  # +1 turns the inclusive end index into an exclusive slice bound.
  answer_end = int(torch.argmax(outputs[1])) + 1
  if answer_end <= answer_start:
    return ''

  span_ids = inputs['input_ids'][0][answer_start:answer_end].tolist()
  tokens = tokenizer.convert_ids_to_tokens(span_ids)
  # Drop [CLS]/[SEP]/[PAD] so a "no answer" span yields '' instead of "[CLS]".
  special = set(tokenizer.all_special_tokens)
  tokens = [tok for tok in tokens if tok not in special]
  return tokenizer.convert_tokens_to_string(tokens)

def normalize_text(s):
  """Normalize answer text the way the official SQuAD evaluation does.

  Lower-cases, strips ASCII punctuation, replaces the articles
  "a"/"an"/"the" with spaces, then collapses whitespace runs.

  Args:
    s: raw answer text.

  Returns:
    The normalized string ('' if nothing remains).
  """
  import string, re

  def remove_articles(text):
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  # Applied innermost-first: lower -> remove_punc -> remove_articles -> white_space_fix.
  return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, truth):
  """Return 1 if both strings are identical after normalize_text, else 0."""
  return int(normalize_text(prediction) == normalize_text(truth))

def compute_f1(prediction, truth):
  """Token-level F1 between a predicted and a gold answer.

  Both strings are normalized with normalize_text and split on whitespace.
  Overlap is a multiset (bag-of-tokens) intersection, so a token repeated
  in both answers counts once per shared occurrence, not just once.

  Returns:
    1 if both answers normalize to no tokens, 0 if exactly one does or if
    they share no tokens, otherwise the harmonic mean of precision and recall.
  """
  from collections import Counter

  pred_tokens = normalize_text(prediction).split()
  truth_tokens = normalize_text(truth).split()

  # An empty answer only matches another empty answer.
  if len(pred_tokens) == 0 or len(truth_tokens) == 0:
    return int(pred_tokens == truth_tokens)

  # Counter & Counter keeps the minimum count of each shared token.
  common_tokens = Counter(pred_tokens) & Counter(truth_tokens)
  num_same = sum(common_tokens.values())

  if num_same == 0:
    return 0

  prec = num_same / len(pred_tokens)
  rec = num_same / len(truth_tokens)

  return 2 * (prec * rec) / (prec + rec)

In [None]:
def give_an_answer(context, query, answer):
  """Predict an answer for one question and print it with its EM/F1 scores.

  Args:
    context: passage the answer is extracted from.
    query: question text.
    answer: expected (gold) answer used for scoring.

  Prints five labelled lines, then print("\n") (two blank lines). Returns None.
  """
  prediction = predict(context, query)
  report = [
    f"Question: {query}",
    f"Prediction: {prediction}",
    f"True Answer: {answer}",
    f"EM: {compute_exact_match(prediction, answer)}",
    f"F1: {compute_f1(prediction, answer)}",
  ]
  for line in report:
    print(line)
  print("\n")

In [None]:
import json
from pathlib import Path
import torch
from torch.utils.data import DataLoader
import time

In [None]:
# Load the SQuAD v2.0 dev set and bucket every (question, gold answer) pair by
# the interrogative word in the question. This variant emits one row per gold
# answer: (question_id, question, context, answer_dict). A question with
# several annotated answers therefore appears several times, and a question
# whose 'answers' list is empty contributes no rows.
# NOTE(review): the next cell rebuilds qa_dict with one record per question;
# this cell looks like an earlier draft and may be removable.
path = Path('/content/drive/MyDrive/Spring22/CS769/Project/squad/dev-v2.0.json')

with open(path, 'rb') as f:
    squad_dict = json.load(f)

qa_dict = {key: [] for key in ['what', 'where', 'how', 'why', 'when', 'which', 'misc', 'who']}
qa_keys = ['what', 'where', 'how', 'why', 'when', 'which', 'who']
# Anchor each keyword at the start of a word so "show" is not "how",
# "elsewhere" is not "where", and "somewhat" is not "what"; "whom"/"whose"
# still count as "who". Keys are tried in qa_keys order; the first hit wins.
qa_patterns = {key: re.compile(r"\b" + key) for key in qa_keys}

for group in squad_dict['data']:
    for passage in group['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
            question = qa['question']
            # The bucket depends only on the question, so choose it once
            # rather than once per answer.
            question_lower = question.lower()
            qtype = next((key for key in qa_keys if qa_patterns[key].search(question_lower)), 'misc')
            for answer in qa['answers']:
                qa_dict[qtype].append((qa['id'], question, context, answer))

# Show bucket sizes instead of dumping every row (full contexts included).
[(key, len(rows)) for key, rows in qa_dict.items()]

In [None]:
# Load the SQuAD v2.0 dev set and group questions by interrogative word.
# Each question yields exactly one record:
#   (question_id, question, context, [gold answer texts])
# The answer list is empty when the question has no gold answer.
path = Path('/content/drive/MyDrive/Spring22/CS769/Project/squad/dev-v2.0.json')

with open(path, 'rb') as f:
    squad_dict = json.load(f)

qa_dict = {key: [] for key in ['what', 'where', 'how', 'why', 'when', 'which', 'misc', 'who']}
qa_keys = ['what', 'where', 'how', 'why', 'when', 'which', 'who']
# Anchor each keyword at the start of a word so "show" is not "how",
# "elsewhere" is not "where", and "somewhat" is not "what"; "whom"/"whose"
# still count as "who". Keys are tried in qa_keys order; the first hit wins,
# and questions matching none of them go to 'misc'.
qa_patterns = {key: re.compile(r"\b" + key) for key in qa_keys}

for group in squad_dict['data']:
    for passage in group['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
            question = qa['question']
            question_lower = question.lower()
            qtype = next((key for key in qa_keys if qa_patterns[key].search(question_lower)), 'misc')
            gold_texts = [ans['text'] for ans in qa['answers']]
            qa_dict[qtype].append((qa['id'], question, context, gold_texts))

qa_dict['how'][0]

In [None]:
# (question type, number of records) for every bucket, in qa_dict order.
[(key, len(rows)) for key, rows in qa_dict.items()]

In [None]:
# Run the QA model on the first 3000 "what" questions and record
# question id -> predicted answer text.
# NOTE(review): the 512 cutoff below counts characters, not tokens, so an
# answer located past character 512 of the passage can never be predicted.
# It presumably exists to keep inputs under BERT's token limit — confirm.
pred_dict = {}
for key in ['what']:
  for qid, query, full_context, _ in qa_dict[key][0:3000]:
    # Slicing past the end returns the whole string, so short passages are kept intact.
    context = full_context[:512]
    pred_dict[qid] = predict(context, query)


In [None]:
# Number of questions that received a prediction.
len(pred_dict)

In [None]:
# Mean F1 over the first 3000 "what" questions, following the SQuAD v2.0
# convention.
# - Each prediction is scored against every gold answer and the best match
#   is kept. Summing over gold answers would let one question contribute
#   more than 1.
# - A question with no gold answers is scored against '', so only an empty
#   prediction earns credit for it.
eval_rows = qa_dict['what'][0:3000]
f1_scores = []
for qid, _, _, gold_answers in eval_rows:
  golds = gold_answers if gold_answers else ['']
  f1_scores.append(max(compute_f1(pred_dict[qid], gt) for gt in golds))

# Average over the questions actually evaluated, not a hard-coded count.
final_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
final_f1


In [None]:
# Total records across all question-type buckets vs. number of predictions made.
print(sum(len(rows) for rows in qa_dict.values()), len(pred_dict))

In [None]:
# Switch the working directory to the project folder on the mounted drive so
# that the relative results file written in the next cell lands there.
%cd /content/drive/MyDrive/Spring22/CS769/Project

In [None]:
# Save the predictions (question id -> answer text). pred_dict is built only
# from the 'what' bucket, so the file is named after that question type
# rather than overwriting the results for another type.
with open('results_what.json', 'w') as fp:
    json.dump(pred_dict, fp)

In [None]:
# Demo: Queen. The passage is bound to `context` because that is the name the
# loop below passes to give_an_answer. Binding it to any other name would make
# the loop reuse whatever `context` was last assigned in the notebook.
context = """ Queen are a British rock band formed in London in 1970. Their classic line-up was Freddie Mercury (lead vocals, piano), 
            Brian May (guitar, vocals), Roger Taylor (drums, vocals) and John Deacon (bass). Their earliest works were influenced 
            by progressive rock, hard rock and heavy metal, but the band gradually ventured into more conventional and radio-friendly 
            works by incorporating further styles, such as arena rock and pop rock. """

queries = ["When did Queen found?",
           "Who were the basic members of Queen band?",
           "What kind of band they are?"
          ]
answers = ["1970",
           "Freddie Mercury, Brian May, Roger Taylor and John Deacon",
           "rock"
          ]

for q,a in zip(queries,answers):
  give_an_answer(context,q,a)

In [None]:
# Demo: Mount Olympus passage with three questions and their expected answers.
context = """ Mount Olympus is the highest mountain in Greece. It is part of the Olympus massif near 
              the Gulf of Thérmai of the Aegean Sea, located in the Olympus Range on the border between 
              Thessaly and Macedonia, between the regional units of Pieria and Larissa, about 80 km (50 mi) 
              southwest from Thessaloniki. Mount Olympus has 52 peaks and deep gorges. The highest peak, 
              Mytikas, meaning "nose", rises to 2917 metres (9,570 ft). It is one of the 
              highest peaks in Europe in terms of topographic prominence. """

queries = [
    "How many metres is Olympus?",
    "Where Olympus is near?",
    "How far away is Olympus from Thessaloniki?",
]
answers = [
    "2917",
    "Gulf of Thérmai of the Aegean Sea",
    "80 km (50 mi)",
]

for question, expected in zip(queries, answers):
  give_an_answer(context, question, expected)

In [None]:
# Demo: COVID-19 passage. The last two questions are probes: one expects a
# "yes" that does not appear verbatim in the passage, and one is unrelated
# to the passage and expects an empty answer.
context = """ The COVID-19 pandemic, also known as the coronavirus pandemic, is an ongoing pandemic of coronavirus disease 2019 (COVID-19) 
              caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). It was first identified in December 2019 in Wuhan, China. 
              The World Health Organization declared the outbreak a Public Health Emergency of International Concern in January 2020 and a pandemic 
              in March 2020. As of 6 February 2021, more than 105 million cases have been confirmed, with more than 2.3 million deaths attributed to COVID-19.
              Symptoms of COVID-19 are highly variable, ranging from none to severe illness. The virus spreads mainly through the air when people are 
              near each other.[b] It leaves an infected person as they breathe, cough, sneeze, or speak and enters another person via their mouth, nose, or eyes. 
              It may also spread via contaminated surfaces. People remain infectious for up to two weeks, and can spread the virus even if they do not show symptoms.[9]"""

queries = [
    "What is COVID-19?",
    "What is caused by COVID-19?",
    "How many cases have been confirmed from COVID-19?",
    "How many deaths have been confirmed from COVID-19?",
    "How is COVID-19 spread?",
    "How long can an infected person remain infected?",
    "Can a infected person spread the virus even if they don't have symptoms?",
    "What do elephants eat?",
]
answers = [
    "an ongoing pandemic of coronavirus disease 2019",
    "severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2)",
    "more than 105 million cases",
    "more than 2.3 million deaths",
    "mainly through the air when people are near each other. It leaves an infected person as they breathe, cough, sneeze, or speak and enters another person via their mouth, nose, or eyes. It may also spread via contaminated surfaces.",
    "up to two weeks",
    "yes",
    "",
]

for question, expected in zip(queries, answers):
  give_an_answer(context, question, expected)

In [None]:
# Demo: Harry Potter passage with eight questions and their expected answers.
context = """ Harry Potter is a series of seven fantasy novels written by British author, J. K. Rowling. The novels chronicle the lives of a young wizard, 
              Harry Potter, and his friends Hermione Granger and Ron Weasley, all of whom are students at Hogwarts School of Witchcraft and Wizardry. 
              The main story arc concerns Harry's struggle against Lord Voldemort, a dark wizard who intends to become immortal, overthrow the wizard 
              governing body known as the Ministry of Magic and subjugate all wizards and Muggles (non-magical people). Since the release of the first novel, 
              Harry Potter and the Philosopher's Stone, on 26 June 1997, the books have found immense popularity, positive reviews, and commercial success worldwide. 
              They have attracted a wide adult audience as well as younger readers and are often considered cornerstones of modern young adult literature.[2] 
              As of February 2018, the books have sold more than 500 million copies worldwide, making them the best-selling book series in history, and have been translated 
              into eighty languages.[3] The last four books consecutively set records as the fastest-selling books in history, with the final installment selling roughly 
              eleven million copies in the United States within twenty-four hours of its release.  """

queries = [
    "Who wrote Harry Potter's novels?",
    "Who are Harry Potter's friends?",
    "Who is the enemy of Harry Potter?",
    "What are Muggles?",
    "Which is the name of Harry Poter's first novel?",
    "When did the first novel release?",
    "Who was attracted by Harry Potter novels?",
    "How many languages Harry Potter has been translated into? ",
]
answers = [
    "J. K. Rowling",
    "Hermione Granger and Ron Weasley",
    "Lord Voldemort",
    "non-magical people",
    "Harry Potter and the Philosopher's Stone",
    "26 June 1997",
    "a wide adult audience as well as younger readers",
    "eighty",
]

for question, expected in zip(queries, answers):
  give_an_answer(context, question, expected)