In [None]:
import os
import sys
import json
from pathlib import Path
# Put the notebook's parent directory on the import path so the project
# packages imported by later cells can be resolved.
module_path = os.path.abspath('..')
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

# Device string that the evaluation cell reads back via os.environ.get('CUDAIDX', ...).
os.environ.update(CUDAIDX='cuda:2')

# Seed Python's global RNG so any sampling done later is repeatable.
import random
random.seed(42)

In [None]:
# Placeholder values ('xxx'): substitute real ones locally before running and
# never commit them. Presumably read by the OpenAI client used inside the
# project code -- TODO confirm. Keep comments on their own line: text after
# the '=' on a %env line becomes part of the value.
%env OPENAI_API_KEY=xxx
%env OPENAI_API_BASE=xxx

In [None]:
from PIL import Image
from IPython.core.display import HTML
from functools import partial

from engine.utils import ProgramGenerator, ProgramInterpreter
from prompts.magicbrush import MAGICBRUSH_CURATED_EXAMPLES

In [None]:
# Program interpreter configured for the 'magicbrush' dataset; the evaluation
# loop below calls its aug_execute() method once per test sample.
interpreter = ProgramInterpreter(dataset='magicbrush')

In [None]:
def create_prompt(inputs,num_prompts=8,method='random',seed=42,group=0):
    """Build the few-shot prompt for a single editing instruction.

    Args:
        inputs: Mapping that provides at least the key ``'instruction'``;
            its value is substituted into the trailing prompt template.
        num_prompts: How many curated examples to draw when
            ``method == 'random'`` (ignored for ``'all'``).
        method: ``'all'`` uses every entry of MAGICBRUSH_CURATED_EXAMPLES in
            order; ``'random'`` draws ``num_prompts`` of them without
            replacement.
        seed: Seed for the random draw, so the selection is reproducible.
        group: Unused; kept so existing callers keep working.

    Returns:
        The prompt string: a header, the selected examples joined by
        newlines, then ``Instruction: ...`` / ``Program:``.

    Raises:
        NotImplementedError: If ``method`` is neither ``'all'`` nor ``'random'``.
        ValueError: From ``sample`` when ``num_prompts`` exceeds the number
            of curated examples.
        KeyError: If ``inputs`` has no ``'instruction'`` entry.
    """
    if method=='all':
        selected = MAGICBRUSH_CURATED_EXAMPLES
    elif method=='random':
        # A private generator yields the same draw as seeding the module-level
        # RNG would, but leaves the global random state (seeded once at the
        # top of the notebook) untouched.
        selected = random.Random(seed).sample(MAGICBRUSH_CURATED_EXAMPLES,num_prompts)
    else:
        raise NotImplementedError(f"unsupported prompt selection method: {method!r}")

    joined_examples = '\n'.join(selected)
    # The literal newline and indentation inside this f-string are part of the
    # prompt text sent to the model; they are kept exactly as they were.
    prompt_examples = f"""Considering the examples provided:\n\n
    {joined_examples}
    """
    # Only the trailing template is formatted, so braces that occur inside
    # the examples themselves are never interpreted as placeholders.
    return prompt_examples + "\nInstruction: {instruction}\nProgram:".format(**inputs)

# Bind method='all' so every prompt contains the complete curated example set
# (num_prompts and seed are then irrelevant).
prompter = partial(create_prompt,method='all')
# Program generator for the 'magicbrush' dataset, given the prompt builder above;
# the evaluation loop calls its generate_prompt() method.
generator = ProgramGenerator(prompter=prompter,dataset='magicbrush')

In [None]:
# Evaluation helpers. The import is absolute on purpose: a notebook cell runs
# as __main__ with no parent package, so a relative `from .utils...` import
# raises ImportError.
# NOTE(review): assumes the `utils` package sits next to the notebook or in
# the parent directory added to sys.path in the first cell -- confirm.
from utils.image_eval import eval_distance, eval_clip_i, eval_clip_t
import clip
import torch
from torchvision.transforms import transforms
from transformers import CLIPModel, ViTModel
from tqdm import tqdm

# Comma-separated names of the metrics to compute after generation.
# Full set including the text-alignment metric: 'l1,l2,clip-i,dino,clip-t'
evaluation = 'l1,l2,clip-i,dino'
# Device string for the metric models; CUDAIDX is set in the first cell.
device = os.environ.get('CUDAIDX', 'cuda:0')

# Test split loaded as a mapping; the loop reads the 'source', 'target' and
# 'instruction' entries of each sample.
test_file = os.path.join(Path.home(), 'codes/ExoViP/datasets/magicbrush/test.json')
with open(test_file) as jp:
    test = json.load(jp)

eval_pred = 0
eval_cnt = 0

# Edited images are written here, one PNG per sample id.
result_dir = os.path.join(Path.home(), 'codes/ExoViP/results/magicbrush/')
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(result_dir, exist_ok=True)

# (result_path, target_path) tuples consumed by the metric functions.
result_pairs = []
cnt = 0

# Values that do not depend on the sample are built once, outside the loop.
imgs_dir = os.path.join(Path.home(), 'codes/ExoViP/datasets/magicbrush/imgs')
prompt_examples = MAGICBRUSH_CURATED_EXAMPLES
# NOTE: the backslash continuations embed each following line's leading
# whitespace in the string. That text is part of the prompt passed to
# aug_execute, so the continuation lines are kept exactly as they were.
pre_instruct = "Think step by step to carry out the instruction. \
            while taking rejected solutions into account and learning from them. Here are evaluated solutions that were rejected: {rejected_solutions}\n\n \
                Answer the question without making the same mistakes you did with the evaluated rejected solutions. Be simple, Be direct, don't repeat or reply other thing \n\n \
                    Applicable modules include: SEG, SELECT, REPLACE, RESULT \
                        Following the examples provided:\n\n"

# Sample ids whose execution raised; their saved "result" is just the
# unedited source image, which matters when interpreting the metrics.
failed_ids = []

for idx, dct in tqdm(test.items()):
    source_img_path = os.path.join(imgs_dir, dct['source'])
    target_img_path = os.path.join(imgs_dir, dct['target'])

    # Shrink in place so the longer side is at most 512 px (aspect ratio kept).
    source_image = Image.open(source_img_path)
    source_image.thumbnail((512, 512),Image.Resampling.LANCZOS)

    init_state = dict(
        IMAGE=source_image.convert('RGB'),
        CT_SCORE=0
    )
    instruction = dct['instruction']
    initial_prompts = [generator.generate_prompt(dict(instruction=instruction))]

    try:
        result, ct_score, cd_score = interpreter.aug_execute(initial_prompts, init_state, instruction, pre_instruct, prompt_examples, task='magicbrush')
        # 'NA' is treated as "no edit produced": fall back to the source image.
        if result == 'NA': result = source_image
    except Exception as e:
        # Best effort: keep evaluating the remaining samples, but report
        # which one failed instead of printing a context-free message.
        print(f'[{idx}] execution failed, falling back to source image: {e!r}')
        failed_ids.append(idx)
        result = source_image

    result_path = os.path.join(result_dir, idx+'.png')
    result.save(result_path)

    result_pairs.append((result_path, target_img_path))

# Parse the comma-separated metric list once so each check below is an exact
# name match rather than a substring test on the raw string.
requested_metrics = {name.strip() for name in evaluation.split(',')}

# One preprocessing pipeline shared by every embedding-based metric (it used
# to be repeated verbatim in each branch). interpolation=3 is PIL's bicubic
# filter code.
# NOTE(review): the mean/std are the ImageNet statistics and are applied to
# the CLIP model as well -- confirm this is what eval_clip_i / eval_clip_t expect.
transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# distance
if 'l1' in requested_metrics:
    l1_score = eval_distance(result_pairs, 'l1')
    print('l1: ', l1_score)
if 'l2' in requested_metrics:
    l2_score = eval_distance(result_pairs, 'l2')
    print('l2: ', l2_score)

# quality
if 'clip-i' in requested_metrics:
    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(device)
    clip_i_score = eval_clip_i(result_pairs, model, transform, device)
    print('clip-i: ', clip_i_score)
if 'dino' in requested_metrics:
    model = ViTModel.from_pretrained('facebook/dino-vits16')
    model.eval()
    model.to(device)
    dino_score = eval_clip_i(result_pairs, model, transform, device, metric='dino')
    print('dino: ', dino_score)
if 'clip-t' in requested_metrics:
    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(device)
    # The second return value is not used by this notebook.
    clip_t_score, _final_turn_oracle_score = eval_clip_t(result_pairs, model, transform, device)
    print('clip-t', clip_t_score)
    
