In [None]:
import os
import random
import sys

# Make the notebook's parent directory importable; presumably this is where the
# project packages imported in a later cell (`engine`, `prompts`) live — TODO confirm.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

# Device selector placed in the environment for the project code; presumably read
# by `engine` when it is imported — TODO confirm.
os.environ['CUDAIDX'] = 'cuda:0'

# Seed Python's global RNG so stochastic steps are reproducible.
random.seed(42)

In [None]:
# OpenAI credentials (presumably consumed by the program generator — TODO confirm).
# Values already exported in the environment take precedence; 'xxx' is only a
# placeholder fallback and must be replaced locally. Never commit real keys here.
for _env_var in ('OPENAI_API_KEY', 'OPENAI_API_BASE'):
    os.environ.setdefault(_env_var, 'xxx')
    if os.environ[_env_var] == 'xxx':
        print(f'warning: {_env_var} is still the placeholder value')
del _env_var

In [None]:
import json
from PIL import Image
from IPython.core.display import HTML
from functools import partial
from pathlib import Path

from engine.utils import ProgramGenerator, ProgramInterpreter
from prompts.nlvr import NLVR_CURATED_EXAMPLES

In [None]:
# Single program interpreter for the NLVR task, reused by the evaluation loop.
interpreter = ProgramInterpreter(dataset="nlvr")

In [None]:
def create_prompt(inputs, num_prompts=8, method='random', seed=42, group=0, examples=None):
    """Build the few-shot NLVR prompt for one statement.

    Args:
        inputs: mapping with at least a 'statement' key, substituted into the
            trailing "Statement: ...\\nProgram:" line.
        num_prompts: number of examples drawn when ``method == 'random'``.
        method: 'all' uses every example in order; 'random' samples ``num_prompts``.
        seed: seed for the 'random' method. A private ``random.Random(seed)`` is used,
            so the same examples are drawn as with ``random.seed(seed)`` but the
            global RNG state is left untouched.
        group: unused; kept only for backward compatibility with existing callers.
        examples: optional list of example strings; defaults to NLVR_CURATED_EXAMPLES.

    Returns:
        The prompt string: a header, the newline-joined examples, then the statement.

    Raises:
        NotImplementedError: if ``method`` is neither 'all' nor 'random'.
        ValueError: from ``random.sample`` if ``num_prompts`` exceeds the number of examples.
    """
    if examples is None:
        examples = NLVR_CURATED_EXAMPLES

    if method == 'all':
        prompt_examples = examples
    elif method == 'random':
        # Local RNG: identical draw to seeding the module RNG, without the global side effect.
        prompt_examples = random.Random(seed).sample(examples, num_prompts)
    else:
        raise NotImplementedError

    joined = '\n'.join(prompt_examples)
    # Header whitespace (three newlines, 4-space indents before and after the examples)
    # is kept byte-identical to the earlier triple-quoted f-string so prompts, and the
    # programs generated from them, stay comparable with previous runs.
    header = "Considering the examples provided:\n\n\n    " + joined + "\n    "
    # Only the statement line is passed through str.format: examples may contain braces.
    return header + "\nStatement: {statement}\nProgram:".format(**inputs)

# Prompt builder pinned to method='all': every curated example goes into each prompt.
prompter = partial(create_prompt, method="all")
generator = ProgramGenerator(dataset="nlvr", prompter=prompter)


In [None]:
import ast

from tqdm import tqdm

# Input split, image directory and output file. Each split entry is read for the keys
# 'left'/'right' (image file names under IMG_DIR), 'sentence' and 'label'.
NLVR2_DIR = Path.home() / 'codes/ExoViP/datasets/nlvr2'
IMG_DIR = NLVR2_DIR / 'imgs'
test_file = NLVR2_DIR / 'test.json'
result_file = Path.home() / 'codes/ExoViP/results/nlvr2/exovip_nlvr.json'
THUMB_SIZE = (640, 640)  # images are shrunk in place to fit inside this box
LOG_EVERY = 10           # print running accuracy every LOG_EVERY examples

# Instruction prefix passed to interpreter.aug_execute. The backslash continuations make
# the leading spaces of each continued line part of the string; the text is kept
# byte-identical to the previous prompt so results stay comparable. `{rejected_solutions}`
# is presumably filled in by aug_execute — TODO confirm.
PRE_INSTRUCT = "Think step by step if the statement is True or False. \
            while taking rejected solutions into account and learning from them. Here are evaluated solutions that were rejected: {rejected_solutions}\n\n \
                Answer the question without making the same mistakes you did with the evaluated rejected solutions. Be simple, Be direct, don't reply other thing \n\n \
                    Applicable modules include: VQA, EVAL, RESULT \
                        Following the examples provided:\n\n"


def _load_thumbnail(img_id):
    """Open IMG_DIR / img_id, shrink it in place to THUMB_SIZE (LANCZOS), return an RGB copy.

    The ``with`` block closes the image file instead of leaving the handle open.
    """
    with Image.open(IMG_DIR / img_id) as img:
        img.thumbnail(THUMB_SIZE, Image.Resampling.LANCZOS)
        return img.convert('RGB')


def _parse_label(raw_label):
    """Turn a stored label (e.g. the string 'True') into a Python value without ``eval``.

    ``ast.literal_eval`` accepts only Python literals, so a malformed or hostile label
    raises ValueError instead of executing code. Labels that are already bools pass through.
    """
    if isinstance(raw_label, bool):
        return raw_label
    return ast.literal_eval(raw_label)


with open(test_file) as jp:
    test = json.load(jp)

eval_pred = 0
eval_cnt = 0
for dct in tqdm(test.values()):
    statement = dct['sentence']
    if 'predict' in dct:
        # Reuse a prediction already stored with the example; its images are not loaded.
        result = dct['predict']
    else:
        init_state = dict(
            LEFT=_load_thumbnail(dct['left']),
            RIGHT=_load_thumbnail(dct['right']),
            CT_SCORE=0,
        )
        initial_prompts = [generator.generate_prompt(dict(statement=statement))]
        result, ct_score, cd_score = interpreter.aug_execute(
            initial_prompts, init_state, statement, PRE_INSTRUCT,
            NLVR_CURATED_EXAMPLES, task='nlvr',
        )
        dct['predict'] = result
    label = _parse_label(dct['label'])

    # tqdm.write prints without corrupting the progress bar (bare print() breaks it).
    tqdm.write(f'result:  {result}')
    tqdm.write(f'label:  {label}')

    eval_pred += int(result == label)
    eval_cnt += 1
    if eval_cnt % LOG_EVERY == 0:
        tqdm.write(f'step {eval_cnt} accuracy:  {round(eval_pred / eval_cnt, 2)}')

# eval_cnt == len(test) here; guard so an empty split prints nan instead of raising.
print('accuracy: ', eval_pred / eval_cnt if eval_cnt else float('nan'))

# Persist the split with each entry's 'predict' field; create the output dir if needed.
result_file.parent.mkdir(parents=True, exist_ok=True)
with open(result_file, 'w') as jp:
    json.dump(test, jp)
