<a href="https://colab.research.google.com/github/jeonghojo00/ImageCaptioning/blob/main/1_ImgCaptioning_EncoderSearching.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Initialization

In [None]:
# Mount Google Drive at /content/drive so the project folder stored there is
# reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Change directory to the package folder so that the relative './data/...'
# paths and the local module imports used by later cells resolve against it.
%cd '/content/drive/MyDrive/Colab Notebooks/ImageCaptioning/'
# Verify the contents of the current folder
!ls

/content/drive/MyDrive/Colab Notebooks/ImageCaptioning
checkpoint		      license.txt		  pycocotools
code			      load_data.py		  README.md
cs7643-final-project	      lstm_decoder.py		  resize_image.py
custom_caption_eval.py	      main.py			  resnet.py
data			      model			  show_attend_tell.py
data_loader.py		      models			  split_caption.py
efficientnet.py		      models.py			  util
experiments		      nic_decoder_ResNet101.ckpt  utils.py
get_google_word2vec_model.sh  nic_encoder_ResNet101.ckpt  vgg.py
get_stanford_models.sh	      pretraining.py		  vocab.pkl
inception.py		      __pycache__		  vocabulary.py
learned_models		      pycocoevalcap


In [None]:
import importlib.util
import subprocess
import sys

# Packages this notebook needs. Each name is both the importable module name
# and a name that pip accepts.
required = {'efficientnet_pytorch', 'timm', 'tqdm', 'torch', 'torchvision'}

# Probe for the importable module rather than comparing against
# pkg_resources keys: those keys are normalized (e.g. 'efficientnet-pytorch'),
# so 'efficientnet_pytorch' never matched and pip was invoked on every run.
# pkg_resources itself is also deprecated.
missing = {name for name in required if importlib.util.find_spec(name) is None}

if missing:
    python = sys.executable
    # Install into the interpreter that runs this kernel; sorted() keeps the
    # command line deterministic. pip's stdout is silenced, errors still raise.
    subprocess.check_call([python, '-m', 'pip', 'install', *sorted(missing)], stdout=subprocess.DEVNULL)

In [None]:
# Import Libraries
import os
from PIL import Image
from tqdm import tqdm
import pickle
import nltk
from collections import Counter

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

import torchvision.models as models
from model import (
    encoderCNN,
    decoderRNN
)
from load_data import *
from resize_image import *
from split_caption import *

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# ---- Original Flickr8k data -------------------------------------------------
image_dir = './data/flickr8k/Images'  # original images folder
caption_path = './data/flickr8k/captions.txt'  # original caption file

# ---- Resized dataset (train / val / test splits) ----------------------------
_resized_root = './data/resized_flickr8k'
train_image_dir = f'{_resized_root}/train/Images'
val_image_dir = f'{_resized_root}/val/Images'
test_image_dir = f'{_resized_root}/test/Images'
train_caption_path = f'{_resized_root}/train/captions.txt'
val_caption_path = f'{_resized_root}/val/captions.txt'
test_caption_path = f'{_resized_root}/test/captions.txt'

# ---- Resized dataset for the 299x299 run (same layout, separate root) -------
_resized299_root = './data/resized_flickr8k_299'
train299_image_dir = f'{_resized299_root}/train/Images'
val299_image_dir = f'{_resized299_root}/val/Images'
test299_image_dir = f'{_resized299_root}/test/Images'
train299_caption_path = f'{_resized299_root}/train/captions.txt'
val299_caption_path = f'{_resized299_root}/val/captions.txt'
test299_caption_path = f'{_resized299_root}/test/captions.txt'

# ---- Preprocessing parameters -----------------------------------------------
resized_image = [299, 299]  # resized image size: 299 for inception, 256 for the others
num_train_images = 6000
num_val_images = 1000

vocab_path = './vocab.pkl'  # vocabulary file
word_threshold = 4  # minimum occurrences of words

# 1. Preprocess Images and Captions

## Preprocess images

In [None]:
# Image resizing is gated behind this flag; set it to True to run
# save_resized_images over the original image folder.
# NOTE(review): the call writes to the non-299 folders while resized_image is
# [299, 299] and the later cells read the *_299 folders - confirm the targets
# before enabling.
resizeImage_required = False
if resizeImage_required:
    save_resized_images(
        image_dir,
        train_image_dir,
        val_image_dir,
        test_image_dir,
        num_train_images,
        num_val_images,
        resize_image=resized_image,
    )

## Preprocess captions for vocab dictionary and caption divisions

In [None]:
# Caption preprocessing is gated behind this flag. When enabled, split_caption
# receives the original caption file, the three 299-run caption paths, the
# vocabulary path, the split sizes and the minimum word count.
splitCaption_required = True

if splitCaption_required:
    split_caption(
        caption_path,
        train299_caption_path,
        val299_caption_path,
        test299_caption_path,
        vocab_path,
        num_train_images,
        num_val_images,
        word_threshold,
    )

In [None]:
# Number of lines in the training captions file
!wc -l ./data/resized_flickr8k_299/train/captions.txt
# Number of lines in the validation captions file
!wc -l ./data/resized_flickr8k_299/val/captions.txt
# Number of lines in the testing captions file
!wc -l ./data/resized_flickr8k_299/test/captions.txt

63556 ./data/resized_flickr8k_299/train/captions.txt
10000 ./data/resized_flickr8k_299/val/captions.txt
10910 ./data/resized_flickr8k_299/test/captions.txt


# 2. Train

In [None]:
def load_encoder(encoder_name, embed_size):
    """Build the image encoder registered under ``encoder_name``.

    Args:
        encoder_name: Name of the encoder to build. Must be one of the names
            listed in ``supported`` below; each one is also the name of the
            constructor looked up on ``encoderCNN``.
        embed_size: Passed unchanged to the encoder constructor.

    Returns:
        The object returned by ``encoderCNN.<encoder_name>(embed_size)``.

    Raises:
        ValueError: If ``encoder_name`` is not a supported encoder. Previously
            an unknown name returned ``None``, which only surfaced later as an
            ``AttributeError`` when the caller chained ``.to(device)``.
    """
    supported = (
        'ResNet152',
        'Efficientnet',
        'DenseNet161',
        'InceptionV3',
        'GoogleNet',
        'MobileNetV3',
        'ResNeXt101',
        'WideResNet101',
        'MNASNet',
        'ShuffleNetV2',
        'SqueezeNet',
    )
    if encoder_name not in supported:
        raise ValueError(
            f"Unknown encoder {encoder_name!r}; expected one of: {', '.join(supported)}"
        )

    # Every supported name matches its constructor attribute on encoderCNN, so
    # a validated lookup replaces the long if/elif chain.
    return getattr(encoderCNN, encoder_name)(embed_size)

In [None]:
import torch
from torch.nn.utils.rnn import pack_padded_sequence


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

models_path = "./learned_models/" # model path that learned models will be saved
# Side length of the image fed to the encoder. 299 matches the InceptionV3 run
# configured in encoder_list and the 3x299x299 sample shape recorded below;
# use 224 for the other encoders.
crop_size = 299

vocab_path = "./vocab.pkl" # Vocabulary path that is preprocessed

# Make a directory that a learned model will be saved
# (exist_ok avoids the check-then-create race and makes the cell re-runnable)
os.makedirs(models_path, exist_ok=True)

# Load Vocabulary dictionary (Vocabulary class needs to be defined first)
# NOTE: pickle.load can execute arbitrary code - only load a vocab file that
# this project produced itself.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

# Channel-wise mean / std shared by every pipeline (the standard ImageNet statistics)
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))


def _build_eval_transform(size):
    """Return the deterministic resize -> tensor -> normalize pipeline used for validation and testing."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        normalize])


# Make transforms for training, validating, and testing the model
# Training adds random crop / horizontal flip augmentation.
train_transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])

val_transform = _build_eval_transform(crop_size)
test_transform = _build_eval_transform(crop_size)

batch_size = 128
num_workers = 2

# Get the dataloaders (only the training loader is shuffled)
train_data_loader = get_loader(train299_image_dir, train299_caption_path, vocab, train_transform, batch_size, shuffle=True, num_workers=num_workers, testing=False)
val_data_loader = get_loader(val299_image_dir, val299_caption_path, vocab, val_transform, batch_size, shuffle=False, num_workers=num_workers, testing=False)
test_data_loader = get_loader(test299_image_dir, test299_caption_path, vocab, test_transform, batch_size, shuffle=False, num_workers=num_workers, testing=True)

In [None]:
# Sanity check: shape of the first item of the first training sample
# (presumably the transformed image tensor - confirm against load_data.py).
train_data_loader.dataset[0][0].shape

torch.Size([3, 299, 299])

In [None]:
# Model hyperparameters
embed_size = 256 # Embedding size for output of encoder and input of decoder
hidden_size = 512 # LSTM hidden states
num_layers = 1 # Number of layers of LSTM

# Declare a decoder sized to the vocabulary and move it to the selected device
decoder = decoderRNN.DecoderLSTM(embed_size, hidden_size, len(vocab), num_layers).to(device)

num_epochs = 10 # Number of epochs run for each encoder in the training cell
learning_rate = 0.001 # Learning rate passed to the Adam optimizer in the training cell

log_step = 20 # Number of steps to show a log for each batch
# NOTE(review): save_step is not referenced by the training cell, which saves
# checkpoints once per epoch instead.
save_step = 1000 # Number of steps to save the learned model

# Set loss function (compares decoder outputs against the packed caption targets)
criterion = nn.CrossEntropyLoss()

In [None]:
# Other candidate encoders, kept for reference (swap them into encoder_list
# to search over them):
#   'ResNet152', 'Efficientnet', 'DenseNet161', 'GoogleNet', 'MobileNetV3',
#   'ResNeXt101', 'WideResNet101', 'MNASNet', 'ShuffleNetV2', 'SqueezeNet'
#
# InceptionV3 needs input_size of 299, so it is processed separately from
# the encoders above.
encoder_list = [
    'InceptionV3',
]

In [None]:
import time
import numpy as np


def _main_output(features):
    """Return ``features.logits`` when the encoder output has that attribute, otherwise ``features`` itself.

    InceptionV3 yields an output object with ``.logits`` while training; a
    plain tensor passes through unchanged, so the same code path works in
    both training and evaluation mode.
    """
    return getattr(features, 'logits', features)


def _run_epoch(encoder, decoder, data_loader, epoch, start_time, optimizer=None):
    """Run one pass over ``data_loader`` and return its per-batch metrics.

    Args:
        encoder: Image encoder module.
        decoder: Caption decoder module.
        data_loader: Iterable yielding ``(images, captions, lengths)`` batches.
        epoch: Zero-based epoch index, used only for the log line.
        start_time: Reference ``time.time()`` value for the elapsed-time log.
        optimizer: When given, the pass trains (backward + optimizer step).
            When ``None``, the pass only evaluates, with gradients disabled
            and both modules in eval mode.

    Returns:
        ``(avg_losses, perplexities)``: two lists with one entry per batch,
        holding the running per-token average loss and the perplexity
        ``exp(batch loss)``.
    """
    training = optimizer is not None
    # Switch dropout / batch-norm behaviour to match the phase; evaluating in
    # train mode would also let validation batches update BatchNorm statistics.
    encoder.train(training)
    decoder.train(training)

    avg_losses, perplexities = [], []
    loss_sum = 0.0
    token_count = 0
    total_step = len(data_loader)

    with torch.set_grad_enabled(training):
        for i, (images, captions, lengths) in enumerate(data_loader):
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            features = _main_output(encoder(images))
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)

            if training:
                decoder.zero_grad()
                encoder.zero_grad()
                loss.backward()
                optimizer.step()

            # criterion returns the mean over the packed target tokens, so
            # weight each batch by its token count to get a true per-token
            # running average. (The old code divided a sum of batch means by
            # the image count, shrinking the value by roughly batch_size.)
            batch_loss = loss.item()
            n_tokens = targets.size(0)
            loss_sum += batch_loss * n_tokens
            token_count += n_tokens

            avg_loss = loss_sum / token_count
            perplexity = np.exp(batch_loss)
            avg_losses.append(avg_loss)
            perplexities.append(perplexity)

            if i % log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Average Loss: {:.4f}, Perplexity: {:5.4f}, Elapsed time: {:.4f}s'
                      .format(epoch, num_epochs, i, total_step, avg_loss, perplexity, time.time() - start_time))

    return avg_losses, perplexities


start_time = time.time()
loss_dict = {'train': {}, 'val': {}}
perplexity_dict = {'train': {}, 'val': {}}

for encoder_name in encoder_list:
    # Fresh histories per encoder, so each encoder's pickle holds only its own
    # curves instead of accumulating those of earlier encoders.
    loss_dict = {'train': {}, 'val': {}}
    perplexity_dict = {'train': {}, 'val': {}}

    encoder = load_encoder(encoder_name, embed_size).to(device)
    # Fresh decoder per encoder: reusing one decoder would hand later encoders
    # an already-trained decoder and bias the comparison.
    decoder = decoderRNN.DecoderLSTM(embed_size, hidden_size, len(vocab), num_layers).to(device)

    each_model_path = os.path.join(models_path, encoder_name)
    os.makedirs(each_model_path, exist_ok=True)

    if encoder_name == "ResNet152":
        # Only the encoder's linear and bn layers are optimized for ResNet152.
        params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    else:
        params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    for epoch in range(num_epochs):
        # Train
        print("[ Training ]")
        loss_dict['train'][epoch], perplexity_dict['train'][epoch] = _run_epoch(
            encoder, decoder, train_data_loader, epoch, start_time, optimizer=optimizer)

        # Save learned models
        decoder_ckpt = os.path.join(each_model_path, f'decoder-{epoch + 1}.ckpt')
        encoder_ckpt = os.path.join(each_model_path, f'encoder-{epoch + 1}.ckpt')
        torch.save(decoder.state_dict(), decoder_ckpt)
        torch.save(encoder.state_dict(), encoder_ckpt)
        print(f"Model saved: {decoder_ckpt}")
        print(f"Model saved: {encoder_ckpt}")

        # Validate (forward only)
        print("[ Validation ]")
        loss_dict['val'][epoch], perplexity_dict['val'][epoch] = _run_epoch(
            encoder, decoder, val_data_loader, epoch, start_time)

    # Save loss and perplexity as pickle files
    loss_path = os.path.join(each_model_path, "loss.pkl")
    with open(loss_path, 'wb') as f:
        pickle.dump(loss_dict, f)

    perp_path = os.path.join(each_model_path, "perplexity.pkl")
    with open(perp_path, 'wb') as f:
        pickle.dump(perplexity_dict, f)

[ Training ]
Epoch [0/10], Step [0/497], Average Loss: 0.0250, Perplexity: 24.5903, Elapsed time: 1.5747s
Epoch [0/10], Step [20/497], Average Loss: 0.0230, Perplexity: 14.4718, Elapsed time: 9.6565s
Epoch [0/10], Step [40/497], Average Loss: 0.0222, Perplexity: 13.9485, Elapsed time: 17.7957s
Epoch [0/10], Step [60/497], Average Loss: 0.0219, Perplexity: 16.8395, Elapsed time: 25.8841s
Epoch [0/10], Step [80/497], Average Loss: 0.0217, Perplexity: 15.6737, Elapsed time: 34.0967s
Epoch [0/10], Step [100/497], Average Loss: 0.0215, Perplexity: 13.7893, Elapsed time: 42.1262s
Epoch [0/10], Step [120/497], Average Loss: 0.0214, Perplexity: 14.4820, Elapsed time: 50.2470s
Epoch [0/10], Step [140/497], Average Loss: 0.0213, Perplexity: 12.3257, Elapsed time: 58.4757s
Epoch [0/10], Step [160/497], Average Loss: 0.0212, Perplexity: 12.9532, Elapsed time: 66.5621s
Epoch [0/10], Step [180/497], Average Loss: 0.0212, Perplexity: 14.6120, Elapsed time: 74.4760s
Epoch [0/10], Step [200/497], Avera