In [None]:
import glob 
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F  # activation function
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

import poisevae
from poisevae.datasets import CUB
from poisevae.utils import NN_lookup

# Global matplotlib styling for the figures saved by this notebook.
# fonttype 42 asks the PDF/PS backends to embed TrueType fonts (see the
# matplotlib rcParams documentation).
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Times New Roman',
    'font.size': 20,
    'font.weight': 'normal',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
})

In [None]:
# Dataset location, resolved relative to the current user's home directory.
# HOME_PATH is reused further down to build the results directory.
HOME_PATH = os.path.expanduser('~')
DATA_PATH = os.path.join(HOME_PATH, 'Datasets/CUB/')

In [None]:
class EncImg(nn.Module):
    """Fully connected encoder with two output heads.

    A flat input of width ``input_dim`` goes through a
    ``input_dim -> 1024 -> 512 -> 256`` stack of Linear+ELU layers; two
    separate linear heads then map the 256 hidden features to ``latent_dim``
    values each.

    Args:
        latent_dim: width of each output head.
        input_dim: width of the input vector (default 2048).
    """

    def __init__(self, latent_dim, input_dim=2048):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        widths = (input_dim, 1024, 512, 256)
        layers = []
        # Pair every width with its successor to get (fan_in, fan_out).
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ELU(inplace=True))
        self.enc = nn.Sequential(*layers)
        self.enc_mu = nn.Linear(256, latent_dim)
        self.enc_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        """Return ``(mu, log_var)``, the outputs of the two heads for ``x``."""
        hidden = self.enc(x)
        return self.enc_mu(hidden), self.enc_var(hidden)

class DecImg(nn.Module):
    """Fully connected decoder mirroring ``EncImg``.

    Maps a latent vector through ``latent_dim -> 256 -> 512 -> 1024 ->
    output_dim`` with an ELU between consecutive linear layers and no
    activation after the last one.

    Args:
        latent_dim: width of the latent input.
        output_dim: width of the produced vector (default 2048).
    """

    def __init__(self, latent_dim, output_dim=2048):
        super().__init__()
        self.latent_dim = latent_dim
        self.output_dim = output_dim

        widths = (256, 512, 1024, output_dim)
        layers = [nn.Linear(latent_dim, widths[0])]
        # Each further stage is "activation, then widen".
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.ELU(inplace=True), nn.Linear(fan_in, fan_out)))
        self.dec = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent ``z`` into a vector of width ``output_dim``."""
        return self.dec(z)

In [None]:
class EncTxt(nn.Module):
    """Convolutional text encoder with two output heads.

    Token ids are embedded, the resulting (txt_len x emb_dim) map is treated
    as a one-channel image and passed through five strided Conv-BN-ReLU
    blocks. Two 4x4 convolution heads then emit ``mu`` and ``log_var``.
    For a 32-token input with ``emb_dim=128`` the stack leaves a 4x4 map, so
    each head produces a 1x1 output per latent channel.

    Args:
        vocab_size: number of rows in the embedding table.
        latent_dim: number of output channels of each head.
        emb_dim: width of the token embedding (default 128).
    """

    def __init__(self, vocab_size, latent_dim, emb_dim=128):
        super(EncTxt, self).__init__()
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.emb_dim = emb_dim
        
        # 0 is for the excluded words and does not contribute to gradient
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)

        n_channels = (1, 32, 64, 128, 256, 512)
        kernels = (4, 4, 4, (1, 4), (1, 4))
        strides = (2, 2, 2, (1, 2), (1, 2))
        paddings = (1, 1, 1, (0, 1), (0, 1))
        li = []
        for i, (n, k, s, p) in enumerate(zip(n_channels[1:], kernels, strides, paddings), 1):
            li += [nn.Conv2d(n_channels[i-1], n, kernel_size=k, stride=s, padding=p), 
                   nn.BatchNorm2d(n), nn.ReLU(inplace=True)]
            
        self.enc = nn.Sequential(*li)
        self.enc_mu = nn.Conv2d(512, latent_dim, kernel_size=4, stride=1, padding=0)
        self.enc_var = nn.Conv2d(512, latent_dim, kernel_size=4, stride=1, padding=0)
        
    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: tensor of token ids, batch first; it is cast to ``long``
               before the embedding lookup.

        Returns:
            ``(mu, log_var)``. When the heads produce a 1x1 spatial output
            each has shape ``(batch, latent_dim)``.
        """
        x = self.emb(x.long()).unsqueeze(1) # add channel dim
        x = self.enc(x)
        # Squeeze only the two spatial axes. A bare ``squeeze()`` would also
        # drop the batch axis for a single-sample batch (and the latent axis
        # when latent_dim == 1), changing the rank of the returned tensors.
        mu = self.enc_mu(x).squeeze(3).squeeze(2)
        log_var = self.enc_var(x).squeeze(3).squeeze(2)
        return mu, log_var

class DecTxt(nn.Module):
    """Transposed-convolution text decoder, the mirror image of ``EncTxt``.

    A ``(batch, latent_dim, 1, 1)`` latent is upsampled by six
    ConvTranspose2d blocks to a one-channel ``txt_len x emb_dim`` map
    (32 x 128 with the fixed kernels/strides below), after which a linear
    layer turns every ``emb_dim``-wide row into ``vocab_size`` logits.

    Args:
        vocab_size: number of logits produced per token position.
        latent_dim: channel count of the latent input.
        emb_dim: width of each decoded row fed to the output linear layer.
        txt_len: number of token positions per sequence.
    """

    def __init__(self, vocab_size, latent_dim, emb_dim=128, txt_len=32):
        super().__init__()
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.emb_dim = emb_dim
        self.txt_len = txt_len

        # Listed from the output side (index 0) towards the latent side.
        channels = (1, 32, 64, 128, 256, 512, latent_dim)
        kernels = (4, 4, 4, (1, 4), (1, 4), 4)
        strides = (2, 2, 2, (1, 2), (1, 2), 1)
        paddings = (1, 1, 1, (0, 1), (0, 1), 0)
        n_blocks = len(kernels)

        layers = []
        for idx in range(n_blocks):
            block = [nn.ConvTranspose2d(channels[idx + 1], channels[idx],
                                        kernel_size=kernels[idx],
                                        stride=strides[idx],
                                        padding=paddings[idx])]
            # No batchnorm at the first and last block of the final stack,
            # i.e. at both ends of this output-to-latent enumeration.
            if 0 < idx < n_blocks - 1:
                block.append(nn.BatchNorm2d(channels[idx]))
            block.append(nn.ReLU(inplace=True))
            # Prepend so that the finished list runs latent -> output.
            layers = block + layers

        self.dec = nn.Sequential(*layers)
        self.anti_emb = nn.Linear(emb_dim, vocab_size)

    def forward(self, z, argmax=False):
        """Decode ``z`` into per-position vocabulary logits.

        Args:
            z: latent tensor accepted by the transposed-convolution stack.
            argmax: when true, return the index of the largest logit per
                position (as a float tensor) instead of the logits.

        Returns:
            ``(batch, txt_len, vocab_size)`` logits, or ``(batch, txt_len)``
            float indices when ``argmax`` is set.
        """
        decoded = self.dec(z)
        logits = self.anti_emb(decoded.view(-1, self.emb_dim))
        logits = logits.view(-1, self.txt_len, self.vocab_size)  # batch x txt len x vocab size
        if not argmax:
            return logits
        return logits.argmax(-1).float()

In [None]:
def tx(data):
    """Convert ``data`` to a ``torch.Tensor`` (default float dtype)."""
    return torch.Tensor(data)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Train and test splits of the CUB dataset, both using `tx` as the transform.
# The test split is built with return_idx=True because the results cells
# unpack (img, txt, idx) from its loader to map samples back to the raw data.
# NOTE(review): DATA_PATH is passed twice — presumably the image and text
# roots; confirm against poisevae.datasets.CUB.
CUB_train = CUB(DATA_PATH, DATA_PATH, 'train', device, tx, return_idx=False)
CUB_test = CUB(DATA_PATH, DATA_PATH, 'test', device, tx, return_idx=True)

In [None]:
# Vocabulary size and maximum caption length come from the training text split.
vocab_size = CUB_train.CUBtxt.vocab_size
txt_len = CUB_train.CUBtxt.max_sequence_length
(vocab_size, txt_len)

In [None]:
# Mini-batch loaders for both splits.
# The test loader is shuffled as well: the results cells below only look at
# its first batch, so shuffling makes that batch a random sample of the test
# set (and therefore different on every run).
batch_size = 128
train_loader = torch.utils.data.DataLoader(CUB_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(CUB_test, batch_size=batch_size, shuffle=True)
# Number of batches per epoch for each split.
len(train_loader), len(test_loader)

In [None]:
# Every encoder/decoder in this notebook uses the same latent width.
LATENT_DIM = 128

enc_img = EncImg(latent_dim=LATENT_DIM).to(device)
dec_img = DecImg(latent_dim=LATENT_DIM).to(device)
enc_txt = EncTxt(vocab_size=vocab_size, latent_dim=LATENT_DIM).to(device)
dec_txt = DecTxt(vocab_size=vocab_size, latent_dim=LATENT_DIM).to(device)

def CELoss(input, target):
    """Summed token-level cross-entropy between logits and target ids.

    Args:
        input: logits whose last axis indexes the classes (vocabulary).
        target: class ids with one entry per logit row; cast to ``long``.

    Returns:
        Scalar tensor: the cross-entropy summed over every position
        (no averaging over batch or sequence length).
    """
    # Take the class count from the logits themselves so the function does
    # not depend on a notebook-level ``vocab_size`` being defined.
    n_classes = input.shape[-1]
    logits = input.reshape(-1, n_classes)
    labels = target.reshape(-1).to(torch.long)
    return nn.functional.cross_entropy(logits, labels, reduction='sum')
    
def MSELoss(input, target):
    """Sum of squared errors over all elements, divided by a fixed factor of 10."""
    summed = F.mse_loss(input, target, reduction='sum')
    return summed / 10
    
def MAELoss(input, target):
    """Sum of absolute errors over all elements, divided by a fixed factor of 10."""
    summed = F.l1_loss(input, target, reduction='sum')
    return summed / 10
    
# Joint model over the two modalities; encoders, decoders and reconstruction
# losses are listed in the same (image, text) order. MAELoss scores the image
# reconstruction and CELoss the text one; MSELoss is defined but not used here.
# NOTE(review): the text latent is declared as (128, 1, 1), which matches the
# 4-D input DecTxt's transposed convolutions need — confirm how POISEVAE
# interprets latent_dims.
vae = poisevae.POISEVAE([enc_img, enc_txt], [dec_img, dec_txt], [MAELoss, CELoss], 
                        latent_dims=[128, (128, 1, 1)]).to(device)

In [None]:
# for i in vae.named_parameters():
#     print(i[0])

In [None]:
# Adam over every parameter exposed by the joint model, learning rate 1e-3.
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

In [None]:
# Output directory for checkpoints, figures and text dumps. The trailing
# separator matters: later cells build file names with `PATH + '<name>'`.
PATH = os.path.join(HOME_PATH, 'Datasets/CUB_train_results/')

In [None]:
# Starting epoch index for the training loop; 0 means training from scratch.
epoch = 0
# Disabled resume logic: it would restore model, optimizer and epoch from the
# last 'train*.pt' file (in sorted order) under PATH.
# NOTE(review): if re-enabled, replace the bare `except: pass`, which would
# silently hide genuine load errors.
# try:
#     vae, optimizer, epoch = poisevae.utils.load_checkpoint(vae, optimizer, sorted(glob.glob(PATH + 'train*.pt'))[-1])
# except:
#     pass

In [None]:
from datetime import datetime

# One TensorBoard run directory per launch, named after the start time
# formatted as yymmddHHMM.
run_stamp = datetime.now().strftime('%y%m%d%H%M')
writer = SummaryWriter('runs/CUB/' + run_stamp)

In [None]:
# Train for N_EPOCHS more epochs starting at `epoch` (0 on a fresh start, or
# the value restored from a checkpoint), evaluating after every epoch and
# writing a checkpoint every CHECKPOINT_EVERY epochs.
N_EPOCHS = 30
CHECKPOINT_EVERY = 10

# Make sure the output directory exists before the first checkpoint is due,
# so a missing folder cannot abort the run after several epochs of training.
os.makedirs(PATH, exist_ok=True)

start_epoch = epoch
epochs = start_epoch + N_EPOCHS
for ep in tqdm(range(start_epoch, epochs)):
    poisevae.utils.train(vae, train_loader, optimizer, ep, writer)
    labels, latent_info = poisevae.utils.test(vae, test_loader, ep, writer, 
                                              record_idx=(2, 3), return_latents=True)
    # Keep `epoch` equal to the number of completed epochs. This is the value
    # stored in checkpoints, and it lets the cell be re-run (or resumed after
    # an interrupt) without repeating an epoch that has already finished.
    epoch = ep + 1
    if epoch % CHECKPOINT_EVERY == 0:
        poisevae.utils.save_checkpoint(vae, optimizer, PATH + 'training_%d.pt' % epoch, epoch)

In [None]:
# Flush pending TensorBoard events, then close the writer now that training is done.
writer.flush()
writer.close()

In [None]:
# Persist the latent information returned by the last evaluation pass of the
# training loop into the results directory.
poisevae.utils.save_latent_info(latent_info, PATH)

## Results

In [None]:
# vae, optimizer, epoch = poisevae.utils.load_checkpoint(vae, optimizer, sorted(glob.glob(PATH + 'train*.pt'))[-1])
# epoch

In [None]:
# Reconstruct one test batch and collect what the figure/text cells below need.
# Switch to inference mode first: the text encoder/decoder contain BatchNorm
# layers, which in training mode would normalise with batch statistics and
# update their running averages during this forward pass.
vae.eval()
with torch.no_grad():
    # Only the first batch is used (a random one, since the loader shuffles).
    img, txt, idx = next(iter(test_loader))
    results = vae([img.to(device), txt.to(device)])
    # For every reconstructed image vector, look up its nearest neighbour
    # among the input vectors of the same batch; idx_h indexes into the batch.
    dist, idx_h = NN_lookup(results['x_rec'][0], img)
    # Sample ids are divided by 10 to address the image dataset.
    # NOTE(review): presumably 10 captions per image — confirm.
    imgs_h = [CUB_test.CUBimg.dataset[int(idx[j]) // 10][0] for j in idx_h]
    imgs = [CUB_test.CUBimg.dataset[int(j) // 10][0] for j in idx]
    txts = [CUB_test.CUBtxt.data[str(int(j))] for j in idx]

In [None]:
# Two-row figure: original images on top, the images retrieved for their
# reconstructions underneath, for at most the first 20 samples of the batch.
ncols = min(len(imgs_h), 20)
fig, ax = plt.subplots(nrows=2, ncols=ncols, figsize=(15, 1.5))
for col, pair in enumerate(zip(imgs[:ncols], imgs_h[:ncols])):
    for row, picture in enumerate(pair):
        panel = ax[row, col]
        # Images are channel-first; imshow expects channel-last.
        panel.imshow(picture.permute(1, 2, 0))
        panel.set_xticks([])
        panel.set_yticks([])
ax[1, 0].set_ylabel('Rec', fontsize=24)
fig.tight_layout(pad=0)
fig.savefig(PATH + 'CUBImgRec.pdf', dpi=300)

In [None]:
# Extended version of the reconstruction figure: every complete group of 20
# samples gets its own band of two rows (originals above, retrieved images
# below); a trailing incomplete group is not drawn.
ncols = min(len(imgs_h), 20)
nrows = 2 * (len(imgs_h) // 20)
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 1.5*nrows/2))
n_shown = (nrows // 2) * 20
for k, pair in enumerate(zip(imgs[:n_shown], imgs_h[:n_shown])):
    band, col = divmod(k, 20)
    top_row = 2 * band
    for offset, picture in enumerate(pair):
        panel = ax[top_row + offset, col]
        panel.imshow(picture.permute(1, 2, 0))
        panel.set_xticks([])
        panel.set_yticks([])
    # Label the lower row of the current band.
    ax[top_row + 1, 0].set_ylabel('Rec', fontsize=24)
fig.tight_layout(pad=0)
fig.savefig(PATH + 'CUBImgRecExtra.pdf', dpi=300)

In [None]:
def _tokens_before_eos(tokens):
    """Collect items from `tokens` until the first '<eos>' (exclusive)."""
    kept = []
    for tok in tokens:
        if tok == '<eos>':
            break
        kept.append(tok)
    return kept


# Reconstructed sentences: most likely token id per position, mapped back to
# words through the test split's index-to-word table. The generator keeps the
# lookup lazy, so ids after '<eos>' are never resolved.
sents_emb = results['x_rec'][1].argmax(-1)
sents_h = [
    ' '.join(_tokens_before_eos(CUB_test.CUBtxt.i2w[str(int(tok_id.item()))]
                                for tok_id in row))
    for row in sents_emb
]

# Reference sentences: the stored tokens of the same samples, cut the same way.
sents = [' '.join(_tokens_before_eos(txts[k]['tok'])) for k in range(len(sents_h))]

In [None]:
# Show each reference sentence directly above its reconstruction.
for reference, reconstruction in zip(sents, sents_h):
    print(f'      {reference}\nRec:  {reconstruction}')

In [None]:
# Dump reference/reconstruction pairs to a text file, two lines per sample.
with open(PATH + 'CUBTxtRec.txt', 'w') as out:
    for reference, reconstruction in zip(sents, sents_h):
        out.write(f'Ref: {reference}\n')
        out.write(f'Rec: {reconstruction}\n')