In [2]:
# https://github.com/godeastone/GAN-torch/blob/main/models/GAN.py

In [3]:
import os
import torch.nn as nn
import torch.utils.data
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Hyper-parameters & Variables setting
num_epoch = 200
batch_size = 100
learning_rate = 0.0002
img_size = 28 * 28
num_channel = 1
dir_name = 'GAN_results'

noise_size = 100
hidden_size1 = 256
hidden_size2 = 512
hidden_size3 = 1024

In [5]:
# Device setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Now using {} device'.format(device))

# Create a directory for saving samples
if not os.path.exists(dir_name):
    os.makedirs(dir_name)

Now using cpu device


In [6]:
# Dataset transform setting
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)])

In [7]:
# MNIST dataset setting
MNIST_dataset = torchvision.datasets.MNIST(root='../../data',
                                            train=True,
                                            transform=transform,
                                            download=True)

In [8]:
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=MNIST_dataset,
                                            batch_size=batch_size,
                                            shuffle=True)

In [9]:
# Declares discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.linear1 = nn.Linear(img_size, hidden_size3)
        self.linear2 = nn.Linear(hidden_size3, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, hidden_size1)
        self.linear4 = nn.Linear(hidden_size1, 1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.leaky_relu(self.linear1(x))
        x = self.leaky_relu(self.linear2(x))
        x = self.leaky_relu(self.linear3(x))
        x = self.linear4(x)
        x = self.sigmoid(x)
        return x

In [10]:
# Declares generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.linear1 = nn.Linear(noise_size, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, hidden_size3)
        self.linear4 = nn.Linear(hidden_size3, img_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.linear4(x)
        x = self.tanh(x)
        return x

In [11]:
# Initialize generator/Discriminator
discriminator = Discriminator()
generator = Generator()

# Device setting
discriminator = discriminator.to(device)
generator = generator.to(device)

# Loss function & Optimizer setting
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)

In [13]:
'''
Traing part
'''

for epoch in range(num_epoch):
    for i, (images, label) in enumerate(data_loader):
        
        # make ground truth (labels) -> 1 for real, 0 for fake
        real_label = torch.full((batch_size, 1), 1, dtype=torch.float32).to(device)
        fake_label = torch.full((batch_size, 1), 0, dtype=torch.float32).to(device)

        # reshape real images from MNIST dataset
        real_images = images.reshape(batch_size, -1).to(device)

        # ----------------
        # train Generator
        # ----------------

        # Initialize grad
        g_optimizer.zero_grad()
        d_optimizer.zero_grad()

        # make fake images with generatoro & noise vector 'z'1128.ipynb
        z = torch.randn(batch_size, noise_size).to(device)
        fake_images = generator(z)

        # Compare result of discriminator with fake images & real labels
        # If generator deceives discriminator, g_loss will decrease
        g_loss = criterion(discriminator(fake_images), real_label)

        # Train generator with backpropagation
        g_loss.backward()
        g_optimizer.step()

        # -------------------
        # train Discriminator
        # -------------------

        # Initialize grad
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        
        # make fake images with generator & noise vector 'z'
        z = torch.randn(batch_size, noise_size).to(device)
        fake_images = generator(z)

        # Calculate fake & real loss with generated images above & real images
        fake_loss = criterion(discriminator(fake_images), fake_label)
        real_loss = criterion(discriminator(real_images), real_label)
        d_loss = (fake_loss + real_loss) / 2

        # Train discriminator with backpropagation
        # In this part, we don't train generator
        d_loss.backward()
        d_optimizer.step()

        d_performance = discriminator(real_images).mean()
        g_performance = discriminator(fake_images).mean()

        if (i + 1) % 150 == 0:
            print('Epoch [ {}/{} ] Step [ {}/{} ] d_loss : {:.5f} g_loss : {:.5f}'
                    .format(epoch, num_epoch, i+1, len(data_loader), d_loss.item(), g_loss.item()))
    
    # print discriminator & generator's performance
    print(" Epock {}'s discriminator performance : {:.2f} generator performance : {:.2f}"
            .format(epoch, d_performance, g_performance))

    # Save fake images in each epoch
    samples = fake_images.reshape(batch_size, 1, 28, 28)
    save_image(samples, os.path.join(dir_name, 'GAN_fake_samples{}.png' .format(epoch + 1)))

Epoch [ 0/200 ] Step [ 150/600 ] d_loss : 0.00346 g_loss : 5.94724
Epoch [ 0/200 ] Step [ 300/600 ] d_loss : 0.01195 g_loss : 7.58132
Epoch [ 0/200 ] Step [ 450/600 ] d_loss : 0.00474 g_loss : 9.49991
Epoch [ 0/200 ] Step [ 600/600 ] d_loss : 0.05915 g_loss : 11.94468
 Epock 0's discriminator performance : 0.98 generator performance : 0.01
Epoch [ 1/200 ] Step [ 150/600 ] d_loss : 1.26506 g_loss : 3.36980
Epoch [ 1/200 ] Step [ 300/600 ] d_loss : 0.28695 g_loss : 3.41164
Epoch [ 1/200 ] Step [ 450/600 ] d_loss : 0.27602 g_loss : 1.90944
Epoch [ 1/200 ] Step [ 600/600 ] d_loss : 2.02220 g_loss : 4.21905
 Epock 1's discriminator performance : 0.95 generator performance : 0.21
Epoch [ 2/200 ] Step [ 150/600 ] d_loss : 0.64651 g_loss : 1.44640
Epoch [ 2/200 ] Step [ 300/600 ] d_loss : 1.54195 g_loss : 3.25332
Epoch [ 2/200 ] Step [ 450/600 ] d_loss : 0.02219 g_loss : 5.01066
Epoch [ 2/200 ] Step [ 600/600 ] d_loss : 0.13064 g_loss : 2.34361
 Epock 2's discriminator performance : 0.95 gener