In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from torchvision import models
from torchvision import transforms

import os
import time

import pandas as pd
from PIL import Image

In [2]:
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [3]:
class Dictionary(object):
    """Bidirectional vocabulary mapping between words and integer ids.

    Ids are assigned densely, starting at 0, in the order words are first seen.
    """
    def __init__(self):
        self.idx2word = []  # id -> word
        self.word2idx = {}  # word -> id

    def add_word(self, word):
        """Return the id of ``word``, assigning the next free id if it is new."""
        # Check membership in the dict (O(1)), not the list (O(n)); the list
        # lookup made building the vocabulary quadratic in its size.
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def __len__(self):
        """Number of distinct words in the vocabulary."""
        return len(self.idx2word)

class Corpus(Dataset):
    """Flickr8k-style caption dataset whose items are whole batches.

    Captions are read from ``<caption_path>/cleaned_captions.csv`` (column 0:
    image file name, column 1: caption text). Images are read from
    ``<caption_path>/Flickr8k_Dataset``. ``corpus[idx]`` returns batch number
    ``idx`` (``bsz`` consecutive images), not a single sample.

    Args:
        caption_path: root data directory.
        bsz: number of images per batch returned by ``__getitem__``.
        transform: optional callable applied to each resized 224x224 RGB PIL
            image; it must return a ``(3, 224, 224)`` tensor. Defaults to
            ``ToTensor`` followed by mean/std normalisation with the constants
            below.

    Note:
        ``captions`` is keyed by image file name. When the CSV holds several
        captions for the same image, only the last one is kept.
        ``len(corpus)`` counts images, not batches; iterate with ``num_batches``.
    """
    def __init__(self, caption_path, bsz = 32, transform = None):
        self.caption_path = caption_path
        self.image_dir = os.path.join(caption_path, 'Flickr8k_Dataset')
        # Honour a caller-supplied transform (previously ignored); build the
        # default pipeline once instead of on every __getitem__ call.
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        self.transform = transform

        self.dictionary = Dictionary()
        self.captions = self.tokenize()
        self.keys = list(self.captions)
        self.bsz = bsz
        self.num_batches = self.__len__() // self.bsz

    def tokenize(self):
        """Build the vocabulary and map each image file name to its caption ids.

        Each caption is lower-cased, whitespace-split and wrapped in
        ``<start>`` / ``<eos>`` markers.

        Returns:
            dict mapping image file name (CSV column 0) to a list of word ids.
        """
        captions_csv = pd.read_csv(os.path.join(self.caption_path, 'cleaned_captions.csv'))

        path2cap = {}
        # A single pass is enough. add_word returns the id, and ids are assigned
        # in the same first-seen order as a build-vocabulary-then-map pass would.
        for path, line in zip(captions_csv.iloc[:, 0], captions_csv.iloc[:, 1]):
            words = ['<start>'] + line.lower().split() + ['<eos>']
            path2cap[path] = [self.dictionary.add_word(word) for word in words]
        return path2cap

    def __len__(self):
        """Number of captioned images (not the number of batches)."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(captions, images)``.

        Returns:
            captions: list of 1-D ``torch.long`` tensors of word ids, one per image.
            images: tensor of shape ``(n, 3, 224, 224)``. ``n`` is the number of
                images in the slice: ``bsz`` for full batches, fewer for the
                trailing partial batch, and 0 past the end.
        """
        img_paths = self.keys[idx * self.bsz : (idx + 1) * self.bsz]
        batch_captions = [torch.tensor(self.captions[path], dtype=torch.long) for path in img_paths]

        images = []
        for path in img_paths:
            # `with` closes the file handle. convert('RGB') guarantees three
            # channels, so grayscale/RGBA files do not break the 3-channel Normalize.
            with Image.open(os.path.join(self.image_dir, path)) as raw:
                img = raw.convert('RGB').resize((224, 224))
            images.append(self.transform(img))

        if not images:
            return batch_captions, torch.zeros(0, 3, 224, 224)
        # Stack once rather than torch.cat inside the loop (quadratic copying).
        return batch_captions, torch.stack(images, 0)
 






In [4]:
# Build the vocabulary and the image-name -> caption-ids index from ../data,
# served in batches of 32 images.
corpus = Corpus(caption_path='../data', bsz=32)

In [5]:
class RNN(nn.Module):
    """Word-level Elman RNN used as the caption decoder.

    Embeds a 1-D sequence of token ids and runs it through a single-layer
    ``nn.RNN`` as one sequence (batch size 1). Returns per-step
    log-probabilities over the vocabulary.

    Args:
        ntokens: vocabulary size.
        ninp: embedding size.
        nhid: hidden size.
        dropout: dropout probability applied to the embeddings.
    """
    def __init__(self, ntokens, ninp, nhid, dropout = 0.5):
        super(RNN, self).__init__()

        self.ntokens = ntokens
        self.nhid = nhid

        self.drop = nn.Dropout(dropout)

        self.encode = nn.Embedding(ntokens, ninp)
        self.rnn = nn.RNN(ninp, nhid)
        self.decode = nn.Linear(nhid, ntokens)

    def forward(self, input, h):
        """Run the decoder over one token sequence.

        Args:
            input: 1-D LongTensor of ``T`` token ids.
            h: hidden state of shape ``(1, 1, nhid)``.

        Returns:
            ``(log_probs, hidden)``: ``log_probs`` has shape ``(T, ntokens)``,
            and ``hidden`` is the final state with shape ``(1, 1, nhid)``.
        """
        # (T, ninp) -> (T, 1, ninp): sequence-first layout with a batch of 1.
        # Out-of-place unsqueeze avoids mutating a tensor tracked by autograd.
        emb = self.drop(self.encode(input)).unsqueeze(1)
        output, hidden = self.rnn(emb, h)
        decoded = self.decode(output).view(-1, self.ntokens)

        return F.log_softmax(decoded, dim = 1), hidden

    def init_hidden(self):
        """Zero initial hidden state of shape ``(1, 1, nhid)`` on the module's device."""
        # Allocate next to the parameters so the state matches after .to(device).
        return torch.zeros(1, 1, self.nhid, device=self.decode.weight.device)

In [6]:
# Must equal the Corpus batch size, because train() indexes captions[batch]
# for batch in range(batch_size). Deriving it keeps the two from drifting apart.
batch_size = corpus.bsz
epochs = 20                       # number of passes over the corpus
ntokens = len(corpus.dictionary)  # vocabulary size (RNN input/output dimension)
criterion = nn.NLLLoss()          # pairs with the RNN's log_softmax output

In [7]:
def repackage_hidden(h):
    """Detach hidden state(s) from the graph that produced them.

    Args:
        h: a tensor, or a (possibly nested) tuple/list of tensors such as an
            LSTM's ``(h, c)`` pair.

    Returns:
        The same structure with every tensor detached; containers are returned
        as tuples.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    # Recurse into the elements. The former `repackage_hidden(h)` recursed on
    # the same object forever.
    return tuple(repackage_hidden(v) for v in h)

In [8]:
# Pretrained VGG11 used as a fixed image encoder. Its output vector is injected
# into the RNN during training, and its weights are never optimized (only
# rnn.parameters() go to the optimizer). Freezing it stops autograd from
# tracking the CNN forward pass.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour
# of `weights=models.VGG11_Weights.DEFAULT`; kept for older torchvision versions.
cnn = models.vgg11(pretrained = True).to(device).eval()
cnn.requires_grad_(False)

In [9]:
# Caption decoder with 200-d word embeddings and a 1000-unit hidden state. The
# hidden size must equal the length of the CNN output vector, because train()
# writes that vector into rnn.rnn.bias_ih_l0 (which has nhid entries).
rnn = RNN(ntokens, ninp=200, nhid=1000).to(device)
# Only the decoder's parameters are optimized.
optimizer = torch.optim.RMSprop(params=rnn.parameters(), lr=0.005)

In [10]:
def train():
    """Run one epoch of caption training over ``corpus``.

    For every caption in a batch, the matching image goes through ``cnn`` and
    the resulting vector is written into ``rnn.rnn.bias_ih_l0``; that bias is
    how the image conditions the RNN. The RNN is trained with teacher forcing
    to predict token ``t + 1`` from tokens ``0..t``. Gradients of all captions
    in the batch are accumulated and applied with one ``optimizer.step()``.

    Uses the notebook globals ``cnn``, ``rnn``, ``optimizer``, ``criterion``,
    ``corpus``, ``batch_size`` and ``device``. The progress line also reads the
    global ``epoch`` set by the training-loop cell.
    """
    cnn.eval()
    rnn.train()
    start_time = time.time()
    log_interval = 50

    for i in range(corpus.num_batches):
        captions, imgs = corpus[i]
        imgs = imgs.to(device)
        optimizer.zero_grad()  # once per batch; per-caption gradients accumulate below
        total_loss = 0.

        for batch in range(batch_size):
            caption = captions[batch].to(device)
            # Teacher forcing: feed "<start> w1 .. wn" and predict "w1 .. wn <eos>".
            # Scoring the output against its own input only teaches the model to copy.
            inputs, targets = caption[:-1], caption[1:]

            # Bias interaction between CNN and RNN: inject the image features into
            # the RNN input bias. copy_ under no_grad works on any device, whereas
            # writing through a .numpy() view only works for CPU tensors.
            with torch.no_grad():
                cnn_output = cnn(imgs[batch].unsqueeze(0)).squeeze(0)
                rnn.rnn.bias_ih_l0.copy_(cnn_output)

            # Captions are independent sequences, so each starts from a zero state.
            hidden = rnn.init_hidden().to(device)
            output, _ = rnn(inputs, hidden)

            l = criterion(output, targets)
            # Backpropagating each caption immediately accumulates the same
            # gradients as backward() on the summed loss. It also frees this graph
            # before the next caption overwrites the bias in place.
            l.backward()
            total_loss += l.item()

        optimizer.step()

        average_loss = total_loss / batch_size
        if i % log_interval == 0 and i > 0:
            elapsed = time.time() - start_time
            print('| epoch {} | {}/{} batches | s/batch {} | loss {} |'.format(epoch, i, corpus.num_batches, elapsed / log_interval, average_loss))
            start_time = time.time()

In [11]:
# Train for the configured number of epochs. `epoch` is also read as a global
# by train() for its progress line. Interrupting the kernel stops training
# cleanly and keeps the current weights.
try:
    for epoch in range(epochs):
        train()
    print('-'*89)
    print('Training complete')
except KeyboardInterrupt:
    print('-'*89)
    print('Exiting from training early')

-----------------------------------------------------------------------------------------
Exiting from training early
