In [None]:
import json
from concurrent.futures import ProcessPoolExecutor
from utils.general import load_video_graph
from utils.chat_api import generate_messages, get_response_with_retry
from retrieve import answer_with_retrieval
from prompts import prompt_agent_verify_answer


def process_qa(qa):
    """Answer one QA item with the retrieval agent, annotating it in place.

    Loads the memory graph stored at ``qa["mem_path"]`` and asks
    ``answer_with_retrieval`` the item's question. The returned answer and
    session are stored back on the item.

    Args:
        qa: dict with at least ``"mem_path"`` and ``"question"`` keys.

    Returns:
        The same ``qa`` dict, now carrying ``"agent_answer"`` and ``"session"``.
    """
    answer, retrieval_session = answer_with_retrieval(
        load_video_graph(qa["mem_path"]), qa["question"]
    )
    qa.update(agent_answer=answer, session=retrieval_session)
    return qa


def verify_qa(qa, model="gpt-4o-2024-11-20"):
    """Ask an LLM judge whether the agent's answer matches the ground truth.

    Builds a three-part prompt and sends it to ``model`` via
    ``get_response_with_retry``. The parts are:
    1. the QA sample as JSON;
    2. the verification instructions in ``prompt_agent_verify_answer``;
    3. a closing instruction.
    The first element of the response is stored on the item.

    Args:
        qa: dict with ``"question"``, ``"answer"`` (ground truth) and
            ``"agent_answer"`` keys. ``"agent_answer"`` is added by
            ``process_qa``.
        model: Name of the chat model used as the judge.

    Returns:
        The same ``qa`` dict with ``"verify_result"`` set.
    """
    qa_sample = {
        "question": qa["question"],
        "ground_truth_answer": qa["answer"],
        "agent_answer": qa["agent_answer"],
    }
    # Named ``message_parts`` rather than ``input`` to avoid shadowing the builtin.
    message_parts = [
        {"type": "text", "content": json.dumps(qa_sample)},
        {"type": "text", "content": prompt_agent_verify_answer},
        {
            "type": "text",
            "content": "Now answer if the answer from the baseline is correct or not:",
        },
    ]
    messages = generate_messages(message_parts)
    response = get_response_with_retry(model, messages)
    # NOTE(review): assumes the helper returns a sequence whose first item is the
    # model's reply text — confirm against utils.chat_api.
    qa["verify_result"] = response[0]
    return qa


def process_qa_list(qa_list, max_workers=16):
    """Run the retrieval agent, then the answer verifier, over QA items in parallel.

    Each item goes through ``process_qa``, which adds ``"agent_answer"`` and
    ``"session"``. Those results then go through ``verify_qa``, which adds
    ``"verify_result"``. Both stages run in worker processes, and each item is
    processed exactly once per stage.

    NOTE(review): ``ProcessPoolExecutor`` requires worker processes to be able
    to import ``__main__``. Functions defined directly in a notebook may
    therefore fail under the "spawn" start method. If that happens, move these
    helpers into a ``.py`` module and import them.

    Args:
        qa_list: QA dicts, each with at least ``"mem_path"``, ``"question"``
            and ``"answer"`` keys.
        max_workers: Upper bound on the number of worker processes.

    Returns:
        A new list of annotated QA dicts, in the same order as ``qa_list``.
    """
    if not qa_list:
        return []
    # No point starting more processes than there are items.
    workers = min(max_workers, len(qa_list))
    with ProcessPoolExecutor(max_workers=workers) as executor:
        # map() preserves input order; materialise stage one before stage two.
        answered = list(executor.map(process_qa, qa_list))
        return list(executor.map(verify_qa, answered))

In [None]:
# Load the annotated QA items, run the agent and the verifier on a sample, and
# write the enriched items back out as JSON Lines.
dataset = "data/annotations/small_train.jsonl"
dataset_with_agent_answer = "data/annotations/small_train_with_agent_answer.jsonl"
# Number of QA items to process; set to None to run on the whole dataset.
N_SAMPLES = 5

qa_list = []
with open(dataset, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline) in the JSONL file
        qa = json.loads(line)
        # Only items with a memory graph can be answered by the retrieval agent.
        if qa.get("mem_path"):
            qa_list.append(qa)

new_qa_list = process_qa_list(qa_list[:N_SAMPLES])
with open(dataset_with_agent_answer, "w", encoding="utf-8") as f:
    for qa in new_qa_list:
        f.write(json.dumps(qa) + "\n")