In [None]:
from PIL import Image
import base64
import anthropic
from dotenv import load_dotenv
import os
from collections import defaultdict
import os
import random
from tqdm import tqdm

SEED = 42
random.seed(SEED)  # fixes the global RNG that random.shuffle uses further down
load_dotenv()  # populate os.environ from a .env file, presumably for the API key


def resize_image(image_path, size):
    """Open an image and downscale it so that its longest edge is at most ``size``.

    The aspect ratio is preserved. Images already within the bound are returned
    at their original resolution. The underlying file is closed before returning.

    Args:
        image_path: path to an image file readable by PIL.
        size: maximum allowed length, in pixels, of the longer edge.

    Returns:
        An in-memory ``PIL.Image.Image`` that no longer depends on the file.
    """
    with Image.open(image_path) as img:
        width, height = img.size

        if width <= size and height <= size:
            # copy() loads the pixels and detaches them from the file closed on exit.
            return img.copy()

        # max(1, ...) keeps very elongated images from collapsing to a 0-pixel
        # edge, which Image.resize rejects.
        if width > height:
            new_width = size
            new_height = max(1, int(height * (size / width)))
        else:
            new_height = size
            new_width = max(1, int(width * (size / height)))
        return img.resize((new_width, new_height), Image.Resampling.LANCZOS)


def encode_image(image_path, max_size=512):
    """Return a base64 JPEG encoding of the image, downscaled to ``max_size``.

    The encoding is done in memory, so nothing is written to the working
    directory (the previous version overwrote a shared ``temp.jpg``).

    Args:
        image_path: path to the source image.
        max_size: maximum length, in pixels, of the longer edge (default 512).

    Returns:
        The JPEG bytes as a base64-encoded UTF-8 string.
    """
    import io

    img = resize_image(image_path, max_size)
    # JPEG cannot store alpha or palette images (e.g. RGBA PNGs); convert those to RGB.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")

    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def construct_mcq(options, correct_option):
    """Format options as a lettered multiple-choice list.

    Args:
        options: the answer choices, in display order.
        correct_option: the choice that is correct.

    Returns:
        A tuple ``(mcq_text, letter)``. ``mcq_text`` has one "x. option" line per
        choice, joined by newlines with no trailing newline. ``letter`` is the
        letter of the last option equal to ``correct_option``, or None if none
        matches.
    """
    lines = []
    answer_letter = None
    for offset, choice in enumerate(options):
        letter = chr(ord("a") + offset)
        if choice == correct_option:
            answer_letter = letter
        lines.append(f"{letter}. {choice}")
    return "\n".join(lines), answer_letter


def add_row(content: list[dict], data: dict, i: int, with_answer=False) -> list[dict]:
    """Append one multiple-choice item (question, image, answer slot) to ``content``.

    ``content`` is extended in place and also returned, so callers may either
    keep their reference or rebind to the return value.

    Args:
        content: message content blocks being assembled for the API request.
        data: an item with "question", "mcq" and "image_path" keys. When
            ``with_answer`` is true it must also have "reasoning" and
            "correct_option_letter".
        i: position number used in the "Image N" / "Answer N" labels.
        with_answer: if true, emit the worked reasoning and the answer letter
            (a few-shot example). Otherwise leave "Answer N: " open for the model.

    Returns:
        The same ``content`` list.
    """
    prompt_text = "Image " + str(i) + ": " + data["question"] + "\n" + data["mcq"]
    content.append({"type": "text", "text": prompt_text})

    image_source = {
        "type": "base64",
        "media_type": "image/jpeg",
        "data": encode_image(data["image_path"]),
    }
    content.append({"type": "image", "source": image_source})

    answer_label = "Answer {}: ".format(i)
    if not with_answer:
        # Leave the answer slot empty for the model to complete.
        content.append({"type": "text", "text": answer_label})
        return content

    content.append({"type": "text", "text": "Reasoning: {}".format(data["reasoning"])})
    content.append({"type": "text", "text": answer_label + data["correct_option_letter"]})
    return content

In [None]:
# Few-shot examples: labels and their images.
FEWSHOT_JSON = "./illusionVQA/sofloc/fewshot_labels.json"
FEWSHOT_IMAGE_DIR = "./illusionVQA/sofloc/FEW_SHOTS/"

# Evaluation split: labels and their images.
EVAL_JSON = "./illusionVQA/sofloc/eval_labels.json"
EVAL_IMAGE_DIR = "./illusionVQA/sofloc/EVAL/"

model_name = "claude-3-5-sonnet-20240620"
# Credentials presumably come from environment variables (load_dotenv() runs in the first cell).
client = anthropic.Anthropic()

In [None]:
import json

with open(FEWSHOT_JSON) as fewshot_file:
    fewshot_dataset = json.load(fewshot_file)

# Attach the derived fields add_row() needs: the image location and the lettered MCQ.
for example in fewshot_dataset:
    example["image_path"] = FEWSHOT_IMAGE_DIR + example["image"]
    mcq_text, answer_letter = construct_mcq(example["options"], example["answer"])
    example["mcq"] = mcq_text
    example["correct_option_letter"] = answer_letter

In [None]:
with open(EVAL_JSON) as f:
    eval_dataset = json.load(f)

# Shuffle with a dedicated generator seeded with the notebook seed (42), so
# re-running this cell always yields the same order. Using the global RNG
# would give a new order on every re-run.
random.Random(42).shuffle(eval_dataset)

In [None]:
category_count = defaultdict(int)
# These categories are folded into a single "misc" bucket for reporting.
misc_cats = [
    "counting",
    "repeating pattern",
    "perspective",
    "occlusion",
    "angle constancy",
]

# List the image directory once (not once per item); a set gives O(1) lookups.
available_images = set(os.listdir(EVAL_IMAGE_DIR))
evaluable_items = []

for data in eval_dataset:
    if data["image"] not in available_images:
        print(data["image"], "not found")
        continue
    data["image_path"] = EVAL_IMAGE_DIR + data["image"]
    data["mcq"], data["correct_option_letter"] = construct_mcq(
        data["options"], data["answer"]
    )

    if data["category"] in misc_cats:
        data["category"] = "misc"

    category_count[data["category"]] += 1
    evaluable_items.append(data)

# Keep only items whose image exists. Skipped items lack "image_path"/"mcq",
# so they would crash the evaluation loop and inflate the reported totals.
eval_dataset = evaluable_items

print(category_count)
print(len(eval_dataset))

In [None]:
instruction_block = {
    "type": "text",
    "text": """You'll be given an image, an instruction and some choices. You have to select the correct one. Reason about the choices in the context of the question and the image. End your answer with "Answer": {letter_of_correct_choice} without the curly brackets. Here are a few examples:""",
}
content = [instruction_block]

# Few-shot examples are numbered from 1 and include reasoning plus the answer letter.
for i, data in enumerate(fewshot_dataset, start=1):
    content = add_row(content, data, i, with_answer=True)

content.append({"type": "text", "text": "Now you try it!"})

# The evaluated question continues the numbering after the last few-shot example.
next_idx = len(fewshot_dataset) + 1

### Evaluation Loop

In [None]:
# Gold letters and raw model responses, appended in eval_dataset order.
ytrue, ypred = [], []

In [None]:
import time

MAX_RETRIES = 2  # total number of API attempts per question
RETRY_DELAY_SECONDS = 30
FAILED_ANSWER = "Claude could not answer this question."

# ytrue/ypred are created in an earlier cell and grow in lockstep with
# eval_dataset. Resume from the first unanswered item. Re-running this cell
# (e.g. after an interruption) then does not append duplicates that would
# misalign predictions with eval_dataset. Re-initialise both lists to start over.
assert len(ytrue) == len(ypred), "ytrue and ypred are out of sync; re-initialise both"
start_idx = len(ypred)

for data in tqdm(eval_dataset[start_idx:]):
    # Resolve the gold label first, so a malformed item fails before an API call is spent.
    answer = data["correct_option_letter"].strip()
    # add_row mutates its list argument, so extend a shallow copy of the shared few-shot prompt.
    content_t = add_row(content.copy(), data, next_idx, with_answer=False)

    claude_ans = FAILED_ANSWER  # kept if every attempt below fails
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            message = client.messages.create(
                model=model_name,
                max_tokens=1024,
                messages=[
                    {
                        "role": "user",
                        "content": content_t,
                    }
                ],
            )
            claude_ans = message.content[0].text.lower().strip()
            break
        except Exception as e:
            # Broad catch: transport errors and malformed responses are all retried.
            print(e)
            if attempt < MAX_RETRIES:
                # Back off only when another attempt will follow.
                time.sleep(RETRY_DELAY_SECONDS)
    else:
        print("retries exhausted")

    ytrue.append(answer)
    ypred.append(claude_ans)

In [None]:
import re

# The prompt asks Claude to end with `"Answer": <letter>`. Take the letter after
# the last "answer" marker instead of the raw last character. The last
# character breaks on trailing punctuation such as "answer: b." or "**answer: b**".
_ANSWER_LETTER_RE = re.compile(r"answer\W*([a-z])\b")
_UNANSWERED = "Claude could not answer this question."


def extract_choice_letter(response: str) -> str:
    """Return the option letter chosen in a model response.

    Falls back to the response's last character when no "answer" marker is
    found, which also makes re-running this cell a no-op on parsed letters.
    The failure sentinel passes through unchanged, so failed requests remain
    identifiable. An empty response yields "" instead of raising IndexError.
    """
    if response == _UNANSWERED:
        return response
    matches = _ANSWER_LETTER_RE.findall(response.lower())
    if matches:
        return matches[-1]
    return response[-1] if response else ""


ypred = [extract_choice_letter(x) for x in ypred]

In [None]:
# Indices of questions for which every API attempt failed. This only matches
# when ypred still holds the full sentinel string. Reducing each response to
# its last character would turn the sentinel into ".".
failed_indices = [
    idx
    for idx, pred in enumerate(ypred)
    if pred == "Claude could not answer this question."
]
for idx in failed_indices:
    print(idx)

In [None]:
# Sanity check: gold and predicted label lists should be the same length.
tuple(len(labels) for labels in (ytrue, ypred))

### Result Analysis

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from collections import Counter

# Overall accuracy, confusion matrix and per-letter precision/recall.
for metric in (accuracy_score, confusion_matrix, classification_report):
    print(metric(ytrue, ypred))

# Label distributions of gold answers vs. predictions.
for labels in (ytrue, ypred):
    print(Counter(labels))

In [None]:
import prettytable
from collections import defaultdict

table = prettytable.PrettyTable()
table.field_names = ["Category", "Total", "Wrong", "Accuracy"]

# Count mistakes per category. ypred[i] / ytrue[i] line up with eval_dataset[i]
# because the evaluation loop appends one entry per item, in order.
got_wrong_dict = defaultdict(int)
total_wrong = 0

for i in range(len(ypred)):
    if ypred[i] != ytrue[i]:
        got_wrong_dict[eval_dataset[i]["category"]] += 1
        total_wrong += 1


# Iterate over every category in the dataset, not only those with a mistake,
# so perfectly answered categories still appear (Wrong = 0, Accuracy = 1).
for category, category_total in category_count.items():
    wrong = got_wrong_dict.get(category, 0)
    table.add_row([category, category_total, wrong, 1 - (wrong / category_total)])

table.add_row(
    [
        "Total",
        len(eval_dataset),
        total_wrong,
        1 - (total_wrong / len(eval_dataset)),
    ]
)


# sort by total
table.sortby = "Total"
table.reversesort = True
print(table)

In [None]:
# Copy each record so that adding "vlm_answer" does not also mutate eval_dataset.
# list.copy() is shallow and would share the same dicts.
eval_dataset_copy = [dict(data) for data in eval_dataset]

print(len(eval_dataset_copy))
for i, data in enumerate(eval_dataset_copy):
    pred = ypred[i]
    # NOTE(review): responses are lowercased in the evaluation loop, so this
    # case-sensitive check appears never to match; kept for compatibility.
    if "BLOCK" in pred:
        data["vlm_answer"] = "BLOCK"
        continue

    # Map a single letter "a", "b", ... to its option. Reject anything else
    # explicitly rather than relying on IndexError: characters just below "a"
    # (e.g. "`") give negative indices that silently pick the wrong option.
    option_index = ord(pred) - ord("a") if len(pred) == 1 else -1
    if 0 <= option_index < len(data["options"]):
        data["vlm_answer"] = data["options"][option_index]
    else:
        # Unparseable prediction: keep the raw text for inspection.
        data["vlm_answer"] = pred
        print(pred)

In [None]:
import json

METRIC_SAVE_DIR = "../../results_and_evaluation/closed_source/results/"

# Create the output directory on a fresh checkout instead of failing with FileNotFoundError.
os.makedirs(METRIC_SAVE_DIR, exist_ok=True)

with open(METRIC_SAVE_DIR + "illusionvqa_soft_loc_claude_with_reasoning.json", "w") as f:
    json.dump(eval_dataset_copy, f, indent=4)