In [None]:
import os
import re
import sys
import json
from pathlib import Path
# Put the parent of the current working directory on the import path (only
# once) so that modules living one level above the notebook can be imported.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
from PIL import Image
from IPython.core.display import HTML
from functools import partial

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch

In [None]:
# Note: The default behavior now has injection attack prevention off.
# NOTE(review): trust_remote_code=True runs Python code shipped with the
# checkpoint repository; consider pinning a `revision` once one is agreed on.
MODEL_ID = "Qwen/Qwen-VL-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Prefer the GPU, but fall back to the CPU so this cell does not crash on a
# machine without a usable CUDA device.
device_map = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative loading options (swap into the call below if needed):
#   bf16:  device_map="auto", bf16=True
#   fp16:  device_map="auto", fp16=True
#   cpu:   device_map="cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map=device_map, trust_remote_code=True
).eval()

# Specify hyperparameters for generation
model.generation_config = GenerationConfig.from_pretrained(MODEL_ID, trust_remote_code=True)


In [None]:
from tqdm import tqdm
from PIL import ImageDraw
# Evaluate the model on the RefCOCO test split: for every sample, ask the model
# to ground the instruction, parse the first predicted box from the response,
# and accumulate its IoU against the labelled box.
test_file = os.path.join(Path.home(), 'codes/ExoViP/datasets/refcoco/test.json')
with open(test_file) as jp:
    test = json.load(jp)
img_dir = os.path.join(Path.home(), 'codes/ExoViP/datasets/refcoco/imgs')
eval_pred = 0
eval_cnt = 0

# Matches "(a,b),(c,d)" in the response; compiled once instead of per sample.
PATTERN = re.compile(r'\((.*?)\),\((.*?)\)')
# The parsed coordinates are rescaled from a 0-1000 grid to pixels
# (presumably the model's normalised output range -- TODO confirm).
COORD_SCALE = 1000
# Fallback prediction when no usable box can be parsed from the response.
ZERO_BOX = (0., 0., 0., 0.)


def _parse_bbox(response, width, height):
    """Return the first box found in ``response`` as pixel ``(x1, y1, x2, y2)``.

    Returns ``ZERO_BOX`` when the response contains no box or the box cannot
    be parsed as two ``x,y`` coordinate pairs.
    """
    match = PATTERN.search(response)
    if match is None:
        return ZERO_BOX
    top_left, bottom_right = match.groups()
    if ',' not in top_left or ',' not in bottom_right:
        return ZERO_BOX
    try:
        x1, y1 = (float(tmp) for tmp in top_left.split(','))
        x2, y2 = (float(tmp) for tmp in bottom_right.split(','))
        return (
            int(x1 / COORD_SCALE * width),
            int(y1 / COORD_SCALE * height),
            int(x2 / COORD_SCALE * width),
            int(y2 / COORD_SCALE * height),
        )
    except (ValueError, OverflowError):
        # Non-numeric text, a wrong number of coordinates, or nan/inf values.
        return ZERO_BOX


def _box_iou(box, label):
    """Intersection-over-union of two ``(x1, y1, x2, y2)`` boxes (0.0 if the union is empty)."""
    label_area = (label[2] - label[0]) * (label[3] - label[1])
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    inter_w = max(0, min(box[2], label[2]) - max(box[0], label[0]))
    inter_h = max(0, min(box[3], label[3]) - max(box[1], label[1]))
    intersection = inter_w * inter_h
    union = label_area + box_area - intersection
    # Guard against degenerate boxes, which would otherwise divide by zero.
    return intersection / union if union > 0 else 0.0


for idx, dct in tqdm(test.items()):
    img_path = os.path.join(img_dir, dct['img'])
    # Only the image size is needed here; close the file handle right away.
    with Image.open(img_path) as image:
        w, h = image.size

    instruction = dct['instruction']
    query = tokenizer.from_list_format([
        {"image": img_path,
         "text": instruction}
    ])
    response, _ = model.chat(tokenizer, query=query, history=None)

    box = _parse_bbox(response, w, h)
    # assumes dct['box'] is (x1, y1, x2, y2) in pixel coordinates -- TODO confirm
    label = dct['box']
    iou = _box_iou(box, label)

    # Record the prediction so that the result file holds more than a copy of
    # the input annotations.
    dct['pred_box'] = list(box)
    dct['iou'] = iou

    eval_pred += iou
    eval_cnt += 1

    if eval_cnt % 20 == 0:
        print(f'step {eval_cnt} iou: ', round(eval_pred/eval_cnt, 2))

# Average over the samples actually evaluated; avoids dividing by zero on an
# empty test set.
mean_iou = eval_pred / eval_cnt if eval_cnt else 0.0
print('iou: ', mean_iou)
result_file = os.path.join(Path.home(), 'codes/visprog/results/refcoco/qwen.json')
# Make sure the output directory exists so a long run is not lost at the end.
os.makedirs(os.path.dirname(result_file), exist_ok=True)
with open(result_file, 'w') as jp:
    json.dump(test, jp)
