# Implementing GAN for MNIST

In the previous recipe, we saw the various building blocks used to construct a GAN architecture. In this recipe, we take the MNIST dataset and design a GAN with a Generator and a Discriminator. We also look at how to implement the loss function and how to provide fake and real images to the Discriminator.

# Importing Requirements

In [None]:
import itertools
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from IPython import display
from tensorboardX import SummaryWriter
from torch.autograd import Variable

%matplotlib inline 
SEED = 1234
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**Load Dataset**: Each MNIST image is of size 28*28. Because we will be using fully connected layers, each image is flattened into a vector of length 784.

In [None]:
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))  # MNIST images have a single channel
])

train_dataset = dsets.MNIST(root='./data/', train=True, download=True, transform=transform)

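def train_loader(batch_size):
    # Yield consecutive batches of images as float tensors.
    # Note: train_dataset.data holds the raw pixel values (0-255), so the
    # transform defined above is not applied on this path.
    for i in range(0, len(train_dataset) - batch_size + 1, batch_size):
        yield train_dataset.data[i:i + batch_size].float()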

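The flattening step mentioned above can be illustrated with a minimal sketch. It uses a random tensor as a stand-in for a batch of MNIST images (an assumption made so that the example runs without downloading the dataset); the names `batch` and `flat` are illustrative only.

```python
import torch

# Stand-in for a batch of 100 grayscale images of 28*28 pixels.
# Random values are used so the sketch does not need the real dataset.
batch = torch.rand(100, 28, 28)

# Flatten each image into a vector of length 784, keeping the batch dimension.
flat = batch.view(batch.size(0), -1)

print(batch.shape)  # shape of the image batch
print(flat.shape)   # shape after flattening
```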
# Model
**The Discriminator:** It takes a flattened image of size 784 and gradually shrinks it to a single output by passing it through 3 fully connected layers. If the output probability is close to 1, the input image is classified as real; otherwise it is classified as fake.

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        out = self.model(x.view(x.size(0), 784))
        out = out.view(out.size(0), -1)
        return out

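As a quick sanity check, the sketch below rebuilds the same layer stack under the hypothetical name `sketch_discriminator` (so that it is self-contained) and passes a placeholder batch through it. It is only meant to show that the output has shape `(batch, 1)` with values between 0 and 1, which is the shape the labels need to match when the loss is computed.

```python
import torch
import torch.nn as nn

# Same layer sizes as the Discriminator above, rebuilt for a standalone check.
sketch_discriminator = nn.Sequential(
    nn.Linear(784, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(256, 1), nn.Sigmoid(),
)
sketch_discriminator.eval()  # disable dropout for a deterministic check

images = torch.rand(4, 1, 28, 28)  # placeholder batch, not real MNIST data
with torch.no_grad():
    probs = sketch_discriminator(images.view(images.size(0), 784))

print(probs.shape)  # one probability per image
```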
**The Generator:** The Generator is defined below. It takes a random vector of size 100 and passes it through 3 fully connected layers, each of which expands its input, ending with an output of size 784 that can be reshaped into an image of size 28*28. This is the image generated by the Generator.

In [None]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 784),
            nn.LeakyReLU(0.2, inplace=True),
        )
    
    def forward(self, x):
        x = x.view(x.size(0), 100)
        out = self.model(x)
        return out

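The shapes involved can be checked with a small sketch. It rebuilds the same layers under the hypothetical name `sketch_generator` so that it runs on its own, feeds it 16 random vectors, and reshapes the result into 28*28 images.

```python
import torch
import torch.nn as nn

# Same layer sizes as the Generator above, rebuilt for a standalone check.
sketch_generator = nn.Sequential(
    nn.Linear(100, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 512),
    nn.LeakyReLU(0.2),
    nn.Linear(512, 784),
    nn.LeakyReLU(0.2),
)

noise = torch.randn(16, 100)  # 16 random input vectors of size 100
with torch.no_grad():
    flat_images = sketch_generator(noise)  # one 784-vector per input
generated = flat_images.view(-1, 28, 28)   # reshape into 28*28 images

print(flat_images.shape, generated.shape)
```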
In [None]:
discriminator = Discriminator().to(device)
generator = Generator().to(device)

**Defining loss and Optimizer for Generator/Discriminator**

In [None]:
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

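To see what `nn.BCELoss` computes, here is a small worked sketch with made-up discriminator outputs (the values 0.9 and 0.2 are hypothetical). With all-ones targets the loss is the mean of `-log(p)`, and with all-zeros targets it is the mean of `-log(1 - p)`.

```python
import math

import torch
import torch.nn as nn

bce = nn.BCELoss()

# Hypothetical discriminator outputs for two images, shaped (batch, 1).
probs = torch.tensor([[0.9], [0.2]])

real_targets = torch.ones_like(probs)   # label 1 means "real"
fake_targets = torch.zeros_like(probs)  # label 0 means "fake"

loss_if_real = bce(probs, real_targets)  # mean of -log(p)
loss_if_fake = bce(probs, fake_targets)  # mean of -log(1 - p)

# The same quantities computed by hand for comparison.
expected_real = -(math.log(0.9) + math.log(0.2)) / 2
expected_fake = -(math.log(0.1) + math.log(0.8)) / 2

print(loss_if_real.item(), expected_real)
print(loss_if_fake.item(), expected_fake)
```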
In [None]:
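def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels):
    discriminator.zero_grad()
    # Loss on real images: the discriminator should output values near 1
    outputs = discriminator(images)
    real_loss = criterion(outputs, real_labels)
    real_score = outputs

    # Loss on fake images: the discriminator should output values near 0.
    # detach() keeps this update from backpropagating into the generator.
    outputs = discriminator(fake_images.detach())
    fake_loss = criterion(outputs, fake_labels)
    fake_score = outputs

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss, real_score, fake_score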

In [None]:
def train_generator(generator, discriminator_outputs, real_labels):
    generator.zero_grad()
    # The generator loss is low when the discriminator labels the fakes as real (1)
    g_loss = criterion(discriminator_outputs, real_labels)
    g_loss.backward()
    g_optimizer.step()
    return g_loss

In [None]:
# draw a fixed batch of noise vectors to inspect the generated samples during training
num_test_samples = 16
test_noise = torch.randn(num_test_samples, 100, device=device)

# Training
**Training Process:** The overall training process has the following steps, listed below along with the code. 

1. Generate a random vector of size 100 for each image in the batch.

```python
noise = torch.randn(images.size(0), 100, device=device)
```

2. Generate fake images by passing the noise through the Generator, and create all-zero labels for these images.

```python
fake_images = generator(noise)
fake_labels = torch.zeros(images.size(0), 1, device=device)
```

3. Train the Discriminator on the fake images with their labels and the real images with their labels. After the update, the function returns `d_loss`, which is the sum of the losses on the real images and the fake images. It also returns `real_score` and `fake_score`. `real_score` is the prediction of the discriminator for the real images; ideally, this output should be close to 1. `fake_score` is the prediction of the discriminator for the fake images; ideally, this output should be close to 0. Monitoring `real_score` and `fake_score` gives a good idea of the convergence of the discriminator.

4. Next, a fresh batch of fake images is generated by the generator. The labels for these images are all ones (the generator loss approaches 0 only when the discriminator treats all of these images as real). These fake images are passed to the discriminator, and the generator loss `g_loss` is calculated from the discriminator's outputs and the all-ones labels.

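Steps 3 and 4 above can be sketched end to end as follows. This is a minimal illustration on CPU that assumes tiny stand-in networks and a random placeholder batch; the names `toy_generator`, `toy_discriminator`, `toy_d_opt` and `toy_g_opt` are hypothetical and are not the models trained in this recipe.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks (hypothetical sizes) so the sketch runs quickly.
toy_generator = nn.Sequential(nn.Linear(100, 784))
toy_discriminator = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
bce = nn.BCELoss()
toy_d_opt = torch.optim.Adam(toy_discriminator.parameters(), lr=0.0002)
toy_g_opt = torch.optim.Adam(toy_generator.parameters(), lr=0.0002)

images = torch.rand(8, 784)      # placeholder for a batch of real images
real_labels = torch.ones(8, 1)   # shape matches the discriminator output
fake_labels = torch.zeros(8, 1)

# Step 3: update the discriminator on a real batch and a fake batch.
fake_images = toy_generator(torch.randn(8, 100))
toy_d_opt.zero_grad()
real_score = toy_discriminator(images)
fake_score = toy_discriminator(fake_images.detach())  # no gradient into the generator
d_loss = bce(real_score, real_labels) + bce(fake_score, fake_labels)
d_loss.backward()
toy_d_opt.step()

# Step 4: update the generator so that fresh fakes are labelled as real.
toy_g_opt.zero_grad()
outputs = toy_discriminator(toy_generator(torch.randn(8, 100)))
g_loss = bce(outputs, real_labels)
g_loss.backward()
toy_g_opt.step()

print(d_loss.item(), g_loss.item())
print(real_score.mean().item(), fake_score.mean().item())
```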

In [None]:
# create figure for plotting
size_figure_grid = int(math.sqrt(num_test_samples))
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
    ax[i,j].get_xaxis().set_visible(False)
    ax[i,j].get_yaxis().set_visible(False)

# set number of epochs and initialize figure counter
num_epochs = 200
num_fig = 0

writter = SummaryWriter()

total_iterations = 0
for epoch in range(num_epochs):
    for n, images in enumerate(train_loader(batch_size=100)):
        images = images.to(device)
        # Labels are shaped (batch, 1) to match the discriminator output
        real_labels = torch.ones(images.size(0), 1, device=device)
        # Sample from generator
        noise = torch.randn(images.size(0), 100, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(images.size(0), 1, device=device)
        
        # Train the discriminator
        d_loss, real_score, fake_score = train_discriminator(discriminator, images, real_labels, fake_images, fake_labels)
        
        # Sample again from the generator and get output from discriminator
        noise = torch.randn(images.size(0), 100, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)

        # Train the generator
        g_loss = train_generator(generator, outputs, real_labels)
        
        if (n+1) % 100 == 0:
            # Generate samples from the fixed test noise without tracking gradients
            with torch.no_grad():
                test_images = generator(test_noise)
            
            for k in range(num_test_samples):
                i = k // size_figure_grid
                j = k % size_figure_grid
                ax[i,j].cla()
                ax[i,j].imshow(test_images[k,:].cpu().numpy().reshape(28, 28), cmap='Greys')
            display.clear_output(wait=True)
            display.display(plt.gcf())
            
            # Log plain Python floats rather than tensors
            writter.add_scalar("Generator/Loss", g_loss.item(), total_iterations)
            writter.add_scalar("Discriminator/Loss", d_loss.item(), total_iterations)
            writter.add_scalar("Score/Real", real_score.mean().item(), total_iterations)
            writter.add_scalar("Score/Fake", fake_score.mean().item(), total_iterations)
        total_iterations = total_iterations +  1

plt.close(fig)

After training for a few iterations, the following are samples generated by the Generator. Training for longer could produce better output.

![](figures/mnist_generated.png)

Figure: Samples generated by the Generator on the MNIST data
