In [1]:
import string, re
def get_prediction(context, question):
    """Extract the model's answer to ``question`` from ``context``.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``device``
    objects, which must be defined before this function is first called.

    Args:
        context: Passage of text in which to look for the answer.
        question: Question to ask about the passage.

    Returns:
        The predicted answer, rebuilt as a string from the selected input tokens.
    """
    inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)

    # Inference only: disable gradient tracking so no autograd graph is built
    # (and kept alive) for every question asked.
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs[0] / outputs[1] are used as the per-token start / end scores;
    # the highest-scoring position of each delimits the answer span.
    answer_start = torch.argmax(outputs[0])
    answer_end = torch.argmax(outputs[1]) + 1  # +1 so the slice includes the end token

    answer_ids = inputs['input_ids'][0][answer_start:answer_end]
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))

    return answer

def normalize_text(s):
    """Normalize an answer string before comparison.

    Steps, applied in this order: lowercase, strip ASCII punctuation, blank out
    the standalone articles "a", "an" and "the", then collapse every run of
    whitespace into a single space (trimming both ends).

    Args:
        s: Raw answer text.

    Returns:
        The normalized text.
    """
    lowered = s.lower()
    # Deleting via a translation table drops every character in string.punctuation.
    without_punctuation = lowered.translate(str.maketrans("", "", string.punctuation))
    # Articles are replaced by a space so that neighbouring words never fuse.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE)
    return " ".join(without_articles.split())

def exact_match(prediction, truth):
    """Tell whether ``prediction`` and ``truth`` are equal once both are normalized.

    Args:
        prediction: Answer text produced by the model.
        truth: Reference answer text.

    Returns:
        True when the two normalized strings are identical, False otherwise.
    """
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return normalized_prediction == normalized_truth

def compute_f1(prediction, truth):
    """Compute the token-level F1 score between a prediction and the reference.

    Both strings are normalized with ``normalize_text`` and split on whitespace.
    Overlap is counted as a multiset intersection, so a token that occurs several
    times is matched as many times as it appears in both answers.

    Args:
        prediction: Answer text produced by the model.
        truth: Reference answer text.

    Returns:
        The F1 score rounded to two decimals. When either side has no tokens the
        result is 1 if both are empty and 0 otherwise; 0 is also returned when
        the two answers share no token.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    # Counter & Counter keeps the minimum count per token. A plain set
    # intersection would count a repeated token only once, giving an F1 below 1
    # even for two identical answers that contain a duplicate word.
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    # if there are no common tokens then f1 = 0
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return round(2 * (prec * rec) / (prec + rec), 2)
  
def question_answer(context, question, answer):
    """Ask the model one question and print how its answer scores.

    Args:
        context: Passage of text the answer should be taken from.
        question: Question to ask about the passage.
        answer: Reference answer the prediction is scored against.

    Prints the question, the prediction, the reference answer, the exact-match
    flag and the F1 score, followed by a blank line.
    """
    prediction = get_prediction(context, question)
    em_score = exact_match(prediction, answer)
    f1_score = compute_f1(prediction, answer)

    report_lines = (
        f'Question: {question}',
        f'Prediction: {prediction}',
        f'True Answer: {answer}',
        f'Exact match: {em_score}',
        f'F1 score: {f1_score}\n',  # trailing newline separates consecutive reports
    )
    for line in report_lines:
        print(line)

In [5]:
from transformers import BertForQuestionAnswering, BertTokenizerFast
import torch

# Directory the model weights and the tokenizer files are both loaded from.
model_path = './bert-squadv6'
model = BertForQuestionAnswering.from_pretrained(model_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Working on {device}')

# The model must sit on the same device as the inputs it will be given.
model = model.to(device)

Working on cuda


In [6]:
# Passage the questions below are asked about.
context = """Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, 
          songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing 
          and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. 
          Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. 
          Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, 
          earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy"."""


# questions[i] is scored against answers[i]; keep the two lists aligned.
questions = ["Who is the passage about?",
             "What city was Beyonce born in?",
             "When did Beyonce born?",
             "What is Beyonce's nationality?",
             "Who was the Destiny's group manager?",
             "What name has the Beyoncé's debut album?",
             "How many Grammy Awards did Beyonce earn?",
             "When did the Beyoncé's debut album release?",
             "Who was the lead singer of R&B girl-group Destiny's Child?",
             "What magazine was Beyonce featured on?"]

answers = ["Beyonce Giselle Knowles - Carter", "Houston, Texas", "September 4, 1981",
           "American", "Mathew Knowles", "Dangerously in Love", "five", "2003",
           "Beyonce Giselle Knowles - Carter", "Billboard Hot 100"]

# zip() stops at the shorter list, so a length mismatch would silently drop
# questions from the evaluation instead of failing.
assert len(questions) == len(answers), "questions and answers must have the same length"

for question, answer in zip(questions, answers):
    question_answer(context, question, answer)

Question: Who is the passage about?
Prediction: mathew knowles
True Answer: Beyonce Giselle Knowles - Carter
Exact match: False
F1 score: 0.33

Question: What city was Beyonce born in?
Prediction: houston
True Answer: Houston, Texas
Exact match: False
F1 score: 0.67

Question: When did Beyonce born?
Prediction: september 4, 1981
True Answer: September 4, 1981
Exact match: True
F1 score: 1.0

Question: What is Beyonce's nationality?
Prediction: american
True Answer: American
Exact match: True
F1 score: 1.0

Question: Who was the Destiny's group manager?
Prediction: mathew knowles
True Answer: Mathew Knowles
Exact match: True
F1 score: 1.0

Question: What name has the Beyoncé's debut album?
Prediction: mathew knowles
True Answer: Dangerously in Love
Exact match: False
F1 score: 0

Question: How many Grammy Awards did Beyonce earn?
Prediction: five
True Answer: five
Exact match: True
F1 score: 1.0

Question: When did the Beyoncé's debut album release?
Prediction: 2003
True Answer: 2003
Ex

In [7]:
# Passage the questions below are asked about.
context = """Athens is the capital and largest city of Greece. Athens dominates the Attica region and is one of the world's oldest cities, 
             with its recorded history spanning over 3,400 years and its earliest human presence starting somewhere between the 11th and 7th millennium BC.
             Classical Athens was a powerful city-state. It was a center for the arts, learning and philosophy, and the home of Plato's Academy and Aristotle's Lyceum.
             It is widely referred to as the cradle of Western civilization and the birthplace of democracy, largely because of its cultural and political impact on the European continent—particularly Ancient Rome.
             In modern times, Athens is a large cosmopolitan metropolis and central to economic, financial, industrial, maritime, political and cultural life in Greece. 
             In 2021, Athens' urban area hosted more than three and a half million people, which is around 35% of the entire population of Greece.
             Athens is a Beta global city according to the Globalization and World Cities Research Network, and is one of the biggest economic centers in Southeastern Europe. 
             It also has a large financial sector, and its port Piraeus is both the largest passenger port in Europe, and the second largest in the world."""

# questions[i] is scored against answers[i]; keep the two lists aligned.
questions = ["Which is the largest city in Greece?",
             "For what was the Athens center?",
             "Which city was the home of Plato's Academy?"]

answers = ["Athens", "center for the arts, learning and philosophy", "Athens"]

# zip() stops at the shorter list, so a length mismatch would silently drop
# questions from the evaluation instead of failing.
assert len(questions) == len(answers), "questions and answers must have the same length"

for question, answer in zip(questions, answers):
    question_answer(context, question, answer)

Question: Which is the largest city in Greece?
Prediction: athens
True Answer: Athens
Exact match: True
F1 score: 1.0

Question: For what was the Athens center?
Prediction: the arts, learning and philosophy
True Answer: center for the arts, learning and philosophy
Exact match: False
F1 score: 0.8

Question: Which city was the home of Plato's Academy?
Prediction: athens
True Answer: Athens
Exact match: True
F1 score: 1.0

