In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image

import gensim

# Pre-trained word vectors stored in binary word2vec format; used below to embed sentences.
model = gensim.models.KeyedVectors.load_word2vec_format(
    fname='../models/word2vec_model.bin',
    binary=True,
)

def get_feature_space_representaion(sentence, keyed_vectors=None):
    """Embed a sentence as the per-dimension min and max of its word vectors.

    The sentence is split on whitespace. Every token found in the word-vector
    vocabulary is looked up; out-of-vocabulary tokens are skipped. The result
    is ``concatenate([min over words, max over words])`` along the embedding
    axis, so its length is twice the word-vector dimensionality.

    Args:
        sentence: Whitespace-separated text to embed.
        keyed_vectors: Word-vector store supporting ``token in kv`` and
            ``kv[token]`` (e.g. a gensim ``KeyedVectors`` or a plain dict of
            1-D numpy arrays). Defaults to the module-level ``model``.

    Returns:
        1-D numpy array of length ``2 * vector_size``.

    Raises:
        ValueError: if no token of ``sentence`` is in the vocabulary (the
            min/max of zero vectors is undefined).
    """
    kv = model if keyed_vectors is None else keyed_vectors
    # `token in kv` works on gensim 3.x and 4.x; `kv.vocab` was removed in gensim 4.0.
    words = [token for token in sentence.split() if token in kv]
    if not words:
        raise ValueError(
            'None of the tokens in {!r} are in the word-vector vocabulary'.format(sentence))
    features = np.vstack([kv[word] for word in words])  # (n_words, vector_size)
    mins = features.min(axis=0)
    maxs = features.max(axis=0)
    return np.concatenate((mins, maxs))


In [None]:
import yaml
from pprint import pprint
from enum import Enum

class Generator(nn.Module):
    """DCGAN-style generator conditioned on a sentence embedding.

    The latent noise and the text embedding are concatenated along the
    channel axis and up-sampled by five transposed convolutions from a 1x1
    map to an ``nc x 64 x 64`` image; the final ``Tanh`` bounds outputs to
    (-1, 1).

    Args:
        config: Object whose attributes expose ``.value`` (in this notebook an
            ``Enum`` built from the YAML config) for ``N_GPU``,
            ``N_LATENT_VECTOR`` (nz), ``N_TEXT_EMBEDDING`` (nzc),
            ``N_GENERATOR_FEATURE_MAP`` (ngf) and ``N_COLOR_CHANNELS`` (nc).
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        # Kept as a public attribute; this class itself does not use it.
        self.ngpu = config.N_GPU.value
        nz = config.N_LATENT_VECTOR.value
        nzc = config.N_TEXT_EMBEDDING.value
        ngf = config.N_GENERATOR_FEATURE_MAP.value
        nc = config.N_COLOR_CHANNELS.value
        # NOTE: the layer order defines the state_dict keys (main.0 ... main.13);
        # reordering it would break loading existing checkpoints.
        self.main = nn.Sequential(
            # input is Z concatenated with the text embedding, going into a convolution
            nn.ConvTranspose2d( nz+nzc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, image, comment):
        """Generate images from latent noise and a text embedding.

        Args:
            image: Latent noise of shape ``(B, nz, 1, 1)`` (despite the name,
                this is the noise input, not an image).
            comment: Text embedding with ``nzc`` values per sample, shaped
                ``(B, nzc)``, ``(B, nzc, 1)`` or ``(B, nzc, 1, 1)``.

        Returns:
            Tensor of shape ``(B, nc, 64, 64)``.

        Raises:
            ValueError: if noise + embedding channels differ from ``nz + nzc``.
        """
        # Flatten to (B, nzc) and lift to a 1x1 feature map so it can be
        # concatenated with the noise along the channel axis.
        comment = comment.reshape(comment.shape[0], -1, 1, 1)
        # Match the noise dtype/device: a float64 embedding (e.g. from numpy)
        # would otherwise be promoted by torch.cat and break the float32 conv.
        comment = comment.to(device=image.device, dtype=image.dtype)
        cat = torch.cat((image, comment), dim=1)
        expected_channels = self.main[0].in_channels
        if cat.shape[1] != expected_channels:
            raise ValueError(
                'Expected nz + nzc = {} input channels, got {} (noise) + {} (text embedding)'.format(
                    expected_channels, image.shape[1], comment.shape[1]))
        return self.main(cat)

# Load the trained text-conditioned generator and render one image for a sentence.

config_path = '../config/config.yml'
# Let a malformed config fail loudly here; swallowing the YAMLError only
# resurfaced later as a confusing NameError when building `config`.
with open(config_path, 'r') as stream:
    config = yaml.safe_load(stream)
if not isinstance(config, dict):
    raise ValueError('Expected a mapping at the top level of {}, got {}'.format(
        config_path, type(config).__name__))
print('Config loaded from: {}'.format(config_path))
sentence = 'cat on table'
comment = get_feature_space_representaion(sentence)

# NOTE(review): Enum turns members with equal values (e.g. True and 1) into
# aliases, so `config.X.value` may return another key's (equal) value. A
# plain namespace would be safer, but Generator reads `.value`, so keep Enum.
config = Enum('config', config)
device = torch.device("cuda:0" if (config.GPU_MODE.value and torch.cuda.is_available() and config.N_GPU.value > 0) else "cpu")
# Seed the latent noise so Restart & Run All reproduces the same image.
torch.manual_seed(0)
fixed_noise = torch.randn(1, config.N_LATENT_VECTOR.value, 1, 1, device=device)

path_to_model = '../models/generator_55.pth'
generator = Generator(config).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
generator.load_state_dict(torch.load(path_to_model, map_location=device))
generator.eval()

# Shape (1, embedding_dim); Generator.forward lifts it to (1, embedding_dim, 1, 1).
comment = torch.as_tensor(comment, dtype=torch.float32, device=device).reshape(1, -1)
if comment.shape[1] != config.N_TEXT_EMBEDDING.value:
    raise ValueError('Sentence embedding has {} values, config expects N_TEXT_EMBEDDING={}'.format(
        comment.shape[1], config.N_TEXT_EMBEDDING.value))
with torch.no_grad():
    img = generator(fixed_noise, comment).cpu()
plt.figure(figsize=(15,15))
plt.axis("off")
plt.title("Result of sentence: {}".format(sentence))
img = vutils.make_grid(img, padding=2, normalize=True)
plt.imshow(img.permute(1, 2, 0).numpy())
plt.show()