In [64]:
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image

from infogan.infogan_module import InfoGAN

In [65]:
# Training run to load: reads lightning_logs/version_{VERSION}/.
VERSION = 0
# Code group to visualize: 'gaussian' sweeps a continuous code, 'categorical' steps through a one-hot code.
DIST_TYPE = 'gaussian'
# Index of the code group within DIST_TYPE.
GRP_NO = 0
# True: save all frames in one JPEG (one frame per row); False: save an animated GIF.
MAKE_GRID = False
# Seed for the noise/code draws so the saved media is reproducible on Restart & Run All.
SEED = 0

# Presumably the model samples with torch; numpy is seeded too in case any sampling goes through it.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [66]:
def interpolate_dim(gan, grp_type='gaussian', grp_no=0, std_range=4, num_samples=5, num_interops=11,
                    value_range=(0, 1)):
    """Render frames that vary one latent code group while holding all other inputs fixed.

    A single batch of ``num_samples`` noise vectors ``z`` and codes ``c`` is drawn once.
    The batch is tiled once per frame, so the only difference between frames is the
    selected code group.

    Args:
        gan: Generator module. It must provide ``eval()``, ``sample_noise(n)``,
            ``sample_code(n)`` and ``hparams.categ_code_dims``, and ``gan(z, c)`` must
            return a batch of images.
        grp_type: ``'gaussian'`` sweeps one continuous code linearly.
            ``'categorical'`` steps through each value of one categorical code.
        grp_no: Index of the code within its type. In ``c``, the categorical one-hot
            blocks come first, with group ``k`` spanning ``categ_code_dims[k]`` columns.
            The continuous codes follow all the categorical blocks.
        std_range: Gaussian only. The code is swept over
            ``linspace(-std_range, std_range, num_interops)``.
        num_samples: Number of (z, c) draws shown side by side in every frame.
        num_interops: Gaussian only. Number of frames in the sweep.
        value_range: Forwarded to torchvision's ``make_grid`` (``value_range`` keyword,
            torchvision >= 0.10) to clamp and rescale pixels. The default ``(0, 1)``
            assumes the generator outputs lie in [0, 1] — TODO confirm. Pass ``None`` for
            per-frame min/max scaling.

    Returns:
        A list of CPU image-grid tensors, one per frame. There are ``num_interops``
        frames for ``'gaussian'`` and ``categ_code_dims[grp_no]`` frames for
        ``'categorical'``.

    Raises:
        ValueError: If ``grp_type`` is not ``'gaussian'`` or ``'categorical'``.
    """
    if grp_type not in ('gaussian', 'categorical'):
        raise ValueError(f"grp_type must be 'gaussian' or 'categorical', got {grp_type!r}")

    gan.eval()
    z = gan.sample_noise(num_samples)
    c = gan.sample_code(num_samples)
    categ_dims = gan.hparams.categ_code_dims

    num_frames = num_interops if grp_type == 'gaussian' else categ_dims[grp_no]
    # Tile the same draws once per frame, giving shape (num_frames, num_samples, dim).
    # Assumes sample_noise / sample_code return 2-D (num_samples, dim) tensors — TODO confirm.
    z = z.repeat((num_frames, 1, 1))
    c = c.repeat((num_frames, 1, 1))

    if grp_type == 'gaussian':
        # Continuous codes sit after every categorical one-hot block.
        col = sum(categ_dims) + grp_no
        sweep = np.linspace(-std_range, std_range, num_interops)
        # Shape (num_frames, 1): each frame's value broadcasts across all samples.
        c[:, :, col] = torch.as_tensor(sweep, dtype=c.dtype, device=c.device).unsqueeze(1)
    else:
        start = sum(categ_dims[:grp_no])
        # Row i of the identity is the one-hot vector for category i; frame i receives it.
        one_hots = torch.eye(num_frames, dtype=c.dtype, device=c.device).unsqueeze(1)
        c[:, :, start:start + num_frames] = one_hots

    with torch.no_grad():
        imgs = gan(z.reshape(-1, z.shape[-1]), c.reshape(-1, c.shape[-1]))
    # Regroup the flat generator output into frames of num_samples images each.
    imgs = imgs.reshape(-1, num_samples, *imgs.shape[-3:])
    return [make_grid(frame, nrow=num_samples, padding=2, normalize=True,
                      value_range=value_range).cpu() for frame in imgs]

In [67]:
# Run directory written by Lightning for the selected training version.
model_dir = os.path.join('lightning_logs', f'version_{VERSION}')

# Hyperparameters saved for this run. They are kept for inspection only and are not passed
# to load_from_checkpoint below.
with open(os.path.join(model_dir, 'hparams.yaml')) as f:
    hparams = yaml.safe_load(f)

ckpt_dir = os.path.join(model_dir, 'checkpoints')
# Only consider checkpoint files; assumes Lightning's default '.ckpt' extension.
files = sorted(name for name in os.listdir(ckpt_dir) if name.endswith('.ckpt'))
if not files:
    raise FileNotFoundError(f'no .ckpt files found in {ckpt_dir}')
# NOTE: the sort is lexicographic, so 'epoch=10-...' sorts before 'epoch=9-...'. With several
# checkpoints, files[0] is not necessarily the latest; pin the file name if that matters.
file = files[0]

model = InfoGAN.load_from_checkpoint(os.path.join(ckpt_dir, file))

In [68]:
# Draw one batch and render the sweep of the selected code group (one image grid per frame).
viz_imgs = interpolate_dim(model, grp_type=DIST_TYPE, grp_no=GRP_NO)

In [69]:
# Save the frames rendered in the previous cell. They are not resampled here, so the saved
# media matches what was computed, and viz_imgs is never overwritten (re-running is safe).
out_stem = f"media/version_{VERSION}_{DIST_TYPE}_{GRP_NO}"
os.makedirs("media", exist_ok=True)  # neither save_image nor PIL creates missing directories

if MAKE_GRID:
    # nrow=1 puts one frame per row, so the frames are stacked top to bottom in one JPEG.
    save_image(viz_imgs, f"{out_stem}.jpg", nrow=1)
else:
    # Convert (C, H, W) floats (expected in [0, 1] after make_grid(normalize=True)) to
    # (H, W, C) uint8. Round and clip the same way torchvision's save_image does.
    gif_frames = [
        Image.fromarray(np.clip(255 * img.numpy().transpose(1, 2, 0) + 0.5, 0, 255).astype(np.uint8))
        for img in viz_imgs
    ]
    # An integer `duration` is PIL's display time for EACH frame in ms, so the per-frame
    # delay here grows with the number of frames.
    gif_frames[0].save(f"{out_stem}.gif", format="GIF",
                       append_images=gif_frames[1:], save_all=True,
                       duration=len(gif_frames) * 50, loop=0)