<a href="https://colab.research.google.com/github/hassansardar193/Natural_Language_Processing-Question_Answering/blob/main/QA_bert_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

def calculate_f1_score(predictions, ground_truths):
    """Return an F1 score based on exact equality of paired answers.

    Each prediction is compared with the ground truth at the same position.
    An equal pair counts as one true positive; an unequal pair counts as one
    false positive and one false negative.

    Args:
        predictions: Sequence of predicted answers.
        ground_truths: Sequence of reference answers aligned with
            ``predictions``. Pairs are formed with ``zip``, so surplus items
            in the longer sequence are ignored.

    Returns:
        The F1 score as a float. Returns 0.0 when there are no pairs or when
        no pair matches, rather than raising ``ZeroDivisionError``.
    """
    tp = 0
    mismatches = 0
    for prediction, ground_truth in zip(predictions, ground_truths):
        if prediction == ground_truth:
            tp += 1
        else:
            mismatches += 1

    # Without any true positive, precision and recall are both 0 (or
    # undefined for empty input); F1 is defined as 0 in that case.
    if tp == 0:
        return 0.0

    # Every mismatch is counted both as a false positive and a false negative.
    fp = fn = mismatches
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)

def calculate_exact_match(predictions, ground_truths):
    """Return the fraction of predictions equal to their ground truth.

    Args:
        predictions: Sequence of predicted answers.
        ground_truths: Sequence of reference answers aligned with
            ``predictions``. Pairs are formed with ``zip``, so surplus items
            in the longer sequence are never compared.

    Returns:
        The number of equal pairs divided by ``len(predictions)``, or 0.0
        when ``predictions`` is empty (instead of raising
        ``ZeroDivisionError``).
    """
    total = len(predictions)
    if total == 0:
        return 0.0
    matches = sum(
        1
        for prediction, ground_truth in zip(predictions, ground_truths)
        if prediction == ground_truth
    )
    return matches / total

def answer_question(model, tokenizer, question, text):
    """Extract the answer span for ``question`` from ``text``.

    The question and passage are encoded together, the model scores every
    token as a possible answer start and answer end, and the tokens between
    the highest-scoring start and end (inclusive) are decoded.

    Args:
        model: Callable taking a batch of input ids plus an
            ``attention_mask`` keyword and returning start scores and end
            scores as its first two outputs (a tuple or a positionally
            indexable output object).
        tokenizer: Object providing ``encode(question, text)`` returning a
            list of token ids and ``decode(ids)`` returning a string.
        question: The question to answer.
        text: The passage to search for the answer.

    Returns:
        The decoded answer string. The slice is empty when the predicted
        end index precedes the predicted start index; the result is then
        whatever ``tokenizer.decode`` returns for an empty id list.
    """
    input_ids = tokenizer.encode(question, text)
    attention_mask = [1] * len(input_ids)

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        outputs = model(
            torch.tensor([input_ids]),
            attention_mask=torch.tensor([attention_mask]),
        )

    # Index positionally instead of tuple-unpacking: dict-like model outputs
    # yield their key names when unpacked, which would hand strings to
    # torch.argmax. Integer indexing works for plain tuples as well.
    start_scores = outputs[0]
    end_scores = outputs[1]

    # Plain ints keep the list slice below independent of tensor semantics.
    answer_start = int(torch.argmax(start_scores))
    answer_end = int(torch.argmax(end_scores))
    return tokenizer.decode(input_ids[answer_start:answer_end + 1])

# Checkpoint identifier passed to both from_pretrained() calls below.
model_name = "distilbert-base-cased-distilled-squad"

# Load the tokenizer and the question-answering model for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# NOTE(review): the `...` entries below are placeholders; they must be
# replaced with real data before this cell can run.
questions = [..., ..., ...]  # Questions, one per passage in `texts`
texts = [..., ..., ...]  # Passages, paired by position with `questions`
predictions = []  # Model answers, in input order
ground_truths = []  # Reference answers, in the same order as `predictions`

# Answer each question against its paired passage and collect the results.
for question, text in zip(questions, texts):
    prediction = answer_question(model, tokenizer, question, text)
    ground_truth = ...  # Placeholder: supply the reference answer here
    predictions.append(prediction)
    ground_truths.append(ground_truth)

# Score the collected predictions against the references.
f1 = calculate_f1_score(predictions, ground_truths)
em = calculate_exact_match(predictions, ground_truths)

print("F1 Score: ", f1)
print("Exact Match: ", em)
