In [None]:
from PIL import Image

def construct_mcq(options, correct_option):
    """Render answer options as a lettered multiple-choice block.

    Args:
        options: Sequence of option strings, labelled "a", "b", ... in order.
        correct_option: The option text that is the correct answer.

    Returns:
        (mcq, correct_option_letter): ``mcq`` has one "<letter>. <option>" line per
        option, joined by newlines (no trailing newline). ``correct_option_letter``
        is the letter of the option equal to ``correct_option`` (the last one if the
        same text appears more than once).

    Raises:
        ValueError: if ``correct_option`` is not among ``options``.
    """
    letters = [chr(ord("a") + k) for k in range(len(options))]
    mcq = "\n".join(f"{letter}. {option}" for letter, option in zip(letters, options))

    matches = [letter for letter, option in zip(letters, options) if option == correct_option]
    if not matches:
        # Carry the offending values in the exception rather than printing them separately.
        raise ValueError(
            f"Correct option not found in the options: {correct_option!r} not in {options!r}"
        )
    return mcq, matches[-1]

def resize_image(image_path, size):
    """Load an image and downscale it so that its longest edge is at most ``size`` pixels.

    The aspect ratio is preserved: the shorter edge is scaled by the same factor,
    truncated to an int, and clamped to at least 1 px. Images already within bounds
    are returned unchanged (not resized).

    Args:
        image_path: Path to an image file readable by PIL.
        size: Maximum allowed length, in pixels, of the longest edge.

    Returns:
        A PIL image whose pixel data has been loaded.
    """
    img = Image.open(image_path)
    # Image.open is lazy and keeps the file open until the pixels are read; loading
    # eagerly lets Pillow release the file handle (for single-frame formats) instead
    # of holding it until the image object is garbage-collected.
    img.load()
    width, height = img.size

    if width <= size and height <= size:
        return img

    if width > height:
        new_width = size
        # Clamp to 1: a very thin image would otherwise truncate to 0 and make resize() fail.
        new_height = max(1, int(height * (size / width)))
    else:
        new_height = size
        new_width = max(1, int(width * (size / height)))
    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)


def add_row(content, data, i, with_answer=False, max_image_size=512):
    """Append one example (label, image, question, options) to a prompt list.

    Args:
        content: Prompt list to extend. It is mutated in place and also returned.
        data: Example dict with "image_path", "question" and "mcq" keys, plus
            "reasoning" and "correct_option_letter" when ``with_answer`` is True.
        i: Example number shown in the "Image <i>: " label.
        with_answer: If True, append the worked reasoning and the answer letter
            (few-shot example). Otherwise append an open "Reasoning: " cue for the
            model to continue.
        max_image_size: Longest image edge, in pixels, passed to ``resize_image``.

    Returns:
        The same ``content`` list.
    """
    content.append(f"Image {i}: ")
    content.append(resize_image(data["image_path"], max_image_size))
    content.append(data["question"])
    content.append(data["mcq"])

    if with_answer:
        content.append(f"Reasoning: {data['reasoning']}")
        content.append(f"Answer: {data['correct_option_letter']}")
    else:
        content.append("Reasoning: ")

    return content
   

In [None]:
import google.generativeai as genai

import getpass
import os

# Read the key from the environment (or prompt for it) instead of hard-coding it,
# so the secret never ends up in the saved notebook or in version control.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY") or getpass.getpass("Google API key: ")

# Image directories keep their trailing "/" because later cells build paths by
# plain string concatenation (DIR + file name).
FEWSHOT_JSON = "illusionVQA/comprehension/fewshot_labels.json"
FEWSHOT_IMAGE_DIR = "illusionVQA/comprehension/FEW_SHOTS/"
EVAL_JSON = "illusionVQA/comprehension/eval_labels.json"
EVAL_IMAGE_DIR = "illusionVQA/comprehension/EVAL/"

# Single source of truth for the model name, so the model and the recorded name cannot drift.
model_name = 'gemini-pro-vision'

genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel(model_name)

In [None]:

import json
with open(FEWSHOT_JSON) as fewshot_file:
    fewshot_dataset = json.load(fewshot_file)

# Attach the resolved image path and the rendered multiple-choice text to every
# few-shot example, in place.
for example in fewshot_dataset:
    example["image_path"] = FEWSHOT_IMAGE_DIR + example["image"]
    mcq_text, answer_letter = construct_mcq(example["options"], example["answer"])
    example["mcq"] = mcq_text
    example["correct_option_letter"] = answer_letter

In [None]:
import json
import os
from collections import defaultdict

with open(EVAL_JSON) as f:
    raw_eval_dataset = json.load(f)

# List the image directory once (rather than once per item) and use a set for O(1) lookups.
available_images = set(os.listdir(EVAL_IMAGE_DIR))

category_count = defaultdict(int)
eval_dataset = []
for data in raw_eval_dataset:
    if data["image"] not in available_images:
        # Report and drop items whose image is missing. They would have no
        # "image_path"/"mcq" keys, which the evaluation loop needs, and keeping
        # them would also break index alignment with the prediction lists.
        print(data["image"])
        continue
    data["image_path"] = EVAL_IMAGE_DIR + data["image"]
    data["mcq"], data["correct_option_letter"] = construct_mcq(data["options"], data["answer"])
    category_count[data["category"]] += 1
    eval_dataset.append(data)

print(category_count)
print(len(eval_dataset))

In [None]:
# Instruction text followed by every few-shot example, each with its reasoning and answer.
content = [
    """You'll be given an image, an instruction and some choices. You have to select the correct one. Reason about the choices in the context of the question and the image. End your answer with "Answer": {letter_of_correct_choice} without the curly brackets. Here are a few examples:"""
]

for example_number, example in enumerate(fewshot_dataset, start=1):
    content = add_row(content, example, example_number, with_answer=True)
content.append("Now you try it.")

# The evaluation item is numbered directly after the last few-shot example.
next_data_idx = len(fewshot_dataset) + 1

In [None]:
import re
import time

from tqdm import tqdm

# Identical for every request, so build the safety settings once.
SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "HIGH"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "HIGH"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "HIGH"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "HIGH"},
]
MAX_RETRIES = 5
# Sentinel predictions. "BLOCK" is the label the result-export cell looks for.
BLOCKED_LABEL = "BLOCK"
ERROR_LABEL = "ERROR"
# Matches the "Answer: x" / "Answer": x marker the prompt asks the model to end with.
_ANSWER_RE = re.compile(r"\banswer\b\W*([a-z])\b", re.IGNORECASE)


def _generate_with_retry(prompt, max_retries=MAX_RETRIES):
    """Call the model, retrying failed requests with exponential backoff.

    Returns the response, or None if every attempt raised.
    """
    for attempt in range(max_retries):
        try:
            return model.generate_content(prompt, safety_settings=SAFETY_SETTINGS)
        except Exception as e:  # API/network failures surface as assorted exception types
            print(e)
            print("Internal Error")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
    return None


def _response_text(response):
    """Return the reply text, or None when the response carries no text (e.g. blocked)."""
    try:
        return response.text
    except (ValueError, IndexError):
        try:
            return response.parts[0].text
        except (ValueError, IndexError):
            print("External Error:", response.prompt_feedback)
            return None


def _extract_answer_letter(text):
    """Extract the chosen option letter from a free-form reply.

    Prefers the last explicit "Answer" marker. Otherwise falls back to the last
    character after trimming whitespace and a trailing period. Returns "" when
    nothing is left.
    """
    matches = _ANSWER_RE.findall(text)
    if matches:
        return matches[-1].lower()
    trimmed = text.strip().rstrip(".").strip()
    return trimmed[-1].lower() if trimmed else ""


ytrue = []
ypred = []

for data in tqdm(eval_dataset):
    content_t = add_row(content.copy(), data, next_data_idx, with_answer=False)
    response = _generate_with_retry(content_t)

    if response is None:
        prediction = ERROR_LABEL
    else:
        gemini_answer = _response_text(response)
        if gemini_answer is None:
            prediction = BLOCKED_LABEL
        else:
            print("GEMINI: ", gemini_answer)
            prediction = _extract_answer_letter(gemini_answer)

    # Append to both lists every iteration so they stay index-aligned with eval_dataset.
    ytrue.append(data["correct_option_letter"])
    ypred.append(prediction)

In [None]:
# Turn every newline character inside a prediction into the placeholder "x",
# so no prediction label contains a raw line break.
ypred = list(map(lambda pred: pred.replace("\n", "x"), ypred))

In [None]:
# Show the first evaluation record and its image. A bare expression that is not the
# last line of a cell is never rendered, so the record is displayed explicitly.
display(eval_dataset[0])
Image.open(eval_dataset[0]["image_path"])

In [None]:
from collections import Counter

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Fix the label order explicitly so the confusion matrix can be printed with its
# row/column names. Predictions may contain labels outside the answer set.
labels = sorted(set(ytrue) | set(ypred))

print(accuracy_score(ytrue, ypred))
print("labels (rows = true, columns = predicted):", labels)
print(confusion_matrix(ytrue, ypred, labels=labels))
# zero_division=0: labels that are never predicted report precision 0 without a warning.
print(classification_report(ytrue, ypred, labels=labels, zero_division=0))

print(Counter(ytrue))
print(Counter(ypred))

In [None]:
from collections import defaultdict

import prettytable

table = prettytable.PrettyTable()
table.field_names = ["Category", "Total", "Wrong", "Accuracy"]

# Count misclassifications per category. eval_dataset, ytrue and ypred are index-aligned.
got_wrong_dict = defaultdict(int)
for data, truth, pred in zip(eval_dataset, ytrue, ypred):
    if pred != truth:
        got_wrong_dict[data["category"]] += 1

# Iterate over every counted category, not only those with errors, so categories
# answered perfectly still appear (with 0 wrong and accuracy 1.0).
for category, total in category_count.items():
    wrong = got_wrong_dict.get(category, 0)
    table.add_row([category, total, wrong, 1 - (wrong / total)])

# Largest categories first.
table.sortby = "Total"
table.reversesort = True
print(table)
        

In [None]:
import copy
import json
import os

METRIC_SAVE_DIR = "performance_metrics/"


def _letter_to_option(letter, options):
    """Return the option text for a single letter ("a" -> options[0]), or None.

    Anything that is not one in-range lowercase letter is rejected, so stray
    characters such as "." cannot map to a negative index and silently pick a
    wrong option.
    """
    if len(letter) != 1 or not ("a" <= letter <= "z"):
        return None
    index = ord(letter) - ord("a")
    return options[index] if index < len(options) else None


# Deep copy: list.copy() is shallow, so annotating the records would also mutate
# the dicts inside eval_dataset.
eval_dataset_copy = copy.deepcopy(eval_dataset)

print(len(eval_dataset_copy))
for data, pred, truth in zip(eval_dataset_copy, ypred, ytrue):
    if pred == truth:
        continue
    # Record the model's (wrong) answer as option text where possible.
    if "BLOCK" in pred:
        data["vlm_answer"] = "BLOCK"
        continue
    option = _letter_to_option(pred, data["options"])
    if option is None:
        data["vlm_answer"] = pred
        print(pred)
    else:
        data["vlm_answer"] = option

os.makedirs(METRIC_SAVE_DIR, exist_ok=True)
with open(METRIC_SAVE_DIR + "gemini_reasoning_results.json", "w") as f:
    json.dump(eval_dataset_copy, f)
