In [None]:
import os
import pathlib
import pickle
import re
import textwrap
import time

import numpy as np
import torch
import torch.nn.functional as F
from IPython.display import Markdown
from IPython.display import display
from sklearn.metrics import roc_auc_score

def to_markdown(text):
    """Render ``text`` as an IPython Markdown blockquote.

    Every line (blank ones included) is prefixed with ``'> '``, and each ``'•'``
    bullet character is turned into a Markdown list marker (``'  *'``).
    """
    bulleted = text.replace('•', '  *')
    quoted = textwrap.indent(bulleted, '> ', predicate=lambda _line: True)
    return Markdown(quoted)

import google.generativeai as genai

In [None]:
# Read the API key from the environment instead of pasting it into the notebook,
# where it would be saved and shared along with the file.
# Set it before starting Jupyter, e.g. `export GOOGLE_API_KEY=...`.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

In [None]:
# Print the names of models that support `generateContent`, the method used for inference below.
for candidate in genai.list_models():
    if 'generateContent' not in candidate.supported_generation_methods:
        continue
    print(candidate.name)

In [None]:
# Turn off blocking for each of these four harm categories so the model still
# answers on offensive memes; a response without usable text is recorded as
# "MODEL ERROR" by the generation loop.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]



In [None]:
# Gemini 1.5 Flash client with the relaxed safety settings defined above.
model = genai.GenerativeModel(
    model_name='models/gemini-1.5-flash-latest',
    safety_settings=safety_settings,
)

In [None]:
import PIL.Image
import json

# Keep the raw JSONL lines (one meme record per line); they are parsed with
# json.loads in the next cell.
with open("/home/haovan/hateful_memes/test_unseen.jsonl", 'r') as jsonl:
    test_data_meta = jsonl.readlines()

with open("/home/haovan/hateful_memes/train.jsonl", 'r') as jsonl:
    train_data_meta = jsonl.readlines()

# Image paths in the records are relative to this directory (plain string concatenation is used later).
img_dir = '/home/haovan/hateful_memes/'



In [None]:
# Index every parsed record by its 'id'. Insertion order follows the JSONL line
# order, and a repeated id keeps its last record.
train_data = {record['id']: record for record in map(json.loads, train_data_meta)}
test_data = {record['id']: record for record in map(json.loads, test_data_meta)}

In [None]:
# Precomputed meme embeddings. They are used below as mappings whose keys are
# meme ids. Presumably each value is a (1, dim) tensor, since F.cosine_similarity
# is called with its default dim=1; TODO confirm.
test_embeddings = torch.load("/home/haovan/hateful_memes/test_unseen_meme_embeddings.pt")
train_embeddings = torch.load("/home/haovan/hateful_memes/train_meme_embeddings.pt")

In [None]:
# For each unseen test meme, compute its cosine similarity to every training meme.
# Store the results as a list of (train_id, score) pairs and cache them to disk,
# so the next cell can reload them instead of recomputing.
test_train_scores = {}
for test_id, test_embedding in test_embeddings.items():
    test_train_scores[test_id] = [
        (train_id, F.cosine_similarity(test_embedding, train_embedding).cpu())
        for train_id, train_embedding in train_embeddings.items()
    ]
    print(test_id)
torch.save(test_train_scores, "unseen_test_train_scores.pt")

In [None]:
# Reload the similarity lists cached by the previous cell, so the all-pairs
# computation can be skipped on later runs.
test_train_scores = torch.load("unseen_test_train_scores.pt")

In [None]:
# Resume bookkeeping for the generation loop. Run this cell only to start from
# scratch: re-running the loop without it continues after last_processed_idx.
last_processed_idx, unprocessed_ids, responses = -1, [], []

In [None]:
def prompt_template(text, label):
    """Return the caption shown after each demonstration image.

    ``text`` (the meme's text field) is accepted but not used; only the gold
    ``label`` string is emitted.
    """
    return f"Classification: {label}\n"


# NOTE(review): train_ids is not used by any cell visible here.
train_ids = list(train_embeddings.keys())

# Number of in-context demonstrations per test meme.
num_shots = 4

# Task instructions placed first in every prompt: the hatefulness definition and
# the decision rule. (The f-prefix is not needed; the string has no placeholders.)
inst = f"""Hatefulness definition: Hate speech is strictly defined as a direct or indirect attack on people based on characteristics, including ethnicity, race, nationality, immigration status, religion, caste, sex, gender identity, sexual orientation, disability, or disease. The definition of attack includes violent or dehumanizing speech, statements of inferiority, and calls for exclusion or segregation. Mocking hate crime is also considered hate speech.
Classification criteria: The meme is hateful if one or both of the visual and textual meanings are hateful. If both visual and textual meanings are non-hateful but together they lead to implicitly hateful meaning, the meme is also hateful. Otherwise, the meme is non-hateful.
You consider visual and textual meanings to classify the meme as hateful or non-hateful based on the hatefulness definition and classification criteria.\n\n"""

def _select_demonstrations(ranked_scores, train_pool, n_shots):
    """Choose the in-context demonstrations for one test meme.

    Args:
        ranked_scores: ``(train_id, score)`` pairs, most similar first.
        train_pool: mapping from train id to its JSONL record (must have ``'label'``).
        n_shots: total number of demonstrations wanted.

    Returns:
        Up to ``n_shots`` training records, interleaved non-hateful, hateful,
        non-hateful, and so on:
        - the ``n_shots - n_shots // 2`` most similar non-hateful (label 0) records;
        - the ``n_shots // 2`` most similar hateful (label 1) records.

        For an even ``n_shots`` this is an exact 50/50 split; an odd one gets the
        extra non-hateful example. If the ranking runs out of one class, fewer
        records are returned instead of ``None`` placeholders.
    """
    n_hateful = n_shots // 2
    n_non_hateful = n_shots - n_hateful
    non_hateful, hateful = [], []
    for train_id, _ in ranked_scores:
        sample = train_pool[train_id]
        if sample['label'] == 0 and len(non_hateful) < n_non_hateful:
            non_hateful.append(sample)
        elif sample['label'] == 1 and len(hateful) < n_hateful:
            hateful.append(sample)
        if len(non_hateful) == n_non_hateful and len(hateful) == n_hateful:
            break
    demos = []
    for pos in range(max(len(non_hateful), len(hateful))):
        # Slices are empty (not IndexError) once one class is exhausted.
        demos.extend(non_hateful[pos:pos + 1])
        demos.extend(hateful[pos:pos + 1])
    return demos


for idx, test_id in enumerate(test_data.keys()):
    # Resume support: samples up to last_processed_idx were already answered by an
    # earlier (interrupted) run of this cell.
    if idx <= last_processed_idx:
        continue
    test_sample = test_data[test_id]

    # Rank all training memes by similarity to this test meme, most similar first.
    sorted_scores = sorted(test_train_scores[test_id], key=lambda x: x[1], reverse=True)
    icl_prompt = ["Here are some demonstrations on classifying memes:\n"]
    for train_sample in _select_demonstrations(sorted_scores, train_data, num_shots):
        img = PIL.Image.open(img_dir+train_sample['img'])
        label = "hateful" if train_sample['label'] == 1 else "non-hateful"
        icl_prompt += [img, prompt_template(train_sample['text'], label)]

    print(f"Processing image: {test_id}, label: {test_sample['label']}")
    prompt = f"""For this test image, please give the classification and probability of the meme being hateful (from 0 to 1) in the following format:
Classification: 
Probability:"""
    test_img = PIL.Image.open(img_dir+test_sample['img'])
    final_prompt = [inst]+icl_prompt+[test_img, prompt]
    # API/network errors are deliberately not caught: last_processed_idx is not
    # advanced, so re-running this cell resumes from the failing sample.
    response = model.generate_content(final_prompt)
    try:
        answer = response.text
    except Exception:
        # No usable text (e.g. a blocked response). Store a placeholder so that
        # `responses` stays aligned with the test-set order.
        # NOTE: this records the image path, not the numeric id.
        unprocessed_ids.append(test_sample['img'])
        answer = "MODEL ERROR"
    print(answer)
    responses.append(answer)
    last_processed_idx = idx
    time.sleep(5)  # presumably throttles requests to stay within the API rate limit

In [None]:
# Debug: prompt feedback for the most recent request only (`response` is left over
# from the last iteration of the generation loop).
response.prompt_feedback

In [None]:
# Number of test memes whose response had no usable text (answer stored as "MODEL ERROR").
len(unprocessed_ids)

In [None]:
# Persist the raw model answers so evaluation can be re-run without new API calls.
file_name = f'results_gemini_unseen_test_prompt_icl_{num_shots}_shots.pkl'
payload = {"responses": responses, "unprocessed_ids": unprocessed_ids}
with open(file_name, mode='wb') as out_file:
    pickle.dump(payload, out_file)

In [None]:
# Restore the saved answers. Only unpickle files you created yourself, because
# pickle.load can execute arbitrary code.
with open(file_name, mode='rb') as in_file:
    data = pickle.load(in_file)
responses = data['responses']

In [None]:
# Debug: inspect the last prompt sent to the model (left over from the generation loop).
final_prompt

In [None]:
# `responses` was produced by iterating `test_data`, which keeps the line order of
# `test_data_meta`, so answers are matched to gold labels by position.
assert len(responses) == len(test_data_meta), (
    f"got {len(responses)} responses for {len(test_data_meta)} test memes; "
    "finish (or restart) the generation loop before evaluating"
)

actuals = []
predictions = []
prob_list = []

for idx, json_str in enumerate(test_data_meta):
    test_sample = json.loads(json_str)
    actuals.append(test_sample['label'])

    lower_response = responses[idx].lower()
    # Take the first number that follows "probability"; the capture group isolates
    # it directly. A raw string avoids the invalid-escape warning for "\.".
    found_numbers = re.findall(r"probability.*?([0-9]+\.?[0-9]*)", lower_response)

    if found_numbers:
        prob = float(found_numbers[0])
        predicted_class = 1 if prob >= 0.5 else 0
    else:
        # No probability given: fall back to the stated classification.
        if "classification: hateful" in lower_response:
            predicted_class, prob = 1, 1.0
        elif "classification: non-hateful" in lower_response:
            predicted_class, prob = 0, 0.0
        else:
            # Unparseable answers (e.g. "MODEL ERROR") are scored as hateful.
            # Assign predicted_class explicitly here; otherwise the previous
            # sample's class would silently be reused.
            predicted_class, prob = 1, 1.0
        print(f'Idx: {idx} - Res: {lower_response}')
        print()
    prob_list.append(prob)
    predictions.append(predicted_class)

acc = np.mean(np.array(actuals) == np.array(predictions))

In [None]:
# Accuracy uses the hard 0/1 predictions; AUROC uses the parsed probabilities.
print(f"Accuracy: {acc}")
auroc = roc_auc_score(actuals, prob_list)
print(f"AUROC: {auroc}")