In [1]:
# Auto-reload edited project modules (models, dataset, utils, train) before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inside the notebook.
%matplotlib inline

In [14]:
import csv
import pickle
import time

import numpy as np
import pandas as pd
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from nltk.translate.bleu_score import corpus_bleu
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from train import train, validate

# Device for models and tensors; a torch.device, consistent with the model-parameters cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# training parameters
# NOTE(review): fit() calls train() from the `train` module without passing these values, and a function
# defined in another module does not see notebook globals — confirm train() actually uses them before tuning.
grad_clip = 5.  # clip gradients at an absolute value of 5
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
print_freq = 100  # print training/validation stats every this many batches


In [3]:
def _first_present(mapping, *keys):
    """Return ``mapping[key]`` for the first of ``keys`` that is present in ``mapping``.

    Raises:
        KeyError: naming the first (preferred) key when none of ``keys`` is present.
    """
    for key in keys:
        if key in mapping:
            return mapping[key]
    raise KeyError(keys[0])


def fit(t_params, checkpoint=None, m_params=None):
    """Train the encoder/decoder captioning model, validating and checkpointing after every epoch.

    Args:
        t_params: dict of training settings. Keys read: 'data_name', 'imgs_path', 'df_path', 'vocab',
            'epochs', 'batch_size', 'workers', 'encoder_lr', 'decoder_lr', 'fine_tune_encoder',
            'pretrained_embeddings', 'fine_tune_embeddings'.
        checkpoint: optional path to a checkpoint to resume from. It is loaded with ``torch.load`` and must
            contain 'epoch', 'bleu-4', 'decoder', 'encoder', 'encoder_optimizer', the decoder optimizer
            and the epochs-since-improvement counter. When given, ``m_params`` is not used.
        m_params: dict of model hyperparameters ('attention_dim', 'embed_dim', 'decoder_dim',
            'encoder_dim', 'dropout', plus 'embeddings_matrix' when t_params['pretrained_embeddings']
            is true). Required when ``checkpoint`` is None.

    Raises:
        ValueError: if ``checkpoint`` is None and ``m_params`` is not provided.
    """
    # info
    data_name = t_params['data_name']
    imgs_path = t_params['imgs_path']
    df_path = t_params['df_path']
    vocab = t_params['vocab']

    start_epoch = 0
    epochs_since_improvement = 0
    best_bleu4 = 0
    epochs = t_params['epochs']
    batch_size = t_params['batch_size']
    workers = t_params['workers']
    encoder_lr = t_params['encoder_lr']
    decoder_lr = t_params['decoder_lr']
    fine_tune_encoder = t_params['fine_tune_encoder']

    pretrained_embeddings = t_params['pretrained_embeddings']
    fine_tune_embeddings = t_params['fine_tune_embeddings']

    # init / load checkpoint
    if checkpoint is None:
        if m_params is None:
            raise ValueError('m_params is required when training from scratch (checkpoint is None)')

        decoder = DecoderWithAttention(attention_dim=m_params['attention_dim'],
                                       embed_dim=m_params['embed_dim'],
                                       decoder_dim=m_params['decoder_dim'],
                                       encoder_dim=m_params['encoder_dim'],
                                       vocab_size=len(vocab),
                                       dropout=m_params['dropout'])
        if pretrained_embeddings:
            # The embedding matrix is only needed (and only read) when pretrained embeddings are requested.
            embeddings_matrix = m_params['embeddings_matrix']
            decoder.load_pretrained_embeddings(torch.tensor(embeddings_matrix, dtype=torch.float32))
            decoder.fine_tune_embeddings(fine_tune=fine_tune_embeddings)

        # Build the optimizer after the embedding freeze/unfreeze so it only sees trainable parameters.
        decoder_optimizer = torch.optim.RMSprop(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                                lr=decoder_lr)

        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.RMSprop(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                lr=encoder_lr) if fine_tune_encoder else None
    # load checkpoint
    else:
        # map_location restores tensors onto the current device (e.g. a GPU-saved checkpoint on a CPU machine).
        # NOTE(review): the checkpoint holds whole modules/optimizers; PyTorch >= 2.6 defaults torch.load to
        # weights_only=True, which rejects such objects — pass weights_only=False there if loading fails.
        checkpoint = torch.load(checkpoint, map_location=device)
        print('Loaded Checkpoint!!')
        start_epoch = checkpoint['epoch'] + 1
        print(f"Starting Epoch: {start_epoch}")
        # Accept the correctly spelled keys as well as the misspelled ones this function used to read.
        epochs_since_improvement = _first_present(checkpoint, 'epochs_since_improvement', 'epochs_since_imrovment')
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = _first_present(checkpoint, 'decoder_optimizer', 'deocder_optimizer')
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        # Start fine-tuning the encoder now if it was frozen when the checkpoint was written.
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.RMSprop(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                    lr=encoder_lr)
    # move to gpu, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # dataloaders: fixed 224x224 crops, normalised with the ImageNet channel statistics
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    print('Loading Data')
    train_loader, val_loader = get_loaders(batch_size, imgs_path, df_path, transform, vocab, False, workers)
    print('_'*50)

    print('-'*20, 'Fitting', '-'*20)
    for epoch in range(start_epoch, epochs):

        # decay lr if there is no improvement for 8 consecutive epochs and terminate after 20
        if epochs_since_improvement == 20:
            print('No improvement for 20 consecutive epochs, terminating...')
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            # Decay whichever encoder optimizer is in use; a resumed checkpoint can carry one even when
            # fine_tune_encoder is False, and it is still handed to train() below.
            if encoder_optimizer is not None:
                adjust_learning_rate(encoder_optimizer, 0.8)

        print('_'*50)
        print('-'*20, 'Training', '-'*20)
        # one epoch of training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # one epoch of validation; validate() returns the BLEU-4 score used for model selection
        print('-'*20, 'Validation', '-'*20)
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion,
                                vocab=vocab)

        # check for improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if is_best:
            # reset
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print(f'\nEpochs since last improvement: {epochs_since_improvement}')

        # NOTE(review): this saves the current epoch's score (recent_bleu4), presumably as 'bleu-4', which
        # resuming reads back as best_bleu4 — only the true best when resuming from the best checkpoint.
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)

In [4]:
# Model parameters
encoder_dim = 2048 # resnet101
emb_dim = 300  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# training parameters
epochs = 1  # number of epochs to train for (if early stopping is not triggered)
batch_size = 64
workers = 2
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
fine_tune_encoder = False  # fine-tune encoder
pretrained_embeddings = True
fine_tune_embeddings = False
checkpoint = None  # path to checkpoint, None if none

In [5]:
DATA_NAME = 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101_fullvocab_fix_ds_rmsprop'

# Data locations for each execution environment. Switch RUN_ENV instead of
# commenting/uncommenting path lines.
DATA_PATHS = {
    'local': {'data_json': 'data.json',
              'imgs': 'flickr/Images/'},
    'kaggle': {'data_json': '/kaggle/working/Image-Captioning/data.json',
               'imgs': '../input/flickr8kimagescaptions/flickr8k/images/'},
    'colab': {'data_json': 'Image-Captioning/data.json',
              'imgs': 'flickr8k/images/'},
}
RUN_ENV = 'local'

DATA_JSON_PATH = DATA_PATHS[RUN_ENV]['data_json']
IMGS_PATH = DATA_PATHS[RUN_ENV]['imgs']

In [6]:
# load vocab
# Build the vocabulary from the caption JSON, then display how many tokens it holds.
vocab = build_vocab(DATA_JSON_PATH)
len(vocab)

100%|██████████| 40000/40000 [00:00<00:00, 388031.83it/s]


5089

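### Pre-trained Embeddings

Load the 300-d GloVe 6B vectors and build an embedding matrix whose rows follow the vocabulary's word indices; words without a GloVe vector get small random vectors.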

In [7]:
# GloVe text format: one token per line followed by its vector components, separated by spaces.
# QUOTE_NONE: quote characters inside tokens must not be treated as CSV quoting.
# keep_default_na=False: otherwise tokens such as "nan", "null" or "n/a" are parsed as NaN and
# their vectors end up under NaN keys instead of the word itself.
glove = pd.read_csv('glove.6B/glove.6B.300d.txt', sep=' ', quoting=csv.QUOTE_NONE,
                    header=None, index_col=0, keep_default_na=False)
# word -> 1-D vector, taken straight from the underlying array (iterating the transposed frame
# builds one Series per word and is much slower).
glove_embedding = dict(zip(glove.index, glove.to_numpy()))

In [8]:
def create_embedding_matrix(vocab, embedding_dict, dimension, rng=None):
    """Build a (len(vocab), dimension) matrix whose row ``i`` holds the vector of the word with index ``i``.

    Words present in ``embedding_dict`` get their pretrained vector. Every other word in ``vocab.stoi``
    gets a vector drawn uniformly from [-0.1, 0.1). Rows whose index appears in no ``stoi`` entry stay zero.

    Args:
        vocab: object supporting ``len()`` with a ``stoi`` mapping of word -> row index.
        embedding_dict: mapping of word -> 1-D array of length ``dimension``.
        dimension: embedding size (number of columns).
        rng: optional ``numpy.random.Generator`` for the random vectors, making the matrix
            reproducible. Defaults to numpy's global random state (the original behaviour).

    Returns:
        numpy float64 array of shape ``(len(vocab), dimension)``.
    """
    sampler = np.random if rng is None else rng
    embedding_matrix = np.zeros((len(vocab), dimension))

    for word, index in vocab.stoi.items():
        if word in embedding_dict:
            embedding_matrix[index] = embedding_dict[word]
        else:
            embedding_matrix[index] = sampler.uniform(-.1, .1, size=dimension)
    return embedding_matrix

In [9]:
# Use the configured embedding size so the matrix agrees with m_params['embed_dim'] (= emb_dim).
embedding_matrix = create_embedding_matrix(vocab, glove_embedding, emb_dim)

In [10]:
# fit() builds the decoder with vocab_size=len(vocab) and embed_dim=emb_dim; the matrix must match both.
assert embedding_matrix.shape == (len(vocab), emb_dim), embedding_matrix.shape
embedding_matrix.shape

(5089, 300)

In [15]:
# Save the matrix so later runs can skip the GloVe load, then read it back to confirm the round-trip is lossless.
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embedding_matrix, f)

with open('embeddings.pkl', 'rb') as f:
    embedding_matrix1 = pickle.load(f)

np.array_equal(embedding_matrix, embedding_matrix1)  # sanity check

True

In [11]:
# Settings consumed by fit(): data locations, vocabulary and optimisation options.
t_params = dict(
    data_name=DATA_NAME,
    imgs_path=IMGS_PATH,
    df_path=DATA_JSON_PATH,
    vocab=vocab,
    epochs=epochs,
    batch_size=batch_size,
    workers=workers,
    decoder_lr=decoder_lr,
    encoder_lr=encoder_lr,
    fine_tune_encoder=fine_tune_encoder,
    pretrained_embeddings=pretrained_embeddings,
    fine_tune_embeddings=fine_tune_embeddings,
)

# Architecture hyperparameters plus the pretrained embedding matrix.
m_params = dict(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    encoder_dim=encoder_dim,
    dropout=dropout,
    embeddings_matrix=embedding_matrix,
)

t_params

{'data_name': 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101_fullvocab_fix_ds_rmsprop',
 'imgs_path': 'flickr/Images/',
 'df_path': 'data.json',
 'vocab': <dataset.Vocabulary at 0x7f3d6dbd0350>,
 'epochs': 1,
 'batch_size': 64,
 'workers': 2,
 'decoder_lr': 0.0004,
 'encoder_lr': 0.0001,
 'fine_tune_encoder': False,
 'pretrained_embeddings': True,
 'fine_tune_embeddings': False}

In [12]:
# Forward the configured `checkpoint` (None = train from scratch) so resuming only needs the config cell changed.
fit(t_params=t_params, checkpoint=checkpoint, m_params=m_params)

Loading Data
Dataset split: train
Unique images: 6000
Total size: 30000
Dataset split: val
Unique images: 1000
Total size: 5000
__________________________________________________
-------------------- Fitting --------------------
__________________________________________________
-------------------- Training --------------------
Epoch: [0][0/469]	Batch Time 6.728 (6.728)	Data Load Time 3.440 (3.440)	Loss 9.4371 (9.4371)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/469]	Batch Time 1.121 (1.090)	Data Load Time 0.000 (0.034)	Loss 5.1338 (5.7062)	Top-5 Accuracy 49.603 (43.075)
Epoch: [0][200/469]	Batch Time 1.105 (1.083)	Data Load Time 0.000 (0.017)	Loss 5.0455 (5.3644)	Top-5 Accuracy 50.611 (47.638)
Epoch: [0][300/469]	Batch Time 1.086 (1.072)	Data Load Time 0.000 (0.012)	Loss 4.6390 (5.1817)	Top-5 Accuracy 56.369 (50.053)
Epoch: [0][400/469]	Batch Time 1.001 (1.064)	Data Load Time 0.000 (0.009)	Loss 4.8202 (5.0634)	Top-5 Accuracy 56.081 (51.656)
-------------------- Validation -----------