In [212]:
from __future__ import unicode_literals, print_function, division

import glob
import string
import unicodedata
from io import open

import pandas as pd
import torch
import torch.nn as nn
from torch.autograd import Variable as V
from torch.utils.data import Dataset, DataLoader

In [187]:
# Load the caption table and discard the image file name column in one
# expression, so re-running the cell always starts again from the CSV.
data = pd.read_csv('data/train2014.csv').drop(columns=['img_fname'])
data.head()

Unnamed: 0,captions
0,['Two ducks floating on top of a lake with bro...
1,['A long bill bird standing on a beach next to...
2,['Two birds sitting on a red towel on a shower...
3,['many elephants walking with a person with a ...
4,['A black and white bird sitting on top of a r...


In [188]:
# Inspect the first row: the `captions` cell is one string that contains a
# list-style repr of several caption sentences, not an actual Python list.
data.captions[0]

"['Two ducks floating on top of a lake with brown water.', 'Two ducks are swimming on the lake next to each other.', 'Two geese swimming on the water and looking toward the camera.', 'Two ducks facing the same direction, wading in a pond.', 'Two geese resting in a body of water']"

In [189]:
# Flatten every row's stringified caption list into individual sentences.
# Each row is lower-cased, stripped of commas, full stops and double quotes,
# then cut at every single quote.
# NOTE(review): cutting at "'" also cuts captions that contain an apostrophe
# (e.g. possessives) into fragments - confirm whether that is acceptable.
caption_pieces = []
for caption_blob in data.captions:
    cleaned = caption_blob.lower()
    for unwanted in (',', '.', '"'):
        cleaned = cleaned.replace(unwanted, '')
    caption_pieces.extend(cleaned.split("'"))

# Pieces of 6 characters or fewer are dropped; this removes the bracket and
# separator fragments left over from the list repr.
all_captions = [piece for piece in caption_pieces if len(piece) > 6]
all_captions[0:5]

['two ducks floating on top of a lake with brown water',
 'two ducks are swimming on the lake next to each other',
 'two geese swimming on the water and looking toward the camera',
 'two ducks facing the same direction wading in a pond',
 'two geese resting in a body of water']

In [190]:
# Vocabulary: every distinct token obtained by splitting the captions on a
# single space, in sorted order. Consecutive spaces produce the empty-string
# token that shows up first in the output.
joined_captions = ' '.join(all_captions)
bow = sorted(set(joined_captions.split(' ')))
bow[0:10]

['', '#', '#2', '&', '(just', '(thailand)', '(two', '(unseen)', '-', '1']

In [191]:
# Map each vocabulary token to its position in the sorted list (ids start at 0).
vocab2id = dict(zip(bow, range(len(bow))))
vocab2id[""]

0

In [192]:
# Register the special tokens with the next free ids instead of hard-coded
# numbers. Word ids are 0..len(bow)-1, so len(vocab2id) is always unused; this
# keeps the special ids contiguous with the vocabulary, so they cannot collide
# with a real word or exceed the embedding table when the corpus changes.
for special_token in ('<start>', '<end>', '<pad>'):
    # setdefault keeps an id that was already assigned, so re-running this
    # cell does not renumber the tokens.
    vocab2id.setdefault(special_token, len(vocab2id))

In [193]:
# Prefix every caption with the '<start>' marker and tokenise on whitespace.
# After this cell `all_captions` holds lists of tokens instead of strings, so
# the cell must not be run twice without re-running the caption parsing first.
all_captions = [('<start> ' + caption).split() for caption in all_captions]
all_captions[0]

['<start>',
 'two',
 'ducks',
 'floating',
 'on',
 'top',
 'of',
 'a',
 'lake',
 'with',
 'brown',
 'water']

In [195]:
# Longest tokenised caption (the count includes the leading '<start>' token).
max_len = max(map(len, all_captions))

In [197]:
# Show the first caption in its tokenised form (a list beginning with '<start>').
all_captions[0]

['<start>',
 'two',
 'ducks',
 'floating',
 'on',
 'top',
 'of',
 'a',
 'lake',
 'with',
 'brown',
 'water']

In [199]:
# Work on a shallow copy of the first tokenised caption so that changes to
# `caption` do not modify all_captions[0].
caption = all_captions[0].copy()


(51, 51)

In [214]:
# Token count of every caption before padding ('<start>' included, no '<end>').
y_length = list(map(len, all_captions))

In [201]:
# Build fixed-length input/target token sequences of length max_len + 1.
#   input  (train_cap): tokens + padding + one extra '<pad>'
#   target (y_cap):     tokens + '<end>' + padding
# NOTE(review): both sequences keep the leading '<start>', so the target is not
# shifted by one position relative to the input - confirm this is intended.
train_cap, y_cap = [], []
for tokens in all_captions:
    filler = ['<pad>'] * (max_len - len(tokens))
    train_cap.append(tokens + filler + ['<pad>'])
    y_cap.append(tokens + ['<end>'] + filler)
len(train_cap[0]), len(y_cap[0])

(51, 51)

In [217]:
# Translate the padded token sequences into id sequences via vocab2id.
# The loop variable keeps the name `x` because it stays bound at module level
# once the cell has run (it ends up as the last row of y_cap).
X_lookup, y_lookup = [], []
for encoded_rows, token_rows in ((X_lookup, train_cap), (y_lookup, y_cap)):
    for x in token_rows:
        encoded_rows.append(list(map(vocab2id.__getitem__, x)))

In [218]:
# Compare the encoded input and the encoded target of the first caption.
X_lookup[0], y_lookup[0]

([10239,
  9470,
  2774,
  3374,
  6007,
  9207,
  5967,
  85,
  4886,
  10057,
  1197,
  9852,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241],
 [10239,
  9470,
  2774,
  3374,
  6007,
  9207,
  5967,
  85,
  4886,
  10057,
  1197,
  9852,
  10240,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241,
  10241])

In [211]:
# Size the embedding table from the largest id in use rather than from the
# number of keys: nn.Embedding needs num_embeddings > every index it is asked
# for, and the special-token ids are assigned separately from the word ids.
# When the ids are contiguous from 0 this equals len(vocab2id).
vocab_len = max(vocab2id.values(), default=-1) + 1
# 50 is the embedding dimension of this stand-alone lookup table.
vocab_emb = nn.Embedding(vocab_len, 50)

In [215]:
class CustomDataset(Dataset):
    """Dataset yielding ``(input, target, length)`` triples.

    Args:
        x: sized, indexable collection of input sequences.
        y: indexable collection of target sequences, aligned with ``x``.
        y_length: indexable collection of per-sample lengths, aligned with ``x``.
    """

    def __init__(self, x, y, y_length):
        self.x = x
        self.y = y
        self.y_length = y_length

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.x)``."""
        # Read the instance's own data. The previous version used a bare `x`,
        # which resolved to whatever module-level `x` happened to exist (or
        # raised NameError on a fresh kernel).
        self.len = len(self.x)  # attribute kept for backward compatibility
        return self.len

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(x, y, y_length)`` triple."""
        return self.x[idx], self.y[idx], self.y_length[idx]
        

In [216]:
class RNN(nn.Module):
    """Single-step recurrent cell operating on embedded token ids.

    One call to ``forward`` consumes one token per batch element and returns
    the output scores together with the updated hidden state.

    Args:
        input_size: width of the per-step input fed to ``linear_i2h``. Because
            ``forward`` feeds the embedded tokens into that layer, this must
            equal ``emb_size``.
        hidden_size: width of the hidden state.
        output_size: width of the output produced for each step.
        vocab_size: number of rows in the embedding table.
        emb_size: embedding dimension.
    """

    def __init__(self, input_size, hidden_size, output_size, vocab_size, emb_size):
        super(RNN, self).__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.hidden_size = hidden_size
        self.linear_i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.linear_h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run one recurrent step.

        Args:
            x: 1-D tensor of token ids, one per batch element (it is cast to
                ``long`` before the embedding lookup).
            hidden: previous hidden state of shape ``(batch, hidden_size)``.
                When omitted, a zero state from ``initHidden`` is used, so
                existing ``model(x)`` calls keep working.

        Returns:
            Tuple ``(output, hidden)`` with shapes ``(batch, output_size)``
            and ``(batch, hidden_size)``.
        """
        x = self.emb(x.long())
        if hidden is None:
            # Previously `hidden` was read before ever being assigned.
            hidden = self.initHidden(x.size(0))
        combined = torch.cat((x, hidden), 1)
        # torch.tanh replaces the deprecated F.tanh (F was never imported).
        hidden = torch.tanh(self.linear_i2h(combined))
        output = self.linear_h2o(hidden)
        return output, hidden

    def initHidden(self, bs):
        """Return a zero hidden state for a batch of ``bs`` elements."""
        # Allocate on the same device as the model's weights so the state can
        # be concatenated with the embeddings after the model has been moved.
        return torch.zeros(bs, self.hidden_size, device=self.linear_i2h.weight.device)
        

In [None]:
def train_loop(model, epochs, lr=0.01, wd=0.0):
    """Train ``model`` for ``epochs`` passes, then run prediction once.

    Relies on names that are not defined in this part of the notebook:
    ``get_optimizer``, ``train``, ``predict`` and the loaders ``train_2_dl``
    and ``test_2_dl``.

    Args:
        model: the model passed to the optimizer factory, ``train`` and ``predict``.
        epochs: number of calls to ``train``.
        lr: forwarded to ``get_optimizer`` as ``lr``.
        wd: forwarded to ``get_optimizer`` as ``wd``.
    """
    optimizer = get_optimizer(model, lr=lr, wd=wd)
    for epoch in range(epochs):
        epoch_loss = train(model, optimizer, train_2_dl)
        # Report on epochs 1, 6, 11, ... (every fifth epoch, offset by one).
        report_now = epoch % 5 == 1
        if report_now:
            print("train loss %.3f" % epoch_loss)
    predict(model, test_2_dl)