In [None]:
import json
from concurrent.futures import ProcessPoolExecutor
from utils.general import load_video_graph
from utils.chat_api import generate_messages, parallel_get_response
from retrieve import answer_with_retrieval
from prompts import prompt_agent_verify_answer


def verify_answers(qas, model="gpt-4o-2024-11-20"):
    """Ask an LLM judge whether each agent answer is correct and report accuracy.

    Each QA record is serialized with ``json.dumps`` and sent together with
    ``prompt_agent_verify_answer`` and a final instruction. A judge reply
    that starts with "yes" counts as correct. The check is case-insensitive
    and ignores leading/trailing whitespace.

    Args:
        qas: Sequence of JSON-serializable QA records. They are presumably
            dicts holding the question, reference answer and agent answer;
            TODO confirm against ``prompt_agent_verify_answer``.
        model: Judge model name forwarded to ``parallel_get_response``.

    Returns:
        ``(accuracy, results)``. ``accuracy`` is the fraction of judge
        replies counted as correct, or 0.0 when ``qas`` is empty.
        ``results`` holds the raw judge replies.
    """
    if not qas:
        # Nothing to judge: skip the API call and avoid dividing by zero.
        return 0.0, []

    inputs = [
        [
            {
                "type": "text",
                "content": json.dumps(qa),
            },
            {
                "type": "text",
                "content": prompt_agent_verify_answer,
            },
            {
                "type": "text",
                "content": "Now answer if the answer from the baseline is correct or not:",
            },
        ]
        for qa in qas
    ]
    messages = [generate_messages(content) for content in inputs]
    responses = parallel_get_response(model, messages)

    # NOTE(review): element 0 appears to be the list of reply strings; any
    # other elements are unused here. Confirm against
    # utils.chat_api.parallel_get_response.
    results = responses[0]

    # Non-string replies count as incorrect instead of crashing on .lower().
    # This covers e.g. a failed request, if the helper reports one as None.
    correct = sum(
        1
        for result in results
        if isinstance(result, str) and result.strip().lower().startswith("yes")
    )
    accuracy = correct / len(results) if results else 0.0

    return accuracy, results


def process_qa(qa):
    """Answer a single QA record using retrieval over its memory graph.

    Loads the graph stored at ``qa["mem_path"]`` and calls
    ``answer_with_retrieval`` with ``qa["question"]``. The returned answer is
    stored on the record under ``"agent_answer"`` and the returned sessions
    under ``"sessions"``.

    The record is modified in place, and the same dict object is returned.
    """
    graph = load_video_graph(qa["mem_path"])
    answer, retrieved_sessions = answer_with_retrieval(graph, qa["question"])
    qa.update(agent_answer=answer, sessions=retrieved_sessions)
    return qa


def process_qa_list(qa_list, max_workers=32):
    """Run ``process_qa`` over every QA record in a pool of worker processes.

    Args:
        qa_list: Iterable of QA records in the form ``process_qa`` expects,
            i.e. with ``"mem_path"`` and ``"question"`` keys.
        max_workers: Maximum number of worker processes.

    Returns:
        A new list of processed records, in the same order as ``qa_list``.
        The work runs on pickled copies in child processes, so the dicts in
        ``qa_list`` are not modified.
    """
    # Each record is processed exactly once, in the pool.
    # NOTE(review): workers must be able to resolve ``process_qa``. A function
    # defined in a notebook's __main__ only works under the "fork" start
    # method (the Linux default). Under "spawn" (the macOS/Windows default),
    # move process_qa into an importable module.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        new_qa_list = list(executor.map(process_qa, qa_list))
    return new_qa_list

In [None]:
qa_list = []
dataset = "data/annotations/small_test.jsonl"
dataset_with_agent_answer = "data/annotations/small_test_with_agent_answer.jsonl"

# Keep only records that point at a memory graph. Records without a
# non-empty "mem_path" cannot be answered by process_qa.
with open(dataset, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines in the JSONL file (e.g. an empty last line).
            continue
        qa = json.loads(line)
        if qa.get("mem_path"):
            qa_list.append(qa)

new_qa_list = process_qa_list(qa_list)
with open(dataset_with_agent_answer, "w", encoding="utf-8") as f:
    for qa in new_qa_list:
        f.write(json.dumps(qa) + "\n")