In [1]:
# (Re)load IPython's autoreload extension and put it in mode 2, which reloads
# imported modules before each cell executes, so edits to the project's .py
# files are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
# Add the directory one level above this notebook to the module search path
# so that Python modules located there can be imported.
import sys

py_files_path = '../'
# Guard the append: without it, every re-run of this cell would push another
# duplicate entry onto sys.path.
if py_files_path not in sys.path:
    sys.path.append(py_files_path)

In [3]:
import time 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

In [64]:
# ---- Model hyperparameters ----
enc_dim = 512        # passed to the decoder as encoder_dim
emb_dim = 256        # word-embedding size
attention_dim = 256  # width of the attention linear layers
decoder_dim = 256    # size of the decoder RNN
dropout = 0.5

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Enable cuDNN benchmarking; only worthwhile when the model's inputs have a
# fixed size, otherwise it adds a lot of computational overhead.
cudnn.benchmark = True

# ---- Training hyperparameters ----
start_epoch = 0
epochs = 120                  # maximum number of epochs (unless early stopping triggers first)
epochs_since_improvement = 0  # epochs elapsed since validation BLEU last improved
batch_size = 32
workers = 4
encoder_lr = 1e-4             # encoder learning rate, relevant only when fine-tuning
decoder_lr = 4e-4             # decoder learning rate
grad_clip = 5.0               # gradients are clipped at this absolute value
alpha_c = 1.0                 # regularization weight for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.0              # BLEU-4 score right now
print_freq = 100              # print training/validation stats every this many batches
fine_tune_encoder = False     # whether to fine-tune the encoder
checkpoint = None             # path to a checkpoint, or None if there is none

In [65]:
# Build the vocabulary from the dataset JSON file. `vocab` is later handed to
# get_loaders, and its length is used as the decoder's vocab_size.
vocab = build_vocab('../data.json')

100%|██████████| 30000/30000 [00:00<00:00, 324992.88it/s]


In [66]:
# Show the vocabulary size; the same value is passed to the decoder as vocab_size.
len(vocab)

4451

In [76]:
# Attention decoder configured from the hyperparameters; the vocabulary
# length is passed as vocab_size.
decoder = DecoderWithAttention(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    vocab_size=len(vocab),
    encoder_dim=enc_dim,
    dropout=dropout,
)

# Only decoder parameters that require gradients are handed to the optimizer.
decoder_optimizer = torch.optim.Adam(
    params=(param for param in decoder.parameters() if param.requires_grad),
    lr=decoder_lr,
)

encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)

# The encoder gets its own optimizer only when it is being fine-tuned.
if fine_tune_encoder:
    encoder_optimizer = torch.optim.Adam(
        params=(param for param in encoder.parameters() if param.requires_grad),
        lr=encoder_lr,
    )
else:
    encoder_optimizer = None

In [77]:
# Place both networks on the selected device (the GPU when one is available).
encoder = encoder.to(device)
decoder = decoder.to(device)

In [78]:
# Cross-entropy loss, placed on the same device as the models.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [79]:
# Image preprocessing pipeline applied to every image by the data loaders.
transform = transforms.Compose(
    [
        # With an int size, Resize scales the shorter side to 256 pixels.
        transforms.Resize(size=256),
        # Keep the central 224x224 region.
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        # Per-channel normalisation of the resulting tensor.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training and validation loaders built from the image folder and the dataset JSON.
train_loader, valid_loader = get_loaders(
    batch_size,
    '../flickr/Images/',
    '../data.json',
    transform,
    vocab,
    workers,
)

Dataset split: train
Unique images: 6000
Total size: 30000
Dataset split: val
Unique images: 1000
Total size: 5000


In [37]:
# Pull a single batch from the training loader and display the shape of each
# tensor in it as a quick sanity check.
img, caption, cap_len = next(iter(train_loader))
tuple(tensor.shape for tensor in (img, caption, cap_len))

(torch.Size([32, 3, 224, 224]), torch.Size([32, 39]), torch.Size([32]))

In [38]:
# Pull a single batch from the validation loader, which yields a fourth item
# (all_tokens) in addition to the three the training loader returns.
# Note: this rebinds img / caption / cap_len, so cells executed after this one
# operate on the validation batch rather than the training batch.
img, caption, cap_len, all_tokens = next(iter(valid_loader))
tuple(tensor.shape for tensor in (img, caption, cap_len))

(torch.Size([32, 3, 224, 224]), torch.Size([32, 39]), torch.Size([32]))

In [72]:
# Run the encoder on the current image batch as a shape check. In the recorded
# output the result is (32, 14, 14, 512), whose last dimension equals enc_dim.
enc_out = encoder(img.to(device))
enc_out.shape

torch.Size([32, 14, 14, 512])

In [80]:
# Display the decoder's module structure (its sub-modules and their sizes).
decoder

DecoderWithAttention(
  (attention): Attention(
    (encoder_att): Linear(in_features=512, out_features=256, bias=True)
    (decoder_att): Linear(in_features=256, out_features=256, bias=True)
    (full_att): Linear(in_features=256, out_features=1, bias=True)
    (relu): ReLU()
    (softmax): Softmax(dim=1)
  )
  (embedding): Embedding(4451, 256)
  (dropout): Dropout(p=0.5, inplace=False)
  (decode_step): LSTMCell(768, 256)
  (init_h): Linear(in_features=512, out_features=256, bias=True)
  (init_c): Linear(in_features=512, out_features=256, bias=True)
  (f_beta): Linear(in_features=256, out_features=512, bias=True)
  (sigmoid): Sigmoid()
  (fc): Linear(in_features=256, out_features=4451, bias=True)
)

In [81]:
# Re-display the encoder output shape before it is fed to the decoder
# (repeats the check made right after the encoder forward pass).
enc_out.shape

torch.Size([32, 14, 14, 512])

In [86]:
# Forward pass through the decoder with the encoder features and the current
# caption batch. cap_len is 1-D (one length per caption); unsqueeze(1) adds a
# trailing dimension before it is passed in.
preds, encoded_caps, decoded_lengths, alphas, sort_id = decoder(
    enc_out,
    caption.to(device),
    cap_len.unsqueeze(1).to(device),
)

torch.Size([32, 196, 512])
torch.Size([32, 512])


In [89]:
# Display the shapes of the decoder's predictions, the captions it returned,
# and the attention weights.
tuple(tensor.shape for tensor in (preds, encoded_caps, alphas))

(torch.Size([32, 16, 4451]), torch.Size([32, 39]), torch.Size([32, 16, 196]))

In [95]:
# Sanity check: the attention weights for the first caption at the first
# decoding step should sum to 1.
alphas[0, 0].sum()

tensor(1., device='cuda:0', grad_fn=<SumBackward0>)

In [None]:
def train(train_loader, encoder, deocder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    decoder.train()
    encoder.train()
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top5acc = AverageMeter()
    
    start = time.time()
    
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        