In [4]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

In [6]:
class QA_Model:
    """Extractive question answering with DistilBERT fine-tuned on SQuAD.

    The tokenizer and model are loaded via ``from_pretrained`` using
    ``MODEL_NAME`` when the instance is constructed.
    """

    MODEL_NAME = "distilbert-base-cased-distilled-squad"

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = AutoModelForQuestionAnswering.from_pretrained(self.MODEL_NAME)

    def answer(self, text, questions):
        """Answer each question against ``text``.

        For every question, the start and end token positions are chosen
        independently as the argmax of the model's start and end logits. The
        tokens in between are then decoded back into a string. The question,
        the decoded answer and the best start logit are printed as a side
        effect.

        Args:
            text: Passage to search for answers.
            questions: Iterable of question strings.

        Returns:
            A list with one answer string per question, in order. When the best
            start logit is negative, or the predicted end lies before the
            predicted start (an empty span), the answer is replaced by a single
            space ``" "``, which serves as the "no answer" sentinel.
        """
        answers = []
        for question in questions:
            inputs = self.tokenizer.encode_plus(
                question,
                text,
                add_special_tokens=True,
                # Truncate only the passage so long texts cannot exceed the
                # model's maximum sequence length and fail at runtime.
                truncation="only_second",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].tolist()[0]

            # Pure inference: skip building the autograd graph.
            with torch.no_grad():
                answer_start_scores, answer_end_scores = self.model(
                    **inputs, return_dict=False
                )

            # Most likely start/end positions; +1 makes the end exclusive for slicing.
            answer_start = int(torch.argmax(answer_start_scores))
            answer_end = int(torch.argmax(answer_end_scores)) + 1
            score = torch.max(answer_start_scores).item()

            answer = self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
            )

            print(f"Question: {question}")
            print(f"Answer: {answer}")
            print(f"Score: {score}")

            # A negative best start logit is used as the "no answer" threshold;
            # an inverted span would decode to "" and is treated the same way.
            if score < 0 or answer_end <= answer_start:
                answer = " "
            answers.append(answer)
        return answers

In [None]:
from pathlib import Path

folder = '../data/'
file = 'event_815871.txt'

questions = ["Who is the suspicious person?",
             "Where is the suspicious person?",
             "When did the suspicious activity take place?"]

qa = QA_Model()

# One list of distinct answers per question, seeded with the " " sentinel
# that QA_Model.answer returns when it finds no answer.
answers = [[" "] for _ in questions]

# Read the event report; the context manager guarantees the file is closed.
with open(Path(folder) / file, encoding="utf-8") as f:
    text = f.readlines()

for i, line in enumerate(text):
    print(str(i) + " " + line)

# Re-run QA on a growing prefix of the document to see how the answers
# evolve as more lines become available.
for i in range(len(text)):
    print(f"Up to line {i}")
    section = " ".join(text[: i + 1])
    answers_i = qa.answer(section, questions)
    for seen, new_answer in zip(answers, answers_i):
        if new_answer not in seen:
            seen.append(new_answer)
print(answers)