# **Generative adversarial networks**

In the following notebook we will look at a simple example of a GAN (generative adversarial network) to understand how it is trained. The purpose of a GAN is to create new images that look like samples taken from a given distribution, in this case handwritten digits from the MNIST dataset.
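For reference, the original GAN paper (Goodfellow et al., 2014) frames training as a two-player minimax game between a discriminator $D$ and a generator $G$. Here $z$ is drawn from a noise prior $p_z$, which in this notebook is a standard normal distribution of dimension `latent_dim`:

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The discriminator minimises the binary cross entropy on real and fake images, which is equivalent to maximising $V$ (up to the factor $1/2$ used in the training loop below). The generator in that loop does not minimise $\log(1 - D(G(z)))$ directly. It labels its fakes as real and minimises $-\log D(G(z))$ instead (the "non-saturating" loss), which gives stronger gradients early in training, when the discriminator rejects the fakes easily.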

In [0]:
"""
The implementation follows the implementation of simple_GAN of vamsi3
"""

import os
import imageio
import numpy as np
import matplotlib.pyplot as plt

# Importing torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# For MNIST dataset and visualization
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid



In [0]:
# Setting the hyperparameters

epochs = 200
batch_size = 64
learning_rate = 0.0002
b1 = 0.5  # Adam beta1: decay rate of the running average of the gradient
b2 = 0.999  # Adam beta2: decay rate of the running average of the squared gradient
latent_dim = 100
img_size = 28
channels = 1
output_dir = 'output'

img_shape = (channels, img_size, img_size)

# Check whether a CUDA-capable GPU is available
cuda_is_present = torch.cuda.is_available()


# **The two opponents**
In the next step, we create the generator and discriminator classes for our training. The generator's purpose is to create images that look like ordinary samples from the MNIST dataset. The discriminator's job is to tell the real MNIST images apart from the fake ones that the generator mixes in.

In [0]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        # Helper that builds one block of layers: a linear layer, an optional
        # batch normalisation and a LeakyReLU activation.
        def layer_block(input_size, output_size, normalize=True):
            layers = [nn.Linear(input_size, output_size)]
            if normalize:
                # The second positional argument of BatchNorm1d is eps (not
                # momentum), so the original call BatchNorm1d(output_size, 0.8)
                # sets eps=0.8. It is written out explicitly here.
                layers.append(nn.BatchNorm1d(output_size, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # The model upsamples a latent vector to a full fake image. The '*'
        # unpacks the list returned by 'layer_block' so that its layers can be
        # passed individually to nn.Sequential to form the full model.
        self.model = nn.Sequential(
            *layer_block(latent_dim, 128, normalize=False),
            *layer_block(128, 256),
            *layer_block(256, 512),
            *layer_block(512, 1024),

            # np.prod() returns the product of all elements, so for 'img_shape'
            # it gives channels * img_size * img_size: the last layer outputs
            # exactly one value per pixel of the image.
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()  # Forces the pixel values into the range [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)  # Reshape each vector into an image
        # This is the image the generator tries to fool the discriminator with.
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        # This model classifies whether an image is real or not. The task is not
        # to recognise which digit is shown, but only whether the digit was
        # drawn by a human or produced by the generator.
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),  # A single output suffices for two classes
            nn.Sigmoid()  # Maps the output to between 0 and 1 (1 = real, 0 = fake)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # Flatten each image into a vector
        verdict = self.model(img_flat)
        return verdict
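To get a feel for the relative size of the two opponents, the following sketch counts their learnable parameters from the layer widths alone. It is plain Python, the helper names (`linear_params`, `batchnorm_params`, ...) are hypothetical, and it assumes the default `affine=True` of `nn.BatchNorm1d`, whose running mean and variance are buffers rather than parameters.

```python
import math

LATENT_DIM = 100
IMG_SHAPE = (1, 28, 28)
IMG_PIXELS = math.prod(IMG_SHAPE)  # 1 * 28 * 28 = 784, same as np.prod(img_shape)


def linear_params(n_in, n_out):
    # nn.Linear stores an (n_out x n_in) weight matrix plus a bias of size n_out
    return n_in * n_out + n_out


def batchnorm_params(n):
    # With affine=True, BatchNorm1d learns one scale and one shift per feature
    return 2 * n


def generator_params():
    widths = [LATENT_DIM, 128, 256, 512, 1024]
    total = 0
    for i, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
        total += linear_params(n_in, n_out)
        if i > 0:  # the first block is built with normalize=False
            total += batchnorm_params(n_out)
    total += linear_params(widths[-1], IMG_PIXELS)  # final layer before Tanh
    return total


def discriminator_params():
    widths = [IMG_PIXELS, 512, 256, 1]
    return sum(linear_params(a, b) for a, b in zip(widths, widths[1:]))


print(IMG_PIXELS)              # 784
print(generator_params())      # 1510032
print(discriminator_params())  # 533505
```

So the generator has roughly 2.8 times as many parameters as the discriminator. In the notebook itself, `sum(p.numel() for p in generator.parameters())` should report the same total.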



# **Getting ready**
Before we start training, we finish the setup and load the MNIST dataset, whose training split contains 60,000 images of handwritten digits. These images serve as the distribution that the generator tries to imitate.

In [0]:
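# Create the two networks and the loss function
generator = Generator()
discriminator = Discriminator()
adversarial_loss = torch.nn.BCELoss()  # Binary cross entropy loss

# The binary cross entropy loss only knows two answers: true and false. It
# becomes zero when the predicted probability for the correct answer is 100%
# and grows towards infinity as that probability approaches 0% (PyTorch clamps
# the log terms, so in practice a single term is capped at 100).

# Move everything to the GPU if CUDA is available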
if cuda_is_present:
	generator.cuda()
	discriminator.cuda()
	adversarial_loss.cuda()

# Loading the MNIST dataset
os.makedirs('data/mnist', exist_ok=True)

# Now the data loader is created, using the predefined datasets.MNIST class
# from torchvision to load the MNIST images. ToTensor() scales the pixels to
# [0, 1] and Normalize([0.5], [0.5]) shifts them to [-1, 1], the same range
# as the generator's Tanh output.
data_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data/mnist', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])),
    batch_size=batch_size, shuffle=True)
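`transforms.ToTensor()` turns 8-bit pixel values into floats in [0, 1], and `transforms.Normalize([0.5], [0.5])` then computes `(x - 0.5) / 0.5` per channel. The sketch below reproduces that arithmetic with NumPy (an illustration, not torchvision's code; the function names are hypothetical) to show that real images end up in [-1, 1], matching the generator's `Tanh` output, and how to undo the mapping for display.

```python
import numpy as np


def to_unit_range(pixels_uint8):
    # Mimics the scaling done by ToTensor: 0..255 -> 0.0..1.0
    return pixels_uint8.astype(np.float32) / 255.0


def normalize(x, mean=0.5, std=0.5):
    # Same formula as Normalize([0.5], [0.5]) for a single channel
    return (x - mean) / std


def denormalize(x, mean=0.5, std=0.5):
    # Inverse mapping, useful before plotting generated images
    return x * std + mean


pixels = np.array([0, 128, 255], dtype=np.uint8)  # black, mid-grey, white
normed = normalize(to_unit_range(pixels))
print(normed)                     # approximately [-1.0, 0.0039, 1.0]
print(denormalize(normed) * 255)  # back to approximately [0, 128, 255]
```

Because both real and generated images live in [-1, 1], the discriminator cannot separate them simply by their value range.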

# **The fight begins**
During training, the generator creates batches of fake images to fool the discriminator. The discriminator sees both real and fake images and learns to distinguish the two groups. To keep fooling the discriminator, the generator therefore has to get better and better at creating realistic images. After training, we end up with a generator that can create realistic samples of handwritten digits.
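The loss bookkeeping of this game can be sketched in a few lines of NumPy. This is an illustration, not the notebook's code: `bce` and `gan_losses` are hypothetical helpers, and the discriminator outputs are made-up numbers rather than results from a trained model. `bce` follows the formula of `torch.nn.BCELoss` with its default mean reduction, including the clamping of the log terms at -100.

```python
import numpy as np


def bce(p, y):
    # Mean binary cross entropy. The log terms are clamped at -100, as in
    # PyTorch's BCELoss, so a confident wrong prediction costs 100, not infinity
    # (and 0 * log(0) does not turn into NaN).
    with np.errstate(divide="ignore"):
        log_p = np.maximum(np.log(p), -100.0)
        log_1mp = np.maximum(np.log(1.0 - p), -100.0)
    return float(np.mean(-(y * log_p + (1.0 - y) * log_1mp)))


def gan_losses(d_real, d_fake):
    # d_real: discriminator outputs on real images; d_fake: on generated images
    d_loss = (bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))) / 2
    g_loss = bce(d_fake, np.ones_like(d_fake))  # generator wants its fakes labelled real
    return d_loss, g_loss


# Hypothetical discriminator outputs for a batch of 4 images
confused = np.full(4, 0.5)  # D cannot tell real from fake at all
print(gan_losses(confused, confused))  # both losses equal ln 2, about 0.693

sharp_real = np.array([0.99, 0.98, 0.97, 0.99])
sharp_fake = np.array([0.01, 0.02, 0.03, 0.01])
print(gan_losses(sharp_real, sharp_fake))  # small D loss, large G loss
```

When the discriminator outputs 0.5 everywhere, both losses equal ln 2 ≈ 0.693, which is a handy reference value when reading the loss plot at the end of the notebook. A confident, accurate discriminator drives its own loss towards zero while the generator's loss grows.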

In [0]:
# Training the GAN

Tensor = torch.cuda.FloatTensor if cuda_is_present else torch.FloatTensor

# Two optimizers are defined, since the generator and the discriminator are going
# to compete with each other during training:
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(b1, b2))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(b1, b2))

losses = []
images_for_gif = []
for epoch in range(1, epochs+1):
  print('Starting epoch number: {}'.format(epoch))
  for i, (images, _) in enumerate(data_loader):

    real_images = Variable(images.type(Tensor))
    real_output = Variable(Tensor(images.size(0), 1).fill_(1.0), requires_grad=False)
    fake_output = Variable(Tensor(images.size(0), 1).fill_(0.0), requires_grad=False)

    # Training Generator
    optimizer_generator.zero_grad()
    z = Variable(Tensor(np.random.normal(0, 1, (images.shape[0], latent_dim))))

    # The generator creates a batch of fake images:
    generated_images = generator(z)

    # The discriminator now evaluates the images. The more of them it believes
    # to be real, the smaller generator_loss becomes, because the generator is
    # fooling the discriminator successfully. Note that the fake images are
    # labelled as real (real_output) here:
    generator_loss = adversarial_loss(discriminator(generated_images), real_output)
    generator_loss.backward()
    optimizer_generator.step()

    # Training Discriminator
    optimizer_discriminator.zero_grad()

    # The discriminator receives the fake and real images and has to determine
    # which are real and which are fake. Its loss is small only if the real
    # images are classified as real and the fake images as fake. detach() keeps
    # this loss from propagating gradients back into the generator:
    discriminator_loss_real = adversarial_loss(discriminator(real_images), real_output)
    discriminator_loss_fake = adversarial_loss(discriminator(generated_images.detach()), fake_output)
    discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
    discriminator_loss.backward()
    optimizer_discriminator.step()

    # print("[Epoch {}/{}] [Batch {}/{}] ---> ".format(epoch, epochs, i, len(data_loader)),
    #     "[D Loss: {}] [G Loss: {}]".format(np.round(discriminator_loss.item(), 4),
    #                       np.round(generator_loss.item(), 4)))

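  # Record the losses of the last batch of this epoch and visualise 25 of the
  # generated images. normalize=True rescales the Tanh output from [-1, 1]
  # to [0, 1], the float range that plt.imshow expects.
  losses.append((generator_loss.item(), discriminator_loss.item()))
  grid_img = make_grid(generated_images.detach()[:25].cpu(), nrow=5, normalize=True)
  plt.imshow(grid_img.permute(1, 2, 0))
  plt.show()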

# Visualizing the losses (taken from the last batch of each epoch)
losses = np.array(losses)
plt.plot(losses.T[0], label='Generator')
plt.plot(losses.T[1], label='Discriminator')
plt.title("Training Losses")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()