In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import json
import glob 
import itertools
from PIL import Image

from transformers import (
    AutoImageProcessor, 
    TrainingArguments, 
    Trainer,
    AutoTokenizer, 
    BertModel,
    BertPreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate

cudnn.benchmark = True
plt.ion()   # interactive mode

<matplotlib.pyplot._IonContext at 0x7fda7c960710>

In [2]:
import json
from torch.utils.data import Dataset, DataLoader
import easyocr
reader = easyocr.Reader(['en'])

In [3]:
class TextCapsDataset(Dataset):
    def __init__(self, root_dir, json_file):
        self.root_dir = root_dir
        self.json_file = json_file
        self.data = []
        with open(json_file, 'r') as f:
            json_data = json.load(f)
            for item in json_data['data']:
                image_id = item['image_id']
                caption_str = item['caption_str']
                self.data.append((image_id, caption_str))

    def __getitem__(self, index):
        image_id, caption_str = self.data[index]
        image_pth = self.root_dir + image_id + ".jpg"
        try: 
            image = Image.open(image_pth).convert("RGB")
        except FileNotFoundError:
            print(f"Could not open image {image_pth}")
            return None
        text_lst = reader.readtext(image_pth, detail = 0)
        text = ' ; '.join(text_lst)
        return {
                "image": image,
                "text": text,
                "caption": caption_str
        }

    def __len__(self):
        return len(self.data)

In [4]:
from transformers import VisionEncoderDecoderModel
text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dataset = TextCapsDataset("./Dataset/train-images/", 
                                './Dataset/TextCaps_0.1_train.json')
eval_dataset = TextCapsDataset("./Dataset/train-images/", 
                                './Dataset/TextCaps_0.1_val.json')

In [5]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

Number of training examples: 109765
Number of validation examples: 15830


In [6]:
print(train_dataset[1010])

{'image': <PIL.Image.Image image mode=RGB size=1024x576 at 0x7FD99C9CAB50>, 'text': '542 ; Epli ; 1z ; B ; B ; t# ; JJe ; Tela ; JJB ; Mi ; u', 'caption': 'A green colored roadway map with Asian writing on it as well as English numbers and letters.'}


In [7]:
# Create the text tokenizer and image pre-processors
max_length = 100
from transformers import AutoImageProcessor, BertTokenizer

image_preprocessor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
text_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", return_tensors="pt", model_max_length=max_length)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [8]:
def collate_fn(batch):
    # Tokenize the text and pad as necessary
    tokenized_text = text_tokenizer([x["text"] for x in batch], return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
    tokenized_caption = text_tokenizer([x["caption"] for x in batch], return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
    
    # Process the images
    processed_images = image_preprocessor([x["image"] for x in batch], return_tensors="pt", padding=True)

    return {
        "image": processed_images,
        "text": tokenized_text,
        "caption": tokenized_caption
    }

In [9]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=16,
    shuffle=True, 
    num_workers=0, 
    collate_fn=collate_fn,
)
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset, 
    batch_size=16,
    shuffle=True, 
    num_workers=0, 
    collate_fn=collate_fn,
)

In [10]:
for batch in train_dataloader:
    break

In [11]:
batch["image"]["pixel_values"].shape

torch.Size([16, 3, 224, 224])

In [12]:
batch["text"]['attention_mask'].shape

torch.Size([16, 100])

In [13]:
batch["caption"]['attention_mask'].shape

torch.Size([16, 100])

In [14]:
batch["caption"]

{'input_ids': tensor([[  101,  1037,  3696,  ...,     0,     0,     0],
        [  101,  1037,  2450,  ...,     0,     0,     0],
        [  101,  1037,  2235,  ...,     0,     0,     0],
        ...,
        [  101,  2048, 11640,  ...,     0,     0,     0],
        [  101,  1037, 10250,  ...,     0,     0,     0],
        [  101,  2019,  2250,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [16]:
# Create a custom multimodal model by modifying vit-gpt2-image-captioning
# to do image captioning
from transformers import BertModel, BertGenerationEncoder, BertGenerationDecoder, ResNetModel, EncoderDecoderModel

# Create a custom multimodal model by modifying BERT and using ResNet50
# to do multimodal classification
class MultimodalBertClassifier(nn.Module):
    def __init__(
        self,
    ):  
        super().__init__()
        
        self.encoder = BertModel.from_pretrained("bert-base-uncased")
        self.decoder = BertGenerationDecoder.from_pretrained("bert-base-uncased", add_cross_attention=True, is_decoder=True, )
        self.resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
        self.image_tokenizer = nn.Linear(2048, self.encoder.config.hidden_size)
        self.image_pos_emb = nn.Embedding(49, self.encoder.config.hidden_size)
        self.txt_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", return_tensors="pt", model_max_length=max_length)
        self.sentence_fuser = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse")
        self.tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")

        
    def forward(
        self,
        image,
        text,
        caption,
    ):
        image = image.to(device)
        text = text.to(device)
        caption = caption.to(device)
        image_outputs = self.resnet(**image)
        image_emb = image_outputs.last_hidden_state.flatten(2).permute(0, 2, 1)
        image_emb = self.image_tokenizer(image_emb)  # [batch_size, 49, 2048] -> [batch_size, 49, 768]
        image_position_ids = torch.arange(image_emb.shape[1]).repeat(image_emb.shape[0], 1).to(image_emb.device)
        image_position_emb = self.image_pos_emb(image_position_ids)
        image_type_ids = torch.LongTensor([1] * image_emb.shape[1]).repeat(image_emb.shape[0], 1).to(image_emb.device)
        image_type_emb = self.encoder.embeddings.token_type_embeddings(image_type_ids)
        image_emb = image_emb + image_position_emb + image_type_emb
        image_emb = self.encoder.embeddings.LayerNorm(image_emb)
        image_emb = self.encoder.embeddings.dropout(image_emb)
        
        text_embedding_output = self.encoder.embeddings(
            input_ids=text.input_ids,
            token_type_ids=text.token_type_ids,
        )
        
        embedding_output = torch.cat([text_embedding_output, image_emb], 1)
        image_attention_mask = torch.LongTensor([1] * image_emb.shape[1]).repeat(image_emb.shape[0], 1).to(image_emb.device)
        extended_attention_mask = torch.cat([text.attention_mask, image_attention_mask], 1)
 
        input_shape = (embedding_output.shape[0], embedding_output.shape[1])
        extended_attention_mask = self.encoder.get_extended_attention_mask(extended_attention_mask, input_shape)

        # And then encode
        encoder_outputs = self.encoder.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
        )
        
        # Get the pooled output for classification and apply the classifier head      
        decoder_outputs = self.decoder(
            input_ids=caption.input_ids,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
        )
        
        logits = decoder_outputs.logits
        loss_fn = nn.CrossEntropyLoss(ignore_index=self.txt_tokenizer.pad_token_id)
        batch_size, seq_len, vocab_size = logits.shape
        loss = loss_fn(logits, caption.input_ids)
        loss_fn(logits.view(batch_size * seq_len, vocab_size), caption.input_ids.view(-1))
        return loss, logits, decoder_outputs.past_key_values

In [17]:
model = MultimodalBertClassifier()
model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield err

MultimodalBertClassifier(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [19]:
for batch in train_dataloader:
    break
outputs = model(**batch)

RuntimeError: Expected target size [16, 30522], got [16, 100]

In [None]:
# A useful function to see the size and # of params of a model
def get_model_info(model):
    # Compute number of trainable parameters in the model
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Compute the size of the model in MB
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        
    size_all_mb = (param_size + buffer_size) / 1024**2
    
    return num_params, size_all_mb

In [None]:
# Print model info
num_params, size_all_mb = get_model_info(model)

print("Number of trainable params:", num_params)
print('Model size: {:.3f}MB'.format(size_all_mb))

In [None]:
# Setup the training arguments
output_dir = "./ocr_based_caption_generator"

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    load_best_model_at_end=True,
    dataloader_num_workers=0,  
    gradient_accumulation_steps=4,
)

In [None]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    evaluation_strategy = "steps",
    eval_steps = 500,
    save_steps = 500,
    logging_steps = 100,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
    logging_dir='./logs',
    save_total_limit=3,
    fp16=True,
    metric_for_best_model="bleu",
    load_best_model_at_end=True,
    greater_is_better=True,
    predict_with_generate=True,
)

In [None]:
# Compute absolute learning rate
base_learning_rate = 1e-3
total_train_batch_size = (
    training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)

training_args.learning_rate = base_learning_rate * total_train_batch_size / 256
print("Set learning rate to:", training_args.learning_rate)

In [None]:
import sacrebleu

def compute_metrics(pred):
    targets = pred.label_ids
    preds = pred.predictions.argmax(-1)
    bleu = sacrebleu.corpus_bleu(preds, [targets], force=True).score
    return {"bleu": bleu}

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)


In [None]:
# Train
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
# Evaluate on the test dataset
metrics = trainer.evaluate(test_dataset)
trainer.log_metrics("test", metrics)