## Setup

In [None]:
from google.colab import userdata
from openai import OpenAI

# OpenAI initialization: the API key is read from the Colab secret named
# 'OPENAI_PROJECT_KEY' via userdata.get, so it is not hardcoded in the notebook.
client = OpenAI(api_key=userdata.get('OPENAI_PROJECT_KEY'))

In [None]:
import gdown
import os
import pandas as pd
import json

In [None]:
from google.colab import files

## Data loading and preprocessing
Select either the extended set or the test set to continue with.

##### test set

In [None]:
# Download the zipped test split from Google Drive and unpack it in the working directory.
file_id = "1GSw-7lcRmBUiypk3iyskvLLbC5Ip4BrE" # test
file_name = "taskA_test.zip"
url = f"https://drive.google.com/uc?id={file_id}"
gdown.download(url, file_name, quiet=True)
# NOTE(review): the lone "-" argument looks superfluous — confirm unzip ignores it.
! unzip -q - taskA_test.zip

# `taska_folder` / `taska_tsv_filename` locate the TSV and image folders read by later
# cells; `dataset` is used as a prefix for the output file names written further down.
# Presumably the archive extracts into a folder named like `taska_folder` — verify.
taska_folder = "test"
taska_tsv_filename = "subtask_a_test.tsv"
dataset = "test"

##### Uncomment to run on the extended set instead

In [None]:
# file_id = "1MPD814bn8lktCTJZt2naTQAz02ffjd9l"
# file_name = "taskA_ext.zip"
# url = f"https://drive.google.com/uc?id={file_id}"
# gdown.download(url, file_name, quiet=True)
# ! unzip -q - taskA_ext.zip

# taska_folder = "xeval"
# taska_tsv_filename = "subtask_a_xe.tsv"
# dataset = "xe"

### Dataframe creation

#### initialize df from the task TSV

In [None]:
# Load the tab-separated task file of the selected split into the main DataFrame `df`.
df = pd.read_csv(f"{taska_folder}/{taska_tsv_filename}", delimiter="\t")

#### preprocess image paths for ranking

In [None]:
def preprocess_data(df, dir_name):
    """
    Attach an ``image_paths`` column listing each row's five candidate image files.

    Every path has the form ``dir_name/<compound>/<image name>``; apostrophes in the
    compound are replaced by underscores when forming the folder name.

    Args:
        df: DataFrame with a ``compound`` column and ``image1_name`` ... ``image5_name`` columns.
        dir_name: Root directory that holds one sub-folder per compound.

    Returns:
        The same DataFrame object, modified in place with the new ``image_paths`` column.
    """
    name_columns = [f'image{k}_name' for k in range(1, 6)]

    def build_paths(record):
        # Folder name is the compound with apostrophes swapped for underscores
        compound_dir = os.path.join(dir_name, record['compound'].replace("'", "_"))
        return [os.path.join(compound_dir, record[column]) for column in name_columns]

    df['image_paths'] = df.apply(build_paths, axis=1)
    return df

# Add the 'image_paths' column; image folders are resolved relative to `taska_folder`.
df = preprocess_data(df, taska_folder)

## Sentence type classification

##### Load saved results
For the extended dataset only. Instead of querying for new classifications (below), load the results we saved earlier.

In [None]:
# Download previously saved classification responses from Google Drive and read the
# tab-separated file into `classification_df`.
classification_results_file_id = "1ZWztepcLYrjJu-3iww3MoJDgFPLyZ33C"
classification_results_file_name = "classification_responses_ext.tsv"

url = f"https://drive.google.com/uc?id={classification_results_file_id}"
gdown.download(url, classification_results_file_name, quiet=True)

classification_df = pd.read_csv(classification_results_file_name, delimiter="\t")

In [None]:
# Index-aligned column assignment (no merge on compound).
# NOTE(review): assumes classification_df rows are in the same order as df — confirm.
df['sentence_type_pred'] = classification_df['result']

##### Define classification prompt

In [None]:
def get_gpt_sentence_types_prompt(samples):
  """
  Build the full sentence-type classification prompt for one batch.

  The template contains the instructions, a worked example (input and JSON response),
  the output JSON schema, and finally the caller's samples.

  Args:
      samples: Pre-formatted text block of 'Target phrase' / 'Context sentence' pairs.

  Returns:
      The prompt string with ``samples`` inserted at the end.
  """
  # Doubled braces are literal braces in the f-string; only {samples} is interpolated.
  # The example response must itself be valid JSON, since the model is told to imitate it.
  return f"""
You are a linguistics expert specializing in figurative language. You will be given a set of samples, each containing a 'target phrase' paired with a 'context sentence' containing a usage of said phrase.
The target phrases all have idiomatic (i.e. figurative) meanings, but they might be used literally in these context sentences!
For each sample, you are to do the following:
1. Looking at the target phrase in isolation, state its *idiomatic* meaning and its *literal* meaning. The literal meaning might be awkward, as some of these phrases are almost always used idiomatically.
2. *Carefully* consider how the target phrase is used in the context sentence. Is it used in its idiomatic sense (in most cases, the way that we're used to understanding it), or is it used as a literal composition of its component words?
3. Verbose explanation: Given your familiarity with the phrase's possible meanings, and having considered how it's used in the sentence, give an explanation of what the phrase means in the context of the sentence. Remember: if the literal usage is *plausible*, it is probably used literally.
4. Final usage determination: Based on steps 1-3, state whether the phrase's use in the context sentence is 'literal' or 'idiomatic'.

Example input:
Target phrase: 'cold turkey'
Context sentence: 'John quit smoking cold turkey and never looked back, not that it was easy.'

Target phrase: 'ghost town'
Context sentence: 'Our wanderings had led us perilously close to the walls of the ghost town where restless spirits haunted the streets, eager to absorb the vitality of the living.'

---

Example response:
{{"data": [
    {{
      "target_phrase": "cold turkey",
      "idiomatic_meaning": "To stop a habit or addiction abruptly and completely, without gradually reducing or tapering off. It often refers to ceasing a harmful behavior or substance like smoking or drugs.",
      "literal_meaning": "A turkey that is cold.",
      "contextual_considerations": "In the sentence, 'John quit smoking cold turkey and never looked back, not that it was easy,' the phrase 'cold turkey' clearly does not refer to food. It is used in the context of quitting smoking, which aligns with the idiomatic usage of the term.",
      "verbose_explanation": "The phrase 'cold turkey' in this sentence means that John abruptly stopped smoking without tapering off or using substitutes like nicotine patches. The description highlights the difficulty of this approach, suggesting that quitting 'cold turkey' was challenging but ultimately successful. The context does not mention anything about literal turkey, further affirming the idiomatic interpretation.",
      "result": "idiomatic"
    }},
    {{
      "target_phrase": "ghost town",
      "idiomatic_meaning": "A deserted town or settlement that was once populated but is now abandoned, often evoking a sense of desolation or emptiness.",
      "literal_meaning": "A town inhabited by ghosts or supernatural entities, as in fictional or mythological contexts.",
      "contextual_considerations": "In the sentence, 'Our wanderings had led us perilously close to the walls of the ghost town where restless spirits haunted the streets, eager to absorb the vitality of the living,' the description explicitly mentions 'restless spirits' and their interaction with the living. This strongly suggests a literal interpretation involving supernatural elements.",
      "verbose_explanation": "Here, 'ghost town' refers to a literal place inhabited by ghosts or spirits, as indicated by the detailed imagery of 'restless spirits' and their haunting presence. The context does not suggest the metaphorical use of the term as an abandoned, non-supernatural settlement.",
      "result": "literal"
    }}
]}}

---

You must return a valid JSON object formatted exactly as follows:
- Do not use double-quotes inside of JSON values. If quotes are necessary, use single-quotes.
- Do not include line breaks inside JSON values.
- Strictly follow the schema:

Output JSON schema:
{{
  "type": "object",
  "properties": {{
    "data": {{
      "type": "array",
      "items": {{
        "type": "object",
        "properties": {{
          "target_phrase": {{"type": "string"}},
          "idiomatic_meaning": {{"type": "string"}},
          "literal_meaning": {{"type": "string"}},
          "contextual_considerations": {{"type": "string"}},
          "verbose_explanation": {{"type": "string"}},
          "result": {{"type": "string", "enum": ["idiomatic", "literal"]}}
        }},
        "required": ["target_phrase", "idiomatic_meaning", "literal_meaning", "contextual_considerations", "verbose_explanation", "result"]
      }}
    }}
  }},
  "required": ["data"]
}}

Ensure the response is a **valid** JSON object with escaped quotes.
DO NOT include extra commentary.

Your turn. These are the samples:
{samples}
"""

##### Define prompting function

In [None]:
def gpt_sentence_types(compounds, sentences):
    """
    Ask the chat model to label one batch of sentences as 'literal' or 'idiomatic'.

    Args:
        compounds: Target phrases, one per sample.
        sentences: Context sentences, paired positionally with ``compounds``.

    Returns:
        The parsed JSON response (the prompt asks for an object with a "data" list),
        or ``{}`` when the reply has no text content or is not valid JSON.
    """
    # Create a combined prompt: one "Target phrase / Context sentence" pair per sample
    samples = "\n\n".join([
        f'Target phrase: "{nc}"\nContext sentence: "{sentence}"' for nc, sentence in zip(compounds, sentences)
    ])

    prompt = get_gpt_sentence_types_prompt(samples)

    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
    )

    raw_content = response.choices[0].message.content
    if raw_content is None:
        # The message content can be absent; treat it like an unparsable reply
        # instead of crashing on .strip().
        print("Warning: response contained no text content")
        return {}

    raw_content = raw_content.strip()
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError as e:
        print(f"JSON decoding error: {e}")
        print("Response content that caused error:", raw_content)
        return {}

##### Get classification responses

In [None]:
# Run sentence type classifier in batches
def run_classification(df, batch_size):
    """
    Classify every row of ``df`` by sending it to the model in consecutive batches.

    Args:
        df: DataFrame with ``compound`` and ``sentence`` columns.
        batch_size: Number of rows sent per request.

    Returns:
        A flat list of the per-sample dicts returned under "data" by each batch.
        Batches whose response has no usable "data" list are skipped with a warning,
        so the result can be shorter than ``df``.
    """
    all_samples = []
    total_rows = len(df)
    for i in range(0, total_rows, batch_size):
        # Clamp the label so the final (possibly short) batch reports its real last row
        batch_label = f"{i}-{min(i + batch_size, total_rows) - 1}"
        print(f"\nStarting batch: {batch_label}")

        batch = df.iloc[i:i + batch_size]
        responses = gpt_sentence_types(batch['compound'].tolist(),
                                       batch['sentence'].tolist())

        # The parsed JSON may be any JSON value; only accept an object with a list under "data"
        samples = responses.get("data") if isinstance(responses, dict) else None
        if isinstance(samples, list):
            all_samples.extend(samples)
        else:
            print(f"Warning: No 'data' key in response for batch {batch_label}")

    return all_samples

In [None]:
# Run the full classification five times; each run is saved to its own TSV
# ({dataset}_classification_responses_<run>.tsv) and offered as a browser download.
# The five files are read back by the voting cells below.
for i in range(5):
    run_idx = i+1
    print(f"\n\nStarting run #{run_idx}")
    classification_responses = run_classification(df, 20)  # 20 rows per request
    print(f"Run #{run_idx} -- Total samples collected: {len(classification_responses)}")
    classification_responses_df = pd.DataFrame(classification_responses)
    cls_filename = f"{dataset}_classification_responses_{run_idx}.tsv"
    classification_responses_df.to_csv(cls_filename, sep='\t', index=False)
    print(f"Downloading file: {cls_filename}")
    files.download(cls_filename)

#### classification voting

In [None]:
# Read the five per-run classification files back from the working directory
# (same file names as written by the run loop above).
df_cls_1 = pd.read_csv(f"{dataset}_classification_responses_1.tsv", delimiter="\t")
df_cls_2 = pd.read_csv(f"{dataset}_classification_responses_2.tsv", delimiter="\t")
df_cls_3 = pd.read_csv(f"{dataset}_classification_responses_3.tsv", delimiter="\t")
df_cls_4 = pd.read_csv(f"{dataset}_classification_responses_4.tsv", delimiter="\t")
df_cls_5 = pd.read_csv(f"{dataset}_classification_responses_5.tsv", delimiter="\t")

In [None]:
cls_dfs = [df_cls_1, df_cls_2, df_cls_3, df_cls_4, df_cls_5]

# One row per row of df, in df's order; each run contributes a 'result_<i>' column.
df_voting = df['compound'].to_frame()

for i, df_cls in enumerate(cls_dfs, start=1):
    # Work on a renamed copy (the loaded run frames stay untouched, so the cell can be
    # re-run) and keep a single vote per phrase: a repeated target_phrase would make
    # the left join emit extra rows and break row alignment with df.
    run_votes = (
        df_cls.rename(columns={'result': f'result_{i}'})[['target_phrase', f'result_{i}']]
        .drop_duplicates(subset='target_phrase')
    )
    df_voting = df_voting.merge(run_votes,
              left_on='compound',
              right_on='target_phrase',
              how='left').drop(columns=['target_phrase'])

In [None]:
def vote(row):
    """
    Majority vote over the five run columns of one row.

    Only the exact labels 'idiomatic' and 'literal' are counted; anything else
    (missing values, other strings) is ignored. Ties resolve to 'literal'.

    Args:
        row: Mapping-like row containing 'result_1' ... 'result_5'.

    Returns:
        'idiomatic' if it has strictly more votes than 'literal', otherwise 'literal'.
    """
    ballots = row[['result_1', 'result_2', 'result_3', 'result_4', 'result_5']].tolist()

    idiomatic_votes = ballots.count('idiomatic')
    literal_votes = ballots.count('literal')

    # A tie (including no valid votes at all) is broken in favour of 'literal'
    return 'idiomatic' if idiomatic_votes > literal_votes else 'literal'

# Apply the majority vote row by row to get one final label per compound row.
df_voting['final_result'] = df_voting.apply(vote, axis=1)

In [None]:
# add classification result to main df
# Index-aligned assignment.
# NOTE(review): relies on df_voting having exactly one row per df row in the same order — confirm.
df['sentence_type_pred'] = df_voting['final_result']

## Idiom definitions for text input

##### Load stored responses
For the extended dataset only. Instead of generating new definitions (below), load the ones we saved earlier.

In [None]:
# Download previously generated idiom definitions from Google Drive and read the CSV
# into `defs_df`.
def_responses_file_id = "1fOZu7wA14JtSoKPF9L2g2ODyoXzPqbv9"
def_responses_file_name = "idiom_definitions.csv"

url = f"https://drive.google.com/uc?id={def_responses_file_id}"
gdown.download(url, def_responses_file_name, quiet=True)
defs_df = pd.read_csv(def_responses_file_name)

In [None]:
# merge defs_df into df on defs_df['target_phrase'] == df['compound']
# - Drop any 'idiom_def' column left by an earlier merge, so re-running this cell does
#   not produce two columns with the same name.
# - Keep one definition per phrase, so the left join cannot multiply df's rows.
df = (
    df.drop(columns=['idiom_def'], errors='ignore')
      .merge(defs_df[['target_phrase', 'result']].drop_duplicates(subset='target_phrase'),
             left_on='compound',
             right_on='target_phrase',
             how='left')
      .rename(columns={'result': 'idiom_def'})
      .drop(columns=['target_phrase'])
)

In [None]:
# fill nan with compound
# `x == x` is False only for NaN, so rows without a merged definition fall back to the
# raw compound as the text query.
# NOTE(review): unlike the prompt-1/prompt-2 cells below, this does not check
# 'sentence_type_pred' — confirm the stored definitions should apply to every row.
df['text_input'] = df.apply(lambda x: x['idiom_def'] if x['idiom_def'] == x['idiom_def'] else x['compound'], axis=1)

### Define prompts and get definitions
Use **either** prompt 1 or prompt 2 to generate definitions for idiomatic sentences (skip this section if you loaded existing definitions above).

In [None]:
def gpt_definitions(compounds, sentence_types, base_prompt):
    """
    Ask the chat model for definitions of one batch of target phrases.

    Args:
        compounds: Target phrases to define.
        sentence_types: Predicted sentence types, paired positionally with ``compounds``.
            They are not used as a filter: every compound with a paired entry is sent.
        base_prompt: Instruction text; one ``The idiom is: "<phrase>".`` line per phrase
            is appended to it.

    Returns:
        The parsed JSON response (the prompts ask for an object with a "data" list),
        or ``{}`` when the reply has no text content or is not valid JSON.
    """
    # zip() stops at the shorter input, so only compounds that have a paired
    # sentence type are kept.
    input_data = [nc for nc, _sentence_type in zip(compounds, sentence_types)]

    # Create a combined prompt
    examples = "\n".join(f'The idiom is: "{nc}".' for nc in input_data)
    prompt = base_prompt + examples

    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
    )

    raw_content = response.choices[0].message.content
    if raw_content is None:
        # The message content can be absent; treat it like an unparsable reply
        # instead of crashing on .strip().
        print("Warning: response contained no text content")
        return {}

    raw_content = raw_content.strip()
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError as e:
        print(f"JSON decoding error: {e}")
        print("Response content that caused error:", raw_content)
        return {}

#### prompt 1

In [None]:
# Prompt 1 ("exp6"): asks the model to explain each idiom, state its literal meaning,
# propose three short definitions and pick one as "result". gpt_definitions appends the
# list of idioms after the final line (base_prompt + examples).
# The f-string has no placeholders; doubled braces render as single literal braces.
prompt_exp6_idiomatic = f"""
You are a linguistics expert specializing in idioms. You will be given a set of idioms to process. For each one, do the following steps aloud (in writing):
1. Give a verbose explanation of the idiom, including what connotations it carries or undertones it evokes.
2. Give a definition of the *literal* meaning of the phrase. For noun phrases representing physical objects, focus on unambiguous visual descriptors.
3. Taking into consideration your response for #1 and #2, list three potential definitions, no longer than 20 words each, that capture the **core emotional or situational essence** conveyed by the idiom. Use **simple language that an average high-schooler would understand** and avoid figurative or overly abstract language. Focus on clear, visually interpretable descriptions that are distinct from the literal definition.
4. Choose the best definition.

---

Example outputs:
{{
  "data": [
    {{
      "target_phrase": "glass ceiling",
      "explanation": "Refers to an invisible barrier that prevents certain groups, often women or minorities, from advancing in their careers or social positions. Evokes frustration, inequality, and hidden obstacles. Frequently used in discussions of systemic discrimination.",
      "literal_definition": "A ceiling made of transparent glass.",
      "potential_definition_1": "A hidden obstacle that blocks people from reaching higher positions.",
      "potential_definition_2": "An unseen barrier that stops progress for qualified individuals.",
      "potential_definition_3": "A quiet limit that keeps certain groups from moving upward.",
      "result": "A hidden obstacle that blocks people from reaching higher positions."
    }},
    {{
      "target_phrase": "missing link",
      "explanation": "Suggests a crucial piece of information or evidence needed to bridge a gap in knowledge or understanding. Evokes the sense of an incomplete puzzle, emphasizing the importance of finding what’s absent.",
      "literal_definition": "A link in a chain that is not present, creating a gap.",
      "potential_definition_1": "A key piece that completes an unfinished idea or puzzle.",
      "potential_definition_2": "Something crucial that holds everything together but is absent.",
      "potential_definition_3": "An important connecting factor that is missing or unknown.",
      "result": "A key piece that completes an unfinished idea or puzzle."
    }},
    {{
      "target_phrase": "paper tiger",
      "explanation": "Describes someone or something that appears threatening or powerful but is actually weak or ineffective. Connotes empty threats or superficial strength.",
      "literal_definition": "A tiger made of paper, such as origami or a paper figure.",
      "potential_definition_1": "Something that seems strong but has little real power.",
      "potential_definition_2": "A fragile threat that looks more dangerous than it is.",
      "potential_definition_3": "A force that seems scary but collapses under pressure.",
      "result": "Something that seems strong but has little real power."
    }}
    ...
  ]
}}

---

You must return a valid JSON object:
- Do not use double quotes inside your value strings.
- Do not include line breaks inside JSON values.
- Strictly follow the schema.

Schema:
{{
  "type": "object",
  "properties": {{
    "data": {{
      "type": "array",
      "items": {{
        "type": "object",
        "properties": {{
          "target_phrase": {{ "type": "string" }},
          "explanation": {{ "type": "string" }},
          "literal_definition": {{ "type": "string" }},
          "potential_definition_1": {{ "type": "string" }},
          "potential_definition_2": {{ "type": "string" }},
          "potential_definition_3": {{ "type": "string" }},
          "result": {{ "type": "string" }}
        }},
        "required": ["target_phrase", "explanation", "potential_definition_1", "potential_definition_2", "potential_definition_3", "result"]
      }}
    }}
  }},
  "required": ["data"]
}}

Ensure the response is a valid JSON object with escaped quotes.

Here are the samples:
"""

In [None]:
# Request prompt-1 definitions for every row of df, 15 compounds per request.
batch_size = 15
all_data = []

for i in range(0, len(df), batch_size):
    # The printed end index is nominal: the last batch may contain fewer rows.
    print(f"\n Starting batch: {i}-{i+batch_size-1}")
    batch = df.iloc[i:i + batch_size]
    responses = gpt_definitions(batch['compound'].tolist(),
                                batch['sentence_type_pred'].tolist(),
                                prompt_exp6_idiomatic)

    # A batch whose reply lacks a "data" key is skipped with a warning, so some
    # compounds may end up without a definition.
    if "data" in responses:
        all_data.extend(responses["data"])
        print(f"len(all_data): {len(all_data)}")
    else:
        print(f"Warning: no 'data' in response for batch {i}-{i+batch_size-1}")

# One row per returned sample dict
df_exp6_defs_idiomatic = pd.DataFrame(all_data)

In [None]:
# Persist the prompt-1 definitions to a CSV in the working directory (no index column).
df_exp6_defs_idiomatic.to_csv(f"{dataset}_exp6_definitions_idiomatic.csv", index=False)

In [None]:
# Offer the saved CSV as a browser download via the Colab files helper.
files.download(f"{dataset}_exp6_definitions_idiomatic.csv")

In [None]:
# merge df_exp6_defs_idiomatic into df on target_phrase == df['compound']
# - Drop any 'idiom_def' column left by an earlier merge, so re-running this cell (or
#   running it after another definitions cell) does not produce duplicate column names.
# - Keep one definition per phrase, so the left join cannot multiply df's rows.
df = (
    df.drop(columns=['idiom_def'], errors='ignore')
      .merge(df_exp6_defs_idiomatic[['target_phrase', 'result']].drop_duplicates(subset='target_phrase'),
             left_on='compound',
             right_on='target_phrase',
             how='left')
      .rename(columns={'result': 'idiom_def'})
      .drop(columns=['target_phrase'])
)

In [None]:
# Use the generated definition as the text query for rows predicted idiomatic; use the
# raw compound for literal rows and for idiomatic rows whose definition is missing
# (e.g. a skipped batch or a phrase the merge could not match), so no NaN is produced.
df['text_input'] = df.apply(
    lambda x: x['idiom_def']
    if x['sentence_type_pred'] == 'idiomatic' and pd.notna(x['idiom_def'])
    else x['compound'],
    axis=1)

#### prompt 2

In [None]:
# Prompt 2 ("exp7"): asks the model to imagine five scenes per idiom and distil them into
# a single caption returned as "result". gpt_definitions appends the list of idioms
# after the final line. Plain (non-f) string, so braces are literal.
prompt_exp7_idiomatic = """
You are a linguistics and visual storytelling expert, with an expertise on differentiating idiomatic from literal language. For each sample idiom below, your task is to create visual and textual representations that align well with the idiom’s figurative meaning for use in matching with images. Follow these steps:

1. Identify the phrase: Give a concise definition of the phrase in its idiomatic sense.
2. Note the literal usage (briefly): Mention the plain or surface meaning, but clarify that you are focusing on the figurative interpretation for your examples.
3. Generate 5 distinct image ideas: For the given idiom, imagine 5 different scenes or situations that visually depict its figurative meaning. Describe each scene in 1-2 sentences, focusing on visual details.
4. Generalize the captions: Write a single caption that could apply to all 5 scenes. It should capture the essence of the idiom in a way that is broad enough to fit any of the scenes.
5. Refine: Reflect on how well your caption generalizes to all five scenes, then attempt to improve on it.
6. Consider which caption is best: Weigh the captions against each other, then pick the one that best fits all 5 scenes.
7. Select the best caption: Repeat the caption you selected.

---

Example outputs:
{
  "data": [
    {
      "target_phrase": "glass ceiling",
      "explanation": "Refers to an invisible barrier that prevents certain groups (often women or minorities) from advancing to higher levels of power or responsibility. Implies a hidden form of discrimination that is not overtly acknowledged but still limits upward mobility.",
      "literal_definition": "A ceiling made of glass.",
      "image_ideas": [
        "A businesswoman standing just below a transparent barrier in a large corporate office, looking up at executives in the floor above.",
        "A group of female or minority employees reaching a fancy mezzanine level only to find an unseen barrier between them and the boardroom.",
        "A symbolic representation of cracks forming in a transparent barrier overhead as a woman holds a briefcase, showing determination to break through.",
        "A silhouette of a person pressed against a clear pane, with a hand raised as though trying to push past it.",
        "A visually layered office setting, where higher floors are accessible but separated by a nearly invisible division, highlighting the subtlety of the barrier."
      ],
      "generalized_caption_1": "Facing an unseen barrier to advancement.",
      "generalized_caption_2": "Pushing against a hidden boundary in pursuit of progress.",
      "thinking": "Both captions address the concept of a hidden obstruction. The second one, 'Pushing against a hidden boundary in pursuit of progress,' suggests active resistance and forward motion, which suits the idiom’s connotation of striving to break through.",
      "result": "Pushing against a hidden boundary in pursuit of progress."
    },
    {
      "target_phrase": "paper tiger",
      "explanation": "Describes someone or something that appears threatening or powerful but is actually weak or ineffectual. Connotes false bravado or an overestimation of strength.",
      "literal_definition": "A tiger made out of paper.",
      "image_ideas": [
        "A large, menacing figure looming over a crowd, only to be revealed as hollow or easily torn.",
        "A roaring tiger image on a billboard that looks scary but is just thin paper peeling at the edges.",
        "A towering cardboard cutout of a tiger in a political rally, symbolizing empty threats or exaggerated power.",
        "A fierce-looking trophy made of paper mache, displayed in a spotlight to highlight its fragile nature.",
        "An intimidating sign with a tiger illustration in front of a building, but the sign is tattered and flapping in the wind, showing its vulnerability."
      ],
      "generalized_caption_1": "A formidable appearance that masks a fragile reality.",
      "generalized_caption_2": "Something that looks strong but lacks real power.",
      "thinking": "The second caption directly addresses the core meaning—'Something that looks strong but lacks real power.' It's concise and precise.",
      "result": "Something that looks strong but lacks real power."
    },
    {
      "target_phrase": "missing link",
      "explanation": "Refers to a crucial piece of information or element that helps connect different ideas, theories, or facts. Connotes something vital that completes a puzzle or fills a gap in understanding.",
      "literal_definition": "A link in a chain (like a ring or segment) that is absent.",
      "image_ideas": [
        "A detective at a crime board tapping a blank space among photos and clues, indicating a vital piece of evidence that’s not yet found.",
        "An evolutionary chart with a silhouette in the middle missing, leaving a gap in the progression from ape to human.",
        "A jigsaw puzzle nearly completed, except for a conspicuously empty spot in the center.",
        "A timeline pinned on a wall with a significant date missing, highlighting the gap in recorded history.",
        "A scientific lab setting where a researcher stands before a half-finished hypothesis, gazing at a large question mark on the board."
      ],
      "generalized_caption_1": "A crucial piece that completes the bigger picture.",
      "generalized_caption_2": "The vital connecting factor that brings everything together.",
      "thinking": "Between the two, 'A crucial piece that completes the bigger picture' fits the notion of something vital and absent, capturing the idiomatic essence succinctly.",
      "result": "A crucial piece that completes the bigger picture."
    }
    ...
    ...
    ...
  ]
}

---

You must return a valid JSON object:
- Do not use double quotes inside your value strings.
- Do not include line breaks inside JSON values.
- Strictly follow the schema.

Schema:
{
  "type": "object",
  "properties": {
    "data": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "target_phrase": { "type": "string" },
          "explanation": { "type": "string" },
          "literal_definition": { "type": "string" },
          "image_ideas": { "type": "array", "items": { "type": "string" } },
          "generalized_caption_1": { "type": "string" },
          "generalized_caption_2": { "type": "string" },
          "thinking": { "type": "string" },
          "result": { "type": "string" }
        },
        "required": ["target_phrase", "image_ideas", "generalized_caption_1", "generalized_caption_2", "thinking", "result"]
      }
    }
  },
  "required": ["data"]
}

Ensure the response is a valid JSON object with properly escaped quotes.

Your turn. Here are the samples:
"""


In [None]:
# Request prompt-2 definitions for every row of df, 8 compounds per request.
batch_size = 8
all_data = []

for i in range(0, len(df), batch_size):
    # The printed end index is nominal: the last batch may contain fewer rows.
    print(f"Starting batch {i}-{i+batch_size-1}...")

    batch = df.iloc[i:i + batch_size]
    responses = gpt_definitions(batch['compound'].tolist(),
                                batch['sentence_type_pred'].tolist(),
                                prompt_exp7_idiomatic)

    # A batch whose reply lacks a "data" key is skipped with a warning, so some
    # compounds may end up without a definition.
    if "data" in responses:
        all_data.extend(responses["data"])
        print(f"all_data now has length: {len(all_data)}")
    else:
        print(f"Warning: no 'data' in response for batch {i}-{i+batch_size-1}")

# One row per returned sample dict
df_exp7_defs_idiomatic = pd.DataFrame(all_data)

In [None]:
# Persist the prompt-2 definitions to a CSV in the working directory (no index column).
df_exp7_defs_idiomatic.to_csv(f'{dataset}_exp7_definitions_idiomatic.csv', index=False)

In [None]:
# Offer the saved CSV as a browser download via the Colab files helper.
files.download(f"{dataset}_exp7_definitions_idiomatic.csv")

In [None]:
# merge df_exp7_defs_idiomatic into df on target_phrase == df['compound']
# - Drop any 'idiom_def' column left by an earlier merge, so re-running this cell (or
#   running it after another definitions cell) does not produce duplicate column names.
# - Keep one definition per phrase, so the left join cannot multiply df's rows.
df = (
    df.drop(columns=['idiom_def'], errors='ignore')
      .merge(df_exp7_defs_idiomatic[['target_phrase', 'result']].drop_duplicates(subset='target_phrase'),
             left_on='compound',
             right_on='target_phrase',
             how='left')
      .rename(columns={'result': 'idiom_def'})
      .drop(columns=['target_phrase'])
)

In [None]:
# Use the generated caption as the text query for rows predicted idiomatic; use the
# raw compound for literal rows and for idiomatic rows whose definition is missing
# (e.g. a skipped batch or a phrase the merge could not match), so no NaN is produced.
df['text_input'] = df.apply(
    lambda x: x['idiom_def']
    if x['sentence_type_pred'] == 'idiomatic' and pd.notna(x['idiom_def'])
    else x['compound'],
    axis=1)

### Run prompts on GPT

### handle GPT responses

In [None]:
# Merge all batches into a single DataFrame
# `all_responses` is not defined by any cell visible in this notebook, so look it up
# defensively: on a fresh run this cell is then a no-op instead of raising NameError.
# When present it is expected to be a list of DataFrames (it is passed to pd.concat).
all_responses = globals().get("all_responses", [])
if all_responses:
    response_df = pd.concat(all_responses, ignore_index=True).rename(columns={"result": "idiom_def"})

    # Merge 'idiom_def' and additional caption columns back into main DataFrame.
    # Only the columns actually present are selected: neither prompt's schema above
    # defines 'generalized_caption_3', and prompt 1 defines no caption columns at all.
    wanted_cols = ['target_phrase', 'idiom_def', 'generalized_caption_1', 'generalized_caption_2', 'generalized_caption_3']
    present_cols = [col for col in wanted_cols if col in response_df.columns]
    df = df.merge(response_df[present_cols],
                  left_on='compound',
                  right_on='target_phrase',
                  how='left')

    # Drop 'target_phrase' since it's redundant after merging
    df = df.drop(columns=['target_phrase'])

## Multimodal model setup

In [None]:
import torch
from PIL import Image
from ast import literal_eval

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

#### OpenCLIP

In [None]:
!pip install open_clip_torch
import open_clip

In [None]:
# define model config
openclip_model_version = "ViT-B-32"
# The second returned value is discarded; the model and one preprocessing transform are kept.
model_openclip, _, preprocess_openclip = open_clip.create_model_and_transforms(openclip_model_version, pretrained='laion2b_s34b_b79k')
model_openclip.to(device)
# Tokenizer matching the same model version
open_clip_tokenizer = open_clip.get_tokenizer(openclip_model_version)
model_openclip.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active

In [None]:
def openclip_image_ranking(model, image_processor, tokenizer, image_paths, sentence):
    """
    Rank candidate images against one text query with an OpenCLIP-style model.

    Args:
        model: Model exposing ``encode_image`` and ``encode_text``.
        image_processor: Callable that turns an opened image into a tensor.
        tokenizer: Callable that turns a list of strings into a token tensor.
        image_paths: Paths of the candidate images (must not be empty).
        sentence: Text query.

    Returns:
        ``(probs, indices)``: softmax scores over the candidates, best first, and the
        matching positions in ``image_paths``. At most five entries are returned.
    """
    image_tensors = []
    for ipath in image_paths:
        # Use the processor passed in (not a module-level one) and close each file
        # as soon as it has been converted.
        with Image.open(ipath) as img:
            image_tensors.append(image_processor(img))
    image_inputs = torch.stack(image_tensors).to(device)
    text_input = tokenizer([sentence]).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_inputs)
        text_features = model.encode_text(text_input)

    # normalise features
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # dot product & softmax
    similarity = (100.0 * text_features @ image_features.T).softmax(dim=-1)

    # order by similarity; cap k at the number of candidates so fewer than 5 images works
    probs, indices = similarity[0].topk(min(5, similarity.shape[-1]))
    return probs, indices

#### CLIP

In [None]:
# install clip
!pip install -q ftfy regex tqdm
!pip install -q git+https://github.com/openai/CLIP.git

import clip

In [None]:
def get_image_ranking_clip(model, image_processor, image_paths, sentence):
    """
    Rank candidate images against one text query with a CLIP model.

    Args:
        model: Model exposing ``encode_image`` and ``encode_text``.
        image_processor: Callable that turns an opened image into a tensor.
        image_paths: Paths of the candidate images (must not be empty).
        sentence: Text query, tokenized with ``clip.tokenize``.

    Returns:
        ``(probs, indices)``: softmax scores over the candidates, best first, and the
        matching positions in ``image_paths``. At most five entries are returned.
    """
    image_tensors = []
    for ipath in image_paths:
        # Close each file as soon as it has been converted to a tensor
        with Image.open(ipath) as img:
            image_tensors.append(image_processor(img))
    image_inputs = torch.stack(image_tensors).to(device)
    text_input = clip.tokenize(sentence).to(device)

    with torch.no_grad():
        # compute embeddings
        image_features = model.encode_image(image_inputs)
        text_features = model.encode_text(text_input)

    # normalize features
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # compute similarity scores
    similarity = (100.0 * text_features @ image_features.T).softmax(dim=-1)

    # rank images by similarity; cap k at the number of candidates so fewer than 5 images works
    probs, indices = similarity[0].topk(min(5, similarity.shape[-1]))
    return probs, indices

#### ALIGN

In [None]:
from transformers import AlignProcessor, AlignModel

In [None]:
def get_image_ranking_align(model, processor, image_paths, sentence):
    """
    Rank candidate images against one text query with an ALIGN-style model.

    Args:
        model: Callable taking the processor's outputs as keyword arguments and
            returning an object with ``logits_per_text``.
        processor: Callable accepting ``images``, ``text`` and ``return_tensors``.
        image_paths: Paths of the candidate images.
        sentence: Text query.

    Returns:
        ``(probs, indices)``: softmax scores sorted in descending order and the
        matching positions in ``image_paths``.
    """
    pil_images = []
    for path in image_paths:
        pil_images.append(Image.open(path))

    batch = processor(images=pil_images, text=sentence, return_tensors="pt")

    with torch.no_grad():
        result = model(**batch)

    # First (only) text row -> probabilities over the images
    scores = result.logits_per_text[0].softmax(dim=-1)
    order = torch.argsort(scores, descending=True)
    return scores[order], order

### Define model config

In [None]:
def _clip_config(model_name, display_name):
    """Load one CLIP checkpoint a single time and return its config entry."""
    # clip.load returns (model, preprocess); unpack once instead of loading twice
    model, preprocess = clip.load(model_name, device)
    return {
        "base_model": "CLIP",
        "model_name": model_name,
        "display_name": display_name,
        "model": model,
        "preprocess": preprocess,
    }


# One entry per model to evaluate. "base_model" selects the ranking function in
# get_predictions; "display_name" is used in log messages and output file names.
model_configs = [
    _clip_config("ViT-B/32", "CLIP1"),
    _clip_config("ViT-L/14", "CLIP2"),
    _clip_config("RN50x64", "CLIP3"),
    {
        "base_model": "Align",
        "model_name": "Base",
        "display_name": "Align",
        # NOTE(review): this model is not moved to `device` here — confirm CPU inference is intended.
        "model": AlignModel.from_pretrained("kakaobrain/align-base"),
        "preprocess": AlignProcessor.from_pretrained("kakaobrain/align-base")
    },
    {
        "base_model": "open_clip",
        "model_name": openclip_model_version,
        "display_name": "openclip",
        "model": model_openclip,
        "preprocess": preprocess_openclip,
        "tokenizer": open_clip_tokenizer,
    }
]

### Define inference function

In [None]:
def get_predictions(model, processor, image_paths_list, text_inputs, base_model, model_name, tokenizer=None, model_display_name=None):
    """
    Rank the candidate images of every sample against its text input.

    Args:
        model: Loaded model, forwarded to the ranking function.
        processor: Image preprocessing callable / processor for that model.
        image_paths_list: One list of image paths per sample.
        text_inputs: One text query per sample, paired positionally.
        base_model: "CLIP", "Align" or "open_clip"; selects the ranking function.
        model_name: Not used inside this function.
        tokenizer: Tokenizer, only forwarded for "open_clip".
        model_display_name: Label used in the progress message.

    Returns:
        ``(predictions, confidence_scores)``: per sample, the ranked image indices and
        the scores multiplied by 100. Samples with no images yield two empty lists.

    Raises:
        ValueError: If ``base_model`` is unknown and a sample with images is reached.
    """
    print(f"get_predictions for {model_display_name}")

    # Lambdas defer the lookup of each ranking function until it is actually called
    rankers = {
        "CLIP": lambda paths, query: get_image_ranking_clip(model, processor, paths, query),
        "Align": lambda paths, query: get_image_ranking_align(model, processor, paths, query),
        "open_clip": lambda paths, query: openclip_image_ranking(model, processor, tokenizer, paths, query),
    }

    predictions = []
    confidence_scores = []

    for ipaths, text in zip(image_paths_list, text_inputs):
        # Nothing to rank: keep the sample's slot with empty results
        if len(ipaths) == 0:
            predictions.append([])
            confidence_scores.append([])
            continue

        if base_model not in rankers:
            raise ValueError(f"Unknown base_model: {base_model}")

        values, indices = rankers[base_model](ipaths, text)
        predictions.append(list(indices.cpu()))
        confidence_scores.append(100 * values)

    return predictions, confidence_scores

## Generate predictions

In [None]:
def run_experiment(model_config):
    """
    Rank every sample's images with one configured model and tabulate the result.

    Reads the notebook-level ``df`` ('image_paths', 'text_input', 'compound' columns).

    Args:
        model_config: Dict with 'model', 'preprocess', 'base_model', 'model_name' and
            'display_name'; an optional 'tokenizer' entry is forwarded when present.

    Returns:
        DataFrame with one row per sample: 'compound', 'expected_order' (image file
        names, best first) and 'confidence_scores' (strings with three decimals).
    """
    # get predictions on images
    predictions, confidence_scores = get_predictions(
        model=model_config['model'],
        processor=model_config['preprocess'],
        image_paths_list=df['image_paths'],
        text_inputs=df['text_input'],
        base_model=model_config['base_model'],
        model_name=model_config['model_name'],
        # Take the tokenizer from this model's own config rather than a global;
        # configs without one pass None.
        tokenizer=model_config.get('tokenizer'),
        model_display_name=model_config['display_name']
    )
    print(f"Done ({len(predictions)} predictions)")

    # format results: map ranked indices back to image file names
    ranked_data = [
        {
            "compound": df["compound"].iloc[i],
            "expected_order": [os.path.basename(df["image_paths"].iloc[i][j]) for j in pred],
            "confidence_scores": [f"{x:.3f}" for x in conf]
        }
        for i, (pred, conf) in enumerate(zip(predictions, confidence_scores))
    ]
    ranked_df = pd.DataFrame(ranked_data)
    return ranked_df


In [None]:
# Run every configured model and save its rankings to '<dataset>_<display_name>.csv'.
for model_config in model_configs:
    print(f"Running model: {model_config['display_name']}")
    ranked_df = run_experiment(model_config)

    # write out (no index=False is passed, so the row index is written as the first column)
    filename = f"{dataset}_{model_config['display_name']}.csv"
    ranked_df.to_csv(filename)