In [1]:
# (Re)load the autoreload extension and reload all modules before each cell runs,
# so edits to the project modules imported below (models, dataset, utils, train)
# take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline.
# NOTE(review): none of the visible cells plot anything; this may be unnecessary.
%matplotlib inline

In [2]:
# Make the project modules in the parent directory (models, dataset, utils, train)
# importable. The path is made absolute so it stays valid if the working directory
# changes later, and it is only added once so re-running this cell does not keep
# growing sys.path.
import os
import sys

py_files_path = os.path.abspath('../')
if py_files_path not in sys.path:
    sys.path.append(py_files_path)

In [3]:
import time 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

In [4]:
# Model parameters
enc_dim = 512  # encoder output feature size, passed to the decoder as encoder_dim
emb_dim = 256  # dimension of word embeddings
attention_dim = 256  # dimension of attention linear layers
decoder_dim = 256  # dimension of decoder RNN
dropout = 0.5  # dropout probability passed to the decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device for the models and tensors
# Every image is resized/center-cropped to 224x224 by the transform, so inputs have a
# fixed size and letting cuDNN benchmark/cache the fastest conv algorithms pays off.
# Note: benchmark mode makes GPU runs non-deterministic even with a fixed seed.
cudnn.benchmark = True

# Reproducibility: seed torch's RNGs on all devices (weight initialisation and,
# presumably, DataLoader shuffling inside get_loaders).
seed = 42
torch.manual_seed(seed)

# Training parameters
start_epoch = 0
epochs = 120  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # number of epochs since the last improvement in validation BLEU
batch_size = 32
workers = 4  # number of data-loading workers passed to get_loaders
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of 5 (i.e. into [-5, 5])
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # best validation BLEU-4 seen so far
print_freq = 100  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none
# NOTE(review): grad_clip, alpha_c, print_freq, epochs_since_improvement and checkpoint
# are never passed to train()/validate() in this notebook; confirm whether train.py
# reads its own copies, otherwise changing them here has no effect.

In [5]:
# Build the word vocabulary from the caption annotation file.
captions_json = '../data.json'
vocab = build_vocab(captions_json)

100%|██████████| 30000/30000 [00:00<00:00, 348962.15it/s]


In [6]:
# Vocabulary size; the decoder below is built with vocab_size=len(vocab).
vocab_size = len(vocab)
vocab_size

4451

In [7]:
# Build the attention decoder and the encoder, move both to `device`, and only then
# create their optimizers: the torch.optim docs recommend moving a model to its
# device before constructing an optimizer over its parameters.
decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(vocab),
                               encoder_dim=enc_dim,
                               dropout=dropout).to(device)

# Only parameters with requires_grad=True are handed to the optimizer.
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                     lr=decoder_lr)

encoder = Encoder()
# Set which encoder parameters are trainable according to fine_tune_encoder.
encoder.fine_tune(fine_tune_encoder)
encoder = encoder.to(device)
# The encoder only gets an optimizer when it is being fine-tuned; otherwise train()
# receives None for encoder_optimizer.
encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                     lr=encoder_lr) if fine_tune_encoder else None

In [8]:
# Place both models on the selected device (GPU when available, otherwise CPU).
decoder, encoder = decoder.to(device), encoder.to(device)

In [9]:
# Cross-entropy loss shared by train() and validate(), placed on the same device as the models.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [10]:
# Per-channel normalisation statistics (the standard ImageNet mean/std, presumably
# matching the encoder's pretraining).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Resize the shorter side to 256, then center-crop to a fixed 224x224, convert to a
# tensor and normalise, so every image fed to the encoder has the same size.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Custom train/val dataloaders over the Flickr images and the caption JSON.
images_root = '../flickr/Images/'
captions_json = '../data.json'
train_loader, val_loader = get_loaders(batch_size, images_root, captions_json, transform, vocab, workers)

Dataset split: train
Unique images: 6000
Total size: 30000
Dataset split: val
Unique images: 1000
Total size: 5000


In [15]:
from train import train, validate

In [48]:
# Quick run: these override epochs=120 / start_epoch=0 from the config cell, and the
# validation cell below reads them too.
epochs = 1
start_epoch = 0

# Arguments that do not change from one epoch to the next.
train_kwargs = dict(
    train_loader=train_loader,
    encoder=encoder,
    decoder=decoder,
    criterion=criterion,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
)

for epoch in range(start_epoch, epochs):
    train(epoch=epoch, **train_kwargs)

2it [00:02,  1.14s/it]

Epoch: [0][0/938]	Batch Time 2.660 (2.660)	Data Load Time 2.489 (2.489)	Loss 6.4553 (6.4553)	Top-5 Accuracy 26.168 (26.168)


102it [00:23,  6.13it/s]

Epoch: [0][100/938]	Batch Time 0.150 (0.233)	Data Load Time 0.002 (0.081)	Loss 5.9524 (6.0867)	Top-5 Accuracy 35.644 (33.742)


202it [00:40,  6.08it/s]

Epoch: [0][200/938]	Batch Time 0.153 (0.204)	Data Load Time 0.001 (0.046)	Loss 5.4442 (5.9013)	Top-5 Accuracy 44.816 (36.508)


302it [01:00,  6.06it/s]

Epoch: [0][300/938]	Batch Time 0.152 (0.200)	Data Load Time 0.001 (0.039)	Loss 5.2968 (5.7479)	Top-5 Accuracy 50.633 (39.434)


402it [01:16,  6.35it/s]

Epoch: [0][400/938]	Batch Time 0.154 (0.191)	Data Load Time 0.001 (0.030)	Loss 5.3618 (5.6247)	Top-5 Accuracy 45.231 (41.541)


502it [01:33,  5.64it/s]

Epoch: [0][500/938]	Batch Time 0.223 (0.188)	Data Load Time 0.001 (0.024)	Loss 5.0540 (5.5220)	Top-5 Accuracy 51.652 (43.261)


602it [01:50,  6.32it/s]

Epoch: [0][600/938]	Batch Time 0.147 (0.183)	Data Load Time 0.001 (0.021)	Loss 5.0498 (5.4323)	Top-5 Accuracy 48.936 (44.673)


701it [02:06,  6.25it/s]

Epoch: [0][700/938]	Batch Time 0.153 (0.180)	Data Load Time 0.001 (0.018)	Loss 4.8629 (5.3619)	Top-5 Accuracy 56.040 (45.785)


802it [02:22,  6.30it/s]

Epoch: [0][800/938]	Batch Time 0.163 (0.178)	Data Load Time 0.001 (0.016)	Loss 4.8263 (5.3007)	Top-5 Accuracy 53.822 (46.738)


902it [02:38,  6.13it/s]

Epoch: [0][900/938]	Batch Time 0.162 (0.176)	Data Load Time 0.001 (0.014)	Loss 4.5594 (5.2471)	Top-5 Accuracy 58.632 (47.555)


938it [02:44,  5.69it/s]


In [75]:
# Evaluate the current models once on the validation split. validate() takes no
# epoch argument and no training happens in this cell, so looping over epochs would
# only repeat the identical evaluation (and leave recent_bleu4 unset if the range
# were empty).
recent_bleu4 = validate(val_loader=val_loader,
                        encoder=encoder,
                        decoder=decoder,
                        criterion=criterion)

Validation: [0/157]	Batch Time 0.875 (0.875)	Loss 5.5678 (5.5678)	Top-5 Accuracy 53.430 (53.430)	
Validation: [100/157]	Batch Time 0.169 (0.140)	Loss 5.5231 (5.6358)	Top-5 Accuracy 50.303 (49.587)	

 * LOSS - 5.613, TOP-5 ACCURACY - 49.821, BLEU-4 - 0.004498661142443737



In [11]:
# Dataset locations, relative to the notebook's working directory.
IMG_DIR, CAP_FILE = '../flickr/Images/', '../data.json'

In [70]:
# Validation split of the caption dataset, using the same transform as the dataloaders.
val_dataset = CaptionDataset(IMG_DIR, CAP_FILE, vocab, transform, split='val')
ds = val_dataset  # short alias used by the inspection cells below

Dataset split: val
Unique images: 1000
Total size: 5000


In [71]:
# Inspect the first validation sample. The two middle fields are unused here
# (presumably a single caption and its length — confirm against CaptionDataset).
# They are bound to throwaway names rather than `_`: assigning `_` in the notebook
# namespace stops IPython from updating its `_` / `__` / `___` output history.
img, _unused_1, _unused_2, all_toks = ds[0]

In [73]:
# Size of the reference-caption tensor for this image; the first dimension appears
# to be the number of captions per image (5000 captions / 1000 images in the val split).
all_toks.size()

torch.Size([5, 38])