# Machine Learning Assignment \#2 - Image Captioning


## Unzip training and validation images
- First, you must download the data files from link in assignment description, then you must move these .zip files to Colab environment in "/content/" path.
- If you use individual environment, you can apply all variables to your environment.


In [13]:
import zipfile

train_zip_file = zipfile.ZipFile('content/train_images.zip')
train_zip_file.extractall('content/')

valid_zip_file = zipfile.ZipFile('content/valid_images.zip')
valid_zip_file.extractall('content/')

## local 환경에 맞춰서 경로 수정함

## Set seed value of several main computation packages (do not touch this settings)

In [14]:
import random
import numpy as np
import torch


random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

## Create Vocabulary Dictionary

## Dataloader for dataset

In [15]:
import nltk
from PIL import Image

import torch.utils.data as data


nltk.download('punkt_tab')
nltk.download('punkt')
class ImageCaptioningDataset(data.Dataset):
    def __init__(
        self,
        root,                         # root path of resized data
        captions_path,                # path of original text data
        vocab, transform=None,
        is_train=True,                # check whether the dataset is for training or not
        gen_train_captions_path=None, # path for generated caption data
        n_gen_captions=None           # the number of generated captions that are used for training. You can change this value. If this variable is 0, only original caption data will be used.
        ):
        self.root = root
        self.captions = []

        with open(captions_path, 'rb') as f:
          captions = pickle.load(f)

        for fname, caps in captions:
            if is_train:
              for cap in caps:
                    self.captions.append((fname, cap.strip()))

            else:
                for cap in caps:
                    self.captions.append((fname, cap.strip()))

        if is_train:
          ################################################################################################################################
          # FILL BLANK #1: load generated captions and append (image, generated_caption) pairs with the value of n_gen_captions.
          with open(gen_train_captions_path, 'rb') as f:
              gen_captions = pickle.load(f)
          for fname, caps in gen_captions:
              for i, cap in enumerate(caps):
                  if i < n_gen_captions:
                      self.captions.append((fname, cap.strip()))
          ################################################################################################################################

        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        vocab = self.vocab
        path = self.captions[index][0]
        caption = self.captions[index][1]

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Tokenize caption string
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []

        ################################################################################################################################
        # FILL BLANK #2: Caption tokens should start with the start token and end with the end token, and should be list type.
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        ################################################################################################################################

        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.captions)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [16]:
def collate_fn(data):
    """
    [input]
    * data: list of tuple (image, caption).
        * image: torch tensor of shape (3, 256, 256).
        * caption: torch tensor of shape (?); variable length.
    [output]
    * images: torch tensor of shape (batch_size, 3, 256, 256).
    * targets: torch tensor of shape (batch_size, padded_length).
    * lengths: list; valid length for each padded caption.
    """
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    images = torch.stack(images, 0)

    lengths = [len(caption) for caption in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths

def collate_fn_test(data):
    images, captions = zip(*data)

    images = torch.stack(images, 0)

    lengths = [len(caption) for caption in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    return images, targets, lengths

def get_loader(root, captions_path, vocab, transform, batch_size, shuffle, num_workers, testing, is_train,
               gen_train_captions_path=None, n_gen_captions=None):
    image_captioning_dataset = ImageCaptioningDataset(root=root, captions_path=captions_path, vocab=vocab, transform=transform, is_train=is_train, gen_train_captions_path=gen_train_captions_path, n_gen_captions=n_gen_captions)
    # This will return (images, captions, lengths) for each iteration by collate_fn in Dataloader.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    if not testing:
        data_loader = torch.utils.data.DataLoader(dataset=image_captioning_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)
    else:
        data_loader = torch.utils.data.DataLoader(dataset=image_captioning_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn_test)
    return data_loader


# Train Image Captioning Model

## Model Definition

In [17]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        ################################################################################################################################
        '''
        FILL BLANK #3:
          - load pretrained resnet152 by torchvision.models
          - drop the last layer.
          - modulate pretrained resnet
          - attach an embedding layer to integrate text embedding (momentum value: 0.01)
        '''
        resnet      = models.resnet152(pretrained=True)
        modules     = list(resnet.children())[:-1]  # drop the last layer (fully connected)
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn     = nn.BatchNorm1d(embed_size, momentum=0.01)
        ################################################################################################################################


    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        super(DecoderRNN, self).__init__()

        ################################################################################################################################
        #FILL BLANK #4: define text embedding layer and lstm layer
        self.embed  = nn.Embedding(vocab_size, embed_size)
        self.lstm   = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        ################################################################################################################################

        self.max_seg_length = max_seq_length


    def forward(self, features, captions, lengths):
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens[0])
        return outputs

    def sample(self, features, states=None):
        sampled_indexes = []
        inputs = features.unsqueeze(1)
        for i in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)   # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))     # outputs: (batch_size, vocab_size)
            _, predicted = outputs.max(1)                 # predicted: (batch_size)
            sampled_indexes.append(predicted)
            inputs = self.embed(predicted)                # inputs: (batch_size, embed_size)
            inputs = inputs.unsqueeze(1)                  # inputs: (batch_size, 1, embed_size)
        sampled_indexes = torch.stack(sampled_indexes, 1) # sampled_indexes: (batch_size, max_seq_length)
        return sampled_indexes


## Vocabulay Class Definition

In [18]:
class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

## Model Training

In [None]:
import os
import pickle

import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

crop_size = 224

# 원래 경로
# train_image_dir = "./train_images" # training image path (train)
# valid_image_dir = "./valid_images"   # validation image path (valid)

# train_captions_path = "./train_captions.pkl"
# valid_captions_path = "./valid_captions.pkl"
# vocab_path = "./vocab.pkl" # pre-processsed vocab file path

# 로컬 환경에 맞도록 수정한 경로
train_image_dir = "content/train_images" # training image path (train)
valid_image_dir = "content/valid_images"   # validation image path (valid)

train_captions_path = "content/train_captions.pkl"
valid_captions_path = "content/valid_captions.pkl"
vocab_path="content/vocab.pkl"

# make directory to save model
model_path = "./models/"   # save model path
if not os.path.exists(model_path):
    os.makedirs(model_path)

# Load vocabulary flie
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

train_transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

val_transform = transforms.Compose([
    transforms.Resize(crop_size),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

batch_size = 128
num_workers = 0 ## <- 2 / 로컬 환경에 맞도록 2->0으로 수정.

n_gen_train_captions = 0     # Decide the number of generated captions for each image (integer)

# 원래 경로
# VLM_gen_train_captions_path = "./generated_captions.pkl" # path of generated captions

# 로컬 환경에 맞도록 경로 수정
VLM_gen_train_captions_path = "content/generated_captions.pkl" # path of generated captions


train_data_loader = get_loader(train_image_dir, train_captions_path, vocab, train_transform, batch_size, shuffle=True, num_workers=num_workers, testing=False, is_train=True,
                               gen_train_captions_path=VLM_gen_train_captions_path, n_gen_captions=n_gen_train_captions)
valid_data_loader = get_loader(valid_image_dir, valid_captions_path, vocab, val_transform, batch_size, shuffle=False, num_workers=num_workers, testing=False, is_train=False)


embed_size = 256
hidden_size = 512
num_layers = 1

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers).to(device)

num_epochs = 10
learning_rate = 0.001

################################################################################################################################
# FILL BLANK #5 : Use cross entropy loss for image captioning model
criterion = torch.nn.CrossEntropyLoss()
################################################################################################################################

# get trainable parameters (freeze the pretrained ResNet backbone).
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)

In [20]:
import time
import numpy as np

log_step = 20
start_time = time.time()
best_valid_perplexity = 9999999999999999

for epoch in range(num_epochs):
    print("[ Training ]")
    total_loss = 0
    total_count = 0
    total_step = len(train_data_loader)
    for i, (images, captions, lengths) in enumerate(train_data_loader):
        images = images.to(device)
        captions = captions.to(device)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_count += images.shape[0]

        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Average Loss: {:.4f}, Perplexity: {:5.4f}, Elapsed time: {:.4f}s'
                  .format(epoch+1, num_epochs, i, total_step, total_loss / total_count, np.exp(loss.item()), time.time() - start_time))

    torch.save(decoder.state_dict(), os.path.join(model_path, f'decoder-{epoch + 1}.ckpt'))
    torch.save(encoder.state_dict(), os.path.join(model_path, f'encoder-{epoch + 1}.ckpt'))
    print(f"Model saved: {os.path.join(model_path, f'decoder-{epoch + 1}.ckpt')}")
    print(f"Model saved: {os.path.join(model_path, f'encoder-{epoch + 1}.ckpt')}")

    print("[ Validation ]")
    total_loss = 0
    total_count = 0
    total_step = len(valid_data_loader)
    with torch.no_grad():
        for i, (images, captions, lengths) in enumerate(valid_data_loader):
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            total_count += images.shape[0]

            if i % log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Average Loss: {:.4f}, Perplexity: {:5.4f}, Elapsed time: {:.4f}s'
                      .format(epoch+1, num_epochs, i, total_step, total_loss / total_count, np.exp(loss.item()), time.time() - start_time))

    ####### You can mount your drive and save the python, checkpoint and data files to your drive to prevent the elimination of related files by runtime disconnection.

    if best_valid_perplexity >= np.exp(total_loss / total_count):
        torch.save(decoder.state_dict(), os.path.join(model_path, 'decoder-best.ckpt'))
        torch.save(encoder.state_dict(), os.path.join(model_path, 'encoder-best.ckpt'))
        print(f"Best Model saved at {epoch+1}: {os.path.join(model_path, 'decoder-best.ckpt')}")
        print(f"Best Model saved at {epoch+1}: {os.path.join(model_path, 'encoder-best.ckpt')}")

        best_valid_perplexity = np.exp(total_loss / total_count)
        best_epoch = epoch+1

print(f"[!] Best model saved at {best_epoch}...")

[ Training ]
Epoch [1/10], Step [0/235], Average Loss: 0.0636, Perplexity: 3447.7605, Elapsed time: 0.8638s
Epoch [1/10], Step [20/235], Average Loss: 0.0487, Perplexity: 176.3974, Elapsed time: 18.2856s
Epoch [1/10], Step [40/235], Average Loss: 0.0427, Perplexity: 71.9925, Elapsed time: 35.7968s
Epoch [1/10], Step [60/235], Average Loss: 0.0393, Perplexity: 60.7641, Elapsed time: 53.5451s
Epoch [1/10], Step [80/235], Average Loss: 0.0370, Perplexity: 42.1145, Elapsed time: 71.1505s
Epoch [1/10], Step [100/235], Average Loss: 0.0354, Perplexity: 35.9566, Elapsed time: 89.4146s
Epoch [1/10], Step [120/235], Average Loss: 0.0342, Perplexity: 36.5266, Elapsed time: 107.1721s
Epoch [1/10], Step [140/235], Average Loss: 0.0333, Perplexity: 33.5420, Elapsed time: 124.8035s
Epoch [1/10], Step [160/235], Average Loss: 0.0325, Perplexity: 31.2103, Elapsed time: 142.4682s
Epoch [1/10], Step [180/235], Average Loss: 0.0318, Perplexity: 27.9156, Elapsed time: 160.1455s
Epoch [1/10], Step [200/235