In [None]:
import os
import sys
import json
from pathlib import Path
# Make the repository root (the notebook directory's parent) importable so the
# project-local `engine` and `prompts` packages can be imported below.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

# Device identifier exported for the project code; presumably selects the GPU
# used by the vision models — TODO confirm where CUDAIDX is read.
os.environ['CUDAIDX'] = 'cuda:0'

# Seed the global `random` generator; create_prompt(method='random') draws from it.
import random
random.seed(42)

In [None]:
# API credentials are read from the environment rather than hardcoded here.
# (The previous `%env OPENAI_API_KEY=xxx` lines overwrote any real values that
# were already exported.) Both variables are presumably consumed by the LLM
# client inside `engine.utils` — export them in the shell that launches Jupyter.
import os

missing_env_vars = [
    name for name in ('OPENAI_API_KEY', 'OPENAI_API_BASE') if not os.environ.get(name)
]
if missing_env_vars:
    raise RuntimeError(
        'Set ' + ', '.join(missing_env_vars) + ' in the environment before running '
        'this notebook (e.g. `export OPENAI_API_KEY=...` in the shell that starts Jupyter).'
    )

In [None]:
from PIL import Image
from IPython.display import HTML, display
from functools import partial

from engine.utils import ProgramGenerator, ProgramInterpreter
from prompts.gqa import GQA_CURATED_EXAMPLES

In [None]:
# Interpreter that executes generated programs against an image state (see the evaluation loop).
interpreter = ProgramInterpreter(
    dataset='gqa',
)

In [None]:
def create_prompt(inputs, num_prompts=8, method='random', seed=42, group=0):
    """Build the few-shot program-generation prompt for one GQA question.

    Args:
        inputs: mapping with at least a ``'question'`` key; it is substituted
            into the trailing ``Question: ... / Program:`` stub.
        num_prompts: number of curated examples to draw when
            ``method='random'``; ignored for ``method='all'``.
        method: ``'all'`` uses every entry of ``GQA_CURATED_EXAMPLES`` in order;
            ``'random'`` draws ``num_prompts`` of them without replacement using
            the module-level ``random`` generator.
        seed: currently unused — sampling relies on the global ``random`` state
            (a per-call reseed was deliberately left disabled in the original).
        group: currently unused; kept for call compatibility.

    Returns:
        The prompt string: an intro line, the examples joined by newlines, a
        blank line, then ``Question: <question>`` and an open ``Program:`` line.

    Raises:
        NotImplementedError: if ``method`` is neither ``'all'`` nor ``'random'``.
        ValueError: from ``random.sample`` when ``num_prompts`` exceeds the
            number of curated examples.
        KeyError: if ``inputs`` has no ``'question'`` key.
    """
    if method == 'all':
        prompt_examples = GQA_CURATED_EXAMPLES
    elif method == 'random':
        prompt_examples = random.sample(GQA_CURATED_EXAMPLES, num_prompts)
    else:
        raise NotImplementedError(
            f"unknown prompt selection method {method!r}; expected 'all' or 'random'"
        )

    # Assembled by plain concatenation: the former indented triple-quoted f-string
    # leaked source indentation into the prompt (four spaces before the first
    # example only, plus a whitespace-only line before the question).
    header = "Considering the examples provided:\n\n\n"
    examples = '\n'.join(prompt_examples)
    # Only the fixed suffix goes through str.format, so braces that appear inside
    # the curated example programs are never treated as format fields.
    suffix = "\nQuestion: {question}\nProgram:".format(**inputs)
    return header + examples + "\n" + suffix

# Every question is prompted with the full curated example set rather than a random subset.
prompter = partial(create_prompt, method='all')
generator = ProgramGenerator(
    dataset='gqa',
    prompter=prompter,
)

In [None]:
from tqdm import tqdm

# All dataset/result paths hang off the local ExoViP checkout; adjust here if it moves.
EXOVIP_ROOT = os.path.join(Path.home(), 'codes/ExoViP')
test_file = os.path.join(EXOVIP_ROOT, 'datasets/gqa/test.json')
img_dir = os.path.join(EXOVIP_ROOT, 'datasets/gqa/imgs')
result_file = os.path.join(EXOVIP_ROOT, 'results/gqa/exovip_gqa.json')

with open(test_file) as jp:
    test = json.load(jp)

eval_pred = 0      # questions whose prediction exactly equals the reference answer
eval_cnt = 0       # questions evaluated so far
eval_ct_score = 0  # running sum of per-question CT_SCORE values
eval_ct_cnt = 0    # questions contributing a CT score (every evaluated question)

for idx, dct in tqdm(test.items()):
    question = dct['question']
    answer = dct['answer']

    if 'predict' in dct:
        # Reuse a prediction already stored in the loaded JSON; no image or inference needed.
        result = dct['predict']
        ct_score = dct['ct_score']
    else:
        img_path = os.path.join(img_dir, dct['imageId'] + '.jpg')
        # The context manager closes the file; only the in-memory RGB copy is kept.
        with Image.open(img_path) as image:
            image.thumbnail((640, 640), Image.Resampling.LANCZOS)
            init_state = dict(
                IMAGE=image.convert("RGB"),
                CT_SCORE=0,
            )
        initial_prompts = [generator.generate_prompt(dict(question=question))]
        result, prog_state = interpreter.aug_execute(initial_prompts, init_state)
        # CT_SCORE is read back from the interpreter's final state — semantics defined in engine.utils.
        ct_score = prog_state["CT_SCORE"]
        dct["predict"] = result
        dct["ct_score"] = ct_score

    eval_pred += int(result == answer)
    eval_ct_score += ct_score
    eval_ct_cnt += 1
    eval_cnt += 1
    if eval_cnt % 100 == 0:
        print(f'step {eval_cnt} accuracy: ', round(eval_pred/eval_cnt, 2), 'ct_score: ', round(eval_ct_score/eval_cnt * 100, 2))

# Previously eval_ct_cnt was only ever 0 or 1, giving ZeroDivisionError or a sum instead of a mean.
if eval_cnt == 0:
    print('no questions evaluated')
else:
    print("accuracy: ", eval_pred / eval_cnt)
    print('ct_score: ', eval_ct_score / eval_ct_cnt * 100)

# The results directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(result_file), exist_ok=True)
with open(result_file, 'w') as jp:
    json.dump(test, jp)

