In [11]:
from wgan import WGAN

import torch
import matplotlib.pyplot as plt
import numpy as np

from utils import load_model, load_yaml

In [None]:
# --- Configuration -----------------------------------------------------------
# Location of the YAML file holding the data / model hyper-parameters.
config_path = "/home/seungwon/Projects/deep-learning-projects/cv-07-wasserstein-generative-adversarial-network-pytorch/config/wgan_config.yml"
config = load_yaml(config_path)

# --- Device and seeding ------------------------------------------------------
# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
seed = config['data']['seed']
torch.manual_seed(seed)
if use_cuda:
  torch.cuda.manual_seed_all(seed)

# --- Model -------------------------------------------------------------------
# Every WGAN constructor argument is read from the 'model' section of the
# config under a key of the same name, so the keyword arguments are built from
# the list of names instead of being spelled out one by one.
wgan_arg_names = (
    'gen_latent_z', 'gen_init_layer',
    'gen_conv_trans', 'gen_conv_filters',
    'gen_conv_kernels', 'gen_conv_strides',
    'gen_conv_pads', 'gen_dropout_rate',
    'crt_input_img', 'crt_conv_filters',
    'crt_conv_kernels', 'crt_conv_strides',
    'crt_conv_pads', 'crt_dropout_rate',
)
model_cfg = config['model']
wgan_kwargs = {arg_name: model_cfg[arg_name] for arg_name in wgan_arg_names}
model = WGAN(**wgan_kwargs).to(device)

# Restore the saved checkpoint into the model. load_model returns four values;
# only the first (the model) is used here, the other three are discarded.
checkpoint_path = '/home/seungwon/Projects/deep-learning-projects/cv-07-wasserstein-generative-adversarial-network-pytorch/models/emnist/2023.07.21.15.32.37/emnist_last.pt'
model, _, _, _ = load_model(checkpoint_path, model)
print(model)

In [None]:
# Draw a row of generated samples from the trained generator.
n_to_show = 10                                # number of images in the row
z_latent = config['model']['gen_latent_z']    # size of one latent vector
sample_seed = 777

# Kept from the original cell: seeds NumPy's global RNG. Note that this does
# NOT influence torch.rand below, which is why a torch generator is used too.
np.random.seed(sample_seed)

# The latent vectors are drawn with torch, so reproducibility needs a torch
# seed. A dedicated Generator makes this cell produce the same images on every
# re-run without disturbing the global torch RNG state set earlier.
latent_rng = torch.Generator(device=device)
latent_rng.manual_seed(sample_seed)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
model.eval()

# Inference only: skip autograd bookkeeping for the generator forward pass.
with torch.no_grad():
    # torch.rand is uniform on [0, 1); rescale to [-1, 1).
    z = 2 * torch.rand(n_to_show, z_latent, generator=latent_rng, device=device) - 1
    img = model.G(z)

# Move the whole batch to host memory once instead of once per subplot.
img_cpu = img.detach().cpu()

for i in range(n_to_show):
    sub = fig.add_subplot(1, n_to_show, i + 1)
    sub.axis('off')
    # squeeze(0) drops a leading size-1 dimension if present
    # (presumably the channel axis of a grayscale image -- TODO confirm).
    sub.imshow(img_cpu[i].squeeze(0).numpy(), cmap='gray_r')
