In [1]:
from PIL import Image
import base64
from together import Together
from dotenv import load_dotenv
import os
from collections import defaultdict
import os
import random
from tqdm import tqdm

random.seed(42)
load_dotenv()


def resize_image(image_path, size):
    """resize image so that the largest edge is atmost size"""
    img = Image.open(image_path)
    width, height = img.size

    if width <= size and height <= size:
        return img

    if width > height:
        new_width = size
        new_height = int(height * (size / width))
    else:
        new_height = size
        new_width = int(width * (size / height))
    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return img


def encode_image(image_path):
    img = resize_image(image_path, 512)
    temp_name = "temp.jpg"
    img.save(temp_name)
    with open(temp_name, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def construct_mcq(options, correct_option):
    correct_option_letter = None
    i = "a"
    mcq = ""

    for option in options:
        if option == correct_option:
            correct_option_letter = i
        mcq += f"{i}. {option}\n"
        i = chr(ord(i) + 1)

    mcq = mcq[:-1]
    return mcq, correct_option_letter


def add_row(content, data, i, with_answer=False):
    encoded_image = encode_image(data["image_path"])

    content.append(
        {
            "type": "text",
            "text": "Image " + str(i) + ": " + data["question"] + "\n" + data["mcq"],
        }
    )
    content.append(
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{encoded_image}",
            },
        }
    )

    if with_answer:
        content.append(
            {
                "type": "text",
                "text": "Answer {}: ".format(i) + data["correct_option_letter"],
            }
        )
    else:
        content.append(
            {
                "type": "text",
                "text": "Answer {}: ".format(i),
            }
        )

    return content

In [2]:
FEWSHOT_JSON = "illusionVQA/comprehension/fewshot_labels.json"
FEWSHOT_IMAGE_DIR = "illusionVQA/comprehension/FEW_SHOTS/"
EVAL_JSON = "illusionVQA/comprehension/eval_labels.json"
EVAL_IMAGE_DIR = "illusionVQA/comprehension/EVAL/"

client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo"

In [3]:
import json

with open(FEWSHOT_JSON) as f:
    fewshot_dataset = json.load(f)


for data in fewshot_dataset:
    data["image_path"] = FEWSHOT_IMAGE_DIR + data["image"]
    data["mcq"], data["correct_option_letter"] = construct_mcq(
        data["options"], data["answer"]
    )

In [4]:
with open(EVAL_JSON) as f:
    eval_dataset = json.load(f)

random.shuffle(eval_dataset)

In [5]:
category_count = defaultdict(int)

for data in eval_dataset:
    if data["image"] not in os.listdir(EVAL_IMAGE_DIR):
        print(data["image"])
        continue
    data["image_path"] = EVAL_IMAGE_DIR + data["image"]
    data["mcq"], data["correct_option_letter"] = construct_mcq(
        data["options"], data["answer"]
    )
    category_count[data["category"]] += 1

print(category_count)
print(len(eval_dataset))

defaultdict(<class 'int'>, {'deceptive design': 37, 'impossible object': 134, 'counting': 11, 'color': 23, 'edited-scene': 21, 'size': 46, 'hidden': 45, 'real-scene': 64, 'angle illusion': 26, 'angle constancy': 2, 'perspective': 2, 'circle-spiral': 6, 'upside-down': 7, 'positive-negative space': 7, 'occlusion': 2, 'repeating pattern': 2})
435


In [6]:
content = []

content.append(
    {
        "type": "text",
        "text": "Image 1\n",
    }
)
content.append(
    {
        "type": "image_url",
        "image_url": {
            "url": f"https://gratisography.com/wp-content/uploads/2024/03/gratisography-funflower-1170x780.jpg",
        },
    }
)
content.append(
    {
        "type": "text",
        "text": "Explain the two images above",
    }
)

In [7]:
content

[{'type': 'text', 'text': 'Image 1\n'},
 {'type': 'image_url',
  'image_url': {'url': 'https://gratisography.com/wp-content/uploads/2024/03/gratisography-funflower-1170x780.jpg'}},
 {'type': 'text', 'text': 'Explain the two images above'}]

In [8]:
response = client.chat.completions.create(
    model=MODEL_NAME,
    messages=[
        {
            "role": "user",
            "content": content,
        }
    ],
    max_tokens=512,
    temperature=0.7,
    top_p=0.7,
    top_k=50,
    stop=["<|eot_id|>", "<|eom_id|>"],
)

In [9]:
print(response.choices[0].message.content)

The image features a sunflower with sunglasses placed on its center, creating a playful and whimsical scene. The sunflower is positioned centrally in the image, with its bright yellow petals and dark brown center standing out against the orange background. The sunglasses add a touch of humor and personality to the flower, making it appear as if it's ready to take on the world with its stylish and cool demeanor.

The overall effect of the image is one of joy and playfulness, inviting the viewer to smile and appreciate the simple pleasures in life. The use of bright colors and bold shapes creates a lively and energetic atmosphere, making the image feel like a celebration of summer and all its delights.


### 0-shot

In [10]:
content = [
    {
        "type": "text",
        "text": "You'll be given an image, an instruction and some options. You have to select the correct one. Do not explain your reasoning. Answer with only the letter which corresponds to the correct option. Do not repeat the entire answer. Do not output anything other than the correct letter. For example, if the correct option is 'a', you should only output 'a'.",
    }
]
next_idx = 1

### 4-shot

In [11]:
# content = [
#     {
#         "type": "text",
#         "text": "You'll be given an image, an instruction and some choices. You have to select the correct one. Do not explain your reasoning. Answer with the option's letter from the given choices directly. Here are a few examples:",
#     }
# ]


# i = 1
# for data in fewshot_dataset:
#     content = add_row(content, data, i, with_answer=True)
#     i += 1

# content.append({
#                     "type": "text",
#                     "text": "Now you try it!",
#                 })

# next_idx = i

### Evaluation Loop

In [12]:
def parse_output(output: str):
    if output.startswith("Answer"):
        return output.split(": ")[1][0].lower()
    else:
        return output[0].lower()

In [13]:
ytrue = []
ypred = []

In [14]:
import time

MAX_RETRIES = 2
for data in tqdm(eval_dataset):
    content_t = add_row(content.copy(), data, next_idx, with_answer=False)
    retries = MAX_RETRIES
    while True:
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "user",
                        "content": content_t,
                    }
                ],
                max_tokens=512,
                temperature=0.7,
                top_p=0.7,
                top_k=50,
                stop=["<|eot_id|>", "<|eom_id|>"],
            )
            intern_ans = parse_output(response.choices[0].message.content.strip())
            break
        except Exception as e:
            print(e)
            retries -= 1
            time.sleep(30)
            if retries == 0:
                intern_ans = "GPT4 could not answer this question."
                print("retries exhausted")
                break
            continue

    answer = data["correct_option_letter"].strip()
    ytrue.append(answer)
    ypred.append(intern_ans)

100%|██████████| 435/435 [10:09<00:00,  1.40s/it]


In [15]:
for i in range(len(ypred)):
    if ypred[i] == "Llama 3.2 could not answer this question.":
        print(i)

In [16]:
len(ytrue), len(ypred)

(435, 435)

### Result Analysis

In [17]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from collections import Counter

print(accuracy_score(ytrue, ypred))
print(confusion_matrix(ytrue, ypred))
print(classification_report(ytrue, ypred))

print(Counter(ytrue))
print(Counter(ypred))

0.4091954022988506
[[62 22 18 11  1  0  0  1  1]
 [34 53 11  7  1  0  0  1  2]
 [34 20 35  8  2  0  1  3  6]
 [18 18 13 26  2  1  0  3  5]
 [ 2  2  2  2  2  0  0  0  1]
 [ 1  1  1  0  1  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0]]
              precision    recall  f1-score   support

           a       0.41      0.53      0.46       116
           b       0.46      0.49      0.47       109
           c       0.44      0.32      0.37       109
           d       0.48      0.30      0.37        86
           e       0.22      0.18      0.20        11
           f       0.00      0.00      0.00         4
           i       0.00      0.00      0.00         0
           t       0.00      0.00      0.00         0
           w       0.00      0.00      0.00         0

    accuracy                           0.41       435
   macro avg       0.22      0.20      0.21       435
weighted avg       0.43      0.41      0.41       435

Count

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [18]:
import prettytable
from collections import defaultdict

table = prettytable.PrettyTable()
table.field_names = ["Category", "Total", "Wrong", "Accuracy"]

got_wrong_dict = defaultdict(int)

for i in range(len(ypred)):
    if ypred[i] != ytrue[i]:
        got_wrong_dict[eval_dataset[i]["category"]] += 1
    else:
        got_wrong_dict[eval_dataset[i]["category"]] += 0


for k, v in got_wrong_dict.items():
    table.add_row([k, category_count[k], v, 1 - (v / category_count[k])])


# sort by total
table.sortby = "Total"
table.reversesort = True
print(table)

+-------------------------+-------+-------+---------------------+
|         Category        | Total | Wrong |       Accuracy      |
+-------------------------+-------+-------+---------------------+
|    impossible object    |  134  |   74  |  0.4477611940298507 |
|        real-scene       |   64  |   34  |       0.46875       |
|           size          |   46  |   33  | 0.28260869565217395 |
|          hidden         |   45  |   27  |         0.4         |
|     deceptive design    |   37  |   27  |  0.2702702702702703 |
|      angle illusion     |   26  |   15  | 0.42307692307692313 |
|          color          |   23  |   16  | 0.30434782608695654 |
|       edited-scene      |   21  |   12  |  0.4285714285714286 |
|         counting        |   11  |   7   | 0.36363636363636365 |
|       upside-down       |   7   |   2   |  0.7142857142857143 |
| positive-negative space |   7   |   4   |  0.4285714285714286 |
|      circle-spiral      |   6   |   3   |         0.5         |
|    repea

In [19]:
METRIC_SAVE_DIR = "../../result_jsons/"

eval_dataset_copy = eval_dataset.copy()

print(len(eval_dataset_copy))
for i, data in enumerate(eval_dataset_copy):
    # map letter to option f
    if "BLOCK" in ypred[i]:
        data["vlm_answer"] = "BLOCK"
    else:
        data["vlm_answer"] = ypred[i]

435


In [20]:
import json

with open(METRIC_SAVE_DIR + "llama3.2_11b_comprehension_0shot.json", "w") as f:
    json.dump(eval_dataset_copy, f, indent=4)