# Part 0: Argument and Import Initialization

- The "Args" class is initialized to contain any pre-defined hyperparameters and variables used in the experiment
- Imports consist of several numbered sections. The first part, titled "Google Colab - Dataset Imports", is optional and is only used to extract a zip file if the data is loaded as a zip file into Google Drive - the reason this is done is to unzip the data and load it into the Colab notebook's local storage for efficiency reasons. Otherwise, the data might be read in from the drive image by image during training, which greatly reduces training efficiency. This section can take around 10 minutes to complete.
- Basic torch, pandas, and numpy imports follow, as well as modules necessary to read in and modify the images from the dataset.

In [None]:
# 1 GOOGLE COLAB - DATASET IMPORTS
import os
import random
import time
import threading
from google.colab import drive, files

# 2 BASIC TORCH IMPORTS
import torch
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import Dataset, random_split, DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
from  torchtext.data.utils import get_tokenizer
import torchvision.models as models
from torchvision import datasets, transforms
import torchtext as tt
import collections

# 3 READING IN IMAGES FROM DATASET
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import zipfile
from io import BytesIO

# 4 TRANSFORMING IMAGES TO DATA
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# 5 BUILDING VOCABULARY FROM CAPTIONS
from torchvision.datasets import CocoCaptions
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu

from torchtext.vocab import vocab

In [1]:
# 6 INITIALIZE ARGUMENTS
class Args:
    """Bundle of hyperparameters and dataset paths shared by every cell."""

    def __init__(self):
        # Run on the GPU when one is visible, otherwise fall back to the CPU.
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # Decoder dimensions and the fixed caption length (in tokens).
        self.embedding_dim = 300
        self.hidden_size = 512
        self.max_caption_length = 20

        # Optimisation settings.
        self.batch_size = 128
        self.epochs = 30
        self.lr = 0.001

        # Side length (pixels) the image transforms resize/crop to.
        self.side_length = 224

        # Image folders for the train and val splits.
        self.train_img_data_path = './data/train2017'
        self.val_img_data_path = './data/val2017'

        # Caption annotation files for the train and val splits.
        self.train_ann_path = './data/annotations_trainval2017/captions_train2017.json'
        self.val_ann_path = './data/annotations_trainval2017/captions_val2017.json'


args = Args()

In [None]:
# 7 GOOGLE COLAB + DATASET PREPARATION
# Mount Google Drive; the downloaded archives are cached under it (see save_dir).
drive.mount('/content/drive')

# COCO 2017 archives: training images, validation images, and annotations.
dataset_urls = [
    "http://images.cocodataset.org/zips/train2017.zip",
    "http://images.cocodataset.org/zips/val2017.zip",
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
]

# save_dir holds the zip files on Drive; extract_dir is the local folder the
# archives are unpacked into and that the paths in Args point at.
save_dir = '/content/drive/MyDrive/data/'
extract_dir = './data'

os.makedirs(save_dir, exist_ok=True)

start_time = time.time()
for url in dataset_urls:
    # Archive name is the last path component of the URL.
    filename = os.path.join(save_dir, url.split('/')[-1])

    # Download only when the archive is not already present in save_dir.
    if not os.path.exists(filename):
        !wget {url} -P {save_dir}

    # Quiet (-q) update-mode (-u) extraction into extract_dir.
    # NOTE(review): Args reads annotations from
    # ./data/annotations_trainval2017/ - confirm the annotations archive
    # actually unpacks into a folder with that name.
    !unzip -q -u {filename} -d {extract_dir}

# Report the wall-clock time of the download/extract loop in minutes.
end_time = time.time()
time_taken = end_time - start_time
print("Time Taken:", time_taken / 60, "minutes")

# Part 1: Dataset Initialization

- Define the IMAGE_DATASET class with the needed functions. Each image is resized and cropped to 224 x 224 pixels (`args.side_length`) and converted to 3 RGB channels to ensure consistency in training. The COCO module is used to associate images and image annotations in the COCO dataset.
- Training and Test datasets are initialized from this class as divided originally in the dataset and a DataLoader is initialized
- A vocabulary is also initialized based on the image captions loaded in from the dataset.

In [3]:
class IMAGE_DATASET(Dataset):
    """COCO captions dataset that yields ``(image, caption)`` pairs.

    Args:
        annotations_file: Path to the COCO captions JSON file.
        img_dir: Directory containing the image files named in the annotations.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the chosen caption.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.coco = COCO(annotations_file)
        self.transform = transform
        self.target_transform = target_transform
        # Fixed list of image ids so integer indices map to stable ids.
        self.img_keys = list(self.coco.imgs.keys())

    def get_single_image(self, img_id):
        """Load one image and one of its captions, picked at random.

        Args:
            img_id: COCO image id.

        Returns:
            Tuple ``(image, caption)`` with ``transform`` / ``target_transform``
            applied when they were provided.

        Raises:
            IndexError: If the image has no caption annotations.
        """
        img = self.coco.loadImgs(img_id)[0]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(img_id))

        # Fail early (before touching the disk) with a message that names the
        # image, instead of the opaque error random.choice raises on [].
        if not annotations:
            raise IndexError(f'image {img_id} has no caption annotations')

        path = os.path.join(self.img_dir, img['file_name'])
        # The context manager closes the file handle once the pixels are
        # decoded; convert('RGB') forces 3 channels even for grayscale files.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # A different caption may be returned each time the image is drawn.
        caption = random.choice(annotations)['caption']
        if self.target_transform is not None:
            caption = self.target_transform(caption)

        return image, caption

    def __getitem__(self, index):
        """Return the ``(image, caption)`` pair for the ``index``-th image id."""
        return self.get_single_image(self.img_keys[index])

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.img_keys)

In [4]:
# initialize image transforms
# Channel-wise normalisation shared by both pipelines (values for resnet).
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])

# Training pipeline: resize, random crop + horizontal flip, tensor, normalise.
train_transform = transforms.Compose([
    transforms.Resize(args.side_length),
    transforms.RandomCrop(args.side_length),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_normalize,
])

# Evaluation pipeline: same sizing but a deterministic centre crop and no flip.
test_transform = transforms.Compose([
    transforms.Resize(args.side_length),
    transforms.CenterCrop(args.side_length),
    transforms.ToTensor(),
    imagenet_normalize,
])

# train2017 backs both the training and validation subsets; val2017 is the
# held-out test set.
full_dataset = IMAGE_DATASET(args.train_ann_path,
                             args.train_img_data_path,
                             train_transform)
test_dataset = IMAGE_DATASET(args.val_ann_path,
                             args.val_img_data_path,
                             test_transform)

# 80/20 random split of the train2017 images.
# NOTE(review): val_dataset is a view onto full_dataset, so it is served with
# train_transform (random crop/flip) rather than test_transform.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Settings common to all three loaders; only the training loader shuffles.
loader_options = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)

image_train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)
image_val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_options)
image_test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)

# initialize vocabulary
# Count every token of every caption in both annotation files. Each caption is
# wrapped in <sos>/<eos> so the two markers receive vocabulary entries too.
tokenizer = get_tokenizer("basic_english")
counter = collections.Counter()

for caption_source in (full_dataset, test_dataset):
    for ann in caption_source.coco.anns.values():
        caption = ann['caption'].strip()
        tokens = ['<sos>', *tokenizer(caption), '<eos>']
        counter.update(tokens)

# Tokens are indexed from most to least frequent; <pad> is appended last.
sorted_by_freq = counter.most_common()
ordered_dict = collections.OrderedDict(sorted_by_freq)
vocabulary = vocab(ordered_dict)
vocabulary.append_token('<pad>')

loading annotations into memory...
Done (t=1.70s)
creating index...
index created!


# Part 2: Neural Network Initialization
- Two primary components of the neural network are initialized. First, a CNN encoder built from a pretrained ResNet-50 with its final classification layer removed.
- An LSTM decoder that receives the 2048-feature CNN output, projected to the embedding size, as its first input, and then passes each generated token as the input to the next step of the loop. The sentence is padded or cut at 20 tokens to ensure simplicity and consistency in data size during training.

In [5]:
class Model(nn.Module):
    """Image-captioning network: ResNet-50 encoder followed by an LSTM decoder.

    Sizes (embedding, hidden, maximum caption length) are read from the
    module-level ``args`` object.

    Args:
        vocab_size: Number of entries in the vocabulary; sets the embedding
            table size and the width of the output logits.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # PART 0: PARAMS
        self.vocab_size = vocab_size

        # PART 1: CNN
        # ResNet-50 without its last child module (the classification layer).
        resnet = models.resnet50(pretrained=True)
        self.CNN = nn.Sequential(*list(resnet.children())[:-1])

        # Projects the 2048 pooled CNN features to the word-embedding size so
        # they can be fed to the LSTM like a word embedding.
        self.additional_layers = nn.Sequential(
            nn.PReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=2048, out_features=args.embedding_dim),
        )

        # Not used by forward(); kept so the parameter set (and therefore
        # saved state_dict keys) stays the same.
        self.FC_CNN = nn.Linear(args.embedding_dim, args.hidden_size)

        # PART 2: LSTM
        self.embedding = nn.Embedding(self.vocab_size, args.embedding_dim)
        self.LSTM = nn.LSTM(args.embedding_dim, args.hidden_size, batch_first=True)
        self.FC = nn.Linear(args.hidden_size, self.vocab_size)

    def forward(self, x, captions=None):
        """Produce per-step vocabulary logits for a batch of images.

        Args:
            x: Batch of image tensors (assumed ``(batch, 3, H, W)`` - TODO
                confirm against the data loader).
            captions: Optional ``(batch, seq_len)`` tensor of token indices.
                When given, the decoder is teacher-forced with these tokens;
                when ``None``, captions are decoded greedily.

        Returns:
            Logits of shape ``(batch, args.max_caption_length, vocab_size)``
            for greedy decoding, or ``(batch, seq_len, vocab_size)`` when
            ``captions`` is given.
        """
        # The backbone is used as a fixed feature extractor: no gradients.
        # NOTE(review): no_grad() does not switch the backbone to eval mode;
        # consider self.CNN.eval() if its normalisation statistics should stay
        # frozen while the rest of the model trains - confirm intent.
        with torch.no_grad():
            cnn_features = self.CNN(x)

        cnn_features = self.additional_layers(cnn_features.flatten(1))

        # Allocate state on the input's device rather than the global
        # args.device, so the module works wherever it has been moved.
        device = x.device
        batch_size = x.size(0)
        h_i = torch.zeros(1, batch_size, args.hidden_size, device=device)
        c_i = torch.zeros_like(h_i)

        # INFERENCE MODE: greedy decoding, one token per step.
        if captions is None:
            out_logits = torch.zeros(
                (batch_size, args.max_caption_length, self.vocab_size),
                device=device,
            )

            # The image features act as the very first "word".
            step_input = cnn_features.unsqueeze(1)
            for t in range(args.max_caption_length):
                output, (h_i, c_i) = self.LSTM(step_input, (h_i, c_i))
                logits = self.FC(output.squeeze(1))
                out_logits[:, t, :] = logits

                # Feed the most likely token back in as the next input.
                step_input = self.embedding(logits.argmax(dim=1, keepdim=True))

            return out_logits

        # TRAINING MODE: teacher forcing. The input sequence is the image
        # features followed by all caption tokens except the last, so the
        # output at step t is scored against captions[:, t].
        embed = self.embedding(captions)
        embed = torch.cat((cnn_features.unsqueeze(1), embed[:, :-1, :]), dim=1)
        lstm_out, _ = self.LSTM(embed, (h_i, c_i))
        return self.FC(lstm_out)

In [6]:
# INITIALIZE NETWORK
# Build the network with one output per vocabulary entry and move it to the
# configured device in a single step.
model = Model(len(vocabulary)).to(args.device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 97.3MB/s]


# Part 3: Training

This part consists of the primary training loop over the desired number of epochs and specified batch size. The labels are tokenized, the <sos> and <eos> tokens are added, and the sequences are padded / shortened as needed. The CNN output is then fed into the LSTM, and backpropagation is performed based on the loss.

In [None]:
# No visible cell defines a bare `device` name, so the original print raised
# NameError. Alias it to the configured device; this also keeps any other cell
# that still refers to `device` working.
device = args.device
print('Using device {}'.format(device))

# initialize tokenizer, loss, optimizer, scheduler
tokenizer = tt.data.utils.get_tokenizer("basic_english")
# <pad> positions are excluded from the loss.
loss_func = nn.CrossEntropyLoss(ignore_index=vocabulary['<pad>']).to(args.device)
# Learning rate comes from Args (0.001) instead of a duplicated literal.
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Multiply the learning rate by 0.1 every 9 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=9, gamma=0.1)

In [None]:
# training helpers
def pad_sequences(sequences, max_length, pad_token='<pad>'):
    """Force every token list to exactly ``max_length`` entries.

    Shorter lists are right-padded with ``pad_token``; lists that are already
    ``max_length`` or longer are cut down to their first ``max_length`` items.

    Args:
        sequences: Iterable of token lists.
        max_length: Target length of every returned list.
        pad_token: Filler appended to short lists.

    Returns:
        A new list of token lists, one per input sequence.
    """
    fixed = []
    for seq in sequences:
        shortfall = max_length - len(seq)
        if shortfall > 0:
            fixed.append(seq + [pad_token] * shortfall)
        else:
            fixed.append(seq[:max_length])
    return fixed

def remove_noise(sequences):
    """Drop every token whose string form is at most one character long.

    This removes single-character tokens such as ``'.'`` or ``'a'`` while
    keeping longer ones (including ``'<sos>'`` / ``'<eos>'``).

    Args:
        sequences: Iterable of token lists.

    Returns:
        A new list of token lists with the short tokens filtered out.
    """
    cleaned = []
    for seq in sequences:
        kept = [token for token in seq if len(str(token)) >= 2]
        cleaned.append(kept)
    return cleaned

def tokenize_and_pad(labels, vocab, max_length = args.max_caption_length):
    """Convert raw caption strings into fixed-length lists of vocabulary indices.

    Each caption is tokenized with the module-level ``tokenizer``, wrapped in
    ``<sos>`` / ``<eos>``, filtered with ``remove_noise``, padded or cut to
    ``max_length`` with ``pad_sequences``, and finally mapped through ``vocab``.

    Args:
        labels: Iterable of caption strings.
        vocab: Mapping from token to integer index (indexed as ``vocab[token]``).
        max_length: Length of every returned index list.

    Returns:
        List of integer lists, one per caption, each of length ``max_length``.
    """
    wrapped = [['<sos>', *tokenizer(text), '<eos>'] for text in labels]
    fixed_length = pad_sequences(remove_noise(wrapped), max_length)
    return [[vocab[token] for token in seq] for seq in fixed_length]

def process_tokens(predicted_tokens):
    """Turn predicted token sequences into readable caption strings.

    For each sequence the first token (the start marker) is dropped and the
    remaining tokens up to, but excluding, the first ``"<eos>"`` are joined
    with spaces. A sequence with no ``"<eos>"`` - e.g. a prediction that hit
    the length limit before ending - is used through its last token instead of
    raising ``ValueError``.

    Args:
        predicted_tokens: Iterable of token sequences (lists of strings).

    Returns:
        List of caption strings, one per input sequence.
    """
    processed_strings = []
    for tokens in predicted_tokens:
        tokens = list(tokens)
        try:
            end = tokens.index("<eos>")
        except ValueError:
            # No end marker was produced: keep everything after the start token.
            end = len(tokens)
        processed_strings.append(' '.join(tokens[1:end]))

    return processed_strings

In [None]:
# Main Training Loop
def _save_checkpoint(epoch, loss_history):
    """Write model, optimizer and scheduler state for ``epoch`` to Drive."""
    ckpt_path = f'/content/drive/MyDrive/Image Caption Generator/checkpoint_{epoch}.pth'
    # torch.save does not create missing parent folders.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss_history': loss_history
    }, ckpt_path)


# Main Training Loop
def train(num_epochs, ckpt_load_idx = -1):
    """Train the module-level ``model`` on ``image_train_loader``.

    Uses the module-level ``optimizer``, ``scheduler``, ``loss_func`` and
    ``vocabulary``. A checkpoint is written after every fourth epoch (except
    epoch 0) and once more after the last epoch that ran.

    Args:
        num_epochs: Exclusive upper bound of the epoch index to train to.
        ckpt_load_idx: Epoch index of a checkpoint to resume from, or ``-1``
            to start from scratch.
    """
    loss_history = []
    start_epoch = 0

    if ckpt_load_idx != -1:
        ckpt_path = f'/content/drive/MyDrive/Image Caption Generator/checkpoint_{ckpt_load_idx}.pth'
        # map_location lets a checkpoint saved on GPU be restored on CPU too.
        checkpoint = torch.load(ckpt_path, map_location=args.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        loss_history = checkpoint['loss_history']

    last_saved_epoch = None
    last_run_epoch = None

    for idx_e in range(start_epoch, num_epochs):
        model.train()

        for idx, (DATA, LABELS) in enumerate(image_train_loader):
            # Captions -> fixed-length index tensor on the training device.
            tokenized_labels = tokenize_and_pad(LABELS, vocabulary)
            target_sequences = torch.tensor(tokenized_labels, dtype=torch.long).to(args.device)

            # Pass the captions so the model runs its teacher-forced training
            # branch; without them forward() falls into greedy inference.
            pred = model(DATA.to(args.device), target_sequences)

            # Flatten (batch, steps, vocab) / (batch, steps) for the loss.
            loss = loss_func(pred.reshape(-1, pred.size(-1)), target_sequences.reshape(-1))
            loss_history.append(loss.item())

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx % 100 == 0:
                # PRINT EPOCH + LOSS INFO
                print(f'Epoch [{idx_e+1}/{num_epochs}], Step [{idx+1}/{len(image_train_loader)}], Loss: {loss.item():.4f}')

                # PRINT SAMPLE SENTENCE: most likely token at every step of
                # the first item in the batch, next to its target caption.
                predicted_indices = pred[0].argmax(dim=-1)
                predicted_tokens = vocabulary.lookup_tokens(predicted_indices.tolist())
                print(predicted_tokens, vocabulary.lookup_tokens(target_sequences[0].tolist()))

        scheduler.step()
        last_run_epoch = idx_e

        if idx_e != 0 and idx_e % 4 == 0:
            _save_checkpoint(idx_e, loss_history)
            last_saved_epoch = idx_e

    # Final checkpoint: skipped when no epoch ran (nothing new to save, and the
    # epoch index would be undefined) or when the last epoch was just saved.
    if last_run_epoch is not None and last_run_epoch != last_saved_epoch:
        _save_checkpoint(last_run_epoch, loss_history)

# Start training for args.epochs epochs; ckpt_load_idx keeps its default of -1,
# so no checkpoint is loaded.
train(args.epochs)

In [None]:
# # most_common_words = counter.most_common(20)

# # # Split the words and counts into separate lists
# # words, counts = zip(*most_common_words)

# # # Create the bar graph
# # plt.figure(figsize=(10, 8))
# # plt.bar(words, counts, color='blue')
# # plt.xlabel('Words')
# # plt.ylabel('Counts')
# # plt.title('20 Most Common Words')
# # plt.xticks(rotation=45)
# # plt.show()

# ckpt_path = f'/content/drive/MyDrive/Image Caption Generator/checkpoint_new.pth'
# torch.save({
#               'model_state_dict': model.state_dict(),
#               'optimizer_state_dict': optimizer.state_dict(),
#               'scheduler_state_dict': scheduler.state_dict(),
# }, ckpt_path)

# Part 4: Testing, Validation, and Analysis

In [None]:
# # # Sampling: Visualize Samples from Full Dataset
# model.eval()

# # ckpt_path = '/content/drive/MyDrive/Image Caption Generator/checkpoint_test.pth'
# # checkpoint = torch.load(ckpt_path)
# # model.load_state_dict(checkpoint['model_state_dict'])

# # # from torch.utils.data import Subset

# random_indices = random.sample(range(len(full_dataset.img_keys)), 20)
# subset = Subset(full_dataset, random_indices)
# sample, true_labels = next(iter(DataLoader(subset, batch_size=20, shuffle=False)))

# # # fig, axes = plt.subplots(1, 10, figsize=(15, 1.5))
# # # for i, (img, label, ax) in enumerate(zip(images, true_labels, axes)):
# # #     ax.imshow(img)
# # #     ax.axis('off')
# # #     ax.set_title(f"Sample True Label: {label.item()}\n Predicted Label {}", fontsize=8)

# # # plt.tight_layout()
# # # plt.show()

In [None]:
# sample = sample.to(args.device)

In [None]:

# images = sample.to('cpu').permute(0, 2, 3, 1).numpy()



# # print(predicted_labels)

# tokenized_labels = tokenize_and_pad(true_labels, vocabulary)

#             # Convert to tensor
# target_sequences = torch.tensor(tokenized_labels, dtype=torch.long).to(args.device)

# predicted_labels = model(sample)

# # print(predicted_labels)
# # probabilities = torch.softmax(predicted_labels, dim=-1)
# # predicted_indices = torch.argmax(probabilities, dim=-1)

# predicted_tokens = [vocabulary.lookup_tokens(i.tolist()) for i in predicted_labels]
# print(predicted_tokens)


In [None]:
# for i, (img, label, pred) in enumerate(zip(images, true_labels, process_tokens(predicted_tokens))):
#     fig, ax = plt.subplots(figsize=(1.5, 1.5))
#     ax.imshow(img)
#     ax.axis('off')
#     ax.set_title(f"Sample True Label: {label}\nPredicted Label: {pred}", fontsize=8)
#     plt.tight_layout()
#     plt.show()

# Part 5: Export Model

In [None]:
# image = Image.open('/content/drive/MyDrive/Image Caption Generator/data/IMG_7504.JPG').convert('RGB')
# transform = transforms.Compose([
#     transforms.Resize((args.img_h, args.img_w)),  # Resize the image
#     transforms.ToTensor()                         # Convert the image to a tensor
# ])
# image = transform(image).unsqueeze(0)

# pred = lstm(image.to(device))
# tokens = [vocabulary.lookup_tokens(i.tolist()) for i in pred]
# label = process_tokens(tokens)

# image = image.permute(0, 2, 3, 1).numpy()

# # for i, (img, label, pred) in enumerate(zip(images, true_labels, process_tokens(predicted_tokens))):
# fig, ax = plt.subplots(figsize=(1.5, 1.5))
# ax.imshow(image[0])
# ax.axis('off')
# ax.set_title(f"Predicted Label: {label}", fontsize=8)
# plt.tight_layout()
# plt.show()

In [None]:
# files.download('/content/drive/MyDrive/Image Caption Generator/checkpoint_49.pth')