# CL Fall School 2024 in Passau: Multimodal NLP
Carina Silberer, University of Stuttgart

*This notebook is based on the [evaluation notebook](https://colab.research.google.com/drive/1RfcUhBTHvREx5X7TMY5UAgMYX8NMKy7u?usp=sharing) of the authors of the work [IRFL: Image Recognition of Figurative Language](https://github.com/irfl-dataset/IRFL).*

---

# Lab 4: Metaphor Detection
In this lab, we will use CLIP in a zero-shot setting to address the metaphor detection task as defined by [IRFL](https://irfl-dataset.github.io/). As has rightly been pointed out, the task might be better described as a cross-modal inference task, but here we study it the way the IRFL authors defined it.

### Setup

#### Required packages: Installation
If one of the packages below is not yet installed on your computer, run the corresponding command. 
For torchvision, see also https://pytorch.org/get-started/locally/

In [None]:
!pip3 install transformers --quiet
!pip3 install -U datasets --quiet
!pip3 install tqdm --quiet
!pip3 install fsspec==2023.6.0 --quiet
!pip3 install torch torchvision --quiet
!pip3 install matplotlib --quiet

In [None]:
#!pip3 install -q git+https://github.com/huggingface/transformers.git

**Run just this to make sure you have compatible versions of pytorch and torchvision:**

In [None]:
!pip3 install torch torchvision -U

In [None]:
import torch
import torchvision
print("Torch version: ", torch.__version__)
print("Torchvision version: ", torchvision.__version__)

#### Package imports

In [None]:
import requests
import operator
import os

import pandas as pd
import torch
from PIL import Image

# We need additional packages for loading (and visualising) the images
import IPython.display
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# from huggingface:
from datasets import load_dataset

from transformers import CLIPProcessor, CLIPModel

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

## Warm-Up: Setting up the data
### Loading the IRFL dataset

There is a dataset for each figure of speech (idiom, metaphor, simile) in IRFL. We will only work with metaphors, but feel free to also work on the other tasks (you need to adapt the script accordingly).


In [None]:
# from huggingface:
from datasets import load_dataset

# loads the IRFL dataset from the huggingface hub
IRFL_images = load_dataset("lampent/IRFL", data_files='IRFL_images.zip')['train']

# IRFL dataset of figurative phrase-image pairs (10k+ images)
#IRFL_idioms_dataset = load_dataset("lampent/IRFL", 'idioms-dataset')['dataset']
#IRFL_similes_dataset = load_dataset("lampent/IRFL", 'similes-dataset')['dataset']
IRFL_metaphors_dataset = load_dataset("lampent/IRFL", 'metaphors-dataset')['dataset']

print('Successfully loaded IRFL dataset and metaphor task')

#### IRFL metaphor dataset

In [None]:
pd.DataFrame(IRFL_metaphors_dataset).head()

#### Retrieve and visualise an image

In [None]:
import PIL.Image as Image

def get_image(image_name, image_folder_path='data/IRFL/images/'):
  return Image.open(os.path.join(image_folder_path, image_name.split(".")[0] + ".jpeg")).convert("RGB")

image = get_image('105928442888727985035455816965889552425794851956670719249414602380285680963206')
image

## IRFL Multimodal Figurative Language Detection task
The task is defined as follows: Given a metaphor and four candidate images, choose the image that conveys the metaphorical message. 

The approach we adopt is very simple: to find the "correct" image for a linguistic metaphor, we measure the similarity between the metaphor and each of the four candidate images, and choose the image with the highest similarity as the one that conveys the meaning of the metaphor.
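
The selection rule amounts to an argmax over the four similarity scores. A minimal sketch, with made-up image names and scores purely for illustration (`pick_best_candidate` is a hypothetical helper, not part of the notebook):

```python
def pick_best_candidate(scores):
    """Return the candidate whose similarity score is highest.

    `scores` maps candidate image names to phrase-image similarity scores.
    Ties are broken in favour of the candidate listed first.
    """
    return max(scores, key=scores.get)


# hypothetical scores for one metaphor and its four candidate images
scores = {"img_a": 0.21, "img_b": 0.34, "img_c": 0.18, "img_d": 0.29}
print(pick_best_candidate(scores))  # img_b
```
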

### Load IRFL Multimodal Metaphor Detection task
We need to load the images (a copy of the code above) as well as the task data: the metaphors (*phrases*) and, for each metaphor, the IDs of four candidate images, namely one correct depiction of the metaphor (*answer*) and three incorrect ones (*distractors*).

We load the *test* split, since we want to evaluate a model on the task in a zero-shot setting (in contrast to training a model on it).
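
Judging from the `preprocessing` function used later in this notebook, `answer` and `distractors` are stored as JSON-encoded lists of image names. The record below is a hypothetical, hand-made stand-in (not taken from IRFL) that shows how one instance unpacks into a phrase, one answer, and three distractors:

```python
import json

# hypothetical task instance; field names follow those read by `preprocessing`
instance = {
    "phrase": "time is a thief",
    "answer": json.dumps(["img_answer"]),
    "distractors": json.dumps(["img_d1", "img_d2", "img_d3"]),
}

answer = json.loads(instance["answer"])            # list with one image name
distractors = json.loads(instance["distractors"])  # list with three image names
candidates = answer + distractors                  # the four candidates

print(instance["phrase"], "->", answer[0])
print(len(candidates), "candidates:", candidates)
```
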

In [None]:
from datasets import load_dataset

IRFL_images = load_dataset("lampent/IRFL", data_files='IRFL_images.zip')['train']

def get_image(image_name, image_folder_path='data/IRFL/images/'):
    return Image.open(os.path.join(image_folder_path, image_name.split(".")[0] + ".jpeg")).convert("RGB")

# Detection task
IRFL_metaphor_detection_task = load_dataset("lampent/IRFL", 'metaphor-detection-task')["test"]

There are 333 items in the test data: 

In [None]:
print(IRFL_metaphor_detection_task)
IRFL_metaphor_detection_task.data

### Approach: Image--Text Matching
You can use any function that gives you a matching score between an image and a text. We will use CLIP: we encode the phrase and its four candidate images with CLIP and compute the (scaled) dot product of the normalised embeddings for each phrase-image pair, which gives us four matching scores.
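
CLIP's score for an image-text pair is the cosine similarity of the two L2-normalised embeddings, multiplied by a learned temperature (`logit_scale`; for `openai/clip-vit-base-patch32` its exponentiated value is, to our knowledge, about 100). This is why the matching function below divides `logits_per_image` by 100. A NumPy sketch, with random vectors standing in for real CLIP embeddings and the scale of 100 taken as an assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
text_emb = rng.normal(size=512)         # stand-in for a text embedding
image_embs = rng.normal(size=(4, 512))  # stand-ins for four image embeddings

# L2-normalise, so that the dot product equals the cosine similarity
text_n = text_emb / np.linalg.norm(text_emb)
image_n = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
cosine = image_n @ text_n               # shape (4,)

logit_scale = 100.0                     # assumed value of exp(logit_scale)
logits_per_image = logit_scale * cosine # what CLIP reports as logits
recovered = logits_per_image / 100.0    # the division used in this notebook

print(np.round(cosine, 3))
print(int(np.argmax(logits_per_image)))
```

Since the scale is a positive constant, dividing by it changes the scores but never which image ranks highest.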

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", clean_up_tokenization_spaces=True)

# We don't train the model, so no gradients are needed; disabling them saves memory and computation
total_parameters = sum(p.numel() for p in model.parameters())
for param in model.parameters():
    param.requires_grad = False


# If you want to work directly with representations, e.g. because you are using another approach, you can use this function
def get_vectors_similarity(v1, v2):
    similarity = v1.detach().cpu().numpy() @ v2.detach().cpu().numpy().T
    return similarity

def get_clip_similarity(phrase, images):
    inputs = processor(text=phrase, images=images, return_tensors="pt", padding=True)
    # keep the inputs on the same device as the model (CPU unless the model was moved)
    inputs = inputs.to(model.device)
    # CLIP computes the scaled dot product of the normalised text and image embeddings
    outputs = model(**inputs)
    return outputs

# This function will return the matching probability of the phrase and its image.
# If `definitions` are provided, it will concatenate the definitions of the phrase. Only idiom instances pass definitions.
# If you wish to test your model replace CLIP with the desired model.
def get_clip_phrase_image_similarity_score(phrase, img_name, definitions):
    if definitions:
        definition_prompt = '.'.join(definitions) + '.'
        phrase += '.' + definition_prompt

    image = get_image(img_name)
    outputs = get_clip_similarity(phrase, image)
    logits_per_image = outputs.logits_per_image.item()/100.0 # CLIP's logit divided by its logit scale (about 100 for this checkpoint), i.e. roughly the cosine similarity
    return logits_per_image

In [None]:
img_name = '94721155644044834970605402726680935822854330271009500854369319195697461678173'
phrase="burnt toast"
get_clip_phrase_image_similarity_score(phrase, img_name, None)

### Evaluation
#### Evaluation Script

In [None]:
#!pip install scikit-learn --quiet

from collections import Counter
import json
from sklearn.metrics import accuracy_score
import random
from tqdm import tqdm

def solve_IRFL_detection_task(task_instances):
    ground_truth = []
    predictions = []
    for task_instance in tqdm(task_instances):
        phrase, answer, distractors, candidates, definitions = preprocessing(task_instance)
        prediction = solve_IRFL_detection_task_instance(phrase, candidates, definitions)

        ground_truth.append(answer)
        predictions.append(prediction)

    return ground_truth, predictions, int(accuracy_score(ground_truth, predictions) * 100)

def solve_IRFL_detection_task_instance(phrase, candidates, definitions):
    sim_for_image = {}
    for img_name in candidates:
        sim_for_image[img_name] = get_clip_phrase_image_similarity_score(phrase, img_name, definitions)

    # the candidate with the highest similarity score (ties go to the first candidate)
    clip_prediction = max(sim_for_image, key=sim_for_image.get)
    return clip_prediction

def preprocessing(task_instance):
    answer = json.loads(task_instance['answer'])
    distractors = json.loads(task_instance['distractors'])
    if task_instance.get('definition'):
        definitions = json.loads(task_instance['definition'])
    else:
        definitions = None
    candidates = answer+distractors
    random.shuffle(candidates)
    return task_instance['phrase'], answer[0], distractors, candidates, definitions

#### Evaluate multimodal metaphor detection task
We iterate over all instances, use CLIP to predict which image conveys the meaning of each metaphor, and compare the predictions with the true answers. 
To assess the model, we measure accuracy: the proportion of items for which the prediction is correct. Since each metaphor has four candidate images, random guessing would achieve about 25%.
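
A hand-rolled sketch of the metric on toy labels (not real results). It computes the same fraction as `sklearn.metrics.accuracy_score` and mirrors the truncation to an integer percentage in `solve_IRFL_detection_task`; `accuracy_percent` is a hypothetical name:

```python
def accuracy_percent(ground_truth, predictions):
    """Percentage of exact matches, truncated to an int as in the notebook."""
    correct = sum(g == p for g, p in zip(ground_truth, predictions))
    return int(correct / len(ground_truth) * 100)

# toy example: 3 of 4 predictions are correct
gt = ["a", "b", "c", "d"]
pred = ["a", "b", "x", "d"]
print(accuracy_percent(gt, pred))  # 75

chance_level = 100 / 4  # expected accuracy of uniform guessing among 4 images
print(chance_level)     # 25.0
```
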

In [None]:
ground_truth, predictions, accuracy = solve_IRFL_detection_task(IRFL_metaphor_detection_task)
print(f"model_accuracy: {accuracy}")

### Visualisation
#### Plot examples with predictions and labels.
True Positive: Correct model prediction marked in <font color='green'>green</font>.
<br>
False Positive: Incorrect model prediction marked in <font color='red'>red</font>.
<br>
Ground Truth: The correct image, marked in <font color='blue'>blue</font> when the model chose a different image.
<br>
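
The colour rule used by the plotting function can be isolated as a small helper. This is a sketch for clarity only; `frame_colour` is a hypothetical name and is not used elsewhere in the notebook:

```python
def frame_colour(candidate, answer, prediction):
    """Return the frame colour for one candidate image, or None for no frame."""
    if candidate == answer and candidate == prediction:
        return "g"  # correct prediction
    if candidate == answer:
        return "b"  # ground truth the model missed
    if candidate == prediction:
        return "r"  # incorrect prediction
    return None     # neither the answer nor the prediction

candidates = ["img1", "img2", "img3", "img4"]
print([frame_colour(c, "img2", "img3") for c in candidates])
# [None, 'b', 'r', None]
```
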

In [None]:
# !pip install matplotlib --quiet
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

def plot_IRFL_detection_task_instance(phrase, candidates, answer, clip_prediction, score, definitions):
  fig, axs = plt.subplots(1, 4, figsize=(40, 20))
  plot_text = f"Phrase: {phrase}. Score: {score}."
  if definitions:
    plot_text = plot_text + f" \n Definitions: {definitions}"
  plt.suptitle(plot_text, fontsize=50)

  # one subplot per candidate image
  for ax, curr_candidate in zip(axs, candidates):
    curr_img = get_image(curr_candidate)
    ax.imshow(curr_img)
    height, width, _ = np.array(curr_img).shape
    # green: correct prediction; blue: missed ground truth; red: incorrect prediction
    color = None
    if curr_candidate == answer and curr_candidate == clip_prediction:
      color = 'g'
    elif curr_candidate == answer:
      color = 'b'
    elif curr_candidate == clip_prediction:
      color = 'r'
    if color is not None:
      rect = patches.Rectangle((0, 0), width, height, linewidth=20, edgecolor=color, facecolor='none')
      ax.add_patch(rect)
    # hide ticks and axis lines on every subplot, not only on the last one
    ax.axis('off')
  plt.tight_layout()
  return fig, axs

In [None]:
task_instance = IRFL_metaphor_detection_task[1]
phrase, answer, distractors, candidates, definitions = preprocessing(task_instance)

clip_prediction = solve_IRFL_detection_task_instance(phrase, candidates, None)
fig, axs = plot_IRFL_detection_task_instance(phrase, candidates, answer, clip_prediction, 1 if answer == clip_prediction else 0, task_instance.get('definition',""))


## Exercise: Analysis
Conduct a small analysis to gain more insight into CLIP's predictions, using zero-shot object classification: 
1. For each instance (one metaphor and four candidate images), have CLIP predict the top K labels for each of the four images. Then compare these labels and their probability scores with CLIP's prediction on the metaphor task: which image did it predict as conveying the metaphor, and which class labels did it predict for that image? Are the class labels related to the metaphor?
2. To do this, you need a vocabulary of class labels, which you create by extracting the words of all phrases (metaphors).
   Recall that for zero-shot classification, CLIP expects a phrase rather than a bare label as input, e.g., "umbrella" --> "this is a photo of an umbrella ."
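
Step 2 can be sketched as follows: split all phrases into words, deduplicate them, and wrap each word in a prompt template. The phrases here are invented examples, and the vocabulary is sorted (unlike the notebook's `list(set(...))`) so that its order is reproducible across runs:

```python
phrases = ["time is a thief", "life is a rollercoaster", "the classroom was a zoo"]

# vocabulary: every distinct whitespace-separated word, sorted for a stable order
vocabulary = sorted({word for phrase in phrases for word in phrase.split()})

# one prompt per class label
prompts = [f"this is a photo of a {label} ." for label in vocabulary]

print(len(vocabulary))
print(prompts[:3])
```

Note that function words such as "a" and "is" end up in the vocabulary too; filtering them is one of the tasks below.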

**Tasks**
* Sample 20 images and analyse CLIP's predictions. Can you identify possible reasons for CLIP's (in)correct predictions?
* Take also into account the instances themselves: Do they and the "correct" answer make sense to you?
* Improve the set of class labels (vocabulary) by filtering out function words and/or applying lemmatisation. You need to write code for this. You may also devise a larger vocabulary.
* If you, just for fun, also want to generate captions for the images, see the code at the bottom (GIT from the first lab).
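
For the first question in the list above, a crude starting point is to check whether any of an image's top-K labels literally appear in the metaphor. The helper below is a hypothetical sketch (`label_overlap` is not part of the notebook); in the real pipeline, `top_labels` would be the strings `classes[index]` for the indices returned by the zero-shot classifier:

```python
def label_overlap(phrase, top_labels):
    """Return the top-K labels that literally occur as words in the phrase.

    A crude, surface-level proxy for relatedness: it ignores synonyms,
    inflection and any semantic relation other than word identity.
    """
    phrase_words = set(phrase.lower().split())
    return [label for label in top_labels if label.lower() in phrase_words]

# hypothetical top-5 labels predicted for one candidate image
top5 = ["thief", "clock", "time", "mask", "night"]
print(label_overlap("Time is a thief", top5))  # ['thief', 'time']
```

Because the vocabulary is built from the metaphors themselves, an overlap is more likely than it would be for an open vocabulary; keep this in mind when interpreting it.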

### Setup: Data preparation and approach

#### Create vocabulary of class labels

In [None]:
def load_labels(IRFL_metaphor_detection_task):
    all_words = []
    for task_instance in tqdm(IRFL_metaphor_detection_task):
        phrase, _, _, _, _ = preprocessing(task_instance)
        all_words.extend(phrase.split())
    return list(set(all_words))

In [None]:
IRFL_metaphor_detection_task = load_dataset("lampent/IRFL", 'metaphor-detection-task')["test"]
classes = load_labels(IRFL_metaphor_detection_task)
prompts = [f"this is a photo of a {label} ." for label in classes]
prompts[:5]

**Remark:** The vocabulary could be improved, e.g., through lemmatisation or by discarding function words.
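
A minimal sketch of such filtering, assuming a small hand-picked stop-word list (a hypothetical stand-in for a proper resource such as NLTK's or spaCy's stop-word lists) and using lowercasing plus punctuation stripping in place of real lemmatisation:

```python
import string

# hypothetical, deliberately tiny list of function words
FUNCTION_WORDS = {"a", "an", "the", "is", "are", "was", "were", "of", "in", "on",
                  "to", "and", "or", "my", "his", "her", "its", "their", "like", "as"}

def clean_vocabulary(words):
    """Lowercase, strip surrounding punctuation, drop function words and duplicates."""
    cleaned = set()
    for word in words:
        token = word.lower().strip(string.punctuation)
        if token and token not in FUNCTION_WORDS:
            cleaned.add(token)
    return sorted(cleaned)

raw = ["The", "classroom", "was", "a", "zoo.", "Zoo", "thief,"]
print(clean_vocabulary(raw))  # ['classroom', 'thief', 'zoo']
```

For actual lemmatisation, the lowercasing step could be replaced by a lemmatiser, e.g. spaCy's `token.lemma_`.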

#### Approach: Zero-shot classification with CLIP
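
`get_clip_scores` below turns each image's row of `logits_per_image` (one logit per prompt) into a probability distribution over the class labels with a softmax, and `topk` then keeps the K most probable labels. The same two steps in NumPy, on made-up logits for 2 images and 4 labels:

```python
import numpy as np

labels = ["thief", "clock", "zoo", "rollercoaster"]
# made-up logits: rows are images, columns are prompts (class labels)
logits_per_image = np.array([[22.0, 25.0, 18.0, 19.0],
                             [20.0, 17.0, 26.0, 21.0]])

# softmax over the label dimension (subtracting the row max for numerical stability)
shifted = logits_per_image - logits_per_image.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

k = 2
# indices of the k most probable labels per image, most probable first
top_idx = np.argsort(-probs, axis=1)[:, :k]
for i, row in enumerate(top_idx):
    print(i, [labels[j] for j in row], np.round(probs[i, row], 3))
```
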

In [None]:
# functions for zero-shot classification
def get_clip_scores(prompts, images):
    with torch.no_grad():
        inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)  # keep the inputs on the same device as the model
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image # image-text similarity scores, one row per image
        probs = logits_per_image.softmax(dim=1) # softmax over the prompts gives the label probabilities
    return probs

def clip_zero_chot_classification(texts, images, topKlabels=5):
    probs = get_clip_scores(texts, images)
    top_probs, top_labels = probs.cpu().topk(topKlabels, dim=-1)
    return top_probs, top_labels

#### Visualisation of the predictions (topK)

In [None]:
def plot_zero_shot_classification(phrase, images, top_probs, top_labels, classes):
    plt.figure(figsize=(16, 16))
    plot_text = f"Phrase: {phrase}."
    plt.suptitle(plot_text, fontsize=12)
    
    for i, image in enumerate(images):
        plt.subplot(4, 6, 3 * i + 1)
        plt.imshow(image)
        plt.axis("off")
    
        plt.subplot(4, 6, 3 * i + 2)
        y = np.arange(top_probs.shape[-1])
        plt.grid()
        plt.barh(y, top_probs[i])
        plt.gca().invert_yaxis()
        plt.gca().set_axisbelow(True)
        plt.yticks(y, [classes[index] for index in top_labels[i].numpy()])
        plt.xlabel("probability")
    
    plt.subplots_adjust(wspace=0.5)
    plt.show()

### Analysis: Putting it all together

In [None]:
plot_predictions = True
plot_zeroshot_classes = True

IRFL_metaphor_detection_task = load_dataset("lampent/IRFL", 'metaphor-detection-task')["test"]
# Sample an instance. You can also hard-code the parameter below to get a specific instance of the dataset
instance_num = np.random.randint(IRFL_metaphor_detection_task.num_rows)

# load an instance of the dataset
task_instance = IRFL_metaphor_detection_task[instance_num]
phrase, answer, distractors, candidates, definitions = preprocessing(task_instance)

# if we do zero-shot classification, we create a vocabulary of classes 
# by extracting the words of all phrases (metaphors)
if plot_zeroshot_classes:
    classes = load_labels(IRFL_metaphor_detection_task)
    prompts = [f"this is a rendering of a {label} ." for label in classes]
    #prompts = [f"this is a photo of a {label} ." for label in classes]

# do the metaphor detection task and plot the prediction
if plot_predictions:
    clip_prediction = solve_IRFL_detection_task_instance(phrase, candidates, None)
    fig, axs = plot_IRFL_detection_task_instance(phrase, candidates, answer, clip_prediction, 1 if answer == clip_prediction else 0, task_instance.get('definition',""))

# do zero-shot classification and plot the top-K predicted classes
print("Instance index: ", instance_num)
if plot_zeroshot_classes:
    images = [get_image(img_name) for img_name in candidates]
    top_probs, top_labels = clip_zero_chot_classification(prompts, images)
    plot_zero_shot_classification(phrase, images, top_probs, top_labels, classes)

### Bonus: Captioning the images

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

## Two different models, trained on textcaps and vatex, resp.
#git_processor = AutoProcessor.from_pretrained("microsoft/git-base-textcaps", clean_up_tokenization_spaces=True)
#git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textcaps")
git_processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex", clean_up_tokenization_spaces=True)
git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")

# run on the GPU if you have one
device = "cuda" if torch.cuda.is_available() else "cpu"
git_model.to(device)

# We don't train the model, so no gradients are needed; disabling them saves memory and computation
total_parameters = sum(p.numel() for p in git_model.parameters())
for param in git_model.parameters():
    param.requires_grad = False

In [None]:
images = [get_image(img_name) for img_name in candidates]

for idx, img in enumerate(images):
    inputs = git_processor(images=img, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)
    generated_ids = git_model.generate(pixel_values=pixel_values, max_length=50)
    # batch_decode returns a list; we decode a single image, so take its only element
    caption = git_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("Generated caption:", caption)
    display(img)