In [1]:
# Fetch the project's Python modules (models, dataset, utils, train) and make them importable.
# Clone into py_files_path explicitly (instead of whatever the current working directory is)
# and skip the clone when the checkout already exists, so Restart & Run All does not fail
# with git's "destination path already exists" error.
import os
import subprocess
import sys

REPO_URL = 'https://github.com/moaaztaha/Image-Captioning'
py_files_path = '/kaggle/working/Image-Captioning/'

if not os.path.isdir(py_files_path):
    clone = subprocess.run(
        ['git', 'clone', REPO_URL, py_files_path],
        capture_output=True, text=True, check=True,  # raise CalledProcessError if git fails
    )
    # git reports its progress on stderr; surface it in the cell output.
    print(clone.stderr)

# Avoid piling up duplicate sys.path entries each time the cell is re-run.
if py_files_path not in sys.path:
    sys.path.append(py_files_path)

Cloning into 'Image-Captioning'...
remote: Enumerating objects: 298, done.[K
remote: Counting objects: 100% (298/298), done.[K
remote: Compressing objects: 100% (221/221), done.[K
remote: Total 298 (delta 176), reused 194 (delta 74), pack-reused 0[K
Receiving objects: 100% (298/298), 14.74 MiB | 14.86 MiB/s, done.
Resolving deltas: 100% (176/176), done.


In [2]:
# Load (or reload) IPython's autoreload extension and, with mode 2, re-import every imported
# module before each cell runs, so edits to the project .py files take effect without a restart.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [3]:
import os
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from train import *
from nltk.translate.bleu_score import corpus_bleu

In [4]:
# Model parameters
encoder_dim = 2048  # feature dimension of the ResNet-101 encoder output
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
# cuDNN autotuning picks the fastest conv algorithms; worthwhile only when input sizes are fixed,
# and it makes runs non-bit-exact even with a fixed seed.
cudnn.benchmark = True

# Reproducibility: seed PyTorch's RNGs (weight init, dropout, and presumably DataLoader
# shuffling if the project's loaders use torch's default generator — confirm in dataset.py).
SEED = 42
torch.manual_seed(SEED)

# training parameters
epochs = 6  # number of epochs to train for (if early stopping is not triggered)
batch_size = 64
workers = 4
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
fine_tune_encoder = True  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none

In [5]:
DATA_NAME = 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101_finetunefromthestart'

# Choose data locations for the current environment instead of toggling commented-out lines.
# The Kaggle layout lives under /kaggle/working; when that directory is absent, fall back
# to the local layout (data.json in the working directory, images under flickr/Images/).
ON_KAGGLE = os.path.isdir('/kaggle/working')
if ON_KAGGLE:
    DATA_JSON_PATH = '/kaggle/working/Image-Captioning/data.json'
    IMGS_PATH = '../input/flickr8kimagescaptions/flickr8k/images/'
else:
    DATA_JSON_PATH = 'data.json'
    IMGS_PATH = 'flickr/Images/'

In [6]:
# Build the word vocabulary from the JSON file at DATA_JSON_PATH.
# build_vocab is star-imported from one of the project modules (dataset / utils / train).
vocab = build_vocab(DATA_JSON_PATH)

100%|██████████| 30000/30000 [00:00<00:00, 223533.68it/s]


In [7]:
# Vocabulary size, displayed as a quick sanity check after building it.
len(vocab)

4451

In [8]:
# Data / optimisation settings handed to fit() as t_params.
t_params = dict(
    data_name=DATA_NAME,
    imgs_path=IMGS_PATH,
    df_path=DATA_JSON_PATH,
    vocab=vocab,
    epochs=epochs,
    batch_size=batch_size,
    workers=workers,
    decoder_lr=decoder_lr,
    encoder_lr=encoder_lr,
    fine_tune_encoder=fine_tune_encoder,
)

# Model-architecture settings handed to fit() as m_params
# (note: the embedding size is exposed under the key 'embed_dim').
m_params = dict(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    encoder_dim=encoder_dim,
    dropout=dropout,
)

In [9]:
# Display the training configuration so the run's settings are recorded in the output.
t_params

{'data_name': 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101_finetunefromthestart',
 'imgs_path': '../input/flickr8kimagescaptions/flickr8k/images/',
 'df_path': '/kaggle/working/Image-Captioning/data.json',
 'vocab': <dataset.Vocabulary at 0x7f6e3b0fcc90>,
 'epochs': 6,
 'batch_size': 64,
 'workers': 4,
 'decoder_lr': 0.0004,
 'encoder_lr': 0.0001,
 'fine_tune_encoder': True}

In [10]:
# Initial training run. No checkpoint argument is passed, so fit()'s default for it applies
# (presumably None, i.e. train from scratch — confirm in train.py).
fit(t_params=t_params, m_params=m_params)

Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /root/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth


  0%|          | 0.00/170M [00:00<?, ?B/s]

Loading Data
Dataset split: train
Unique images: 6000
Total size: 30000
Dataset split: val
Unique images: 1000
Total size: 5000
__________________________________________________
-------------------- Fitting --------------------
__________________________________________________
-------------------- Training --------------------
Epoch: [0][0/469]	Batch Time 6.150 (6.150)	Data Load Time 2.252 (2.252)	Loss 9.3009 (9.3009)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/469]	Batch Time 0.685 (0.777)	Data Load Time 0.000 (0.023)	Loss 5.6434 (6.2207)	Top-5 Accuracy 42.301 (35.474)
Epoch: [0][200/469]	Batch Time 0.753 (0.748)	Data Load Time 0.000 (0.012)	Loss 4.9385 (5.7774)	Top-5 Accuracy 55.918 (41.256)
Epoch: [0][300/469]	Batch Time 0.712 (0.737)	Data Load Time 0.000 (0.008)	Loss 4.9255 (5.5053)	Top-5 Accuracy 51.216 (45.052)
Epoch: [0][400/469]	Batch Time 0.716 (0.732)	Data Load Time 0.000 (0.006)	Loss 4.8176 (5.3192)	Top-5 Accuracy 55.140 (47.666)
-------------------- Validation -----------

In [11]:
# Load the best checkpoint written by the first training run. The file name is derived from
# DATA_NAME rather than hard-coded, so changing DATA_NAME keeps it in step (assumes fit()
# saves as BEST_checkpoint_<data_name>.pth.tar — confirm in train.py).
# load_checkpoint reports the last epoch and best BLEU-4 stored in the file.
checkpoint = load_checkpoint(f'./BEST_checkpoint_{DATA_NAME}.pth.tar')

Loaded Checkpoint!!
Last Epoch: 4
Best Bleu-4: 0.13634096537364324


In [15]:
# Resume-training parameters: continue from the first run's best checkpoint with
# learning rates scaled down by LR_DECAY.
LR_DECAY = 0.8  # multiplier applied to the base learning rates below
epochs = 10  # presumably the total epoch count — fit() resumes at the checkpoint's next epoch
batch_size = 64
workers = 4
encoder_lr = 1e-5 * LR_DECAY
decoder_lr = 3e-4 * LR_DECAY
fine_tune_encoder = True
# Path to the checkpoint to resume from (None to train from scratch); derived from DATA_NAME
# so it stays consistent with the run's data name.
checkpoint = f'./BEST_checkpoint_{DATA_NAME}.pth.tar'

In [16]:
# Rebuild t_params so fit() sees the resume-phase epochs and learning rates set above.
t_params = dict(
    data_name=DATA_NAME,
    imgs_path=IMGS_PATH,
    df_path=DATA_JSON_PATH,
    vocab=vocab,
    epochs=epochs,
    batch_size=batch_size,
    workers=workers,
    decoder_lr=decoder_lr,
    encoder_lr=encoder_lr,
    fine_tune_encoder=fine_tune_encoder,
)

In [17]:
# Display the resume-phase configuration for the record.
t_params

{'data_name': 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101_finetunefromthestart',
 'imgs_path': '../input/flickr8kimagescaptions/flickr8k/images/',
 'df_path': '/kaggle/working/Image-Captioning/data.json',
 'vocab': <dataset.Vocabulary at 0x7f6e3b0fcc90>,
 'epochs': 10,
 'batch_size': 64,
 'workers': 4,
 'decoder_lr': 0.00023999999999999998,
 'encoder_lr': 8.000000000000001e-06,
 'fine_tune_encoder': True}

In [18]:
# Resume training from the checkpoint path held in `checkpoint`.
# Arguments are positional in fit()'s order: (t_params, checkpoint, m_params).
fit(t_params, checkpoint, m_params)

Loaded Checkpoint!!
Starting Epoch: 5
Loading Data
Dataset split: train
Unique images: 6000
Total size: 30000
Dataset split: val
Unique images: 1000
Total size: 5000
__________________________________________________
-------------------- Fitting --------------------
__________________________________________________
-------------------- Training --------------------
Epoch: [5][0/469]	Batch Time 2.652 (2.652)	Data Load Time 1.821 (1.821)	Loss 3.3510 (3.3510)	Top-5 Accuracy 73.906 (73.906)
Epoch: [5][100/469]	Batch Time 0.713 (0.735)	Data Load Time 0.000 (0.018)	Loss 3.3715 (3.4215)	Top-5 Accuracy 73.708 (73.025)
Epoch: [5][200/469]	Batch Time 0.753 (0.727)	Data Load Time 0.000 (0.009)	Loss 3.2441 (3.4304)	Top-5 Accuracy 77.097 (73.028)
Epoch: [5][300/469]	Batch Time 0.720 (0.722)	Data Load Time 0.000 (0.006)	Loss 3.7493 (3.4386)	Top-5 Accuracy 69.679 (72.877)
Epoch: [5][400/469]	Batch Time 0.680 (0.720)	Data Load Time 0.000 (0.005)	Loss 3.2901 (3.4456)	Top-5 Accuracy 75.043 (72.777)
---

In [20]:
# Reload the best checkpoint after the resumed run. The file name is derived from DATA_NAME
# rather than hard-coded (assumes fit() saves as BEST_checkpoint_<data_name>.pth.tar —
# confirm in train.py). load_checkpoint reports the last epoch and best BLEU-4 stored.
checkpoint = load_checkpoint(f'./BEST_checkpoint_{DATA_NAME}.pth.tar')

Loaded Checkpoint!!
Last Epoch: 6
Best Bleu-4: 0.13916614323948495


<a href='./BEST_checkpoint_flickr8k_5_cap_per_img_2_min_word_freq_resnet101_finetunefromthestart.pth.tar'>download</a>