# Image Captioning Model
Multi-modal large language models (MLLMs) can go in two directions: image-to-text and text-to-image. Transformer architectures are general sequence-processing architectures, so they can process both images and text. This notebook focuses on image-to-text. We will:
1. Connect pretrained vision and text models as a single transformer architecture for image captioning
1. Train the image captioning model
1. Qualitatively inspect the generated captions
1. Compare the output with a pretrained image captioning model

We are going to use `sbu_captions` as our dataset to train our own image captioning model.

`sbu_captions` contains 1 million pairs of image URLs and captions (plus a `user_id` column). Visit [this link](https://huggingface.co/datasets/sbu_captions) to view the dataset on the Hugging Face Datasets Hub.

In [None]:
from datasets import load_dataset

data = load_dataset("sbu_captions", split="train").shuffle(seed=42)
data

Found cached dataset sbu_captions (/root/.cache/huggingface/datasets/sbu_captions/default/0.0.0/0c994175c0b31f384f0b9a6aa51c5c9969c468df4c3d33a52074686ea51ee684)
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/sbu_captions/default/0.0.0/0c994175c0b31f384f0b9a6aa51c5c9969c468df4c3d33a52074686ea51ee684/cache-eaeeb44c757dcd3e.arrow


Dataset({
    features: ['image_url', 'user_id', 'caption'],
    num_rows: 1000000
})

## Initialize an image-to-text model
We will use the [Vision Encoder Decoder Model](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder) to initialize our image-to-text model. The encoder is a transformer-based vision model that processes the images, and the decoder is a language model that generates the caption. We connect the vision and language models and then further train the combined vision-language model on a subset of `sbu_captions`.
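A ViT encoder turns an image into a token sequence by cutting it into fixed-size patches. Assuming the 224×224 input and 16×16 patches implied by the checkpoint name `google/vit-base-patch16-224-in21k`, the arithmetic below (using a hypothetical helper, `vit_sequence_length`) shows how long that sequence is. In ViT-base each token is a 768-dimensional vector, so the decoder attends to an encoder output of shape (batch, 197, 768).

```python
def vit_sequence_length(image_size=224, patch_size=16, add_cls_token=True):
    """Number of tokens a ViT encoder produces for a square image (hypothetical helper)."""
    patches_per_side = image_size // patch_size   # 224 // 16 = 14
    num_patches = patches_per_side ** 2           # 14 * 14 = 196 patches
    return num_patches + (1 if add_cls_token else 0)  # plus one [CLS] token

print(vit_sequence_length())     # 197
print(vit_sequence_length(384))  # 577 (24 * 24 patches + [CLS])
```

Larger input resolutions grow the sequence quadratically, which is one reason ViT checkpoints fix the input size.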

### Define data processing function



In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import requests
from io import BytesIO

class ProcessDataset(Dataset):
    def __init__(self, df, tokenizer, feature_extractor, decoder_max_length=20):
        self.df = df  # column-oriented data, e.g. data[:2000] (a dict mapping column name -> list)
        self.tokenizer = tokenizer  # tokenizer for the language model (decoder)
        self.feature_extractor = feature_extractor  # image preprocessor for the vision model (encoder)
        self.decoder_max_length = decoder_max_length  # maximum caption length in tokens

    def __len__(self):
        # PyTorch's DataLoader (used by the Hugging Face Trainer) needs the number of examples.
        # len(self.df) would count the columns of the dict, so count the captions instead.
        return len(self.df["caption"])

    def __getitem__(self, idx):
        # the Trainer's DataLoader calls this to fetch one example: image URL + caption
        img_path = self.df["image_url"][idx]
        caption = self.df["caption"][idx]

        # download the image; the timeout makes unreachable URLs fail fast instead of hanging
        response = requests.get(img_path, timeout=10)
        response.raise_for_status()
        # convert to RGB so grayscale or palette images match the encoder's 3-channel input
        image = Image.open(BytesIO(response.content)).convert("RGB")
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values

        # labels are the caption's token IDs, truncated/padded to a fixed length
        labels = self.tokenizer(caption,
                                truncation=True,
                                padding="max_length",
                                max_length=self.decoder_max_length).input_ids
        # replace padding IDs with -100 so the loss ignores padded positions
        labels = [token if token != self.tokenizer.pad_token_id else -100 for token in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
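During training, the caption token IDs in `labels` play two roles. Positions set to -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`) are skipped by the loss, which is the usual way to keep padding from being learned. And when only `labels` are passed, `VisionEncoderDecoderModel` derives the decoder inputs by shifting the labels one position to the right, prepending `decoder_start_token_id` and turning any -100 back into `pad_token_id`. Below is a plain-Python sketch of both steps. It assumes GPT-2's token ID 50256 (`<|endoftext|>`) serves as BOS, EOS and pad, as configured in this notebook; the other token IDs and the helper names are made up for illustration.

```python
def mask_padding(token_ids, pad_token_id):
    """Replace padding IDs with -100 so cross-entropy ignores them (hypothetical helper)."""
    return [t if t != pad_token_id else -100 for t in token_ids]

def shift_right(labels, decoder_start_token_id, pad_token_id):
    """Build decoder inputs: prepend the start token, drop the last label,
    and map -100 back to a real token ID (hypothetical helper)."""
    shifted = [decoder_start_token_id] + labels[:-1]
    return [pad_token_id if t == -100 else t for t in shifted]

PAD = BOS = 50256  # GPT-2's <|endoftext|> ID, reused as BOS and pad
caption_ids = [64, 3797, 319, 257, 4171, PAD, PAD, PAD]  # made-up IDs, padded to length 8

labels = mask_padding(caption_ids, PAD)
decoder_inputs = shift_right(labels, BOS, PAD)
print(labels)          # [64, 3797, 319, 257, 4171, -100, -100, -100]
print(decoder_inputs)  # [50256, 64, 3797, 319, 257, 4171, 50256, 50256]
```

At each position the decoder sees the previous tokens (teacher forcing) and is trained to predict the label at that position.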

### Initialize tokenizer and image feature extractor

Next, we initialize a tokenizer to process text and a feature extractor to process images. After this, we are ready to wrap our training data in the `ProcessDataset` class defined above.

In [None]:
from transformers import GPT2TokenizerFast, ViTFeatureExtractor

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir=DA.paths.datasets+"/models")
# GPT2 doesn't have a pad token
tokenizer.pad_token = tokenizer.eos_token

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k", cache_dir=DA.paths.datasets+"/models")
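Under the hood, the feature extractor resizes each image and turns it into normalized, channels-first `pixel_values`. Below is a NumPy sketch of the rescale-and-normalize part only. It assumes the checkpoint's configuration uses a per-channel mean and standard deviation of 0.5 (you can check `feature_extractor.image_mean` and `feature_extractor.image_std`), leaves out resizing, and uses a hypothetical helper, `normalize_like_vit`, which is not part of `transformers`.

```python
import numpy as np

def normalize_like_vit(image_uint8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Rescale an (H, W, 3) uint8 image to [0, 1], normalize each channel,
    and return a (1, 3, H, W) float32 batch (hypothetical helper)."""
    x = image_uint8.astype(np.float32) / 255.0
    x = (x - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
    return x.transpose(2, 0, 1)[np.newaxis, ...]  # HWC -> 1 x C x H x W

black = np.zeros((224, 224, 3), dtype=np.uint8)
white = np.full((224, 224, 3), 255, dtype=np.uint8)
print(normalize_like_vit(black).shape)                                   # (1, 3, 224, 224)
print(normalize_like_vit(black).min(), normalize_like_vit(white).max())  # -1.0 1.0
```

With mean and std of 0.5, pixel intensities land in the range [-1, 1].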



In [None]:
train_dataset = ProcessDataset(df=data[:2000],
                               tokenizer=tokenizer,
                               feature_extractor=feature_extractor)
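Note that `data[:2000]` is not a list of rows: slicing a Hugging Face `Dataset` returns a dictionary mapping each column name to a list of values. So `len()` of the slice counts columns, and the number of examples comes from the length of any one column. The plain-Python sketch below mimics that structure with two made-up rows (the URLs, user IDs and captions are hypothetical).

```python
# Hypothetical stand-in for a two-row slice such as data[:2]:
# one key per column, each holding a list with one value per example
batch = {
    "image_url": ["http://example.com/a.jpg", "http://example.com/b.jpg"],
    "user_id": ["user_1", "user_2"],
    "caption": ["a dog running on the beach", "sunset over the lake"],
}

num_columns = len(batch)              # counts keys, i.e. columns
num_examples = len(batch["caption"])  # counts values in one column, i.e. examples

# Reassemble row 1 the way ProcessDataset.__getitem__ reads it: column first, then index
row = {column: values[1] for column, values in batch.items()}

print(num_columns, num_examples)  # 3 2
print(row["caption"])             # sunset over the lake
```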

### Using VisionEncoderDecoder

Here, we finally use `VisionEncoderDecoderModel` to connect our pretrained image and text models of choice.

In the output, you may see a warning that some weights of the GPT-2 model were not initialized from the checkpoint. These are the cross-attention layers (`crossattention.*`, `ln_cross_attn.*`) that let the decoder attend to the image encoder's outputs; the original GPT-2 checkpoint has no such layers, so they start from random values. Hugging Face accordingly warns: `You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.` For the best performance, we could first fine-tune the decoder separately on our own dataset and load that fine-tuned decoder. For simplicity, however, we use the decoder as is and fine-tune the image-captioning model as a whole.
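To make "connecting" the two models concrete: each GPT-2 block (`h.0` to `h.11` in the warning) gains a cross-attention layer in which the caption tokens act as queries and the ViT encoder's output tokens supply the keys and values. Below is a single-head NumPy sketch with random projection matrices, standing in for the freshly initialized cross-attention weights. It is a simplification: real GPT-2 cross-attention uses 12 heads and an output projection, both omitted here. The sizes assume ViT-base and GPT-2 small (both use a hidden size of 768) and 197 encoder tokens for a 224×224 image with 16×16 patches; `cross_attention` is a hypothetical helper.

```python
import numpy as np

def cross_attention(decoder_states, encoder_states, w_q, w_k, w_v):
    """Single-head scaled dot-product cross-attention (hypothetical helper)."""
    q = decoder_states @ w_q   # (T, d): one query per caption token
    k = encoder_states @ w_k   # (P, d): one key per image token
    v = encoder_states @ w_v   # (P, d): one value per image token
    scores = q @ k.T / np.sqrt(q.shape[-1])       # (T, P) similarity of tokens to patches
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax numerically
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each caption token's weights sum to 1
    return weights @ v, weights                     # (T, d) image-informed token states

rng = np.random.default_rng(0)
d, T, P = 768, 5, 197  # hidden size, caption tokens, encoder tokens
decoder_states = rng.standard_normal((T, d))
encoder_states = rng.standard_normal((P, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

out, weights = cross_attention(decoder_states, encoder_states, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (5, 768) (5, 197)
```

Because these projections start out random, the decoder cannot yet use the image meaningfully, which is why the combined model has to be trained.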

In [None]:
from transformers import VisionEncoderDecoderModel

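# ViT encoder + GPT-2 decoder. Keyword arguments prefixed with "encoder_"/"decoder_"
# are routed to the corresponding sub-model's from_pretrained call (here, cache_dir).
# No tie_encoder_decoder: ViT and GPT-2 have different architectures, so there are no
# matching weights to share between them.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
    decoder_pretrained_model_name_or_path="gpt2",
    encoder_cache_dir=DA.paths.datasets + "/models",
    decoder_cache_dir=DA.paths.datasets + "/models",
)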

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.5.crossattention.q_attn.weight', 'h.6.ln_cross_attn.weight', 'h.5.crossattention.c_proj.weight', 'h.10.crossattention.c_proj.bias', 'h.5.crossattention.masked_bias', 'h.6.crossattention.c_attn.weight', 'h.3.crossattention.masked_bias', 'h.11.crossattention.c_proj.bias', 'h.11.crossattention.bias', 'h.5.ln_cross_attn.weight', 'h.0.ln_cross_attn.weight', 'h.11.crossattention.masked_bias', 'h.1.crossattention.q_attn.weight', 'h.0.crossattention.masked_bias', 'h.3.crossattention.bias', 'h.9.crossattention.masked_bias', 'h.5.crossattention.bias', 'h.10.crossattention.c_attn.weight', 'h.6.crossattention.c_proj.weight', 'h.7.ln_cross_attn.weight', 'h.9.crossattention.c_attn.weight', 'h.1.crossattention.c_attn.weight', 'h.9.ln_cross_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.8.crossattention.masked_bias', 'h.7.crossattention.c_proj.bias', 'h.7.crossattention.masked_bias

In [None]:
# GPT-2 only has bos/eos tokens, not dedicated decoder_start/pad tokens, so reuse them
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id

# Vocabulary size and generation (beam search) settings used by model.generate()
model.config.vocab_size = model.config.decoder.vocab_size
model.config.early_stopping = True  # stop beam search once num_beams finished candidates exist
model.config.no_repeat_ngram_size = 3  # no sequence of 3 tokens may appear twice in a caption
model.config.length_penalty = 2.0  # exponent on length in the beam score; > 0 favors longer captions

# Set these on model.config (not as attributes of model.decoder) so generate() uses them
model.config.num_beams = 4
model.config.max_length = 20
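Two of these settings change which captions beam search can return. The plain-Python sketch below illustrates the rules in simplified form: `no_repeat_ngram_size=n` bans any next token that would complete an n-gram already present in the output, and `length_penalty` divides a hypothesis' summed log-probabilities by its length raised to that power. The helper names and all token IDs and scores are hypothetical, and the library's actual beam search also handles batching, beams and exact length counting, which are omitted here.

```python
def banned_next_tokens(generated, n):
    """Tokens that would repeat an n-gram already in `generated` (hypothetical helper)."""
    if len(generated) < n - 1:
        return set()
    prefix = tuple(generated[len(generated) - (n - 1):])  # last n-1 tokens
    banned = set()
    for i in range(len(generated) - n + 1):
        if tuple(generated[i:i + n - 1]) == prefix:
            banned.add(generated[i + n - 1])  # this token completed the earlier n-gram
    return banned

def beam_score(sum_logprobs, length, length_penalty):
    """Length-normalized hypothesis score (hypothetical helper)."""
    return sum_logprobs / (length ** length_penalty)

# Made-up token IDs: the 3-gram (5, 8, 2) already occurred, so 2 is banned after (5, 8)
print(banned_next_tokens([5, 8, 2, 9, 5, 8], n=3))  # {2}

# Two made-up finished hypotheses: (summed log-probability, length in tokens)
short, long = (-4.0, 5), (-6.0, 10)
for penalty in (0.0, 2.0):
    best = max([short, long], key=lambda h: beam_score(h[0], h[1], penalty))
    print(penalty, "long" if best == long else "short")
# 0.0 short  (no normalization: the higher raw log-probability wins)
# 2.0 long   (-6 / 10**2 = -0.06 beats -4 / 5**2 = -0.16)
```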

## Train image-captioning model



In [None]:
from transformers import Trainer, TrainingArguments
from transformers import default_data_collator
import os

BATCH_SIZE = 16
TRAIN_EPOCHS = 20

output_directory = os.path.join(DA.paths.working_dir, "captioning_outputs")

training_args = TrainingArguments(
    output_dir=output_directory,
    per_device_train_batch_size=BATCH_SIZE,
    do_train=True,
    num_train_epochs=TRAIN_EPOCHS,  # number of full passes over the training data
    overwrite_output_dir=True,
    no_cuda=True,  # train on CPU only (slow; remove this if a GPU is available)
    dataloader_pin_memory=False,  # pinned memory only speeds up CPU-to-GPU copies, so skip it on CPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=default_data_collator,  # stacks the per-example tensors into batches
    tokenizer=feature_extractor,  # saved alongside checkpoints so the image preprocessor travels with the model
)
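Before launching training, it helps to estimate the workload. Assuming a single device, the default `gradient_accumulation_steps=1`, and a dataset that reports 2,000 examples, the arithmetic below gives the number of optimizer steps. Because `ProcessDataset.__getitem__` downloads the image every time it is called, it also counts the HTTP requests made over all epochs (assuming no caching).

```python
import math

num_examples = 2000  # rows passed to ProcessDataset
batch_size = 16      # BATCH_SIZE
epochs = 20          # TRAIN_EPOCHS

steps_per_epoch = math.ceil(num_examples / batch_size)  # the last batch may be partial
total_steps = steps_per_epoch * epochs
image_downloads = num_examples * epochs  # one request per example per epoch

print(steps_per_epoch, total_steps, image_downloads)  # 125 2500 40000
```

Every one of those requests can hit a slow or dead URL, so caching the downloaded images (or pre-filtering unreachable ones) is worth considering before scaling up.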

In [None]:
trainer.train()

Trainer is attempting to log a value of "{'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'chunk_size_feed_forward': 0, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 

---------------------------------------------------------------------------
TimeoutError                              Traceback (most recent call last)
File /databricks/python/lib/python3.10/site-packages/urllib3/connection.py:174, in HTTPConnection._new_conn(self)
    173 try:
--> 174     conn = connection.create_connection(
    175         (self._dns_host, self.port), self.timeout, **extra_kw
    176     )
    178 except SocketTimeout:

File /datab

## Generate caption from an image

Now, let's generate a caption for an image that was not among our 2,000 training examples (index 2021 of the shuffled dataset).

In [None]:
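test_img = data[2021]  # outside the first 2,000 rows used for training

test_img_path = test_img["image_url"]
test_img_response = requests.get(test_img_path, timeout=10)  # fail fast on unreachable URLs
test_img_response.raise_for_status()
test_image = Image.open(BytesIO(test_img_response.content)).convert("RGB")
display(test_image)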

---------------------------------------------------------------------------
TimeoutError                              Traceback (most recent call last)
File /databricks/python/lib/python3.10/site-packages/urllib3/connection.py:174, in HTTPConnection._new_conn(self)
    173 try:
--> 174     conn = connection.create_connection(
    175         (self._dns_host, self.port), self.timeout, **extra_kw
    176     )
    178 except SocketTimeout:

File /datab

In [None]:
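# preprocess the image, generate token IDs with beam search, and decode them to text
pixel_values = feature_extractor(test_image, return_tensors="pt").pixel_values
generated_ids = trainer.model.generate(pixel_values)
caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)  # drop <|endoftext|> tokens
print("--"*20)
print(caption)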

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
File <command-3263905024751150>:1
----> 1 caption = tokenizer.decode(trainer.model.generate(feature_extractor(test_image, return_tensors="pt").pixel_values)[0])
      2 print("--"*20)
      3 print(caption)

NameError: name 'test_image' is not defined


If the generated caption is not very good, there are a few likely reasons:

- Part of our text decoder's weights (the cross-attention layers) were not loaded from the pretrained checkpoint. A better approach would have been to first train the decoder separately on the training dataset and then load the fine-tuned decoder.
- Our image-captioning model needs more fine-tuning! Try increasing the number of epochs and training samples, and adjusting other model hyperparameters if you'd like.



## What if we use an existing image captioning model instead?


Let's compare against a model called `BLIP`, which stands for Bootstrapping Language-Image Pre-training. It is a modeling approach by [Li et al. 2022](https://arxiv.org/abs/2201.12086) that unifies vision-language understanding and generation.

In [None]:
from transformers import BlipProcessor, BlipForConditionalGeneration

blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", cache_dir=DA.paths.datasets+"/models")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", cache_dir=DA.paths.datasets+"/models")

# conditional image captioning
# in many earlier vision-language models, adding a prefix text such as "a photo of" was crucial for good captions
# the prefix makes caption generation "conditional": the model continues the caption from that text
text = "a photo of"
inputs = blip_processor(test_image, text, return_tensors="pt")

conditional_output = blip_model.generate(**inputs)
print("Conditional output: ", blip_processor.decode(conditional_output[0], skip_special_tokens=True))

# unconditional image captioning
# newer models have largely removed the need to add a prefix text
# so the caption generation is "unconditional"
# note that the `text` argument is omitted here (it is optional)
inputs = blip_processor(test_image, return_tensors="pt")

unconditional_output = blip_model.generate(**inputs)
print("Unconditional output: ", blip_processor.decode(unconditional_output[0], skip_special_tokens=True))

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
File <command-3263905024751153>:10
      6 # conditional image captioning
      7 # in many of the initial vision-language models, adding a prefix text like below "a photo of " is crucial for models to do well
      8 # the addition of the prefix text makes the caption generation "conditional"
      9 text = "a photo of"
---> 10 inputs = blip_processor(test_image, text, return_tensors="pt")
     12 conditional_output = blip_model.generate(**inputs)
     13