# SocraticFlanT5 - Evaluation on MS COCO | DL2 Mini-project, May 2023
---

This notebook downloads the images from the validation split of the [MS COCO Dataset (2017 version)](https://cocodataset.org/#download) together with the corresponding ground-truth captions, generates captions with the Socratic model pipeline outlined below, and evaluates the generated captions against the MS COCO ground-truth captions. We will evaluate the following two approaches:
1. Baseline: a Socratic model based on the work by [Zeng et al. (2022)](https://socraticmodels.github.io/) where GPT-3 is replaced by [FLAN-T5-xl](https://huggingface.co/docs/transformers/model_doc/flan-t5). 

2. Improved prompting: an improved baseline model where the template prompt filled by CLIP is processed before being passed to FLAN-T5-xl.

The generated captions are evaluated in two ways: rule-based and embedding-based.

---
For the **rule-based approach**, the following metrics will be used, based on the [pycocoevalcap](https://github.com/salaniz/pycocoevalcap) repository:

* *BLEU-4*: BLEU (Bilingual Evaluation Understudy) is a metric that measures the similarity between the generated captions and the ground truth captions based on n-gram matching. The BLEU-4 score measures the precision of the generated captions up to four-grams compared to the ground truth captions.

* *METEOR*: METEOR (Metric for Evaluation of Translation with Explicit ORdering) aligns the words of a generated caption with those of the ground truth captions, matching exact words, stems, and synonyms. The score is a weighted harmonic mean of unigram precision and recall, reduced by a penalty when the matched words appear in a different order.

* *CIDEr*: CIDEr (Consensus-based Image Description Evaluation) is a metric that measures the consensus between the generated captions and the ground truth captions. It represents each caption by its TF-IDF-weighted n-grams and computes the cosine similarity between the generated caption and the reference captions, so that n-grams that are common across the whole dataset count less than distinctive ones.

* *SPICE*: SPICE (Semantic Propositional Image Caption Evaluation) is a metric that measures the semantic similarity between the generated captions and the ground truth captions. It analyzes the similarity between the semantic propositions present in the generated captions and those in the reference captions, taking into account the structure and meaning of the propositions.

* *ROUGE-L*: ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is a metric that measures the similarity between the generated captions and the ground truth captions based on overlapping sequences of words. ROUGE-L measures the longest common subsequence (LCS) between the generated captions and the reference captions, taking into account sentence-level structure and word order.
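
To make the n-gram and LCS ideas behind BLEU and ROUGE-L concrete, here is a minimal sketch in plain Python. It is only an illustration and not the pycocoevalcap implementation: the helper names and the example sentences are made up, full BLEU-4 additionally combines the 1- to 4-gram precisions with a geometric mean and a brevity penalty, and ROUGE-L turns the LCS length into an F-measure.

```python
from collections import Counter


def ngrams(tokens, n):
    """Return the list of n-grams (as tuples) in a token list."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def clipped_precision(candidate, references, n):
    """BLEU-style modified n-gram precision of one candidate against several references."""
    cand_counts = Counter(ngrams(candidate, n))
    if not cand_counts:
        return 0.0
    # For each n-gram, keep the largest count seen in any single reference.
    max_ref_counts = Counter()
    for ref in references:
        for gram, count in Counter(ngrams(ref, n)).items():
            max_ref_counts[gram] = max(max_ref_counts[gram], count)
    # A candidate n-gram is credited at most as often as it occurs in a reference.
    clipped = sum(min(count, max_ref_counts[gram]) for gram, count in cand_counts.items())
    return clipped / sum(cand_counts.values())


def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


# Made-up example captions (not taken from MS COCO).
candidate = "a dog runs on the beach".split()
references = ["a dog is running on the beach".split(),
              "a brown dog runs along the shore".split()]

unigram_precision = clipped_precision(candidate, references, 1)
bigram_precision = clipped_precision(candidate, references, 2)
lcs = lcs_length(candidate, references[0])
print(unigram_precision, bigram_precision, lcs)
```

Every candidate word appears in at least one reference, but the bigram "runs on" appears in neither, which is why the bigram precision is lower than the unigram precision.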

---

For the **embedding-based** approach, we use CLIP embeddings. For each image, we calculate the cosine similarities between the image embedding and the embeddings of its ground truth captions, and then the cosine similarity between the image embedding and the embedding of the caption generated with FLAN-T5-xl.
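
The comparison can be sketched as follows. This is an illustration under assumptions rather than the code in `eval.py`: `cosine_similarity` is a hypothetical helper, and random vectors stand in for the CLIP embeddings.

```python
import numpy as np


def cosine_similarity(vec, mat):
    """Cosine similarity between one vector of shape (d,) and each row of a matrix (n, d)."""
    vec = vec / np.linalg.norm(vec)
    mat = mat / np.linalg.norm(mat, axis=1, keepdims=True)
    return mat @ vec


rng = np.random.default_rng(0)
img_embedding = rng.normal(size=768)        # stand-in for one image embedding
gt_embeddings = rng.normal(size=(5, 768))   # stand-in for five ground truth captions
gen_embedding = rng.normal(size=(1, 768))   # stand-in for one generated caption

sims_gt = cosine_similarity(img_embedding, gt_embeddings)
sims_gen = cosine_similarity(img_embedding, gen_embedding)
print(f'ground truth: avg = {sims_gt.mean():.3f}, std = {sims_gt.std():.3f}')
print(f'generated:    avg = {sims_gen.mean():.3f}')
```

With real embeddings, the two averages can be compared per image or aggregated over the whole sample.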

### Loading the required packages

In [2]:
from image_captioning import ClipManager, ImageManager, VocabManager, FlanT5Manager, COCOManager
from eval import SocraticEvalCap
from utils import get_device
import os
import re
import json
import numpy as np
import pickle
import time
import random

### Step 1: Downloading the MS COCO images and annotations

In [3]:
coco_manager = COCOManager()
coco_manager.download_data()

### Step 2: Generating the captions via the Socratic pipeline


#### Set the device and instantiate managers

In [4]:
# Set the device to use
device = get_device()

# Instantiate the clip manager
clip_manager = ClipManager(device)

# Instantiate the image manager
image_manager = ImageManager()

# Instantiate the vocab manager
vocab_manager = VocabManager()

load_places starting!
load_places took 0.0s!
load_objects starting!
load_objects took 0.0s!


#### Compute place and object features

In [5]:
# Make sure the cache directory exists before anything is saved to it
os.makedirs('cache', exist_ok=True)

# Calculate the place features, or load them from the cache
if not os.path.exists('cache/place_feats.npy'):
    place_feats = clip_manager.get_text_feats([f'Photo of a {p}.' for p in vocab_manager.place_list])
    np.save('cache/place_feats.npy', place_feats)
else:
    place_feats = np.load('cache/place_feats.npy')

# Calculate the object features, or load them from the cache
if not os.path.exists('cache/object_feats.npy'):
    object_feats = clip_manager.get_text_feats([f'Photo of a {o}.' for o in vocab_manager.object_list])
    np.save('cache/object_feats.npy', object_feats)
else:
    object_feats = np.load('cache/object_feats.npy')
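
The compute-or-load pattern used for both feature arrays can be factored into one helper. The sketch below is a suggestion, not part of the project: `load_or_compute` and `fake_features` are hypothetical names, and a temporary directory is used so that the example does not touch the real cache.

```python
import os
import tempfile

import numpy as np


def load_or_compute(path, compute_fn):
    """Load an array from `path` if it exists, otherwise compute it and cache it there."""
    if os.path.exists(path):
        return np.load(path)
    feats = compute_fn()
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    np.save(path, feats)
    return feats


calls = []


def fake_features():
    """Stand-in for an expensive call such as computing CLIP text features."""
    calls.append(1)
    return np.arange(6, dtype=np.float32).reshape(2, 3)


with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, 'cache', 'place_feats.npy')
    first = load_or_compute(cache_path, fake_features)   # computed and saved
    second = load_or_compute(cache_path, fake_features)  # loaded from disk
```

In the notebook, `compute_fn` would be a lambda wrapping the `clip_manager.get_text_feats(...)` call.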

#### Define the parameters of the template prompt passed to the VLM (CLIP)

In [7]:
# Zero-shot VLM: classify image type
img_types = ['photo', 'cartoon', 'sketch', 'painting']
img_types_feats = clip_manager.get_text_feats([f'This is a {t}.' for t in img_types])

# Zero-shot VLM: classify number of people
ppl_texts = ['no people', 'people']
ppl_feats = clip_manager.get_text_feats([f'There are {p} in this photo.' for p in ppl_texts])

# Zero-shot VLM: how many top places are returned by the VLM
place_topk = 3

# Zero-shot VLM: how many top objects are returned by the VLM
obj_topk = 10

# Zero-shot LM: how many captions are generated
num_captions = 10

get_text_feats starting!
get_text_feats took 0.8s!
get_text_feats starting!
get_text_feats took 0.1s!
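
Each zero-shot classification step ranks a list of candidate texts by how well their embeddings match the image embedding. `get_nn_text` lives in the project's `image_captioning` module and is not shown in this notebook, so the helper below is a hypothetical stand-in. It assumes the ranking is a dot product between normalized features, and it uses toy three-dimensional vectors instead of CLIP features.

```python
import numpy as np


def rank_texts_by_similarity(texts, text_feats, img_feats):
    """Sort `texts` by their dot-product score with the image features, highest first."""
    scores = text_feats @ img_feats
    order = np.argsort(scores)[::-1]
    return [texts[i] for i in order], scores[order]


# Toy unit-norm embeddings, one row per candidate text.
img_types = ['photo', 'cartoon', 'sketch', 'painting']
text_feats = np.array([[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.0, 0.0, 1.0],
                       [0.6, 0.8, 0.0]])
img_feats = np.array([0.8, 0.6, 0.0])

sorted_types, sorted_scores = rank_texts_by_similarity(img_types, text_feats, img_feats)
top_type = sorted_types[0]
print(top_type, sorted_scores)
```

The same ranking is reused for image types, people counts, places, objects, and finally for picking the best of the generated captions.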


#### Generate image captions

In [None]:
# A helper function to caption a single image
def caption_this_image(file_name, imgs_folder, ix, random_numbers, image_manager, clip_manager, vocab_manager, flan_manager):
    # Consider only image files that belong to the random sample
    if not (file_name.endswith('.jpg') and ix in random_numbers):
        return None
    start_time = time.time()

    # Get the image id: COCO file names are zero-padded ids, and int() ignores leading zeros
    image_id = int(os.path.splitext(file_name)[0])

    img_path = os.path.join(imgs_folder, file_name)
    img = image_manager.load_image(img_path)
    img_feats = clip_manager.get_img_feats(img)
    img_feats = img_feats.flatten()

    # Zero-shot VLM: classify image type.
    sorted_img_types, img_type_scores = clip_manager.get_nn_text(img_types, img_types_feats, img_feats)
    img_type = sorted_img_types[0]

    # Zero-shot VLM: classify number of people.
    sorted_ppl_texts, ppl_scores = clip_manager.get_nn_text(ppl_texts, ppl_feats, img_feats)
    ppl_result = sorted_ppl_texts[0]
    if ppl_result == 'people':
        # Use separate local names here: assigning to ppl_texts or ppl_feats inside the
        # function would make them local and break the lookup of the notebook-level values above.
        ppl_count_texts = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']
        ppl_count_feats = clip_manager.get_text_feats([f'There {p} in this photo.' for p in ppl_count_texts])
        sorted_ppl_texts, ppl_scores = clip_manager.get_nn_text(ppl_count_texts, ppl_count_feats, img_feats)
        ppl_result = sorted_ppl_texts[0]
    else:
        ppl_result = f'are {ppl_result}'

    # Zero-shot VLM: classify places.
    sorted_places, places_scores = clip_manager.get_nn_text(vocab_manager.place_list, place_feats, img_feats)

    # Zero-shot VLM: classify objects and keep the top obj_topk as a comma-separated list.
    sorted_obj_texts, obj_scores = clip_manager.get_nn_text(vocab_manager.object_list, object_feats, img_feats)
    object_list = ', '.join(sorted_obj_texts[:obj_topk])

    # Zero-shot LM: generate captions.
    prompt = f'''I am an intelligent image captioning bot.
    This image is a {img_type}. There {ppl_result}.
    I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
    I think there might be a {object_list} in this {img_type}.
    A creative short caption I can generate to describe this image is:'''

    # Generate multiple captions
    model_params = {'temperature': 0.9, 'max_length': 40, 'do_sample': True}
    caption_texts = flan_manager.generate_response(num_captions * [prompt], model_params)

    # Zero-shot VLM: rank captions.
    caption_feats = clip_manager.get_text_feats(caption_texts)
    sorted_captions, caption_scores = clip_manager.get_nn_text(caption_texts, caption_feats, img_feats)
    best_caption = sorted_captions[0]
    print(f'time taken {time.time()-start_time}')
    cpt_feats = clip_manager.get_text_feats([best_caption]).flatten()

    return image_id, best_caption, img_feats, cpt_feats
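
The prompt passed to FLAN-T5-xl is a template filled with the zero-shot CLIP outputs. Because the template above is a triple-quoted string inside a function, the indentation of its continuation lines becomes part of the prompt. The sketch below shows one way to assemble the same text without that indentation. `build_prompt` is a hypothetical helper, and the place and object values are invented for illustration.

```python
def build_prompt(img_type, ppl_result, places, objects):
    """Fill the caption prompt template with the zero-shot CLIP outputs."""
    place_text = ', '.join(places[:-1]) + f', or {places[-1]}'
    object_text = ', '.join(objects)
    return (
        'I am an intelligent image captioning bot.\n'
        f'This image is a {img_type}. There {ppl_result}.\n'
        f'I think this photo was taken at a {place_text}.\n'
        f'I think there might be a {object_text} in this {img_type}.\n'
        'A creative short caption I can generate to describe this image is:'
    )


# Invented example values, not actual CLIP outputs.
prompt = build_prompt(
    img_type='photo',
    ppl_result='are two people',
    places=['beach', 'coast', 'harbor'],
    objects=['surfboard', 'wetsuit', 'wave'],
)
print(prompt)
```

Whether the extra whitespace affects the captions generated by FLAN-T5-xl is not something this notebook measures, so changing the template would call for rerunning the evaluation.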
    

In [None]:
if not os.path.exists('cache/res.pickle'):
    # Instantiate the Flan T5 manager
    flan_manager = FlanT5Manager(version="google/flan-t5-xl", use_api=False)

    res = {}
    embed_imgs = {}
    embed_capt_res = {}

    # Draw a fixed random sample of N image indices
    file_names = os.listdir(imgs_folder)
    random.seed(42)
    N = 100
    random_numbers = set(random.sample(range(len(file_names)), N))

    for ix, file_name in enumerate(file_names):
        result = caption_this_image(file_name, imgs_folder, ix, random_numbers, image_manager, clip_manager, vocab_manager, flan_manager)
        # Files outside the sample (and non-image files) yield no result
        if result is None:
            continue
        image_id, best_caption, img_feats, cpt_feats = result
        res[image_id] = [{
            'image_id': image_id,
            'id': image_id,
            'caption': best_caption
        }]
        embed_imgs[image_id] = img_feats
        embed_capt_res[image_id] = cpt_feats

    # Saving the generated captions, image and generated caption embeddings
    os.makedirs('cache', exist_ok=True)
    with open('cache/res.pickle', 'wb') as handle:
        pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('cache/embed_imgs.pickle', 'wb') as handle:
        pickle.dump(embed_imgs, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('cache/embed_capt_res.pickle', 'wb') as handle:
        pickle.dump(embed_capt_res, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
    with open('cache/res.pickle', 'rb') as handle:
        res = pickle.load(handle)
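
One caveat about the sampling: `os.listdir` returns entries in arbitrary order, so a fixed seed alone does not guarantee that the same images are selected on another machine. A minimal sketch of an order-independent alternative follows. `sample_file_names` is a hypothetical helper and the file names are generated for the example.

```python
import random


def sample_file_names(file_names, n, seed=42):
    """Pick `n` file names reproducibly, independent of the order of the input."""
    rng = random.Random(seed)  # local generator, leaves the global random state alone
    return sorted(rng.sample(sorted(file_names), n))


listing_a = [f'{i:012d}.jpg' for i in range(1, 21)]
listing_b = list(reversed(listing_a))  # same files, different directory order

subset_a = sample_file_names(listing_a, 5)
subset_b = sample_file_names(listing_b, 5)
print(subset_a == subset_b)
```

Sorting before sampling changes which images end up in the sample compared with an unsorted listing, so cached results from an earlier run would no longer correspond to the new sample.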

### Step 3: Evaluate the generated captions against the ground truth

#### Load the ground truth annotations

In [None]:
with open(annotation_file, 'r') as f:
    lines = json.load(f)['annotations']
gts = {}
for item in lines:
    if item['image_id'] not in gts:
        gts[item['image_id']] = []
    gts[item['image_id']].append({'image_id': item['image_id'], 'caption': item['caption']})
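
The grouping step is easier to follow on a small example. The sketch below uses `collections.defaultdict` on three hand-written records that mimic the `image_id`, `id`, and `caption` fields of the COCO caption annotations. The ids and captions are invented, and `gts_example` is a separate variable so that the real `gts` is left untouched.

```python
from collections import defaultdict

# Invented records in the shape of the entries under 'annotations'.
annotations = [
    {'image_id': 139, 'id': 1, 'caption': 'A room with a table and chairs.'},
    {'image_id': 139, 'id': 2, 'caption': 'A living room with a television.'},
    {'image_id': 285, 'id': 3, 'caption': 'A bear looking at the camera.'},
]

# Group the captions by the image they describe.
gts_example = defaultdict(list)
for item in annotations:
    gts_example[item['image_id']].append(
        {'image_id': item['image_id'], 'caption': item['caption']}
    )

captions_per_image = {img_id: len(caps) for img_id, caps in gts_example.items()}
print(captions_per_image)
```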

#### Compute the embeddings for the gt captions

In [None]:
if not os.path.exists('cache/embed_capt_gt.pickle'):
    embed_capt_gt = {}
    for img_id, list_of_capt_dict in gts.items():
        list_of_captions = [capt_dict['caption'] for capt_dict in list_of_capt_dict]

        # Dims of capt_feats_gt: (number of captions, usually 5) x 768
        capt_feats_gt = clip_manager.get_text_feats(list_of_captions)

        embed_capt_gt[img_id] = capt_feats_gt

    with open('cache/embed_capt_gt.pickle', 'wb') as handle:
        pickle.dump(embed_capt_gt, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Evaluation

In [None]:
eval_cap = {}
evaluator = SocraticEvalCap(gts, res)

# Rule-based metrics
evaluator.evaluate_rulebased()
eval_rulebased = {}
for metric, score in evaluator.eval.items():
    print(f'{metric}: {score:.3f}')
    eval_rulebased[metric] = round(score, 5)
eval_cap['rulebased'] = eval_rulebased

# Embedding-based metric
evaluator.evaluate_cossim()
for source_caption, sim in evaluator.sims.items():
    print(f'{source_caption}: avg = {sim[0]:.3f}, std = {sim[1]:.3f}')
eval_cap['cossim'] = evaluator.sims

#### Save the evaluation scores

In [None]:
with open('eval_cap.pickle', 'wb') as handle:
    pickle.dump(eval_cap, handle, protocol=pickle.HIGHEST_PROTOCOL)
