In [1]:
import json
from tqdm import tqdm 

import torch
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

from datasets import get_MNIST, postprocess
from model import Glow

import matplotlib.pyplot as plt

In [2]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths below are built by plain string concatenation, so the folder name
# must keep its trailing slash.
output_folder = 'output/'
model_name = 'glow_checkpoint_9800.pt'

# Hyper-parameter dict stored in the output folder; later cells index into it
# to rebuild the model and to decide how to sample.
with open(output_folder + 'hparams.json') as json_file:
    hparams = json.load(json_file)

In [3]:
# get_MNIST returns a pair; only the second element is kept here.
_, test_mnist = get_MNIST(hparams['num_classes'])

# Rebuild the network from the stored hyper-parameters. Arguments are passed
# positionally, so their order has to match Glow's constructor.
model = Glow(
    hparams['img_shape'],
    hparams['hidden_channels'],
    hparams['K'],
    hparams['L'],
    hparams['actnorm_scale'],
    hparams['flow_permutation'],
    hparams['flow_coupling'],
    hparams['LU_decomposed'],
    hparams['num_classes'],
    hparams['learn_top'],
    hparams['y_condition']
)

# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from; without map_location, torch.load tries to restore tensors to their
# original device and fails when that device does not exist on this machine.
# The model is moved to `device` as a whole a few lines below.
model.load_state_dict(
    torch.load(output_folder + model_name, map_location='cpu')['model']
)
# NOTE(review): presumably flags the actnorm layers as already initialised so
# the loaded values are not overwritten on the first forward pass - confirm
# against model.py.
model.set_actnorm_init()
model = model.to(device)
# Switch to inference mode (nn.Module.eval returns the module itself).
model = model.eval()

In [4]:
def sample(model, n_samples=4, temperature=0.7):
    """Generate images by running `model` in reverse.

    Relies on the notebook-level `hparams` dict and `device`.

    Args:
        model: Callable accepting `y_onehot`, `temperature` and `reverse`
            keyword arguments.
        n_samples: Number of images requested per class when
            `hparams['y_condition']` is truthy.
            NOTE(review): in the unconditional branch this value is not
            forwarded to the model, so the batch size is whatever the model
            chooses - confirm whether that is intended.
        temperature: Sampling temperature forwarded to the model. Defaults to
            0.7, the value that was previously hard-coded.

    Returns:
        The result of `postprocess` applied to the model output, moved to the
        CPU.
    """
    with torch.no_grad():
        if hparams['y_condition']:
            num_classes = hparams['num_classes']
            # One-hot label matrix of shape (num_classes * n_samples, num_classes):
            # row block i (n_samples consecutive rows) has a 1 in column i.
            # Built directly on `device` to avoid a host-to-device copy.
            y = torch.eye(num_classes, device=device).repeat_interleave(n_samples, dim=0)
        else:
            y = None
        images = postprocess(model(y_onehot=y, temperature=temperature, reverse=True))
    return images.cpu()

In [6]:
# Number of images requested per class; also used as the grid row width so
# that each row of the figure holds one block of n_samples images.
n_samples = 12
images = sample(model, n_samples)

# Tile the batch into a single image, then move axis 0 to the end
# (presumably channels-first -> channels-last, which imshow expects).
tiled = make_grid(images, nrow=n_samples)
grid = tiled.permute(1, 2, 0)

# Explicit figure/axes interface instead of the pyplot state machine.
fig, ax = plt.subplots(figsize=(14, 14))
ax.imshow(grid)
ax.axis('off')
ax.set_title("Class-conditional generated examples", fontsize=26)
plt.show()