# Imports

In [None]:
from __future__ import print_function
import random
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from gan.constants import *
from gan.generator import Generator

# Dataset import

In [None]:
# Preprocessing applied to every image: resize and center-crop it to
# image_size, convert it to a tensor (values in [0, 1]), then normalize each
# of the three channels with mean 0.5 / std 0.5, i.e. map it to [-1, 1].
_image_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset read from the image folder tree under dataroot.
# NOTE(review): dataset/dataloader are not used by the generation cells in
# this notebook; they may be kept only for parity with train.ipynb.
dataset = dset.ImageFolder(root=dataroot, transform=_image_transform)

# Shuffled mini-batch loader over that dataset
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Run on the first CUDA GPU only when one is available AND ngpu asks for at
# least one GPU; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if _use_cuda else "cpu")

# Instantiate the generator and load its trained weights

In [None]:
# Create the generator and move it to the selected device
netG = Generator(ngpu).to(device)

# Load the weights produced by train.ipynb (saved to the file 'generator').
# map_location=device places the tensors directly on the device we run on,
# so a checkpoint saved on a GPU can still be loaded on a CPU-only machine.
state_dict = torch.load('generator', map_location=device)

# A checkpoint saved from an nn.DataParallel-wrapped model prefixes every key
# with 'module.'; strip that prefix so the weights match the bare Generator.
_prefix = 'module.'
state_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in state_dict.items()
}

# Load into the bare model *before* wrapping it: once wrapped, DataParallel
# expects 'module.'-prefixed keys and loading a plain checkpoint would fail.
netG.load_state_dict(state_dict)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Switch to inference mode (affects layers such as BatchNorm/Dropout)
netG.eval()

# Set the random seed and sample the latent noise

In [None]:
# Seed selection: set fixed_seed to an int (e.g. 2) to regenerate exactly the
# same images, or leave it as None to draw a new seed (new results) each run.
fixed_seed = None
seed = fixed_seed if fixed_seed is not None else random.randint(1, 10000)

random.seed(seed)
torch.manual_seed(seed)
print("Seed: ", seed)

# Sample the 64 latent vectors only AFTER seeding; sampling before
# torch.manual_seed meant the printed seed could not reproduce the images.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Generate the images and display them

In [None]:
# Generate the images from the generator. This is pure inference, so
# torch.no_grad() skips building the autograd graph and saves memory.
with torch.no_grad():
    generated_images = netG(fixed_noise).detach().cpu()

In [None]:
# Tile the generated images into a single grid image without padding;
# normalize=True rescales the pixel values into [0, 1] for display.
im = vutils.make_grid(generated_images, padding=0, normalize=True)
fig = plt.figure(figsize=(20, 20))
fig.set_facecolor('white')
# make_grid returns a channels-first (C, H, W) tensor while imshow expects
# (H, W, C). Permute the tensor explicitly rather than relying on
# np.transpose's fallback conversion of a torch.Tensor.
plt.imshow(im.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()