# SocraticFlanT5 - Caption Generation (baseline) | DL2 Project, May 2023
---

This notebook downloads the images from the validation split of the [MS COCO Dataset (2017 version)](https://cocodataset.org/#download) together with their ground-truth captions, then generates a caption for each image with the Socratic Models pipeline outlined below. Captions are generated with the baseline approach:
* <span style="color:#006400">**Baseline**</span>: a Socratic model based on the work by [Zeng et al. (2022)](https://socraticmodels.github.io/) where GPT-3 is replaced by [FLAN-T5-xl](https://huggingface.co/docs/transformers/model_doc/flan-t5). 

In other words, the goal of this Jupyter notebook is to reproduce the image-captioning pipeline of the Socratic Models paper with FLAN-T5 in place of GPT-3. This gives us a baseline to build upon.

## Set-up
If you haven't done so already, create the environment by running `conda env create -f environment.yml` in the terminal, then activate it with `conda activate socratic`.

### Loading the required packages

In [1]:
# Package loading
import pandas as pd
from transformers import set_seed
import random
import sys
sys.path.append('..')

# Local imports
from scripts.image_captioning import ClipManager, ImageManager, VocabManager, LmManager, CocoManager
from scripts.image_captioning import LmPromptGenerator as pg
from scripts.image_captioning import CacheManager as cm
from scripts.utils import get_device

### Set seeds for reproducible results

In [2]:
# Set HuggingFace seed
set_seed(42)

# Set seed for 100 random images of the MS COCO validation split
random.seed(42)
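
To illustrate what seeding buys us, here is a minimal sketch that uses only Python's `random` module. The `image_ids` list and the `sample_images` helper are made up for illustration; they are not part of `CocoManager`, whose sampling code is not shown in this notebook.

```python
import random

# Hypothetical stand-in for the list of COCO image file names
image_ids = [f'img_{i:04d}.jpg' for i in range(1000)]

def sample_images(seed, k=100):
    """Re-seed Python's RNG, then draw k distinct image names."""
    random.seed(seed)
    return random.sample(image_ids, k)

first_draw = sample_images(42)
second_draw = sample_images(42)

# With the same seed, both draws select the same images in the same order
print(first_draw == second_draw)
```

Note that `random.seed` only affects Python's built-in generator, which is why the notebook also calls `set_seed` from `transformers` for the language model's sampling.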

## Step 1: Downloading the MS COCO images and annotations

In [3]:
imgs_folder = '../data/coco/val2017/'
annotation_file = '../data/coco/captions_val2017.json'

coco_manager = CocoManager()

## Step 2: Generating the captions via the Socratic pipeline


### Set the device and instantiate managers

In [None]:
# Set the device to use
device = get_device()

# Instantiate the clip manager
clip_manager = ClipManager(device)

# Instantiate the image manager
image_manager = ImageManager()

# Instantiate the vocab manager
vocab_manager = VocabManager()

# Instantiate the Flan T5 manager
flan_manager = LmManager()

load_places starting!
load_places took 0.0s!
load_objects starting!
load_objects took 0.0s!



### Compute place and object features

In [None]:
# Calculate the place features
place_emb = cm.get_place_emb(clip_manager, vocab_manager)

# Calculate the object features
object_emb = cm.get_object_emb(clip_manager, vocab_manager)
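
The name `CacheManager` suggests that these embeddings are computed once and reused on later runs, but its implementation is not shown in this notebook. One common way to do this is a compute-or-load pattern, sketched below under that assumption. `get_or_compute` and `fake_embed` are hypothetical names, and the real module may store embeddings differently.

```python
import os
import pickle
import tempfile

def get_or_compute(cache_path, compute_fn):
    """Load a cached result from disk if present, else compute and store it."""
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    result = compute_fn()
    with open(cache_path, 'wb') as f:
        pickle.dump(result, f)
    return result

calls = []

def fake_embed():
    """Stand-in for an expensive CLIP text-embedding call (toy values)."""
    calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'place_emb.pkl')
    first = get_or_compute(path, fake_embed)   # computes and writes the cache
    second = get_or_compute(path, fake_embed)  # reads the cache instead

print(len(calls))  # number of times the expensive function actually ran
```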

### Load images and compute image embedding

In [None]:
# Randomly select images from the COCO dataset
img_files = coco_manager.get_random_image_paths(num_images=100)

# Set the approach to use
approach = 'baseline'

# Create dictionaries to store the images features
img_dic = {}
img_feat_dic = {}

for img_file in img_files:
    # Load the image
    img_dic[img_file] = image_manager.load_image(coco_manager.image_dir + img_file)
    # Generate the CLIP image embedding
    img_feat_dic[img_file] = clip_manager.get_img_emb(img_dic[img_file]).flatten()

### Zero-shot VLM (CLIP)
We prompt CLIP zero-shot to infer several properties of an image, such as its type, the number of people in it, the place it shows and the objects it contains.

#### Classify image type

In [None]:
img_types = ['photo', 'cartoon', 'sketch', 'painting']
img_types_emb = clip_manager.get_text_emb([f'This is a {t}.' for t in img_types])

# Create a dictionary to store the image types
img_type_dic = {}
for img_name, img_feat in img_feat_dic.items():
    # Score the image types
    sorted_img_types, img_type_scores = clip_manager.get_nn_text(img_types, img_types_emb, img_feat)
    # Store the best image type
    img_type_dic[img_name] = sorted_img_types[0]
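
`get_nn_text` is defined in the project's `scripts.image_captioning` module and its body is not shown here. Judging from how it is called, it returns the candidate texts sorted by their similarity to the image embedding, together with the matching scores. A minimal sketch of that idea follows, assuming cosine similarity as the score. `rank_texts` is a hypothetical helper and the 3-dimensional vectors are toy values, not real CLIP embeddings.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def rank_texts(texts, text_embs, img_emb):
    """Sort candidate texts by similarity to the image, best first."""
    scored = [(cosine_similarity(emb, img_emb), text)
              for text, emb in zip(texts, text_embs)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    sorted_texts = [text for _, text in scored]
    sorted_scores = [score for score, _ in scored]
    return sorted_texts, sorted_scores

# Toy embeddings, made up for illustration only
labels = ['photo', 'cartoon', 'sketch', 'painting']
label_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.7, 0.7, 0.0]]
image_emb = [0.9, 0.1, 0.0]

sorted_labels, scores = rank_texts(labels, label_embs, image_emb)
print(sorted_labels[0])  # the label whose embedding is closest to the image
```

The same ranking step is reused below for people, places, objects and finally for the generated captions; only the candidate texts change.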

#### Classify number of people

In [None]:
ppl_texts_bool = ['no people', 'people']
ppl_emb_bool = clip_manager.get_text_emb([f'There are {p} in this photo.' for p in ppl_texts_bool])

ppl_texts_mult = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']
ppl_emb_mult = clip_manager.get_text_emb([f'There {p} in this photo.' for p in ppl_texts_mult])

# Create a dictionary to store the number of people
num_people_dic = {}

for img_name, img_feat in img_feat_dic.items():
    sorted_ppl_texts, ppl_scores = clip_manager.get_nn_text(ppl_texts_bool, ppl_emb_bool, img_feat)
    ppl_result = sorted_ppl_texts[0]
    if ppl_result == 'people':
        sorted_ppl_texts, ppl_scores = clip_manager.get_nn_text(ppl_texts_mult, ppl_emb_mult, img_feat)
        ppl_result = sorted_ppl_texts[0]
    else:
        ppl_result = f'are {ppl_result}'

    num_people_dic[img_name] = ppl_result

#### Classify image place

In [None]:
place_topk = 3

# Create a dictionary to store the most likely place for each image
location_dic = {}
for img_name, img_feat in img_feat_dic.items():
    sorted_places, places_scores = clip_manager.get_nn_text(vocab_manager.place_list, place_emb, img_feat)
    location_dic[img_name] = sorted_places[0]

#### Classify image object

In [None]:
obj_topk = 10

# Create a dictionary to store the similarity of each object with the images
obj_list_dic = {}
for img_name, img_feat in img_feat_dic.items():
    sorted_obj_texts, obj_scores = clip_manager.get_nn_text(vocab_manager.object_list, object_emb, img_feat)
    # Keep the top-k objects as a comma-separated string
    object_list = ', '.join(sorted_obj_texts[:obj_topk])
    obj_list_dic[img_name] = object_list

#### Generate captions

In [None]:
num_captions = 50

# Set LM params
model_params = {'temperature': 0.9, 'max_length': 40, 'do_sample': True}

# Create dictionaries to store the outputs
prompt_dic = {}
sorted_caption_map = {}
caption_score_map = {}

for img_name in img_dic:
    # Create the prompt for the language model
    prompt_dic[img_name] = pg.create_baseline_lm_prompt(
        img_type_dic[img_name], num_people_dic[img_name], location_dic[img_name], obj_list_dic[img_name]
    )

    # Generate the caption using the language model
    caption_texts = flan_manager.generate_response(num_captions * [prompt_dic[img_name]], model_params)

    # Zero-shot VLM: rank captions.
    caption_emb = clip_manager.get_text_emb(caption_texts)
    sorted_captions, caption_scores = clip_manager.get_nn_text(caption_texts, caption_emb, img_feat_dic[img_name])
    sorted_caption_map[img_name] = sorted_captions
    caption_score_map[img_name] = dict(zip(sorted_captions, caption_scores))
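
The wording produced by `create_baseline_lm_prompt` is not shown in this notebook. As a rough illustration of how the four CLIP outputs can be combined into one prompt, here is a sketch loosely modelled on the captioning prompt from the Socratic Models demo. `build_caption_prompt` is a hypothetical function, and the project's actual prompt may differ.

```python
def build_caption_prompt(img_type, ppl_result, location, object_list):
    """Hypothetical prompt builder; not the project's create_baseline_lm_prompt."""
    return (
        'I am an intelligent image captioning bot. '
        f'This image is a {img_type}. There {ppl_result}. '
        f'I think this photo was taken at a {location}. '
        f'I think there might be a {object_list} in this {img_type}. '
        'A creative short caption I can generate to describe this image is:'
    )

# Example inputs in the same shape as the dictionaries built above (made up)
prompt = build_caption_prompt(
    'photo', 'are two people', 'kitchen', 'oven, refrigerator, sink'
)

# Repeating the prompt yields one sampled candidate per copy
num_candidates = 3
prompts = num_candidates * [prompt]

print(prompt)
```

Because `do_sample` is `True` and the temperature is 0.9, identical prompts can produce different captions, which is what gives the CLIP ranking step a pool of candidates to choose from.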

### Save the outputs

In [None]:
data_list = []
for img_name in img_dic:
    generated_caption = sorted_caption_map[img_name][0]
    data_list.append({
        'image_name': img_name,
        # Same path construction as used when loading the image
        'image_path': coco_manager.image_dir + img_name,
        'generated_caption': generated_caption,
        'cosine_similarity': caption_score_map[img_name][generated_caption]
    })
pd.DataFrame(data_list).to_csv(f'{approach}_outputs.csv', index=False)