### Load models and helpers

In [2]:
import torch
from torchvision.transforms import Normalize
from torchvision.utils import save_image
from torchvision.utils import make_grid
from gan_model import UpsampleGenerator, Discriminator
from vae_model import VAE
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
latent_dimension: int = 100  # length of the latent vector z given to both the generator and the VAE

In [4]:
# (x - (-1)) / 2 per channel: maps values in [-1, 1] back to [0, 1] for saving/plotting.
re_norm = Normalize(mean=[-1] * 3, std=[2] * 3)

In [None]:
# load the generator from checkpoint index 99
gan_checkpoint = '../upsample/checkpoint_{}.pth'.format(99)
gen_model = UpsampleGenerator(latent_dimension).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load(gan_checkpoint, map_location=device)
# Inference only: put layers such as batch-norm / dropout (if any) into eval behaviour.
gen_model.eval()
gen_model.load_state_dict(checkpoint['gen_model'])

In [None]:
# load VAE from checkpoint index 99
vae_checkpoint = '../upsample/vae_checkpoint_{}.pth'.format(99)
vae_model = VAE(latent_dimension).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load(vae_checkpoint, map_location=device)
# Inference only: put layers such as batch-norm / dropout (if any) into eval behaviour.
vae_model.eval()
vae_model.load_state_dict(checkpoint['model'])

In [None]:
def show_grid(imgs, nrow, padding=100):
    """Draw a batch of images as one grid on the current matplotlib axes.

    Args:
        imgs: images accepted by ``torchvision.utils.make_grid`` (a batched
            tensor, presumably ``(B, C, H, W)``, or a list of ``(C, H, W)``
            tensors). They may live on any device and may require grad.
        nrow: number of images per grid row.
        padding: pixels of padding between images (default 100, as before).
    """
    # Detach and move to CPU so the grid can be converted to a NumPy array.
    grid = make_grid(imgs, nrow=nrow, padding=padding).detach().cpu()
    # make_grid returns (C, H, W); plt.imshow wants channel-last (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).numpy(), interpolation='nearest')

### Qualitative evaluation

In [None]:
# generate visual samples: 1000 images from each model using the same latent
# codes, each written to its own PNG in the working directory.
z = torch.randn(1000, latent_dimension, device=device)
with torch.no_grad():  # pure inference: don't build an autograd graph for 1000 samples
    # NOTE(review): this feeds latent codes straight into VAE.forward; if that
    # forward expects images or returns a tuple, call the decoder instead — confirm.
    vae_out = vae_model(z)
    gan_out = gen_model(z)
for i, im in enumerate(vae_out):
    save_image(re_norm(im), 'vae_sample_{}.png'.format(i))
for i, im in enumerate(gan_out):
    save_image(re_norm(im), 'gan_sample_{}.png'.format(i))

In [None]:
# perturbing dimension
def pertube(z, epsilon, steps, dim):
    """Return ``steps`` copies of one latent code with a single component swept.

    Row ``k`` equals ``z`` with ``epsilon * (k - steps // 2)`` added to
    component ``dim``. The offsets therefore run from ``-(steps // 2)`` to
    ``steps - 1 - steps // 2`` times ``epsilon`` (e.g. -8..7 for 16 steps).

    Args:
        z: a single latent code of shape ``(1, D)`` or ``(D,)``.
        epsilon: spacing between consecutive perturbation offsets.
        steps: number of perturbed copies (rows) to produce.
        dim: index of the latent component to perturb.

    Returns:
        Tensor of shape ``(steps, D)`` with ``z``'s dtype and device.
        ``z`` itself is not modified.

    Raises:
        ValueError: if ``z`` holds more than one latent code.
    """
    if z.dim() == 1:
        z = z.unsqueeze(0)
    if z.dim() != 2 or z.shape[0] != 1:
        raise ValueError(
            'pertube expects a single latent code of shape (1, D) or (D,), '
            'got {}'.format(tuple(z.shape)))
    # Integer offsets centred on 0. Floor division keeps the bounds integral and
    # always yields exactly `steps` values (range() rejects `steps / 2` floats).
    offsets = epsilon * torch.arange(
        -(steps // 2), steps - steps // 2, dtype=z.dtype, device=z.device)
    # Tensor.repeat copies the data, so the in-place add never touches the caller's z.
    new_z = z.repeat(steps, 1)
    new_z[:, dim] += offsets
    return new_z

In [None]:
# Sweep one latent dimension of a single random code and show how the VAE and
# GAN outputs change along that direction (one grid row per model).
z = torch.randn(1, latent_dimension, device=device)
pbs = 16  # number of perturbation steps, i.e. images per row
# NOTE(review): epsilon=1e-4 over 16 steps spans only about ±8e-4 on an N(0, 1)
# latent, which is likely too small to see any change — consider a larger step.
with torch.no_grad():  # inference only
    new_z = pertube(z, 0.0001, pbs, 5)
    # NOTE(review): VAE.forward is given latent codes here — confirm it decodes z.
    vae_out = vae_model(new_z)
    gan_out = gen_model(new_z)
# Map outputs back to [0, 1] as the save cell does (presumably they lie in
# [-1, 1]) and move them to the CPU for plotting; one figure per model.
plt.figure()
show_grid(re_norm(vae_out).cpu(), pbs)
plt.figure()
show_grid(re_norm(gan_out).cpu(), pbs)