In [6]:
import pandas 
import numpy as np
import os
import re
import operator
import math
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from constants import *
from data import *
from encoders import *
from decoders import *
from solver import *
import matplotlib.pyplot as plt
import torchvision.models as torchmodels

In [208]:
train_size = 10
valid_size = 10
batch_size = 2
# Directory holding the preprocessed COCO 2014 CSV and HDF5 files used below;
# change this single constant to relocate the data.
COCO_DIR = "/mnt/raid/davech2y/COCO_2014/preprocessed"
# Overfitting sanity check: swap the "valid" CSV/HDF5 paths below for their
# coco_train2014 counterparts to validate on training data.
coco = COCO(
    pandas.read_csv(os.path.join(COCO_DIR, "coco_train2014.caption.csv")),
    pandas.read_csv(os.path.join(COCO_DIR, "coco_valid2014.caption.csv")),
    # NOTE(review): presumably caps the number of train/valid samples — confirm in data.COCO
    [train_size, valid_size]
)
train_captions = coco.transformed_data['train']
valid_captions = coco.transformed_data['valid']
dict_idx2word = coco.dict_idx2word
dict_word2idx = coco.dict_word2idx
corpus = coco.corpus
train_ds = COCOCaptionDataset(
    None,
    train_captions,
    database=os.path.join(COCO_DIR, "coco_train2014_224_new.hdf5")
)
valid_ds = COCOCaptionDataset(
    None,
    valid_captions,
    database=os.path.join(COCO_DIR, "coco_valid2014_224_new.hdf5")
)
train_dl = DataLoader(train_ds, batch_size=batch_size)
valid_dl = DataLoader(valid_ds, batch_size=batch_size)
# Phase name -> loader, consumed by the training loop.
dataloader = {
    'train': train_dl,
    'valid': valid_dl
}

In [225]:
class AE(nn.Module):
    """Plain convolutional autoencoder baseline (no caption conditioning).

    Five Conv(3x3)-ReLU-MaxPool(2) stages downsample by 2**5 = 32 and five
    ConvTranspose(2x2, stride 2)-ReLU stages upsample by 32, so inputs whose
    height and width are divisible by 32 are reconstructed at their own size.

    Args:
        dict_size: vocabulary size. Unused here; accepted so this model can be
            swapped for the caption-conditioned models with the same constructor.
    """

    def __init__(self, dict_size):
        # Must name this class: super(AttnAE, self) raised TypeError because an
        # AE instance is not an AttnAE.
        super(AE, self).__init__()
        channels = (3, 16, 32, 64, 128, 256)

        # Same module order as an explicit Sequential, so state_dict keys
        # (encoder.0, encoder.3, ...) are unchanged.
        encoder_layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_channels = channels[::-1]
        decoder_layers = []
        for c_in, c_out in zip(decoder_channels[:-1], decoder_channels[1:]):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2),
                nn.ReLU(),
            ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, inputs, caption_inputs):
        """Reconstruct `inputs`; `caption_inputs` is ignored (kept for a shared call signature)."""
        return self.decoder(self.encoder(inputs))

In [226]:
class AE_skip(nn.Module):
    """U-Net-style convolutional autoencoder with encoder->decoder skip connections.

    Each decoder stage after the first consumes the previous decoder output
    concatenated (on the channel axis) with the encoder output of matching
    resolution. Shape comments assume a (3, 224, 224) input; any height/width
    divisible by 32 works the same way.

    Args:
        dict_size: vocabulary size. Unused here; accepted so this model can be
            swapped for the caption-conditioned models with the same constructor.
    """

    def __init__(self, dict_size):
        # Must name this class: super(AttnAE, self) raised TypeError because an
        # AE_skip instance is not an AttnAE.
        super(AE_skip, self).__init__()
        self.encoder_1 = nn.Sequential(
            # (3, 224, 224)
            nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (8, 112, 112)
        )
        self.encoder_2 = nn.Sequential(
            # (8, 112, 112)
            nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (16, 56, 56)
        )
        self.encoder_3 = nn.Sequential(
            # (16, 56, 56)
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (32, 28, 28)
        )
        self.encoder_4 = nn.Sequential(
            # (32, 28, 28)
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # (64, 14, 14)
        )
        self.encoder_5 = nn.Sequential(
            # (64, 14, 14)
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
            # (128, 7, 7)
        )

        self.decoder_1 = nn.Sequential(
            # (128, 7, 7)
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            # (64, 14, 14)
        )
        self.decoder_2 = nn.Sequential(
            # (128, 14, 14) = decoded_1 (64) ++ encoded_4 (64)
            nn.ConvTranspose2d(128, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            # (32, 28, 28)
        )
        self.decoder_3 = nn.Sequential(
            # (64, 28, 28) = decoded_2 (32) ++ encoded_3 (32)
            nn.ConvTranspose2d(64, 16, kernel_size=2, stride=2),
            nn.ReLU(),
            # (16, 56, 56)
        )
        self.decoder_4 = nn.Sequential(
            # (32, 56, 56) = decoded_3 (16) ++ encoded_2 (16)
            nn.ConvTranspose2d(32, 8, kernel_size=2, stride=2),
            nn.ReLU()
            # (8, 112, 112)
        )
        self.decoder_5 = nn.Sequential(
            # (16, 112, 112) = decoded_4 (8) ++ encoded_1 (8)
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
            nn.ReLU()
            # (3, 224, 224)
        )

    def forward(self, inputs, caption_inputs):
        """Reconstruct `inputs` via skip connections; `caption_inputs` is ignored."""
        encoded_1 = self.encoder_1(inputs)
        encoded_2 = self.encoder_2(encoded_1)
        encoded_3 = self.encoder_3(encoded_2)
        encoded_4 = self.encoder_4(encoded_3)
        encoded_5 = self.encoder_5(encoded_4)
        decoded_1 = self.decoder_1(encoded_5)
        decoded_2 = self.decoder_2(torch.cat((decoded_1, encoded_4), dim=1))
        decoded_3 = self.decoder_3(torch.cat((decoded_2, encoded_3), dim=1))
        decoded_4 = self.decoder_4(torch.cat((decoded_3, encoded_2), dim=1))
        outputs = self.decoder_5(torch.cat((decoded_4, encoded_1), dim=1))

        return outputs

In [227]:
class AttnAE(nn.Module):
    """Convolutional autoencoder whose bottleneck is re-weighted by caption attention.

    The image is encoded to a 256-channel feature map downsampled 32x. The
    caption is run token by token through an LSTMCell, and the mean hidden
    state scores every spatial position of that map (additive attention). The
    attention-weighted map is then decoded back up 32x.

    Args:
        dict_size: number of embedding rows; caption token ids must be < dict_size.
    """

    hidden_size = 512  # LSTM hidden width and attention projection width

    def __init__(self, dict_size):
        super(AttnAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
            nn.ReLU()
        )
        self.comp_visual = nn.Linear(256, self.hidden_size)
        self.comp_hidden = nn.Linear(self.hidden_size, self.hidden_size)
        self.attn_combine = nn.Linear(self.hidden_size, 1)
        self.embedding = nn.Embedding(dict_size, self.hidden_size)
        self.text_encoder = nn.LSTMCell(self.hidden_size, self.hidden_size)

    def initHidden(self, visual_encoded):
        """Return zero (h, c) LSTMCell states for the batch of `visual_encoded`.

        `new_zeros` puts the states on the same device/dtype as the encoder
        output; the previous `torch.zeros` always used the default device,
        which breaks once the model is moved to a GPU.
        """
        batch = visual_encoded.size(0)
        return (
            visual_encoded.new_zeros(batch, self.hidden_size),
            visual_encoded.new_zeros(batch, self.hidden_size),
        )

    def attention(self, visual_encoded, hiddens):
        """Additive attention over the spatial positions of the feature map.

        Args:
            visual_encoded: (batch, 256, H', W') encoder output.
            hiddens: (batch, hidden_size) caption summary.

        Returns:
            (batch, H' * W') softmax weights, one per spatial position.
        """
        batch, channels = visual_encoded.size(0), visual_encoded.size(1)
        # (batch, H'*W', 256): one feature vector per spatial position.
        positions = visual_encoded.view(batch, channels, -1).transpose(2, 1).contiguous()
        scores = torch.tanh(self.comp_visual(positions) + self.comp_hidden(hiddens).unsqueeze(1))
        # Squeeze only the trailing score dim: a bare squeeze() also dropped the
        # batch dim when batch == 1 (or H'*W' == 1), breaking softmax(dim=1).
        scores = self.attn_combine(scores).squeeze(-1)
        return torch.softmax(scores, dim=1)

    def forward(self, inputs, caption_inputs):
        """Reconstruct `inputs` through a caption-attended bottleneck.

        Args:
            inputs: (batch, 3, H, W) images; output matches H, W when both are
                divisible by 32.
            caption_inputs: (batch, seq) integer token ids.
        """
        visual_encoded = self.encoder(inputs)
        states = self.initHidden(visual_encoded)
        step_hiddens = []
        for step in range(caption_inputs.size(1)):
            states = self.text_encoder(self.embedding(caption_inputs[:, step]), states)
            step_hiddens.append(states[0])
        # Mean of the per-step hidden states summarizes the whole caption.
        caption_summary = torch.stack(step_hiddens, dim=1).mean(1)
        attn_weights = self.attention(visual_encoded, caption_summary)
        flat = visual_encoded.view(visual_encoded.size(0), visual_encoded.size(1), -1)
        attended = flat * attn_weights.unsqueeze(1)
        # view_as restores the real (H', W') grid; the old sqrt-based reshape
        # only worked for square feature maps.
        return self.decoder(attended.view_as(visual_encoded))

In [223]:
# NOTE(review): the +1 presumably reserves an extra embedding row (e.g. a
# padding index) beyond the vocabulary — confirm against data.COCO.
model = AE(len(dict_idx2word) + 1)
cr = nn.MSELoss()  # pixel-wise reconstruction loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [224]:
train_loss = []
valid_loss = []
# Phase name -> per-epoch mean-loss history, so both phases share one code path.
history = {"train": train_loss, "valid": valid_loss}
for e in range(1000):
    for phase in ["train", "valid"]:
        is_train = phase == "train"
        # Explicit train/eval mode; matters as soon as dropout/batch-norm is added.
        model.train(is_train)
        total = []
        for model_ids, visuals, captions, cap_lengths in dataloader[phase]:
            inputs = visuals
            # `captions` appears to be a sequence of per-position (batch,) tensors
            # (default collate); cat + transpose gives (batch, seq). Truncate to
            # the longest caption in the batch: cap_lengths[0] cut longer captions
            # unless the batch happened to be sorted by length.
            caption_inputs = torch.cat(
                [item.view(1, -1) for item in captions]
            ).transpose(1, 0)[:, :int(max(cap_lengths))]
            # Build the autograd graph only when we will backpropagate.
            with torch.set_grad_enabled(is_train):
                outputs = model(inputs, caption_inputs)
                loss = cr(outputs, visuals)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total.append(loss.item())
        epoch_loss = np.mean(total)
        print("{}_loss:".format(phase), epoch_loss)
        history[phase].append(epoch_loss)

loss: 1.4731552600860596
loss: 1.4684765338897705
loss: 1.4642805337905884
loss: 1.4601991653442383
loss: 1.455910849571228
loss: 1.4514995574951173
loss: 1.4465100049972535
loss: 1.4355175018310546
loss: 1.311687970161438
loss: 1.1681266188621522
loss: 1.1369916200637817
loss: 1.0546221375465392
loss: 1.0198426604270936
loss: 0.9784405827522278
loss: 0.9332048296928406
loss: 0.9131057739257813
loss: 0.9004910826683045
loss: 0.8944319009780883
loss: 0.8956144094467163
loss: 0.8869249224662781
loss: 0.8753095030784607
loss: 0.8689142942428589
loss: 0.8660258650779724
loss: 0.8594482421875
loss: 0.8554783463478088
loss: 0.8501441597938537
loss: 0.8411869883537293
loss: 0.8266507148742676
loss: 0.8190279364585876
loss: 0.798250138759613
loss: 0.7715648531913757
loss: 0.7178617238998413
loss: 0.6668466567993164
loss: 0.6422518968582154
loss: 0.61940438747406
loss: 0.6143876433372497
loss: 0.5965945899486542
loss: 0.5900753378868103
loss: 0.5881588578224182
loss: 0.5769123315811158
loss: 0.

loss: 0.21944119334220885
loss: 0.21641429364681244
loss: 0.2163900375366211
loss: 0.21368127167224885
loss: 0.21223822832107545
loss: 0.21224642992019654
loss: 0.2134406477212906
loss: 0.21015147268772125
loss: 0.21016659140586852
loss: 0.211619234085083
loss: 0.21015809178352357
loss: 0.21265054643154144
loss: 0.2115221858024597
loss: 0.20753279328346252
loss: 0.20657336115837097
loss: 0.2085155129432678
loss: 0.2087952733039856
loss: 0.20818174183368682
loss: 0.2117968827486038
loss: 0.21058105528354645
loss: 0.2073229283094406
loss: 0.21580494940280914
loss: 0.20881400406360626
loss: 0.2097170054912567
loss: 0.2021850675344467
loss: 0.20223195254802703
loss: 0.20116907954216004
loss: 0.2012250989675522
loss: 0.19731835424900054
loss: 0.19894111454486846
loss: 0.19922910630702972
loss: 0.1961271196603775
loss: 0.19574955403804778
loss: 0.20119810998439788
loss: 0.1984974503517151
loss: 0.19914743900299073
loss: 0.19828783273696898
loss: 0.19830714762210847
loss: 0.2017400085926056
l

loss: 0.13041398376226426
loss: 0.13220637291669846
loss: 0.13187080174684523
loss: 0.13275076746940612
loss: 0.1332360103726387
loss: 0.13213045597076417
loss: 0.13474689573049545
loss: 0.134626667201519
loss: 0.13082030266523362
loss: 0.13075461536645888
loss: 0.1281462401151657
loss: 0.12764533013105392
loss: 0.12821342051029205
loss: 0.1271494671702385
loss: 0.12844094783067703
loss: 0.12761468887329103
loss: 0.12704060226678848
loss: 0.1294735088944435
loss: 0.12781810462474824
loss: 0.12810629606246948
loss: 0.1281293272972107
loss: 0.12765832543373107
loss: 0.1295992374420166
loss: 0.1298574134707451
loss: 0.12956509590148926
loss: 0.13009819835424424
loss: 0.1289130002260208
loss: 0.1307766392827034
loss: 0.13017017990350724
loss: 0.13169043362140656
loss: 0.13280769139528276
loss: 0.1316063940525055
loss: 0.13438787311315536
loss: 0.1327378734946251
loss: 0.12973684072494507
loss: 0.1317090019583702
loss: 0.12912764847278596
loss: 0.1287749320268631
loss: 0.12953146249055864
l

loss: 0.10700096487998963
loss: 0.10667206197977067
loss: 0.10638737231492996
loss: 0.10570160150527955
loss: 0.10576317757368088
loss: 0.10494960695505143
loss: 0.10518445968627929
loss: 0.10432399362325669
loss: 0.10457066893577575
loss: 0.10468761175870896
loss: 0.10416443049907684
loss: 0.10445534586906433
loss: 0.10424883812665939
loss: 0.10421891063451767
loss: 0.10414347052574158
loss: 0.10418057292699814
loss: 0.10390780121088028
loss: 0.10358244478702545
loss: 0.10330913066864014
loss: 0.1043105110526085
loss: 0.10566964745521545
loss: 0.10568356215953827
loss: 0.1056232824921608
loss: 0.10434226989746094
loss: 0.10576246082782745
loss: 0.10705870985984803
loss: 0.10787830799818039
loss: 0.108588607609272
loss: 0.10704617202281952
loss: 0.10795335024595261
loss: 0.10728948712348937
loss: 0.10681288689374924
loss: 0.10797723978757859
loss: 0.1064338818192482
loss: 0.10706541091203689
loss: 0.10713682025671005
loss: 0.10825761556625366
loss: 0.11143080443143845
loss: 0.109591948

In [221]:
# NOTE(review): `test_1` is not defined in any cell shown here — it presumably
# comes from hidden kernel state, so this cell would fail after Restart & Run All.
len(test_1)

1000