# Socratic Models for Image Captioning & Multimodal Reasoning
## Introduction
In this notebook, we apply the Socratic Models approach to image captioning and to multimodal reasoning tasks, i.e., chain-of-thought (CoT) reasoning and visual question answering (VQA).

### Imports

In [1]:
# Global
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from transformers import set_seed
from datasets import load_dataset

# Local 
sys.path.insert(0, '../')
import scripts.image_captioning as ic
from scripts.utils import get_device
# Extensions
%load_ext autoreload
%autoreload 2

#### Set seed & device

In [2]:
set_seed(42)    # set seed for reproducibility
# set the device to use
device = get_device()
print(f'Using device: {device}')

#### Class instantiation

In [3]:
# instantiate managers
clip_manager = ic.ClipManager(device=device)
image_manager = ic.ImageManager()
vocab_manager = ic.VocabManager()
lm_manager = ic.LmManager(version='google/flan-ul2', use_api=True)

#### Create image & text embeddings

In [None]:
# compute place & objects features
place_feats = clip_manager.get_text_emb([f'Photo of a {p}.' for p in vocab_manager.place_list])
obj_feats = clip_manager.get_text_emb([f'Photo of a {o}.' for o in vocab_manager.object_list])

## Image captioning
We extract information from an input image using CLIP and use it to construct a text summary, which is then fed as a prompt to the LM to generate captions.

In [None]:
# set the image path
img_dir = '../data/images/example_images'
fname = 'astronaut_with_beer.jpg'
img_path = f'{img_dir}/{fname}'

# load image
img = image_manager.load_image(img_path)
# get image representation
img_feats = clip_manager.get_img_emb(img)
# show image
plt.imshow(img)
plt.show()

#### Get img info
Extract image info using CLIP (zero-shot classification)

In [None]:
img_type, n_people, location, sorted_obj_texts, obj_list, obj_scores = clip_manager.get_img_info(img, place_feats, obj_feats, vocab_manager)
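
To make the zero-shot classification step concrete, here is a minimal sketch of the idea: each candidate label has a text embedding, and the labels are ranked by cosine similarity to the image embedding. The helper names (`cosine`, `zero_shot_rank`) and the toy 3-d vectors are illustrative assumptions, not the internals of `clip_manager.get_img_info`; real CLIP embeddings would normally be handled as tensors.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def zero_shot_rank(img_emb, text_embs, labels):
    """Return (label, score) pairs sorted by similarity to the image, best first."""
    scored = [(label, cosine(img_emb, emb)) for label, emb in zip(labels, text_embs)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

# Toy 3-d vectors standing in for CLIP embeddings (hypothetical values).
labels = ['kitchen', 'beach', 'office']
text_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
img_emb = [0.2, 0.9, 0.1]

ranking = zero_shot_rank(img_emb, text_embs, labels)
best_label = ranking[0][0]
print(best_label)
```

The same ranking applies to any vocabulary (places, objects, image types); only the list of label prompts changes.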

#### Filter unique objects
Remove near-duplicate objects, using the cosine similarity of their embeddings.

In [None]:
# filter unique objects
filtered_objs = ic.filter_objs(sorted_obj_texts, obj_scores, clip_manager, obj_topk=10, sim_threshold=0.7)
# alternative: filtered_objs = ic.filter_objs_alt(vocab_manager.object_list, sorted_obj_texts, obj_feats, img_feats, clip_manager, obj_top=10)
print(f'filtered objects: {filtered_objs}')
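
One plausible way to filter by embedding similarity is a greedy pass over the ranked objects: keep an object only if it is not too similar to any object already kept. The sketch below assumes this greedy scheme; `greedy_unique` and the toy 2-d embeddings are hypothetical, and the actual `ic.filter_objs` may work differently.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def greedy_unique(items, embs, topk=10, sim_threshold=0.7):
    """Walk items in ranked order, skipping any too similar to one already kept."""
    kept, kept_embs = [], []
    for item, emb in zip(items, embs):
        if len(kept) == topk:
            break
        # Keep the item only if it is below the threshold against every kept item.
        if all(cosine(emb, k) < sim_threshold for k in kept_embs):
            kept.append(item)
            kept_embs.append(emb)
    return kept

# Ranked objects with toy 2-d embeddings (hypothetical values).
objs = ['beer bottle', 'beer glass', 'space suit', 'helmet']
embs = [[1.0, 0.0], [0.95, 0.1], [0.0, 1.0], [0.1, 0.9]]

unique_objs = greedy_unique(objs, embs, topk=10, sim_threshold=0.7)
print(unique_objs)
```

A lower `sim_threshold` drops more objects and yields a shorter, more diverse list; a higher one keeps closer synonyms.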

### Generate captions
Generate captions by composing a prompt using info extracted from CLIP, and use the LM to generate captions from the prompt.

In [None]:
# generate n captions, then rank them with CLIP
n_captions = 20
# alternative prompt ending: 'Create a creative beautiful caption from this context:'
prompt = f'''This image is a {img_type}. It was taken in a {location}. It has {n_people}. It contains a {', '.join(filtered_objs)}. A caption I can generate to describe this image is:'''
print(f'prompt: {prompt}')

lm_params = {
    "min_new_tokens": 5,
    "max_new_tokens": 30,
    "length_penalty": 2,
    "num_beams": 8,
    "no_repeat_ngram_size": 3,
    "temperature": 0.9,
    "early_stopping": True,
    "do_sample": True,
    "num_return_sequences": 1,
}

caption_texts = lm_manager.generate_response([prompt] * n_captions, lm_params)

# rank captions
clip_manager.rank_gen_outputs(img, caption_texts)
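
The prompt composition used above can also be wrapped in small helpers, which makes it easy to reuse the scene summary across tasks and to skip the object sentence when no objects survive filtering. This is only a sketch: `describe_scene` and `caption_prompt` are hypothetical names that are not part of `scripts.image_captioning`, and the attribute values are made up.

```python
def describe_scene(img_type, location, n_people, objects):
    """Compose the text summary of an image from its CLIP-derived attributes."""
    parts = [
        f'This image is a {img_type}.',
        f'It was taken in a {location}.',
        f'It has {n_people}.',
    ]
    if objects:  # omit the sentence entirely when the object list is empty
        parts.append(f"It contains a {', '.join(objects)}.")
    return ' '.join(parts)

def caption_prompt(scene):
    """Append the caption instruction to a scene summary."""
    return f'{scene} A caption I can generate to describe this image is:'

# Hypothetical attribute values, for illustration only.
scene = describe_scene('photo', 'bar', 'one person', ['spacesuit', 'beer bottle'])
prompt = caption_prompt(scene)
print(prompt)
```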

## Chain-of-thought reasoning
### Data
We use the [ScienceQA](https://scienceqa.github.io/) dataset.

In [None]:
# load scienceQA dataset
scienceQA_dataset = load_dataset('derek-thomas/ScienceQA', split='validation')
# filter out samples with no image
scienceQA_dataset = [sample for sample in scienceQA_dataset if sample['image'] is not None]

#### Show samples

In [None]:
# show samples 120-129 (good samples: 68, 90, 122)
start = 120
for i, sample in enumerate(scienceQA_dataset[start:start + 10], start=start):
    print(f'sample {i}:')  # print the dataset index so it can be used as sample_idx
    plt.figure(figsize=(5, 5))
    plt.imshow(sample['image'])
    plt.axis('off')
    plt.show()
    # sample['image'].show()
    print('question:', sample['question'])
    print('choices:', sample['choices'])
    print('hint:', sample['hint'])
    print('lecture:', sample['lecture'])
    print('answer:', sample['answer'])
    print('solution:', sample['solution'])
    print('-'*50)

#### Select sample

In [None]:
# get sample
sample_idx = 122
sample = scienceQA_dataset[sample_idx]
# show sample
# plt.figure(figsize=(5, 5))
plt.imshow(sample['image'])
# plt.axis('off')
# plt.show()
print('question:', sample['question'])
print('choices:', sample['choices'])
print('hint:', sample['hint'])
print('lecture:', sample['lecture'])
print('answer:', sample['answer'])
print('solution:', sample['solution'])

#### Get img info
Extract image info using CLIP (zero-shot classification)

In [None]:
# get image info
img_feats = clip_manager.get_img_emb(sample['image'])
img_type, n_people, location, sorted_obj_texts, obj_list, obj_scores = clip_manager.get_img_info(sample['image'], place_feats, obj_feats, vocab_manager)
# filter unique objects
# alternative: filtered_objs = ic.filter_objs(sorted_obj_texts, obj_scores, clip_manager, obj_topk=10, sim_threshold=0.7)
filtered_objs = ic.filter_objs_alt(vocab_manager.object_list, sorted_obj_texts, obj_feats, img_feats, clip_manager)
print(f'filtered objects: {filtered_objs}')

### Zero-shot CoT reasoning
Compose a prompt from the image info (CLIP) and the question, choices and hint from the dataset, followed by a trigger sentence ("Let's think step by step") that induces zero-shot CoT reasoning. The LM then generates the outputs (rationale + answer).

In [None]:
# generate n outputs from LM using prompt
num_outputs = 5

# compose prompt
prompt = f'''This image is a {img_type}. It was taken in a {location}. It has {n_people}. It contains a {', '.join(filtered_objs)}.
Question: {sample['question']}
Choices: {sample['choices']}
Hint: {sample['hint']}
Answer: Let's think step by step...'''
# Lecture: {sample['lecture']}
print(f'prompt: {prompt}\n')

# generate outputs from LM
lm_params = {
    "min_new_tokens": 5,
    "max_new_tokens": 30,
    "length_penalty": 2,
    "num_beams": 8,
    "no_repeat_ngram_size": 3,
    "temperature": 0.9,
    "early_stopping": True,
    "do_sample": True,
    "num_return_sequences": 1,
}
outputs = lm_manager.generate_response([prompt] * num_outputs, lm_params)
for i, output in enumerate(outputs):
    print(f'{i + 1}. {output}')
    
print(f'\ngt solution: {sample["solution"]}\ngt answer: {sample["answer"]}')
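
As a variation on the prompt above, the choices can be lettered instead of printed as a Python list, and the hint line dropped when it is empty. The sketch below is an assumption about a reasonable format, not the notebook's method; `build_cot_prompt` is a hypothetical helper and the example question is made up.

```python
def build_cot_prompt(context, question, choices, hint=''):
    """Compose a zero-shot CoT prompt with lettered choices and an optional hint."""
    # Label the choices (A), (B), ... so the LM can refer to them by letter.
    lettered = ' '.join(f'({chr(ord("A") + i)}) {c}' for i, c in enumerate(choices))
    lines = [context, f'Question: {question}', f'Choices: {lettered}']
    if hint:  # skip the hint line when the sample has no hint
        lines.append(f'Hint: {hint}')
    # Trigger sentence that induces step-by-step reasoning.
    lines.append("Answer: Let's think step by step...")
    return '\n'.join(lines)

# Hypothetical sample, for illustration only.
cot_prompt = build_cot_prompt(
    context='This image is a map.',
    question='Which state is highlighted?',
    choices=['Ohio', 'Utah'],
)
print(cot_prompt)
```

Lettered choices also make it simpler to compare a generated answer against the ground-truth index afterwards.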