In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

from torchvision import models
from torchvision import transforms

import os
import time

import pandas as pd
from PIL import Image

In [2]:
# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


In [3]:
class Dictionary(object):
    """Bidirectional vocabulary mapping between words and integer ids.

    Ids are handed out in insertion order starting at 0. The constructor
    registers '<pad>' first, so the padding token always has id 0.
    """

    def __init__(self):
        self.idx2word = []  # id -> word, in insertion order
        self.word2idx = {}  # word -> id
        self.add_word('<pad>')

    def add_word(self, word):
        """Register `word` if it is new and return its id.

        Adding a word that is already known is a no-op that returns the
        existing id.
        """
        # Membership is tested on the dict (constant time) instead of the list
        # (linear scan), so building a large vocabulary stays linear overall.
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def __len__(self):
        """Number of distinct words, including '<pad>'."""
        return len(self.idx2word)

class Corpus(Dataset):
    """Caption corpus that serves whole mini-batches of (captions, targets, images).

    Captions are read from `<caption_path>/cleaned_captions.csv` (column 0:
    image file name, column 1: caption text) and images from
    `<caption_path>/Flickr8k_Dataset/`.

    Indexing is per *batch*: `corpus[i]` returns the i-th group of `bsz`
    images. `len(corpus)` is the number of captioned images, not the number of
    batches; use `num_batches` for the count of full batches.

    Args:
        caption_path: directory holding the CSV and the image folder.
        bsz: number of images per batch.
        transform: accepted for interface compatibility but not used; images
            always go through the fixed resize / ToTensor / Normalize pipeline.
    """

    def __init__(self, caption_path, bsz = 32, transform = None):
        self.caption_path = caption_path

        self.dictionary = Dictionary()
        self.captions = self.tokenize()
        self.keys = list(self.captions)
        self.bsz = bsz
        self.num_batches = len(self) // self.bsz

    def tokenize(self):
        """Read the caption CSV, grow the vocabulary and encode every caption.

        Returns:
            dict mapping image file name -> list of token ids ending in the
            id of '<eos>'. When several rows share a file name, the last
            row's caption is the one kept.
        """
        captions_csv = pd.read_csv(os.path.join(self.caption_path, 'cleaned_captions.csv'))

        path2cap = {}
        # Single pass: add_word returns the id, so no second lookup pass is needed.
        for img_path, line in zip(captions_csv.iloc[:, 0], captions_csv.iloc[:, 1]):
            words = line.lower().split() + ['<eos>']
            path2cap[img_path] = [self.dictionary.add_word(word) for word in words]

        return path2cap

    def __len__(self):
        """Number of captioned images (not the number of batches)."""
        return len(self.captions)

    def _batch_paths(self, idx):
        """Image file names of batch `idx`, ordered longest caption first."""
        paths = self.keys[idx * self.bsz : (idx + 1) * self.bsz]
        # Same stable, descending-length order that pad() applies to the
        # captions, so caption row i and image row i describe the same file.
        return sorted(paths, key = lambda path: len(self.captions[path]), reverse = True)

    def _load_image(self, path):
        """Load one image as a normalised float tensor of shape (3, 224, 224)."""
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        img_file = os.path.join(self.caption_path, 'Flickr8k_Dataset', path)
        with Image.open(img_file) as raw:
            # Force three channels so grayscale / RGBA / palette files do not
            # break the 3-channel Normalize; the file is closed on exit.
            resized = raw.convert('RGB').resize((224, 224))

        return normalize(to_tensor(resized))

    def __getitem__(self, idx):
        """Return batch `idx`.

        Returns:
            A tuple `((padded, lens), targets, images)` where
            `padded` is a LongTensor (batch, max_len) of zero-padded captions
            sorted longest first, `lens` the matching caption lengths,
            `targets` the same token ids flattened to 1-D (no shift is
            applied), and `images` a float tensor (batch, 3, 224, 224) in the
            same row order as `padded`.

        Raises:
            IndexError: if `idx` selects no images.
        """
        img_paths = self._batch_paths(idx)
        if not img_paths:
            raise IndexError('batch index {} is out of range'.format(idx))

        batch_captions = [
            torch.tensor(self.captions[path], dtype = torch.long) for path in img_paths
        ]
        padded_captions = self.pad(batch_captions)
        # Independent copy so in-place edits to the targets never touch the inputs.
        targets = padded_captions[0].clone().view(-1)

        # One stack call instead of re-allocating the batch with cat per image.
        img_tensor = torch.stack([self._load_image(path) for path in img_paths], 0)

        return padded_captions, targets, img_tensor

    def pad(self, batch):
        """Sort `batch` in place (longest first) and zero-pad it.

        Args:
            batch: list of 1-D tensors; it is re-ordered in place.

        Returns:
            `(padded, lens)`: a (batch, max_len) tensor and the list of
            original lengths in the sorted order.
        """
        batch.sort(reverse = True, key = len)
        lens = [len(x) for x in batch]

        padded = pad_sequence(batch, batch_first = True)

        return (padded, lens)

In [4]:
# Build the vocabulary and caption table from data/cleaned_captions.csv
# (batch size defaults to 32).
corpus = Corpus('data')

In [5]:
class Encoder(nn.Module):
    """Image encoder: a ResNet-18 whose final layer projects to `encode_size`.

    The backbone is created without pretrained weights; the notebook loads
    them from a checkpoint afterwards.

    NOTE(review): the traceback recorded when loading `encoder.pt` lists
    checkpoint keys such as `resnet.0.weight` and `resnet.4.0.conv3.weight`,
    which do not match this layout -- confirm which architecture produced the
    checkpoint.

    Args:
        encode_size: number of features emitted per image.
    """

    def __init__(self, encode_size):
        super().__init__()

        backbone = models.resnet18(pretrained = False, progress = True)
        # Replace the stock fully connected head with a 512 -> encode_size projection.
        backbone.fc = nn.Linear(512, encode_size)
        self.resnet = backbone

    def forward(self, img):
        """Encode an image batch; tanh squashes every feature into (-1, 1)."""
        features = self.resnet(img)
        return torch.tanh(features)


In [6]:
class Decoder(nn.Module):
    """Single-layer LSTM caption decoder.

    Args:
        ntokens: vocabulary size (also the output dimension).
        ninp: LSTM input size; equals the word-embedding size and must also
            equal the size of the image features fed in at step 0.
        nhid: LSTM hidden size.
        dropout: dropout probability applied to the LSTM output.
    """

    def __init__(self, ntokens, ninp, nhid, dropout = 0.5):
        super(Decoder, self).__init__()

        self.ntokens = ntokens
        self.nhid = nhid

        self.drop = nn.Dropout(dropout)
        self.embed = nn.Embedding(ntokens, ninp)
        self.LSTM = nn.LSTM(ninp, nhid)
        self.decode = nn.Linear(nhid, ntokens)

    def forward(self, input, h, t):
        """Run the LSTM over `input` and score every vocabulary entry.

        Args:
            input: at `t == 0`, a float tensor that is passed to the LSTM
                as-is (the image features); at `t > 0`, a LongTensor of token
                ids that is embedded first. The LSTM is not batch-first, so
                the leading dimension is the sequence.
            h: `(h_0, c_0)` LSTM state, e.g. from `init_hidden`.
            t: decoding step; only `t > 0` triggers the embedding lookup.

        Returns:
            `(log_probs, hidden)` where `log_probs` has shape
            (seq * batch, ntokens) and `hidden` is the updated LSTM state.
        """
        if t > 0:
            input = self.embed(input)

        out, hidden = self.LSTM(input, h)
        out = self.drop(out)

        decoded = self.decode(out).view(-1, self.ntokens)

        return F.log_softmax(decoded, dim = 1), hidden

    def init_hidden(self, bsz):
        """Return a zeroed `(h_0, c_0)` pair, each of shape (1, bsz, nhid).

        The tensors are created on the module's own device and in its own
        dtype (taken from a parameter) instead of a notebook-level global, so
        the state stays valid after `.to(...)`, `.cuda()` or `.double()`.
        """
        weight = self.decode.weight
        return (weight.new_zeros(1, bsz, self.nhid), weight.new_zeros(1, bsz, self.nhid))

In [9]:
# The encoder's output size (200) must equal the decoder's input size (ninp = 200):
# at step 0 the decoder feeds the image features to the LSTM without embedding them.
# 256 is the LSTM hidden size.
encoder = Encoder(200).to(device)
decoder = Decoder(len(corpus.dictionary), 200, 256).to(device)

In [10]:
# Restore trained weights, mapping tensors onto the CPU regardless of where they were saved.
# NOTE(review): in the recorded run the first call raised RuntimeError -- the keys in
# encoder.pt (resnet.0.*, resnet.4.0.conv3.*, ...) do not match this notebook's
# resnet18-based Encoder -- so the decoder line below did not execute in that run.
# Confirm which architecture produced encoder.pt before trusting generated captions.
encoder.load_state_dict(torch.load('encoder.pt', map_location='cpu'))
decoder.load_state_dict(torch.load('decoder.pt', map_location = 'cpu'))

RuntimeError: Error(s) in loading state_dict for Encoder:
	Missing key(s) in state_dict: "resnet.conv1.weight", "resnet.bn1.weight", "resnet.bn1.bias", "resnet.bn1.running_mean", "resnet.bn1.running_var", "resnet.layer1.0.conv1.weight", "resnet.layer1.0.bn1.weight", "resnet.layer1.0.bn1.bias", "resnet.layer1.0.bn1.running_mean", "resnet.layer1.0.bn1.running_var", "resnet.layer1.0.conv2.weight", "resnet.layer1.0.bn2.weight", "resnet.layer1.0.bn2.bias", "resnet.layer1.0.bn2.running_mean", "resnet.layer1.0.bn2.running_var", "resnet.layer1.1.conv1.weight", "resnet.layer1.1.bn1.weight", "resnet.layer1.1.bn1.bias", "resnet.layer1.1.bn1.running_mean", "resnet.layer1.1.bn1.running_var", "resnet.layer1.1.conv2.weight", "resnet.layer1.1.bn2.weight", "resnet.layer1.1.bn2.bias", "resnet.layer1.1.bn2.running_mean", "resnet.layer1.1.bn2.running_var", "resnet.layer2.0.conv1.weight", "resnet.layer2.0.bn1.weight", "resnet.layer2.0.bn1.bias", "resnet.layer2.0.bn1.running_mean", "resnet.layer2.0.bn1.running_var", "resnet.layer2.0.conv2.weight", "resnet.layer2.0.bn2.weight", "resnet.layer2.0.bn2.bias", "resnet.layer2.0.bn2.running_mean", "resnet.layer2.0.bn2.running_var", "resnet.layer2.0.downsample.0.weight", "resnet.layer2.0.downsample.1.weight", "resnet.layer2.0.downsample.1.bias", "resnet.layer2.0.downsample.1.running_mean", "resnet.layer2.0.downsample.1.running_var", "resnet.layer2.1.conv1.weight", "resnet.layer2.1.bn1.weight", "resnet.layer2.1.bn1.bias", "resnet.layer2.1.bn1.running_mean", "resnet.layer2.1.bn1.running_var", "resnet.layer2.1.conv2.weight", "resnet.layer2.1.bn2.weight", "resnet.layer2.1.bn2.bias", "resnet.layer2.1.bn2.running_mean", "resnet.layer2.1.bn2.running_var", "resnet.layer3.0.conv1.weight", "resnet.layer3.0.bn1.weight", "resnet.layer3.0.bn1.bias", "resnet.layer3.0.bn1.running_mean", "resnet.layer3.0.bn1.running_var", "resnet.layer3.0.conv2.weight", "resnet.layer3.0.bn2.weight", "resnet.layer3.0.bn2.bias", "resnet.layer3.0.bn2.running_mean", "resnet.layer3.0.bn2.running_var", "resnet.layer3.0.downsample.0.weight", 
"resnet.layer3.0.downsample.1.weight", "resnet.layer3.0.downsample.1.bias", "resnet.layer3.0.downsample.1.running_mean", "resnet.layer3.0.downsample.1.running_var", "resnet.layer3.1.conv1.weight", "resnet.layer3.1.bn1.weight", "resnet.layer3.1.bn1.bias", "resnet.layer3.1.bn1.running_mean", "resnet.layer3.1.bn1.running_var", "resnet.layer3.1.conv2.weight", "resnet.layer3.1.bn2.weight", "resnet.layer3.1.bn2.bias", "resnet.layer3.1.bn2.running_mean", "resnet.layer3.1.bn2.running_var", "resnet.layer4.0.conv1.weight", "resnet.layer4.0.bn1.weight", "resnet.layer4.0.bn1.bias", "resnet.layer4.0.bn1.running_mean", "resnet.layer4.0.bn1.running_var", "resnet.layer4.0.conv2.weight", "resnet.layer4.0.bn2.weight", "resnet.layer4.0.bn2.bias", "resnet.layer4.0.bn2.running_mean", "resnet.layer4.0.bn2.running_var", "resnet.layer4.0.downsample.0.weight", "resnet.layer4.0.downsample.1.weight", "resnet.layer4.0.downsample.1.bias", "resnet.layer4.0.downsample.1.running_mean", "resnet.layer4.0.downsample.1.running_var", "resnet.layer4.1.conv1.weight", "resnet.layer4.1.bn1.weight", "resnet.layer4.1.bn1.bias", "resnet.layer4.1.bn1.running_mean", "resnet.layer4.1.bn1.running_var", "resnet.layer4.1.conv2.weight", "resnet.layer4.1.bn2.weight", "resnet.layer4.1.bn2.bias", "resnet.layer4.1.bn2.running_mean", "resnet.layer4.1.bn2.running_var", "resnet.fc.weight", "resnet.fc.bias". 
	Unexpected key(s) in state_dict: "resnet.0.weight", "resnet.1.weight", "resnet.1.bias", "resnet.1.running_mean", "resnet.1.running_var", "resnet.1.num_batches_tracked", "resnet.4.0.conv1.weight", "resnet.4.0.bn1.weight", "resnet.4.0.bn1.bias", "resnet.4.0.bn1.running_mean", "resnet.4.0.bn1.running_var", "resnet.4.0.bn1.num_batches_tracked", "resnet.4.0.conv2.weight", "resnet.4.0.bn2.weight", "resnet.4.0.bn2.bias", "resnet.4.0.bn2.running_mean", "resnet.4.0.bn2.running_var", "resnet.4.0.bn2.num_batches_tracked", "resnet.4.0.conv3.weight", "resnet.4.0.bn3.weight", "resnet.4.0.bn3.bias", "resnet.4.0.bn3.running_mean", "resnet.4.0.bn3.running_var", "resnet.4.0.bn3.num_batches_tracked", "resnet.4.0.downsample.0.weight", "resnet.4.0.downsample.1.weight", "resnet.4.0.downsample.1.bias", "resnet.4.0.downsample.1.running_mean", "resnet.4.0.downsample.1.running_var", "resnet.4.0.downsample.1.num_batches_tracked", "resnet.4.1.conv1.weight", "resnet.4.1.bn1.weight", "resnet.4.1.bn1.bias", "resnet.4.1.bn1.running_mean", "resnet.4.1.bn1.running_var", "resnet.4.1.bn1.num_batches_tracked", "resnet.4.1.conv2.weight", "resnet.4.1.bn2.weight", "resnet.4.1.bn2.bias", "resnet.4.1.bn2.running_mean", "resnet.4.1.bn2.running_var", "resnet.4.1.bn2.num_batches_tracked", "resnet.4.1.conv3.weight", "resnet.4.1.bn3.weight", "resnet.4.1.bn3.bias", "resnet.4.1.bn3.running_mean", "resnet.4.1.bn3.running_var", "resnet.4.1.bn3.num_batches_tracked", "resnet.4.2.conv1.weight", "resnet.4.2.bn1.weight", "resnet.4.2.bn1.bias", "resnet.4.2.bn1.running_mean", "resnet.4.2.bn1.running_var", "resnet.4.2.bn1.num_batches_tracked", "resnet.4.2.conv2.weight", "resnet.4.2.bn2.weight", "resnet.4.2.bn2.bias", "resnet.4.2.bn2.running_mean", "resnet.4.2.bn2.running_var", "resnet.4.2.bn2.num_batches_tracked", "resnet.4.2.conv3.weight", "resnet.4.2.bn3.weight", "resnet.4.2.bn3.bias", "resnet.4.2.bn3.running_mean", "resnet.4.2.bn3.running_var", "resnet.4.2.bn3.num_batches_tracked", "resnet.5.0.conv1.weight", 
"resnet.5.0.bn1.weight", "resnet.5.0.bn1.bias", "resnet.5.0.bn1.running_mean", "resnet.5.0.bn1.running_var", "resnet.5.0.bn1.num_batches_tracked", "resnet.5.0.conv2.weight", "resnet.5.0.bn2.weight", "resnet.5.0.bn2.bias", "resnet.5.0.bn2.running_mean", "resnet.5.0.bn2.running_var", "resnet.5.0.bn2.num_batches_tracked", "resnet.5.0.conv3.weight", "resnet.5.0.bn3.weight", "resnet.5.0.bn3.bias", "resnet.5.0.bn3.running_mean", "resnet.5.0.bn3.running_var", "resnet.5.0.bn3.num_batches_tracked", "resnet.5.0.downsample.0.weight", "resnet.5.0.downsample.1.weight", "resnet.5.0.downsample.1.bias", "resnet.5.0.downsample.1.running_mean", "resnet.5.0.downsample.1.running_var", "resnet.5.0.downsample.1.num_batches_tracked", "resnet.5.1.conv1.weight", "resnet.5.1.bn1.weight", "resnet.5.1.bn1.bias", "resnet.5.1.bn1.running_mean", "resnet.5.1.bn1.running_var", "resnet.5.1.bn1.num_batches_tracked", "resnet.5.1.conv2.weight", "resnet.5.1.bn2.weight", "resnet.5.1.bn2.bias", "resnet.5.1.bn2.running_mean", "resnet.5.1.bn2.running_var", "resnet.5.1.bn2.num_batches_tracked", "resnet.5.1.conv3.weight", "resnet.5.1.bn3.weight", "resnet.5.1.bn3.bias", "resnet.5.1.bn3.running_mean", "resnet.5.1.bn3.running_var", "resnet.5.1.bn3.num_batches_tracked", "resnet.5.2.conv1.weight", "resnet.5.2.bn1.weight", "resnet.5.2.bn1.bias", "resnet.5.2.bn1.running_mean", "resnet.5.2.bn1.running_var", "resnet.5.2.bn1.num_batches_tracked", "resnet.5.2.conv2.weight", "resnet.5.2.bn2.weight", "resnet.5.2.bn2.bias", "resnet.5.2.bn2.running_mean", "resnet.5.2.bn2.running_var", "resnet.5.2.bn2.num_batches_tracked", "resnet.5.2.conv3.weight", "resnet.5.2.bn3.weight", "resnet.5.2.bn3.bias", "resnet.5.2.bn3.running_mean", "resnet.5.2.bn3.running_var", "resnet.5.2.bn3.num_batches_tracked", "resnet.5.3.conv1.weight", "resnet.5.3.bn1.weight", "resnet.5.3.bn1.bias", "resnet.5.3.bn1.running_mean", "resnet.5.3.bn1.running_var", "resnet.5.3.bn1.num_batches_tracked", "resnet.5.3.conv2.weight", "resnet.5.3.bn2.weight", 
"resnet.5.3.bn2.bias", "resnet.5.3.bn2.running_mean", "resnet.5.3.bn2.running_var", "resnet.5.3.bn2.num_batches_tracked", "resnet.5.3.conv3.weight", "resnet.5.3.bn3.weight", "resnet.5.3.bn3.bias", "resnet.5.3.bn3.running_mean", "resnet.5.3.bn3.running_var", "resnet.5.3.bn3.num_batches_tracked", "resnet.6.0.conv1.weight", "resnet.6.0.bn1.weight", "resnet.6.0.bn1.bias", "resnet.6.0.bn1.running_mean", "resnet.6.0.bn1.running_var", "resnet.6.0.bn1.num_batches_tracked", "resnet.6.0.conv2.weight", "resnet.6.0.bn2.weight", "resnet.6.0.bn2.bias", "resnet.6.0.bn2.running_mean", "resnet.6.0.bn2.running_var", "resnet.6.0.bn2.num_batches_tracked", "resnet.6.0.conv3.weight", "resnet.6.0.bn3.weight", "resnet.6.0.bn3.bias", "resnet.6.0.bn3.running_mean", "resnet.6.0.bn3.running_var", "resnet.6.0.bn3.num_batches_tracked", "resnet.6.0.downsample.0.weight", "resnet.6.0.downsample.1.weight", "resnet.6.0.downsample.1.bias", "resnet.6.0.downsample.1.running_mean", "resnet.6.0.downsample.1.running_var", "resnet.6.0.downsample.1.num_batches_tracked", "resnet.6.1.conv1.weight", "resnet.6.1.bn1.weight", "resnet.6.1.bn1.bias", "resnet.6.1.bn1.running_mean", "resnet.6.1.bn1.running_var", "resnet.6.1.bn1.num_batches_tracked", "resnet.6.1.conv2.weight", "resnet.6.1.bn2.weight", "resnet.6.1.bn2.bias", "resnet.6.1.bn2.running_mean", "resnet.6.1.bn2.running_var", "resnet.6.1.bn2.num_batches_tracked", "resnet.6.1.conv3.weight", "resnet.6.1.bn3.weight", "resnet.6.1.bn3.bias", "resnet.6.1.bn3.running_mean", "resnet.6.1.bn3.running_var", "resnet.6.1.bn3.num_batches_tracked", "resnet.6.2.conv1.weight", "resnet.6.2.bn1.weight", "resnet.6.2.bn1.bias", "resnet.6.2.bn1.running_mean", "resnet.6.2.bn1.running_var", "resnet.6.2.bn1.num_batches_tracked", "resnet.6.2.conv2.weight", "resnet.6.2.bn2.weight", "resnet.6.2.bn2.bias", "resnet.6.2.bn2.running_mean", "resnet.6.2.bn2.running_var", "resnet.6.2.bn2.num_batches_tracked", "resnet.6.2.conv3.weight", "resnet.6.2.bn3.weight", "resnet.6.2.bn3.bias", 
"resnet.6.2.bn3.running_mean", "resnet.6.2.bn3.running_var", "resnet.6.2.bn3.num_batches_tracked", "resnet.6.3.conv1.weight", "resnet.6.3.bn1.weight", "resnet.6.3.bn1.bias", "resnet.6.3.bn1.running_mean", "resnet.6.3.bn1.running_var", "resnet.6.3.bn1.num_batches_tracked", "resnet.6.3.conv2.weight", "resnet.6.3.bn2.weight", "resnet.6.3.bn2.bias", "resnet.6.3.bn2.running_mean", "resnet.6.3.bn2.running_var", "resnet.6.3.bn2.num_batches_tracked", "resnet.6.3.conv3.weight", "resnet.6.3.bn3.weight", "resnet.6.3.bn3.bias", "resnet.6.3.bn3.running_mean", "resnet.6.3.bn3.running_var", "resnet.6.3.bn3.num_batches_tracked", "resnet.6.4.conv1.weight", "resnet.6.4.bn1.weight", "resnet.6.4.bn1.bias", "resnet.6.4.bn1.running_mean", "resnet.6.4.bn1.running_var", "resnet.6.4.bn1.num_batches_tracked", "resnet.6.4.conv2.weight", "resnet.6.4.bn2.weight", "resnet.6.4.bn2.bias", "resnet.6.4.bn2.running_mean", "resnet.6.4.bn2.running_var", "resnet.6.4.bn2.num_batches_tracked", "resnet.6.4.conv3.weight", "resnet.6.4.bn3.weight", "resnet.6.4.bn3.bias", "resnet.6.4.bn3.running_mean", "resnet.6.4.bn3.running_var", "resnet.6.4.bn3.num_batches_tracked", "resnet.6.5.conv1.weight", "resnet.6.5.bn1.weight", "resnet.6.5.bn1.bias", "resnet.6.5.bn1.running_mean", "resnet.6.5.bn1.running_var", "resnet.6.5.bn1.num_batches_tracked", "resnet.6.5.conv2.weight", "resnet.6.5.bn2.weight", "resnet.6.5.bn2.bias", "resnet.6.5.bn2.running_mean", "resnet.6.5.bn2.running_var", "resnet.6.5.bn2.num_batches_tracked", "resnet.6.5.conv3.weight", "resnet.6.5.bn3.weight", "resnet.6.5.bn3.bias", "resnet.6.5.bn3.running_mean", "resnet.6.5.bn3.running_var", "resnet.6.5.bn3.num_batches_tracked", "resnet.6.6.conv1.weight", "resnet.6.6.bn1.weight", "resnet.6.6.bn1.bias", "resnet.6.6.bn1.running_mean", "resnet.6.6.bn1.running_var", "resnet.6.6.bn1.num_batches_tracked", "resnet.6.6.conv2.weight", "resnet.6.6.bn2.weight", "resnet.6.6.bn2.bias", "resnet.6.6.bn2.running_mean", "resnet.6.6.bn2.running_var", 
"resnet.6.6.bn2.num_batches_tracked", "resnet.6.6.conv3.weight", "resnet.6.6.bn3.weight", "resnet.6.6.bn3.bias", "resnet.6.6.bn3.running_mean", "resnet.6.6.bn3.running_var", "resnet.6.6.bn3.num_batches_tracked", "resnet.6.7.conv1.weight", "resnet.6.7.bn1.weight", "resnet.6.7.bn1.bias", "resnet.6.7.bn1.running_mean", "resnet.6.7.bn1.running_var", "resnet.6.7.bn1.num_batches_tracked", "resnet.6.7.conv2.weight", "resnet.6.7.bn2.weight", "resnet.6.7.bn2.bias", "resnet.6.7.bn2.running_mean", "resnet.6.7.bn2.running_var", "resnet.6.7.bn2.num_batches_tracked", "resnet.6.7.conv3.weight", "resnet.6.7.bn3.weight", "resnet.6.7.bn3.bias", "resnet.6.7.bn3.running_mean", "resnet.6.7.bn3.running_var", "resnet.6.7.bn3.num_batches_tracked", "resnet.6.8.conv1.weight", "resnet.6.8.bn1.weight", "resnet.6.8.bn1.bias", "resnet.6.8.bn1.running_mean", "resnet.6.8.bn1.running_var", "resnet.6.8.bn1.num_batches_tracked", "resnet.6.8.conv2.weight", "resnet.6.8.bn2.weight", "resnet.6.8.bn2.bias", "resnet.6.8.bn2.running_mean", "resnet.6.8.bn2.running_var", "resnet.6.8.bn2.num_batches_tracked", "resnet.6.8.conv3.weight", "resnet.6.8.bn3.weight", "resnet.6.8.bn3.bias", "resnet.6.8.bn3.running_mean", "resnet.6.8.bn3.running_var", "resnet.6.8.bn3.num_batches_tracked", "resnet.6.9.conv1.weight", "resnet.6.9.bn1.weight", "resnet.6.9.bn1.bias", "resnet.6.9.bn1.running_mean", "resnet.6.9.bn1.running_var", "resnet.6.9.bn1.num_batches_tracked", "resnet.6.9.conv2.weight", "resnet.6.9.bn2.weight", "resnet.6.9.bn2.bias", "resnet.6.9.bn2.running_mean", "resnet.6.9.bn2.running_var", "resnet.6.9.bn2.num_batches_tracked", "resnet.6.9.conv3.weight", "resnet.6.9.bn3.weight", "resnet.6.9.bn3.bias", "resnet.6.9.bn3.running_mean", "resnet.6.9.bn3.running_var", "resnet.6.9.bn3.num_batches_tracked", "resnet.6.10.conv1.weight", "resnet.6.10.bn1.weight", "resnet.6.10.bn1.bias", "resnet.6.10.bn1.running_mean", "resnet.6.10.bn1.running_var", "resnet.6.10.bn1.num_batches_tracked", "resnet.6.10.conv2.weight", 
"resnet.6.10.bn2.weight", "resnet.6.10.bn2.bias", "resnet.6.10.bn2.running_mean", "resnet.6.10.bn2.running_var", "resnet.6.10.bn2.num_batches_tracked", "resnet.6.10.conv3.weight", "resnet.6.10.bn3.weight", "resnet.6.10.bn3.bias", "resnet.6.10.bn3.running_mean", "resnet.6.10.bn3.running_var", "resnet.6.10.bn3.num_batches_tracked", "resnet.6.11.conv1.weight", "resnet.6.11.bn1.weight", "resnet.6.11.bn1.bias", "resnet.6.11.bn1.running_mean", "resnet.6.11.bn1.running_var", "resnet.6.11.bn1.num_batches_tracked", "resnet.6.11.conv2.weight", "resnet.6.11.bn2.weight", "resnet.6.11.bn2.bias", "resnet.6.11.bn2.running_mean", "resnet.6.11.bn2.running_var", "resnet.6.11.bn2.num_batches_tracked", "resnet.6.11.conv3.weight", "resnet.6.11.bn3.weight", "resnet.6.11.bn3.bias", "resnet.6.11.bn3.running_mean", "resnet.6.11.bn3.running_var", "resnet.6.11.bn3.num_batches_tracked", "resnet.6.12.conv1.weight", "resnet.6.12.bn1.weight", "resnet.6.12.bn1.bias", "resnet.6.12.bn1.running_mean", "resnet.6.12.bn1.running_var", "resnet.6.12.bn1.num_batches_tracked", "resnet.6.12.conv2.weight", "resnet.6.12.bn2.weight", "resnet.6.12.bn2.bias", "resnet.6.12.bn2.running_mean", "resnet.6.12.bn2.running_var", "resnet.6.12.bn2.num_batches_tracked", "resnet.6.12.conv3.weight", "resnet.6.12.bn3.weight", "resnet.6.12.bn3.bias", "resnet.6.12.bn3.running_mean", "resnet.6.12.bn3.running_var", "resnet.6.12.bn3.num_batches_tracked", "resnet.6.13.conv1.weight", "resnet.6.13.bn1.weight", "resnet.6.13.bn1.bias", "resnet.6.13.bn1.running_mean", "resnet.6.13.bn1.running_var", "resnet.6.13.bn1.num_batches_tracked", "resnet.6.13.conv2.weight", "resnet.6.13.bn2.weight", "resnet.6.13.bn2.bias", "resnet.6.13.bn2.running_mean", "resnet.6.13.bn2.running_var", "resnet.6.13.bn2.num_batches_tracked", "resnet.6.13.conv3.weight", "resnet.6.13.bn3.weight", "resnet.6.13.bn3.bias", "resnet.6.13.bn3.running_mean", "resnet.6.13.bn3.running_var", "resnet.6.13.bn3.num_batches_tracked", "resnet.6.14.conv1.weight", 
"resnet.6.14.bn1.weight", "resnet.6.14.bn1.bias", "resnet.6.14.bn1.running_mean", "resnet.6.14.bn1.running_var", "resnet.6.14.bn1.num_batches_tracked", "resnet.6.14.conv2.weight", "resnet.6.14.bn2.weight", "resnet.6.14.bn2.bias", "resnet.6.14.bn2.running_mean", "resnet.6.14.bn2.running_var", "resnet.6.14.bn2.num_batches_tracked", "resnet.6.14.conv3.weight", "resnet.6.14.bn3.weight", "resnet.6.14.bn3.bias", "resnet.6.14.bn3.running_mean", "resnet.6.14.bn3.running_var", "resnet.6.14.bn3.num_batches_tracked", "resnet.6.15.conv1.weight", "resnet.6.15.bn1.weight", "resnet.6.15.bn1.bias", "resnet.6.15.bn1.running_mean", "resnet.6.15.bn1.running_var", "resnet.6.15.bn1.num_batches_tracked", "resnet.6.15.conv2.weight", "resnet.6.15.bn2.weight", "resnet.6.15.bn2.bias", "resnet.6.15.bn2.running_mean", "resnet.6.15.bn2.running_var", "resnet.6.15.bn2.num_batches_tracked", "resnet.6.15.conv3.weight", "resnet.6.15.bn3.weight", "resnet.6.15.bn3.bias", "resnet.6.15.bn3.running_mean", "resnet.6.15.bn3.running_var", "resnet.6.15.bn3.num_batches_tracked", "resnet.6.16.conv1.weight", "resnet.6.16.bn1.weight", "resnet.6.16.bn1.bias", "resnet.6.16.bn1.running_mean", "resnet.6.16.bn1.running_var", "resnet.6.16.bn1.num_batches_tracked", "resnet.6.16.conv2.weight", "resnet.6.16.bn2.weight", "resnet.6.16.bn2.bias", "resnet.6.16.bn2.running_mean", "resnet.6.16.bn2.running_var", "resnet.6.16.bn2.num_batches_tracked", "resnet.6.16.conv3.weight", "resnet.6.16.bn3.weight", "resnet.6.16.bn3.bias", "resnet.6.16.bn3.running_mean", "resnet.6.16.bn3.running_var", "resnet.6.16.bn3.num_batches_tracked", "resnet.6.17.conv1.weight", "resnet.6.17.bn1.weight", "resnet.6.17.bn1.bias", "resnet.6.17.bn1.running_mean", "resnet.6.17.bn1.running_var", "resnet.6.17.bn1.num_batches_tracked", "resnet.6.17.conv2.weight", "resnet.6.17.bn2.weight", "resnet.6.17.bn2.bias", "resnet.6.17.bn2.running_mean", "resnet.6.17.bn2.running_var", "resnet.6.17.bn2.num_batches_tracked", "resnet.6.17.conv3.weight", 
"resnet.6.17.bn3.weight", "resnet.6.17.bn3.bias", "resnet.6.17.bn3.running_mean", "resnet.6.17.bn3.running_var", "resnet.6.17.bn3.num_batches_tracked", "resnet.6.18.conv1.weight", "resnet.6.18.bn1.weight", "resnet.6.18.bn1.bias", "resnet.6.18.bn1.running_mean", "resnet.6.18.bn1.running_var", "resnet.6.18.bn1.num_batches_tracked", "resnet.6.18.conv2.weight", "resnet.6.18.bn2.weight", "resnet.6.18.bn2.bias", "resnet.6.18.bn2.running_mean", "resnet.6.18.bn2.running_var", "resnet.6.18.bn2.num_batches_tracked", "resnet.6.18.conv3.weight", "resnet.6.18.bn3.weight", "resnet.6.18.bn3.bias", "resnet.6.18.bn3.running_mean", "resnet.6.18.bn3.running_var", "resnet.6.18.bn3.num_batches_tracked", "resnet.6.19.conv1.weight", "resnet.6.19.bn1.weight", "resnet.6.19.bn1.bias", "resnet.6.19.bn1.running_mean", "resnet.6.19.bn1.running_var", "resnet.6.19.bn1.num_batches_tracked", "resnet.6.19.conv2.weight", "resnet.6.19.bn2.weight", "resnet.6.19.bn2.bias", "resnet.6.19.bn2.running_mean", "resnet.6.19.bn2.running_var", "resnet.6.19.bn2.num_batches_tracked", "resnet.6.19.conv3.weight", "resnet.6.19.bn3.weight", "resnet.6.19.bn3.bias", "resnet.6.19.bn3.running_mean", "resnet.6.19.bn3.running_var", "resnet.6.19.bn3.num_batches_tracked", "resnet.6.20.conv1.weight", "resnet.6.20.bn1.weight", "resnet.6.20.bn1.bias", "resnet.6.20.bn1.running_mean", "resnet.6.20.bn1.running_var", "resnet.6.20.bn1.num_batches_tracked", "resnet.6.20.conv2.weight", "resnet.6.20.bn2.weight", "resnet.6.20.bn2.bias", "resnet.6.20.bn2.running_mean", "resnet.6.20.bn2.running_var", "resnet.6.20.bn2.num_batches_tracked", "resnet.6.20.conv3.weight", "resnet.6.20.bn3.weight", "resnet.6.20.bn3.bias", "resnet.6.20.bn3.running_mean", "resnet.6.20.bn3.running_var", "resnet.6.20.bn3.num_batches_tracked", "resnet.6.21.conv1.weight", "resnet.6.21.bn1.weight", "resnet.6.21.bn1.bias", "resnet.6.21.bn1.running_mean", "resnet.6.21.bn1.running_var", "resnet.6.21.bn1.num_batches_tracked", "resnet.6.21.conv2.weight", 
"resnet.6.21.bn2.weight", "resnet.6.21.bn2.bias", "resnet.6.21.bn2.running_mean", "resnet.6.21.bn2.running_var", "resnet.6.21.bn2.num_batches_tracked", "resnet.6.21.conv3.weight", "resnet.6.21.bn3.weight", "resnet.6.21.bn3.bias", "resnet.6.21.bn3.running_mean", "resnet.6.21.bn3.running_var", "resnet.6.21.bn3.num_batches_tracked", "resnet.6.22.conv1.weight", "resnet.6.22.bn1.weight", "resnet.6.22.bn1.bias", "resnet.6.22.bn1.running_mean", "resnet.6.22.bn1.running_var", "resnet.6.22.bn1.num_batches_tracked", "resnet.6.22.conv2.weight", "resnet.6.22.bn2.weight", "resnet.6.22.bn2.bias", "resnet.6.22.bn2.running_mean", "resnet.6.22.bn2.running_var", "resnet.6.22.bn2.num_batches_tracked", "resnet.6.22.conv3.weight", "resnet.6.22.bn3.weight", "resnet.6.22.bn3.bias", "resnet.6.22.bn3.running_mean", "resnet.6.22.bn3.running_var", "resnet.6.22.bn3.num_batches_tracked", "resnet.7.0.conv1.weight", "resnet.7.0.bn1.weight", "resnet.7.0.bn1.bias", "resnet.7.0.bn1.running_mean", "resnet.7.0.bn1.running_var", "resnet.7.0.bn1.num_batches_tracked", "resnet.7.0.conv2.weight", "resnet.7.0.bn2.weight", "resnet.7.0.bn2.bias", "resnet.7.0.bn2.running_mean", "resnet.7.0.bn2.running_var", "resnet.7.0.bn2.num_batches_tracked", "resnet.7.0.conv3.weight", "resnet.7.0.bn3.weight", "resnet.7.0.bn3.bias", "resnet.7.0.bn3.running_mean", "resnet.7.0.bn3.running_var", "resnet.7.0.bn3.num_batches_tracked", "resnet.7.0.downsample.0.weight", "resnet.7.0.downsample.1.weight", "resnet.7.0.downsample.1.bias", "resnet.7.0.downsample.1.running_mean", "resnet.7.0.downsample.1.running_var", "resnet.7.0.downsample.1.num_batches_tracked", "resnet.7.1.conv1.weight", "resnet.7.1.bn1.weight", "resnet.7.1.bn1.bias", "resnet.7.1.bn1.running_mean", "resnet.7.1.bn1.running_var", "resnet.7.1.bn1.num_batches_tracked", "resnet.7.1.conv2.weight", "resnet.7.1.bn2.weight", "resnet.7.1.bn2.bias", "resnet.7.1.bn2.running_mean", "resnet.7.1.bn2.running_var", "resnet.7.1.bn2.num_batches_tracked", "resnet.7.1.conv3.weight", 
"resnet.7.1.bn3.weight", "resnet.7.1.bn3.bias", "resnet.7.1.bn3.running_mean", "resnet.7.1.bn3.running_var", "resnet.7.1.bn3.num_batches_tracked", "resnet.7.2.conv1.weight", "resnet.7.2.bn1.weight", "resnet.7.2.bn1.bias", "resnet.7.2.bn1.running_mean", "resnet.7.2.bn1.running_var", "resnet.7.2.bn1.num_batches_tracked", "resnet.7.2.conv2.weight", "resnet.7.2.bn2.weight", "resnet.7.2.bn2.bias", "resnet.7.2.bn2.running_mean", "resnet.7.2.bn2.running_var", "resnet.7.2.bn2.num_batches_tracked", "resnet.7.2.conv3.weight", "resnet.7.2.bn3.weight", "resnet.7.2.bn3.bias", "resnet.7.2.bn3.running_mean", "resnet.7.2.bn3.running_var", "resnet.7.2.bn3.num_batches_tracked". 

In [9]:
# Put both networks in inference mode (e.g. the decoder's dropout layer becomes a no-op).
encoder.eval()
decoder.eval()

Decoder(
  (drop): Dropout(p=0.5, inplace=False)
  (embed): Embedding(4528, 200)
  (LSTM): LSTM(200, 256)
  (decode): Linear(in_features=256, out_features=4528, bias=True)
)

In [10]:
def load_img(path, root = 'data/Flickr8k_Dataset'):
    """Load one image as a normalised single-image batch.

    Args:
        path: image file name, relative to `root`.
        root: directory that holds the images.

    Returns:
        Float tensor of shape (1, 3, 224, 224) on the CPU.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    with Image.open(os.path.join(root, path)) as raw:
        # Force three channels so grayscale / RGBA / palette files do not break
        # the 3-channel Normalize; the file handle is closed on exit.
        resized = raw.convert('RGB').resize((224, 224))

    img = normalize(to_tensor(resized))

    # Add the leading batch dimension expected by the encoder.
    return img.unsqueeze(0)

In [13]:
def generate(img_name = '44129946_9eeb385d77.jpg', max_len = 15, temperature = 3.0):
    """Sample a caption for one image and print it as a list of words.

    Uses the notebook-level `encoder`, `decoder`, `corpus` and `device`.

    Args:
        img_name: image file name passed to `load_img`.
        max_len: maximum number of words to sample.
        temperature: divisor applied to the log-probabilities before
            exponentiating; larger values flatten the sampling distribution.

    Returns:
        None. The sampled words are printed; sampling stops early once
        '<eos>' is drawn.
    """
    encoder.eval()
    decoder.eval()

    out_sen = []
    # The image must live on the same device as the encoder.
    img = load_img(img_name).to(device)
    hidden = decoder.init_hidden(1)
    # (1, 1) buffer holding the previously sampled token id; filled from step 0 on.
    token = torch.empty((1, 1), dtype = torch.long).to(device)

    # Inference only: keep the encoder pass inside no_grad as well so that no
    # autograd graph is recorded for the image features.
    with torch.no_grad():
        img_features = encoder(img).unsqueeze(0)

        for t in range(max_len):
            # Step 0 is driven by the image features, later steps by the last word.
            step_input = img_features if t == 0 else token
            out, hidden = decoder(step_input, hidden, t)

            word_weights = out.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0].item()

            token.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]
            out_sen.append(word)

            if word == '<eos>':
                break

    print(out_sen)

In [14]:
# Sample and print one caption. Words are drawn with torch.multinomial, so the
# output varies between runs unless a seed is set.
generate()

['fire', 'pulling', 'spectators', 'cyclist', 'hello', 'takes', 'older', 'down', 'campfire', 'stick', 'swarmed', 'barrier', 'face', 'paddles', 'stands']
