In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from torchvision import models
from torchvision import transforms

import os
import time

import pandas as pd
from PIL import Image

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [3]:
class Dictionary(object):
    def __init__(self):
        self.idx2word = []
        self.word2idx = {}
        
    def add_word(self, word):
        if word not in self.idx2word:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]
    
    def __len__(self):
        return len(self.idx2word)

class Corpus(Dataset):
    def __init__(self, caption_path, bsz = 32, transform = None):
        self.caption_path = caption_path
        
        self.dictionary = Dictionary()
        self.captions = self.tokenize()
        self.keys = list(self.captions)
        self.bsz = bsz
        self.num_batches = self.__len__() // self.bsz
        
    def tokenize(self):
        captions_csv = pd.read_csv(os.path.join(self.caption_path, 'cleaned_captions.csv'))
        
        for i in range(len(captions_csv)):
            line = captions_csv.iloc[i, 1]
            words = ['<start>'] + line.lower().split() + ['<eos>']
            for word in words:
                self.dictionary.add_word(word)
                
        path2cap = {}
        for i in range(len(captions_csv)):
            line = captions_csv.iloc[i, 1]
            words = ['<start>'] + line.lower().split() + ['<eos>']
            ids = []
            for word in words:
                ids.append(self.dictionary.word2idx[word])
            path2cap[captions_csv.iloc[i, 0]] = ids
        
        return path2cap

    def __len__(self):
        return len(self.captions)
    
    def __getitem__(self, idx):
        transform = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        img_paths = self.keys[idx*self.bsz : (idx+1) * self.bsz]
        batch_captions = []
        for path in img_paths:
            batch_captions.append(torch.Tensor(self.captions[path]).type(torch.LongTensor))
        
        img_tensor = torch.zeros(0, 3, 224, 224)
        for path in img_paths:
            img = Image.open(os.path.join(os.path.join(self.caption_path, 'Flickr8k_Dataset'), path))
            img = img.resize((224, 224))
            img = transform(img)
            img = normalize(img)
            img.unsqueeze_(0)
            
            img_tensor = torch.cat((img_tensor, img), 0)
            
        return batch_captions, img_tensor 
 






In [4]:
corpus = Corpus('../data')

In [5]:
class RNN(nn.Module):
    def __init__(self, ntokens, ninp, nhid, dropout = 0.5):
        super(RNN, self).__init__()
        
        self.ntokens = ntokens
        self.nhid = nhid
        
        self.drop = nn.Dropout(dropout)
        
        self.encode = nn.Embedding(ntokens, ninp)
        self.rnn = nn.RNN(ninp, nhid)
        self.decode = nn.Linear(nhid, ntokens)
        
    def forward(self, input, h):
        emb = self.drop(self.encode(input))
        emb.unsqueeze_(1)
        output, hidden = self.rnn(emb, h)
        decoded = self.decode(output)
        decoded = decoded.view(-1, self.ntokens)
        
        return F.log_softmax(decoded, dim = 1), hidden
    
    def init_hidden(self):
        return torch.zeros(1, 1, self.nhid)

In [6]:
cnn = models.vgg11(pretrained = True).to(device)

In [7]:
rnn = RNN(len(corpus.dictionary), 200, 1000).to(device)

In [8]:
rnn.load_state_dict(torch.load('../multimodal.pt', map_location = torch.device('cpu')))

<All keys matched successfully>

In [9]:
cnn.eval()


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [10]:
rnn.eval()

RNN(
  (drop): Dropout(p=0.5, inplace=False)
  (encode): Embedding(4528, 200)
  (rnn): RNN(200, 1000)
  (decode): Linear(in_features=1000, out_features=4528, bias=True)
)

In [12]:
def load_img(img_name):
    transform = transforms.ToTensor()
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    img = Image.open(os.path.join('../data/Flickr8k_Dataset', img_name))
    img = img.resize((224, 224))
    img = transform(img)
    img = normalize(img)
    img.unsqueeze_(0)

    return img

In [36]:
def generate(img_name):
    rnn.eval()
    cnn.eval()
    
    out_sen = []
    
    img = load_img(img_name)
    hidden = rnn.init_hidden()
    
    cnn_output = cnn(img)
    
    cnn_output.squeeze_(0)
    tmp = cnn_output.detach().numpy()
    bhi = rnn.rnn.bias_ih_l0.detach().numpy()
    bhi[:] = tmp
    
    input = torch.randint(len(corpus.dictionary), (1,), dtype = torch.long)
    
    with torch.no_grad():
        for i in range(10):
            output, hidden = rnn(input, hidden)
            
            word_weights = output.squeeze().div(1.0).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            
            input.fill_(word_idx)
            
            word = corpus.dictionary.idx2word[word_idx]
            out_sen.append(word)

    print(out_sen)        
    return out_sen

In [37]:
out = generate('58363930_0544844edd.jpg')

['attempts', 'attempts', 'attempts', 'morning', 'stuff', 'stuff', 'barrels', 'marking', 'athlete', 'dogs']


In [38]:
print(out)

['attempts', 'attempts', 'attempts', 'morning', 'stuff', 'stuff', 'barrels', 'marking', 'athlete', 'dogs']
