In [1]:
from Generator import Generator

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

In [14]:
# Length of each latent (noise) vector fed to the generator.
GEN_IN = 100
# Number of images produced per run; they are saved as a single grid row.
IMG_LEN = 16
# Generation modes, keyed by the integer `mode` picked in the sampling cell.
# 'semantic' performs latent arithmetic, for example (female - male) + child.
selection = dict(enumerate(('random', 'interpolation', 'semantic')))

In [3]:
# Instantiate the generator and move its parameters onto the GPU.
# Module.cuda() works in place and returns the same module, so the
# result does not need to be re-bound (assumes Generator is an nn.Module).
gen = Generator()
gen.cuda()
# Restore the trained weights. Because this is the cell's last expression,
# Jupyter displays the key-matching report returned by load_state_dict.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
gen.load_state_dict(torch.load('./models/generator34.pt'))

<All keys matched successfully>

In [78]:
# Generation mode: a key into `selection` (0=random, 1=interpolation, 2=semantic).
mode = 1
# Set to an int to make the 'random' and 'semantic' modes reproducible.
seed = None

# Rows of vectors.txt used as anchors. vectors.txt holds the seeds saved by a
# previous 'random' run: one GEN_IN-dim vector per row, in the same order as
# the images in img_random.png.
INTERP_ROWS = (1, 10)            # start / end seeds for interpolation
SEMANTIC_ROWS = ((0, 6, 8),      # z1: added
                 (1, 2, 4),      # z2: subtracted
                 (3, 4, 6))      # z3: added
NOISE_SCALE = 0.1                # std of the per-image jitter in 'semantic' mode

mode_name = selection[mode]

if seed is not None:
    torch.manual_seed(seed)
    np.random.seed(seed)

if mode_name == 'random':
    # Fresh standard-normal seeds shaped (IMG_LEN, GEN_IN, 1, 1), moved to the GPU.
    latent_space_vec = torch.randn(IMG_LEN, GEN_IN, 1, 1).cuda()

elif mode_name == 'interpolation':
    load_vector = np.loadtxt('vectors.txt')
    start, end = INTERP_ROWS
    y = np.vstack([load_vector[start], load_vector[end]])
    # Linearly interpolate IMG_LEN evenly spaced seeds from start (t=0) to end (t=1).
    sample = interp1d([0, 1], y, axis=0)(np.linspace(0, 1, num=IMG_LEN))
    latent_space_vec = torch.tensor(sample.reshape(IMG_LEN, GEN_IN, 1, 1), dtype=torch.float32).cuda()

elif mode_name == 'semantic':
    load_vector = np.loadtxt('vectors.txt')
    # Semantically combine faces: average each group of seeds, then z1 - z2 + z3.
    z1, z2, z3 = (load_vector[list(rows)].mean(axis=0) for rows in SEMANTIC_ROWS)
    z_new = z1 - z2 + z3
    # Zero-mean jitter so each image is a small variation of z_new; a non-zero
    # mean would shift every image away from z_new by the same offset.
    sample = z_new + NOISE_SCALE * np.random.normal(0.0, 1.0, size=(IMG_LEN, GEN_IN))
    latent_space_vec = torch.tensor(sample.reshape(IMG_LEN, GEN_IN, 1, 1), dtype=torch.float32).cuda()

# Inference only: turn off autograd bookkeeping.
with torch.no_grad():
    viz_sample = gen(latent_space_vec)

viz_vector = latent_space_vec.cpu().numpy().reshape(IMG_LEN, GEN_IN)
# Only 'random' refreshes vectors.txt. The other modes READ their anchor seeds
# from it, so overwriting it here would silently change the anchors next run.
vector_path = 'vectors.txt' if mode_name == 'random' else f'vectors_{mode_name}.txt'
np.savetxt(vector_path, viz_vector)
# Save all IMG_LEN images as one row of a grid.
save_image(viz_sample, f'img_{mode_name}.png', nrow=IMG_LEN, normalize=True)