In [4]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class Generator(nn.Module):
    """Generator mapping a latent vector to a 3-channel 32x32 image.

    Layout: Linear(z_dim -> 384*4*4) + ReLU, reshape to (384, 4, 4), then three
    stride-2 transposed convolutions that each double H and W
    (4 -> 8 -> 16 -> 32) while reducing channels 384 -> 192 -> 96 -> 3.
    The first two are followed by BatchNorm + ReLU; the last by Tanh, which
    bounds outputs to [-1, 1].

    All layers live in the single ``nn.Sequential`` attribute ``net``. Their
    order must not change: state-dict keys (``net.0.weight``,
    ``net.3.weight``, ...) are positional, so saved checkpoints only load
    while the layout is identical.
    """

    def __init__(self, z_dim):
        super().__init__()

        def up_block(in_ch, out_ch):
            # Transposed conv that exactly doubles H and W:
            # (H - 1) * 2 - 2 * 2 + 5 + 1 = 2H.
            return nn.ConvTranspose2d(
                in_ch, out_ch, kernel_size=5, stride=2, padding=2,
                output_padding=1, bias=False,
            )

        layers = [
            nn.Linear(z_dim, 384 * 4 * 4),
            nn.ReLU(inplace=True),
            nn.Unflatten(-1, (384, 4, 4)),
        ]
        for in_ch, out_ch in ((384, 192), (192, 96)):
            # No conv bias here: the following BatchNorm applies its own shift.
            layers += [up_block(in_ch, out_ch), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
        layers += [up_block(96, 3), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a batch of images (N, 3, 32, 32) for latents ``x``.

        ``x`` is expected to be (N, z_dim); the BatchNorm layers require the
        batch dimension.
        """
        return self.net(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
z_dim = 100

# --- testing: sample images from a trained generator checkpoint ---
output_root = Path("./WGAN_With_Data_Enhance")  # holds the checkpoint; PNGs go to gen_imgs/ inside it
checkpoint_epoch = 400  # loads generator_WGAN_<checkpoint_epoch>.pth
num_samples = 10000
batch_size = 500  # forward in chunks to bound peak activation memory

generator_test = Generator(z_dim).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state_dict = torch.load(output_root / f"generator_WGAN_{checkpoint_epoch}.pth", map_location=device)
generator_test.load_state_dict(state_dict)
# eval(): BatchNorm uses its running statistics, so chunking does not change outputs.
generator_test.eval()

noise = torch.normal(0, 1, (num_samples, z_dim)).to(device)
# Inference only: no autograd graph is needed (the original kept one for all samples).
with torch.no_grad():
    images = torch.cat([generator_test(chunk) for chunk in noise.split(batch_size)])
# NCHW -> NHWC, the layout PIL's Image.fromarray expects for RGB arrays.
images = images.permute(0, 2, 3, 1)
# Map Tanh output [-1, 1] -> [0, 255]; round (not truncate) and clamp before the uint8 cast.
images = ((images + 1.0) / 2.0 * 255).round().clamp(0, 255).to(torch.uint8)
print(images.shape)

gen_dir = output_root / "gen_imgs"
gen_dir.mkdir(parents=True, exist_ok=True)  # avoid failing on a fresh checkout
images_np = images.cpu().numpy()
for i, img_array in enumerate(images_np):
    img = Image.fromarray(img_array)
    img.save(gen_dir / f"image_{i}.png")

torch.Size([10000, 32, 32, 3])
