In [None]:
import os
import sys
import json
from pathlib import Path
# Make the notebook's parent directory importable so the project packages imported in
# later cells (`engine`, `prompts`) resolve; never add it twice on re-run.
module_path = os.path.abspath(os.path.join('..'))
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

# Device string, presumably read by project code (not visible here) — TODO confirm.
os.environ.update(CUDAIDX='cuda:0')
import numpy as np

import random
# Seeds the global `random` module, which create_prompt(method='random') samples from.
random.seed(42)

In [None]:
# API credentials, presumably consumed by the program generator. Never commit real keys
# here: export OPENAI_API_KEY / OPENAI_API_BASE in the shell or kernel environment.
# Values that are already set are kept; the 'xxx' placeholders only fill in unset
# variables, so re-running this cell never overwrites exported credentials.
import warnings

API_ENV_DEFAULTS = {'OPENAI_API_KEY': 'xxx', 'OPENAI_API_BASE': 'xxx'}
for env_name, placeholder in API_ENV_DEFAULTS.items():
    os.environ.setdefault(env_name, placeholder)
    if os.environ[env_name] == placeholder:
        warnings.warn(
            f"{env_name} is still the placeholder {placeholder!r}; "
            "set it in the environment before generating programs."
        )

In [None]:
from PIL import Image
from IPython.display import HTML, display
from functools import partial

from engine.utils import ProgramGenerator, ProgramInterpreter
from prompts.agqa import AGQA_CURATED_EXAMPLES

In [None]:
# Interpreter used by the evaluation loop (via `aug_execute`) to run generated programs.
# NOTE(review): this notebook evaluates AGQA but passes dataset='gqa' — presumably the
# AGQA module set is shared with GQA; confirm against engine.utils.ProgramInterpreter.
INTERPRETER_DATASET = 'gqa'
interpreter = ProgramInterpreter(dataset=INTERPRETER_DATASET)

In [None]:
def create_prompt(inputs, num_prompts=8, method='random', seed=42, group=0, examples=None):
    """Build the few-shot program-generation prompt for a single question.

    Args:
        inputs: mapping whose ``question`` entry is substituted into the trailing
            ``Question: ...`` / ``Program:`` lines (extra keys are ignored).
        num_prompts: number of examples drawn when ``method == 'random'``; capped at
            the number of available examples.
        method: ``'all'`` uses every example in order; ``'random'`` draws
            ``num_prompts`` examples using the global ``random`` module.
        seed: currently unused (per-call reseeding is disabled); kept for callers.
        group: currently unused; kept for callers.
        examples: sequence of example strings; defaults to ``AGQA_CURATED_EXAMPLES``.

    Returns:
        The prompt string: a header, the newline-joined examples, then the question.

    Raises:
        NotImplementedError: if ``method`` is not ``'all'`` or ``'random'``.
        KeyError: if ``inputs`` has no ``question`` entry.
    """
    if examples is None:
        examples = AGQA_CURATED_EXAMPLES

    if method == 'all':
        chosen = examples
    elif method == 'random':
        # Uses the global RNG (no per-call reseed), so successive calls draw different
        # example sets. random.sample raises ValueError if k exceeds the population.
        chosen = random.sample(examples, min(num_prompts, len(examples)))
    else:
        raise NotImplementedError(f"unknown prompt sampling method: {method!r}")

    # The header whitespace (three newlines, a 4-space indent before the examples and a
    # trailing "\n    ") is deliberately kept exactly as is: it is part of the text sent
    # to the generator, and changing it changes the prompt.
    header = "Considering the examples provided:\n\n\n    "
    joined = '\n'.join(chosen)
    return header + joined + "\n    " + "\nQuestion: {question}\nProgram:".format(**inputs)
# The generator always receives every curated example (method='all'), so its few-shot
# prompt does not depend on random sampling.
prompter = partial(create_prompt, method='all')
generator = ProgramGenerator(dataset='gqa', prompter=prompter)

In [None]:
from tqdm import tqdm

# Evaluation: run the interpreter on every AGQA test question, report exact-match accuracy
# and the mean ct_score, and write per-question predictions back to disk.

# NOTE(review): questions/results live under ~/codes/ExoViP but frames are read from
# ~/codes/visjoint — confirm both roots before running on another machine.
test_file = os.path.join(Path.home(), 'codes/ExoViP/datasets/agqa/test.json')
frames_root = os.path.join(Path.home(), 'codes/visjoint/datasets/agqa/imgs')
result_file = os.path.join(Path.home(), 'codes/ExoViP/results/agqa/exovip_agqa.json')
NUM_SAMPLED_FRAMES = 6  # frames uniformly sampled per video


def sample_frame_indices(num_frames, num_samples=NUM_SAMPLED_FRAMES):
    """Return ``num_samples`` frame indices, one near the middle of each equal segment.

    ``[0, num_frames)`` is split into ``num_samples`` segments via ``np.linspace``; the
    index for segment ``[lo, hi)`` is ``(lo + hi - 1) // 2``, clamped to be >= 0.

    Raises:
        ValueError: if ``num_frames`` is not positive.
    """
    if num_frames <= 0:
        raise ValueError("cannot sample frames from a video with no frames")
    bounds = np.linspace(start=0, stop=num_frames, num=num_samples + 1).astype(int)
    # With fewer frames than segments a leading segment is empty (lo == hi == 0); the
    # midpoint formula then gives -1, which would silently wrap to the *last* frame.
    return [max((int(lo) + int(hi) - 1) // 2, 0) for lo, hi in zip(bounds[:-1], bounds[1:])]


def load_video_frames(video_id):
    """Load the sampled frames (filenames sorted) of ``video_id`` as RGB PIL images."""
    video_dir = os.path.join(frames_root, video_id)
    frame_names = sorted(os.listdir(video_dir))
    frames = []
    for i in sample_frame_indices(len(frame_names)):
        # `with` closes the file handle; convert() returns an independent in-memory image.
        with Image.open(os.path.join(video_dir, frame_names[i])) as im:
            frames.append(im.convert("RGB"))
    return frames


# Instruction prefix passed to interpreter.aug_execute. `{rejected_solutions}` is left as a
# literal placeholder — presumably formatted inside aug_execute (TODO confirm). The
# backslash-newline continuations keep each next line's leading spaces inside the string,
# so these lines must stay exactly as written or the prompt text changes.
pre_instruct = "Think step by step to answer the question. \
            while taking rejected solutions into account and learning from them. Here are evaluated solutions that were rejected: {rejected_solutions}\n\n \
                Answer the question without making the same mistakes you did with the evaluated rejected solutions. Be simple, Be direct, don't repeat or reply other thing \n\n \
                    Applicable modules include: FIND, VQA, EVAL, RESULT, MEASURE, GET/GET_BEFORE/GET_AFTER/GET_BETWEEN, CROP/CROP_RIGHTOF/CROP_LEFT/CROP_FRONTOF/CROP_INFRONTOF/CROP_INFRONT/CROP_BEHIND/CROP_AHEAD/CROP_BELOW/CROP_ABOVE \
                        Following the examples provided:\n\n"
prompt_examples = AGQA_CURATED_EXAMPLES

with open(test_file) as jp:
    test = json.load(jp)

eval_pred = 0      # exact-match correct predictions
eval_cnt = 0       # questions evaluated
eval_ct_score = 0  # running sum of per-question ct_score
eval_ct_cnt = 0    # questions contributing to eval_ct_score (denominator of its mean)

for idx, dct in tqdm(test.items()):
    question = dct['question']
    answer = dct['answer']

    if 'predict' not in dct:
        # Frames are only needed when a new prediction has to be computed.
        init_state = dict(
            IMAGE=load_video_frames(dct['video_id']),
            CT_SCORE=0,
        )
        initial_prompts = [generator.generate_prompt(dict(question=question))]
        result, ct_score, cd_score = interpreter.aug_execute(
            initial_prompts, init_state, question, pre_instruct, prompt_examples, task='agqa')
        # `dct` is the same dict object as test[idx], so these land in the saved JSON.
        dct["predict"] = result
        dct["ct_score"] = ct_score
        dct["cd_score"] = cd_score
    else:
        # The question already carries a stored prediction: reuse it without re-running.
        result = dct["predict"]
        ct_score = dct["ct_score"]
        cd_score = dct["cd_score"]

    eval_pred += int(result == answer)
    eval_ct_score += ct_score
    eval_ct_cnt += 1
    eval_cnt += 1
    if eval_cnt % 100 == 0:
        print(f'step {eval_cnt} accuracy: ', round(eval_pred/eval_cnt, 2), 'ct_score: ', round(eval_ct_score/eval_cnt * 100, 2))

# Both metrics are averaged over every evaluated question (matching the step logs);
# an empty test set is reported instead of raising ZeroDivisionError.
if eval_cnt:
    print("accuracy: ", eval_pred / eval_cnt)
    print('ct_score: ', eval_ct_score / eval_ct_cnt * 100)
else:
    print("accuracy: ", "n/a (no questions evaluated)")

os.makedirs(os.path.dirname(result_file), exist_ok=True)
with open(result_file, 'w') as jp:
    json.dump(test, jp)

