In [2]:
import pandas as pd
import nltk
from tqdm import tqdm
from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
import json

In [None]:
# Load the instruction-tuned FLAN-T5 checkpoint that answers the timing questions.
agent_model_name = 'google/flan-t5-xl'
agent_tokenizer = AutoTokenizer.from_pretrained(agent_model_name)
agent_model = T5ForConditionalGeneration.from_pretrained(agent_model_name)

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# and finally plain CPU so the notebook still runs on machines with neither
# (unconditionally falling back to 'mps' fails where MPS is unavailable).
_mps_backend = getattr(torch.backends, 'mps', None)  # absent on older torch builds
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

agent_model = agent_model.to(device)
# Switch to evaluation mode (disables dropout) since the model is only used for inference.
agent_model.eval()

In [5]:
# Read the upstream results checkpoint (a JSON document) into memory.
results_path = '../checkpoint/results.json'
with open(results_path, 'r') as file:
    results = json.loads(file.read())

In [6]:
# Prompt asking the model for the numeric value of the timing word
# (the prompt instructs it to answer 'None' when no number applies).
question3 = "Given a sentence and a relative timing word, transform the timing word into its corresponding numeric value. Return only the number. If no number can be determined, return 'None'."

# Prompt asking the model to pick the time unit of the timing word from a
# fixed list of options.
# NOTE(review): the instruction says to choose 'Other' while the listed option
# is lowercase "other" — confirm which casing downstream parsing expects.
question4 = """Given a sentence and a relative timing word, select the option that best matches the unit of the timing word. If none of the options are appropriate, choose 'Other':
hour or hours
minute or minutes
day or days
month or months
week or weeks
year or years
other
"""

# Prompt asking whether the timing expression points before or after the
# reference point in the sentence.
question5 = "Given a sentence and a relative timing expression, identify whether the timing indicates 'before' or 'after' the reference point in the sentence."

# Order matters: sequence_QA zips the answers with the keys
# ['number', 'unit', 'before_or_after'] in this same order.
question345 = [question3, question4, question5]

In [7]:
def model_QA(query, question, timing_word, max_length=500, num_beams=4,
             do_sample=True, temperature=0.9):
    """Ask the agent model a single question about a timing word in a sentence.

    Args:
        query: The sentence that contains the timing expression.
        question: The instruction/question text placed at the top of the prompt.
        timing_word: The relative timing word or expression being asked about.
        max_length: Maximum length passed to ``generate`` (default 500).
        num_beams: Number of beams used during generation (default 4).
        do_sample: Whether to sample during generation (default True). With
            sampling enabled the answer may differ between calls unless the
            random state is seeded by the caller.
        temperature: Sampling temperature; only forwarded when ``do_sample``
            is true (default 0.9).

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    # The prompt layout (including its indentation) is kept exactly as-is so
    # the text fed to the model does not change.
    input_text = f"""
    question: {question}
    given sentence: {query}
    relative timing word: {timing_word}
    answer:
    """
    encoded = agent_tokenizer(input_text, return_tensors="pt").to(device)

    generation_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        # early_stopping is a beam-search option; skip it for single-beam decoding.
        "early_stopping": num_beams > 1,
        "do_sample": do_sample,
    }
    if do_sample:
        # temperature is only meaningful when sampling is enabled.
        generation_kwargs["temperature"] = temperature

    output_ids = agent_model.generate(
        encoded["input_ids"],
        # Forward the tokenizer's mask explicitly instead of relying on
        # generate() to infer one.
        attention_mask=encoded["attention_mask"],
        **generation_kwargs,
    )
    return agent_tokenizer.decode(output_ids[0], skip_special_tokens=True)

def sequence_QA(query, questions, timing_word):
    """Run every question against one sentence and label the answers.

    Each question is sent to ``model_QA`` in order; the resulting answers are
    paired positionally with the keys 'number', 'unit' and 'before_or_after'.
    Pairing stops at the shorter of the two sequences.

    Args:
        query: The sentence containing the timing expression.
        questions: Iterable of question prompts, asked in order.
        timing_word: The relative timing word the questions refer to.

    Returns:
        A dict mapping the answer labels to the model's answers.
    """
    labels = ('number', 'unit', 'before_or_after')
    replies = [model_QA(query, prompt, timing_word) for prompt in questions]
    return dict(zip(labels, replies))

In [None]:
# Keep only the sentences whose category is not 'Other', and attach the
# model-extracted number / unit / direction for each kept timing word.
filtered_results = {}

for key, values in tqdm(results.items()):
    useful_sentences = values['useful_sentence']
    timing_words = values['relative_timing']
    category_ls = values['category']

    # Per-key accumulators; the four lists stay index-aligned.
    tmp_timing_words, tmp_category_ls = [], []
    tmp_useful_sentences, tmp_exact_timing = [], []

    for idx, sentence in enumerate(useful_sentences):
        timing_word = timing_words[idx]
        query_category = category_ls[idx]
        if query_category == 'Other':
            # Nothing to extract for uncategorised timing words.
            continue

        tmp_timing_words.append(timing_word)
        tmp_category_ls.append(query_category)
        tmp_useful_sentences.append(sentence)
        answer = sequence_QA(sentence, question345, timing_word)
        tmp_exact_timing.append(answer)

    # Skip keys for which every sentence was filtered out.
    if tmp_timing_words:
        filtered_results[key] = dict(
            useful_sentence=tmp_useful_sentences,
            original_timing_word=tmp_timing_words,
            category=tmp_category_ls,
            exact_timing=tmp_exact_timing,
        )

In [None]:
# Persist the filtered results next to the input checkpoint as indented JSON.
output_file = "../checkpoint/filtered_results.json"
with open(output_file, "w") as json_file:
    json_file.write(json.dumps(filtered_results, indent=4))

print(f"Data successfully saved to {output_file}")