In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import time 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from train import *
from nltk.translate.bleu_score import corpus_bleu

In [3]:
# Model parameters
encoder_dim = 2048  # encoder feature size; 2048 matches the final conv block of ResNet-101
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5  # presumably the decoder dropout probability — passed to fit() as m_params['dropout']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
# cuDNN autotuning picks the fastest kernels for the observed input shapes. It only pays off
# when input sizes are fixed (otherwise it re-tunes constantly), and it makes kernel choice,
# and therefore results, non-deterministic across runs even with a fixed seed.
cudnn.benchmark = True

# Reproducibility: seed torch's RNGs (weight initialisation and, presumably, data shuffling
# inside fit()). torch.manual_seed seeds every device, including CUDA.
SEED = 42
torch.manual_seed(SEED)

# training parameters
epochs = 10  # number of epochs to train for (if early stopping is not triggered)
batch_size = 64
workers = 4  # presumably DataLoader worker processes — passed to fit() as t_params['workers']
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none
# NOTE(review): `checkpoint` is not passed to fit() anywhere in this notebook — confirm it is still needed.

In [4]:

# Identifier of the preprocessed dataset, forwarded to fit() as t_params['data_name'].
# The name suggests Flickr8k, 5 captions/image, min word frequency 2, ResNet-101 features.
DATA_NAME = 'flickr8k_5_cap_per_img_2_min_word_freq_resnet101'

# Local (relative) input locations
DATA_JSON_PATH = 'data.json'   # read by build_vocab() and forwarded to fit() as t_params['df_path']
IMGS_PATH = 'flickr/Images/'   # image directory, forwarded to fit() as t_params['imgs_path']

In [5]:
# Build the vocabulary from the JSON at DATA_JSON_PATH. The recorded progress bar iterates
# 30000 items, the same count fit() later reports for the train split.
vocab = build_vocab(
    DATA_JSON_PATH
)

100%|██████████| 30000/30000 [00:00<00:00, 356318.14it/s]


In [6]:
len(vocab)  # vocabulary size, shown as the cell output (4451 in the recorded run)

4451

In [7]:
# Train for longer than the config-cell default. The override must come BEFORE t_params is
# built: the dict captures the value of `epochs` at construction time. The previous version
# reassigned `epochs` after the dict existed, so fit() silently kept the old value (10).
epochs = 100

# Data / optimisation settings forwarded to fit().
t_params = {
    'data_name': DATA_NAME,
    'imgs_path': IMGS_PATH,
    'df_path': DATA_JSON_PATH,
    'vocab': vocab,
    'epochs': epochs,
    'batch_size': batch_size,
    'workers': workers,
    'decoder_lr': decoder_lr,
    'encoder_lr': encoder_lr,
    'fine_tune_encoder': fine_tune_encoder
}

# Model hyper-parameters forwarded to fit(). Note the key is 'embed_dim' while the
# config variable is `emb_dim` — the key is what fit() reads, so keep it as-is.
m_params = {
    'attention_dim': attention_dim,
    'embed_dim': emb_dim,
    'decoder_dim': decoder_dim,
    'encoder_dim': encoder_dim,
    'dropout': dropout
}

fit(t_params=t_params, m_params=m_params)

Loading Data
Dataset split: train
Unique images: 6000
Total size: 30000
Dataset split: val
Unique images: 1000
Total size: 5000
__________________________________________________
-------------------- Fitting --------------------
Epoch: [0][0/469]	Batch Time 4.229 (4.229)	Data Load Time 1.813 (1.813)	Loss 9.3459 (9.3459)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/469]	Batch Time 0.935 (0.986)	Data Load Time 0.000 (0.018)	Loss 5.7318 (6.3204)	Top-5 Accuracy 39.776 (33.845)
Epoch: [0][200/469]	Batch Time 0.943 (0.970)	Data Load Time 0.000 (0.009)	Loss 5.2457 (5.8528)	Top-5 Accuracy 47.798 (39.861)
Epoch: [0][300/469]	Batch Time 1.038 (0.966)	Data Load Time 0.000 (0.006)	Loss 5.0713 (5.5781)	Top-5 Accuracy 51.399 (43.790)
Epoch: [0][400/469]	Batch Time 1.108 (0.975)	Data Load Time 0.000 (0.005)	Loss 4.8134 (5.3795)	Top-5 Accuracy 52.443 (46.636)
Validation: [0/79]	Batch Time 4.246 (4.246)	Loss 5.5030 (5.5030)	Top-5 Accuracy 50.759 (50.759)	

 * LOSS - 5.426, TOP-5 ACCURACY - 51.889, BLEU-

KeyboardInterrupt: 