# SocraticFlanT5 - Exploring the use of an additional vocabulary

## Introduction

The goal of this Jupyter notebook is to reproduce the Socratic Models paper with the FlanT5 model.
Here we simply reimplement the pipeline used by A. Zeng et al. as closely as possible.
This provides a baseline for us to build upon; the final section then explores an additional vocabulary of '-ing' verb forms on top of that baseline.

## Imports

In [1]:
# Package loading
import os
import requests
import clip
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from profanity_filter import ProfanityFilter
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import print_time_dec
import pandas as pd

# Local imports
from image_captioning import ClipManager, ImageManager, VocabManager, FlanT5Manager, print_clip_info

## Set device

In [2]:
# Set the device to use: Apple-silicon MPS first, then CUDA, then CPU.
# `torch.has_mps` is deprecated and only reports build-time support, so ask the
# backend whether MPS is usable at runtime instead (the getattr guard keeps
# older torch builds without `torch.backends.mps` working).
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    # 'cuda' is the device string torch accepts; 'gpu' is rejected by torch.device().
    device = 'cuda'
else:
    device = 'cpu'

## Class instantiation

In [None]:
# Instantiate the clip manager on the device selected in the previous cell
clip_manager = ClipManager(device)

# Instantiate the image manager (used below to load the input image)
image_manager = ImageManager()

# Instantiate the vocab manager; later cells read its place_list and object_list
vocab_manager = VocabManager()

# Instantiate the Flan T5 manager (used by the caption-generation cell)
# NOTE(review): unlike ClipManager, no device is passed here — confirm where the language model is placed.
flan_manager = FlanT5Manager()

# Print out clip model info
# NOTE(review): the same call is repeated in the "CLIP model info" cell further down.
print_clip_info(clip_manager.model)

load_places starting!
load_places took 0.0s!
load_objects starting!


## Set image path

In [None]:
# Path of the image to caption; handed to image_manager.load_image in a later cell.
# NOTE(review): relative path — presumably resolved against the working directory; confirm in ImageManager.
img_path = 'monkey_with_gun.jpg'

## CLIP model info

In [None]:
# Print out clip model info
# NOTE(review): duplicates the print_clip_info call at the end of the class-instantiation cell.
print_clip_info(clip_manager.model)

## Create text embeddings

In [None]:
# Turn every vocabulary entry into a short CLIP prompt, then embed each prompt list.

# Place vocabulary -> text features
place_prompts = [f'Photo of a {place}.' for place in vocab_manager.place_list]
place_feats = clip_manager.get_text_feats(place_prompts)

# Object vocabulary -> text features
object_prompts = [f'Photo of a {obj}.' for obj in vocab_manager.object_list]
object_feats = clip_manager.get_text_feats(object_prompts)

## Load image and compute image embedding

In [None]:
# Load the image found at img_path through the image manager.
img = image_manager.load_image(img_path)
# Get the CLIP image representation; every zero-shot cell below compares text features against img_feats.
img_feats = clip_manager.get_img_feats(img)
# Show the image so the predictions below can be checked by eye.
# NOTE(review): assumes load_image returns something plt.imshow accepts (array or PIL image) — confirm.
plt.imshow(img)
plt.show()

## Zero shot VLM - Image type classification

In [None]:
# Zero-shot VLM: rank the candidate image types against the image embedding.
img_types = ['photo', 'cartoon', 'sketch', 'painting']
img_type_prompts = [f'This is a {candidate}.' for candidate in img_types]
img_types_feats = clip_manager.get_text_feats(img_type_prompts)
sorted_img_types, img_type_scores = clip_manager.get_nn_text(
    img_types, img_types_feats, img_feats
)
# The first entry of the ranking is taken as the prediction.
img_type = sorted_img_types[0]
print(f'This is a {img_type}.')

## Zero shot VLM - Number of people classification

In [None]:
# Zero-shot VLM: estimate how many people are in the image.
# Each fragment completes the sentence "There ... in this photo."
ppl_texts = [
    'are no people',
    'is one person',
    'are two people',
    'are three people',
    'are several people',
    'are many people',
]
ppl_prompts = [f'There {fragment} in this photo.' for fragment in ppl_texts]
ppl_feats = clip_manager.get_text_feats(ppl_prompts)
sorted_ppl_texts, ppl_scores = clip_manager.get_nn_text(ppl_texts, ppl_feats, img_feats)
# The first entry of the ranking is kept for the caption prompt.
ppl_result = sorted_ppl_texts[0]

## Zero shot VLM - Image place classification

In [None]:
# Zero-shot VLM: classify places.
# Ranks the place vocabulary against the image embedding and reports the first entry of the ranking.
# NOTE(review): place_topk is assigned but not read in any visible cell; only the top-1 place is printed.
place_topk = 3
sorted_places, places_scores = clip_manager.get_nn_text(vocab_manager.place_list, place_feats, img_feats)
print(f'Location: {sorted_places[0]}')

## Zero shot VLM - Image object classification

In [None]:
# Zero-shot VLM: classify objects.
# Ranks the object vocabulary against the image embedding (presumably best match first — the
# first entry is what the other cells treat as the prediction).
obj_topk = 10
sorted_obj_texts, obj_scores = clip_manager.get_nn_text(vocab_manager.object_list, object_feats, img_feats)

# Comma-separated string of the first `obj_topk` labels. Slicing + join replaces the index loop
# and trailing-separator trim, and no longer raises IndexError when fewer than `obj_topk`
# labels are returned.
top_obj_texts = sorted_obj_texts[:obj_topk]
object_list = ', '.join(str(label) for label in top_obj_texts)

# Report exactly the labels that went into `object_list` (the count was hard-coded to 10 before).
print(f'Top {obj_topk} objects recognized: \n{top_obj_texts}')

## Exploring an additional vocabulary

In [2]:
# Text file available here: https://github.com/nizarhabash1/catvar/blob/master/English-Morph.txt
VOCAB_PATH = 'English-Morph.txt'
# Number of best-matching '-ing' words that are tried inside a "<term> <verb> <term>" phrase.
N_VERB_CANDIDATES = 100

# NOTE(review): this frame is not read by any visible cell; kept in case a later cell uses it.
english_vocab = pd.read_csv(VOCAB_PATH, sep=' ')

# Collect every tab-separated field that looks like an '-ing' form.
# Fields are stripped first: the last field on a line still carries the newline, so without
# stripping 'running\n'.endswith('ing') is False and that word is silently dropped.
with open(VOCAB_PATH, 'r') as file:
    ing_words = {
        word
        for line in file
        for word in (field.strip() for field in line.split('\t'))
        if len(word) > 4 and word.endswith('ing')
    }

# Sorted so the candidate order is identical on every run (bare set order varies between runs).
active_verbs = sorted(ing_words)

# Embed the candidates and rank them against the image embedding.
active_verbs_fea = clip_manager.get_text_feats(active_verbs)
active_verbs_texts, active_verbs_scores = clip_manager.get_nn_text(active_verbs, active_verbs_fea, img_feats)

active_verb_map = dict(zip(active_verbs_texts, active_verbs_scores))

# NOTE(review): `terms_to_include` is not defined in any visible cell, so on a fresh kernel this
# cell raised NameError. An existing definition is respected; otherwise fall back to the two
# top-ranked object labels so the notebook survives Restart & Run All.
# TODO confirm which terms were intended here.
if 'terms_to_include' not in globals():
    terms_to_include = [str(label) for label in sorted_obj_texts[:2]]

# Score "<term 0> <verb> <term 1>" phrases against the image for the best-ranked candidates.
data_list = []
if len(terms_to_include) > 1:
    for v in active_verbs_texts[:N_VERB_CANDIDATES]:
        test_term = f'{terms_to_include[0]} {v} {terms_to_include[1]}'
        score = clip_manager.get_image_caption_score(test_term, img_feats)
        data_list.append({'verb': v, 'new_term': test_term, 'score': score})

# Always create verb_df (empty when fewer than two terms are available) so the following cell
# never hits an undefined name; explicit columns keep sort_values valid on an empty frame.
verb_df = pd.DataFrame(data_list, columns=['verb', 'new_term', 'score']).sort_values('score', ascending=False)

NameError: name 'pd' is not defined

In [None]:
# For each of the first ten verbs in verb_df (sorted by score, highest first, when it was built),
# ask the Flan-T5 manager for num_captions responses to the same prompt.
# NOTE(review): relies on `verb_df` and `terms_to_include` from earlier cells — confirm both exist
# on a fresh kernel before running this cell.
for v in verb_df['verb'].iloc[:10]:
    # NOTE(review): num_captions and model_params do not depend on v; they could be set once before the loop.
    num_captions = 10
    # NOTE(review): the double quote opened before "This image" is never closed, and the leading
    # spaces of the continuation lines are part of the text sent to the model — confirm both are intended.
    prompt = f'''Create a creative beautiful caption from this context:
    "This image is a {img_type}. There {ppl_result}.
    The context is: {', '.join(terms_to_include)}.
    The verb is: {v}
    A creative short caption I can generate to describe this image is:'''

    # Generation settings passed through to flan_manager.generate_response.
    model_params = {'temperature': 0.9, 'max_length': 40, 'do_sample': True}
    # NOTE(review): caption_texts is overwritten on every iteration and not read in the visible
    # lines — confirm the captions are displayed or stored somewhere.
    caption_texts = [flan_manager.generate_response(prompt, model_params) for _ in range(num_captions)]