In [1]:
# Reload imported project modules automatically when their source files change,
# so edits to models/dataset/utils/train are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline (no plotting cell is visible in this notebook).
%matplotlib inline

In [2]:
import os
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from train import *
from nltk.translate.bleu_score import corpus_bleu

In [3]:
# Model parameters
encoder_dim = 2048  # resnet101
embed_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Reproducibility: seed torch's RNGs (all devices) so anything drawn from them,
# e.g. randomly initialised weights, is the same on every Restart & Run All.
# Note cudnn.benchmark above can still introduce small run-to-run differences.
seed = 42
torch.manual_seed(seed)

# training parameters
epochs = 1  # number of epochs to train for (if early stopping is not triggered)
batch_size = 128
workers = 4
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none

In [4]:
DATA_NAME = 'flickr30k_5_cap_per_img_2_min_word_freq_resnet101'

# Data locations. The defaults are the local Flickr30k setup. To run elsewhere,
# export FLICKR_DATA_JSON / FLICKR_IMGS_PATH before starting Jupyter instead of
# editing this cell. Previously used alternatives:
#   local:   FLICKR_DATA_JSON=data.json
#            FLICKR_IMGS_PATH=flickr/Images/
#   kaggle:  FLICKR_DATA_JSON=/kaggle/working/Image-Captioning/data30.json
#            FLICKR_IMGS_PATH=../input/flickr30k/Images/
DATA_JSON_PATH = os.environ.get('FLICKR_DATA_JSON', 'data30.json')

# os.path.join(path, '') guarantees a trailing separator (the default already has
# one and is returned unchanged). Kept in case image file paths are built by
# string concatenation downstream — not visible here, TODO confirm.
IMGS_PATH = os.path.join(
    os.environ.get('FLICKR_IMGS_PATH', '/run/media/kelwa/DEV/data/flickr30k/Images/'),
    '',
)

In [5]:
# load vocab
# Build the word vocabulary from the dataset JSON at DATA_JSON_PATH. build_vocab
# comes from one of the wildcard imports above (presumably dataset or utils).
# The result sizes the decoder (vocab_size=len(vocab)) and maps ids back to words (vocab.itos).
vocab = build_vocab(DATA_JSON_PATH)

100%|██████████| 155070/155070 [00:00<00:00, 297615.10it/s]


In [6]:
# Vocabulary size; the decoder below is constructed with vocab_size=len(vocab).
len(vocab)

12096

In [7]:
# `epochs` is set once in the training-parameters cell. Re-assigning it here
# silently overrode any change made there, so this cell only displays the
# configured value.
epochs

In [8]:
# Bundle the run configuration into keyword dicts:
#   t_params: data locations, vocabulary and training settings (not consumed by any
#             visible cell; presumably meant for a training helper — TODO confirm)
#   m_params: decoder architecture settings; the keys match DecoderWithAttention's
#             keyword arguments used when the decoder is constructed.
t_params = dict(
    data_name=DATA_NAME,
    imgs_path=IMGS_PATH,
    df_path=DATA_JSON_PATH,
    vocab=vocab,
    epochs=epochs,
    batch_size=batch_size,
    workers=workers,
    decoder_lr=decoder_lr,
    encoder_lr=encoder_lr,
    fine_tune_encoder=fine_tune_encoder,
)

m_params = dict(
    attention_dim=attention_dim,
    embed_dim=embed_dim,
    decoder_dim=decoder_dim,
    encoder_dim=encoder_dim,
    dropout=dropout,
)

In [9]:
# dataloaders
# Image preprocessing handed to get_loaders: resize so the shorter side is 256,
# take the central 224x224 crop, convert to a tensor, then normalise each channel
# with the standard ImageNet mean/std.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [40]:
# Only the validation loader is needed in this notebook (the train loader is discarded).
# Use the configured batch size instead of a hard-coded 128, so the config cell
# remains the single source of truth.
_, val_loader = get_loaders(batch_size, IMGS_PATH, DATA_JSON_PATH, transform, vocab)

Dataset split: train
Unique images: 29000
Total size: 145000
Dataset split: val
Unique images: 1014
Total size: 5070


In [41]:
# Construct the encoder and the attention decoder on the selected device.
# m_params carries attention_dim / embed_dim / decoder_dim / encoder_dim / dropout.
# The vocabulary size is supplied separately because it is only known after
# build_vocab has run.
encoder = Encoder().to(device)
decoder = DecoderWithAttention(vocab_size=len(vocab), **m_params).to(device)

In [42]:
# RMSprop over trainable parameters only (anything with requires_grad=False is skipped).
decoder_optimizer = torch.optim.RMSprop(
    params=(p for p in decoder.parameters() if p.requires_grad),
    lr=decoder_lr,
)

# The encoder only gets an optimiser when it is being fine-tuned.
if fine_tune_encoder:
    encoder_optimizer = torch.optim.RMSprop(
        params=(p for p in encoder.parameters() if p.requires_grad),
        lr=encoder_lr,
    )
else:
    encoder_optimizer = None

In [48]:
# Cross-entropy loss passed to validate() below; moved to `device` alongside the models.
criterion = nn.CrossEntropyLoss().to(device)

In [56]:
# Run validation on the val split. The captured output reports loss, top-5 accuracy
# and BLEU-1..4. The returned value is presumably the BLEU-4 score (name `bl4`,
# printed as "BLEU-4") — confirm in train.validate.
# No training cell precedes this one, so zero BLEU / ~10 loss reflect untrained weights.
bl4 = validate(val_loader, encoder, decoder, criterion, vocab)


Validation: [0/40]	Batch Time 2.644 (2.644)	Loss 10.3171 (10.3171)	Top-5 Accuracy 0.000 (0.000)	
----- Bleu-n Scores -----
1: 0
2: 0
3: 0
4: 0
-------------------------

 * LOSS - 10.317, TOP-5 ACCURACY - 0.000, BLEU-4 - 0



In [None]:
# Display the value returned by validate() above (not executed in the saved notebook: In [None]).
bl4

In [55]:
# Debug: nesting sizes of the reference lists fed to print_scores (BLEU).
# Captured (128, 5, 22) presumably means batch x references-per-image x tokens — TODO confirm.
# NOTE(review): `refs` is never assigned in any cell of this notebook. This only works
# because of leftover kernel state and fails under Restart & Run All. Expose it from
# validate() (e.g. return it) or drop this cell.
len(refs), len(refs[0]), len(refs[0][0])  

(128, 5, 22)

In [51]:
# Debug: number of hypotheses and length of the first one, as fed to print_scores.
# NOTE(review): `hyps` is never assigned in any cell of this notebook (hidden kernel
# state); this cell fails on a fresh kernel.
len(hyps), len(hyps[0])

(128, 14)

In [52]:
# Print BLEU-1..4 for the reference/hypothesis lists (print_scores comes from a wildcard import).
# NOTE(review): depends on `refs` and `hyps`, which no cell in this notebook defines.
print_scores(refs, hyps)

----- Bleu-n Scores -----
1: 0
2: 0
3: 0
4: 0
-------------------------


In [18]:
# Debug: shapes of a target id sequence and its per-position score matrix.
# NOTE(review): `targets`/`scores` are never assigned in this notebook (hidden kernel state).
# The captured second dim of `scores` (5089) does not match len(vocab) (12096), so this
# output appears to be stale, from a different vocabulary/run.
targets.shape, scores.shape

(torch.Size([8]), torch.Size([8, 5089]))

In [19]:
# Map the target token ids back to words via vocab.itos.
# NOTE(review): `targets` is not defined by any cell of this notebook.
[vocab.itos[i.item()] for i in targets]

['two', 'women', 'smile', '<unk>', 'at', 'the', 'viewer', '<eos>']

In [20]:
# Take the highest-scoring id per position (argmax over dim 1) and map it to a word.
# NOTE(review): `scores` is not defined by any cell of this notebook. Its captured width (5089)
# differs from len(vocab), so these decoded words may come from a mismatched vocabulary.
[vocab.itos[i.item()] for i in scores.argmax(1)]

['footbridge', 'toilet', 'pickup', 'hats', 'those', 'break', 'those', 'pickup']