In [None]:
import os
import json
import torch
import sys
# Make the parent of the current working directory importable; presumably this
# is what lets the `engine` and `prompts` packages imported further down
# resolve -- TODO confirm against the repository layout.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

# Device selector exported through the environment before the project modules
# are imported (NOTE(review): assumed to be read by the project code; the
# consumer is not visible in this notebook).
os.environ['CUDAIDX'] = 'cuda:0'

# Seed Python's global RNG so any sampling done later starts from a fixed state.
import random
random.seed(42)

In [None]:
# Export OPENAI_API_KEY / OPENAI_API_BASE into the kernel's environment via the
# %env line magic (presumably read by the LLM client inside the project code --
# TODO confirm).
# NOTE(review): `xxx` looks like a placeholder; supply real values locally and
# never commit actual credentials or leave them in saved cell output.
%env OPENAI_API_KEY=xxx
%env OPENAI_API_BASE=xxx

In [None]:
from PIL import Image
from IPython.core.display import HTML
from functools import partial
from pathlib import Path

from engine.utils import ProgramGenerator, ProgramInterpreter
from prompts.kilogram import KILOGRAM_CURATED_EXAMPLES

In [None]:
# Interpreter configured for the 'kilogram' dataset; the evaluation loop below
# drives it through `interpreter.aug_execute(...)`.
interpreter = ProgramInterpreter(dataset='kilogram')

In [None]:
def create_prompt(inputs,num_prompts=8,method='random',seed=42,group=0):
    """Build the few-shot prompt for one instruction.

    Args:
        inputs: Mapping that must contain the key ``'instruction'``; its value
            is substituted into the trailing ``Instruction: ...`` line.
        num_prompts: Number of curated examples to draw when
            ``method='random'``. Ignored for ``method='all'``.
        method: ``'all'`` embeds every entry of ``KILOGRAM_CURATED_EXAMPLES``
            in order; ``'random'`` embeds a seeded random sample of them.
        seed: Seed for the sample drawn when ``method='random'``.
        group: Unused; kept so existing callers keep working.

    Returns:
        The prompt string: a header, the selected examples joined by newlines,
        then ``Instruction: <instruction>`` and a dangling ``Program:``.

    Raises:
        NotImplementedError: If ``method`` is neither ``'all'`` nor ``'random'``.
        KeyError: If ``inputs`` has no ``'instruction'`` key.
        ValueError: From ``sample`` if ``num_prompts`` exceeds the number of
            curated examples.
    """
    if method=='all':
        selected = KILOGRAM_CURATED_EXAMPLES
    elif method=='random':
        # A private generator yields the same sample as seeding the global RNG
        # with `seed`, but leaves the process-wide random state untouched.
        selected = random.Random(seed).sample(KILOGRAM_CURATED_EXAMPLES,num_prompts)
    else:
        raise NotImplementedError(f"unsupported prompt selection method: {method!r}")

    joined = '\n'.join(selected)
    # The odd whitespace (three newlines, four-space indents) is part of the
    # prompt text the model receives, so it is reproduced exactly.
    header = "Considering the examples provided:\n\n\n    " + joined + "\n    "
    # Only the trailing template is formatted, so braces inside the curated
    # examples are never interpreted as format fields.
    return header + "\nInstruction: {instruction}\nProgram:".format(**inputs)
   
# Bind method='all' so every prompt embeds the complete curated example list
# (no random sub-sampling, so num_prompts/seed have no effect here).
prompter = partial(create_prompt,method='all')
# The generator receives the prompter; in this notebook it is only used through
# `generator.generate_prompt(...)` in the evaluation loop.
generator = ProgramGenerator(prompter=prompter,dataset='kilogram')

In [None]:
from tqdm import tqdm
from PIL import ImageDraw
# ---------------------------------------------------------------------------
# Evaluate on the KiloGram test split: for every entry, load its candidate
# images, run the interpreter's augmented execution on the instruction, store
# the prediction back into the entry, and finally dump everything to JSON.
# ---------------------------------------------------------------------------

# Input / output locations, all rooted at ~/codes/ExoViP.
test_file = os.path.join(Path.home(), 'codes/ExoViP/datasets/kilogram/test_whole.json')
img_dir = os.path.join(Path.home(), 'codes/ExoViP/datasets/kilogram/imgs')
result_file = os.path.join(Path.home(), 'codes/ExoViP/results/kilogram/exovip_kilogram.json')

with open(test_file) as jp:
    test = json.load(jp)
eval_pred = 0  # number of entries whose prediction equals 0
eval_cnt = 0   # number of entries processed so far

# Loop-invariant arguments for aug_execute, built once instead of per entry.
prompt_examples = KILOGRAM_CURATED_EXAMPLES
# The whitespace produced by the line continuations is part of the prompt text,
# so the literal is kept exactly as written. `{rejected_solutions}` is left
# unformatted here (presumably filled in by aug_execute -- TODO confirm).
pre_instruct = "Think step by step to carry out the instruction. \
            while taking rejected solutions into account and learning from them. Here are evaluated solutions that were rejected: {rejected_solutions}\n\n \
                Answer the question without making the same mistakes you did with the evaluated rejected solutions. Be simple, Be direct, don't repeat or reply other thing \n\n \
                    Applicable modules include:  PART, SEGS, ALIGN, RESULT\
                        Following the examples provided:\n\n"

for idx, dct in tqdm(test.items()):

    img_ids = dct['images']
    # Each entry is keyed by the id of its first image.
    assert idx == img_ids[0]
    img_paths = [os.path.join(img_dir, img_id+'.png') for img_id in img_ids]
    images = []
    for img_path in img_paths:
        # The context manager closes the underlying file; convert() returns an
        # independent in-memory image, so it stays valid after the file closes.
        with Image.open(img_path) as image:
            image.thumbnail((224,224))  # in-place downscale, aspect ratio kept
            images.append(image.convert('RGB'))
    init_state = dict(
        IMAGE=images,
        CT_SCORE=0
    )
    instruction = dct['texts'][0]

    initial_prompts = [generator.generate_prompt(dict(instruction=instruction))]

    # Baseline alternative (plain generate + execute), kept for reference:
    #   prog,_ = generator.generate(dict(instruction=instruction))
    #   result, prog_state = interpreter.execute(prog,init_state,inspect=False)
    try:
        result, ct_score, cd_score = interpreter.aug_execute(initial_prompts, init_state, instruction, pre_instruct, prompt_examples, inspect=False, task='kilogram')
    except Exception as e:
        # Best effort: a failing entry is logged (with its id so it can be
        # re-run) and recorded as -1, which counts as a wrong prediction.
        print(f'[{idx}] {e!r}')
        result = -1

    test[idx]['predict'] = result
    # A prediction of 0 is scored as correct (together with the assert above
    # this suggests the target image is listed first -- TODO confirm).
    eval_pred += int(result == 0)
    eval_cnt += 1

    if eval_cnt % 100 == 0:
        print(f'step {eval_cnt} accuracy: ', round(eval_pred/eval_cnt, 2))

# Use the processed count as denominator (same as the running accuracy above)
# and avoid a ZeroDivisionError when the split is empty.
final_accuracy = eval_pred/eval_cnt if eval_cnt else float('nan')
print('accuracy: ', final_accuracy)

# Create the output directory up front so a long run cannot fail on the write.
os.makedirs(os.path.dirname(result_file), exist_ok=True)
with open(result_file, 'w') as jp:
    json.dump(test, jp)