# 5. Vanilla GAN on MNIST with FC layers

### About this notebook

This notebook was used in the 50.039 Deep Learning course at the Singapore University of Technology and Design.

**Author:** Matthieu DE MARI (matthieu_demari@sutd.edu.sg)

**Version:** 1.1 (29/08/2023)

**Requirements:**
- Python 3 (tested on v3.11.4)
- Matplotlib (tested on v3.7.2)
- Numpy (tested on v1.25.2)
- Torch (tested on v2.0.1+cu118)
- Torchvision (tested on v0.15.2+cu118)
- We also strongly recommend setting up CUDA on your machine! (At this point, honestly, it is almost mandatory).

### Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

### Dataset and dataloader

As seen many times before...

In [None]:
# Image transform to be applied to dataset
# - Tensor conversion
# - Normalization with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5,
#   so that real images cover [-1, 1] like the Tanh output of the
#   generator (otherwise real and fake images live in different ranges)
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean = (0.5,), std = (0.5,))])

In [None]:
# MNIST train dataset (train split, stored under ./data/),
# with `transform` attached to the dataset
mnist = datasets.MNIST(root='./data/', train=True,
                       transform=transform, download=True)

In [None]:
# Data loader: shuffled mini-batches of `batch_size` MNIST samples
batch_size = 32
data_loader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True)

### Discriminator model as a set of FC layers

Very similar to the Encoder in Notebook 1, or really to any image processing model using FC layers...

In [None]:
# Discriminator
class Dicriminator(nn.Module):
    """Fully connected discriminator for flattened images.

    Stack of three Linear layers: the first two are followed by
    LeakyReLU(0.2) and the last one by a Sigmoid, so a single value
    between 0 and 1 is produced for each input sample.
    """

    def __init__(self, hidden_size, image_size):
        """Build the layer stack.

        hidden_size: width of the two hidden layers.
        image_size: number of features of one flattened input image.
        """
        # Init from nn.Module
        super().__init__()

        # Layers listed in forward order, then wrapped in a Sequential
        stack = [
            nn.Linear(image_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.D = nn.Sequential(*stack)

    def forward(self, x):
        """Return the discriminator output for the batch x."""
        out = self.D(x)
        return out

### Generator model as a set of FC layers

The generator is based on Linear layers and produces 1D vectors of size 784, matching the size of the MNIST samples after they have been flattened.

In [None]:
# Generator
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to flat images.

    Stack of three Linear layers: the first two are followed by ReLU
    and the last one by Tanh, so every output value lies in [-1, 1].
    """

    def __init__(self, latent_size, hidden_size, image_size):
        """Build the layer stack.

        latent_size: number of features of one input noise vector.
        hidden_size: width of the two hidden layers.
        image_size: number of features of one produced (flattened) image.
        """
        # Init from nn.Module
        super().__init__()

        # Layers listed in forward order, then wrapped in a Sequential
        stack = [
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh(),
        ]
        self.G = nn.Sequential(*stack)

    def forward(self, x):
        """Return the generated flat images for the noise batch x."""
        generated = self.G(x)
        return generated

### Trainer function

The trainer function will implement the interleaved training discussed in the slides.

In [None]:
# Hyperparameters for model generation and training
latent_size = 64        # size of the noise vectors given to the generator
hidden_size = 256       # width of the hidden FC layers of both models
image_size = 28 * 28    # one flattened 28x28 image, i.e. 784 values
num_epochs = 300        # number of passes over the data loader
batch_size = 32         # number of samples per mini-batch

In [None]:
# Create discriminator model and place it on the selected device
# (Module.to returns the module itself, so the call can be chained)
D = Dicriminator(hidden_size, image_size).to(device)
# Last expression of the cell: shows the model summary
D

In [None]:
# Create generator model and place it on the selected device
# (Module.to returns the module itself, so the call can be chained)
G = Generator(latent_size, hidden_size, image_size).to(device)
# Last expression of the cell: shows the model summary
G

In [None]:
# Loss: binary cross-entropy, used for both models
criterion = nn.BCELoss()
# Optimizers: one Adam instance per model, same learning rate for both
g_optimizer = optim.Adam(G.parameters(), lr=2e-4)
d_optimizer = optim.Adam(D.parameters(), lr=2e-4)

In [None]:
# History trackers for training curves: one slot per epoch for
# - the discriminator and generator losses,
# - the mean discriminator outputs on real and on fake images
d_losses, g_losses, real_scores, fake_scores = (
    np.zeros(num_epochs) for _ in range(4)
)

### Training a GAN is difficult (read me before running)

**Note:** running the cell below (our trainer function) will take a long time!

It is also sensitive to unlucky initialization (as discussed in class).

This basically means that the models are not guaranteed to train well at all, as different values for the initial parameters could lead to two different outcomes for the interleaved training.

This is what makes training a GAN extra difficult!

Feel free to try seeding the models and using different seeds to see how they may converge to different results!

In [None]:
# Interleaved GAN training: for every mini-batch, one update of the
# discriminator D followed by one update of the generator G.
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # 1. Flatten image
        # The number of samples is read from the batch itself, so that a
        # smaller final batch (dataset size not a multiple of batch_size)
        # does not break the reshape or the label shapes.
        current_batch_size = images.size(0)
        images = images.view(current_batch_size, -1).to(device)

        # 2. Create the labels which are later used as input for the BCE loss
        # (created directly on the target device, no CPU -> device copy)
        real_labels = torch.ones(current_batch_size, 1, device = device)
        fake_labels = torch.zeros(current_batch_size, 1, device = device)

        # ---------------------------------------------------------------
        # PART 1: TRAIN THE DISCRIMINATOR
        # ---------------------------------------------------------------

        # 3. Compute BCE_Loss using real images
        # Here, BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels = 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # 3.bis. Compute BCELoss using fake images
        # Here, BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # First term of the loss is always zero since fake_labels = 0
        # The fake images are detached: this step only updates D, so there
        # is no need to backpropagate through G.
        z = torch.randn(current_batch_size, latent_size, device = device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # 4. Backprop and optimize for D
        # Remember to reset gradients for both optimizers!
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---------------------------------------------------------------
        # PART 2: TRAIN THE GENERATOR
        # ---------------------------------------------------------------

        # 5. Generate fresh noise samples and produce fake images
        # (no detach here: the gradient has to flow through D back to G)
        z = torch.randn(current_batch_size, latent_size, device = device)
        fake_images = G(z)
        outputs = D(fake_images)

        # 6. We train G to maximize log(D(G(z)))
        # instead of minimizing log(1-D(G(z)))
        # (Same goal of fooling D, but this form keeps useful gradients
        # when D confidently rejects the fakes; empirically better)
        g_loss = criterion(outputs, real_labels)

        # 7. Backprop and optimize G
        # Remember to reset gradients for both optimizers!
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # ---------------------------------------------------------------
        # PART 3: UPDATE STATISTICS FOR VISUALIZATION LATER
        # ---------------------------------------------------------------

        # 8. Update the losses and scores for mini-batches
        # Running mean over the mini-batches seen so far in this epoch:
        # new_mean = old_mean * i/(i+1) + value * 1/(i+1)
        d_losses[epoch] = d_losses[epoch]*(i/(i+1.)) \
            + d_loss.item()*(1./(i+1.))
        g_losses[epoch] = g_losses[epoch]*(i/(i+1.)) \
            + g_loss.item()*(1./(i+1.))
        real_scores[epoch] = real_scores[epoch]*(i/(i+1.)) \
            + real_score.mean().item()*(1./(i+1.))
        fake_scores[epoch] = fake_scores[epoch]*(i/(i+1.)) \
            + fake_score.mean().item()*(1./(i+1.))

        # 9. Display
        # Epoch is shown 1-based, consistently with the step counter
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch + 1, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))

### Visualization

As usual, show the training curves for the losses of both models and for the discriminator scores (its mean output on real and on fake images), along with some images produced by the generator.

In [None]:
# Display the per-epoch losses of the discriminator and of the generator
epoch_axis = range(1, num_epochs + 1)
plt.figure()
for history, tag in ((d_losses, 'd loss'), (g_losses, 'g loss')):
    plt.plot(epoch_axis, history, label=tag)
plt.legend()
plt.show()

In [None]:
# Display the per-epoch scores tracked for fake and for real images
epoch_axis = range(1, num_epochs + 1)
plt.figure()
for history, tag in ((fake_scores, 'fake score'), (real_scores, 'real score')):
    plt.plot(epoch_axis, history, label=tag)
plt.legend()
plt.show()

In [None]:
# Generate a few fake samples (5 of them) for visualization
n_samples = 5
# Pure inference: gradient tracking is switched off, so no autograd graph
# is built and the result can be converted to NumPy directly
with torch.no_grad():
    z = torch.randn(n_samples, latent_size, device = device)
    fake_images = G(z)
# Back to the CPU, then one 28x28 array per generated sample
fake_images = fake_images.cpu().numpy().reshape(n_samples, 28, 28)
print(fake_images.shape)

In [None]:
# Display every generated sample in its own figure
# (iterating over the array keeps this cell valid for any n_samples)
for fake_image in fake_images:
    plt.figure()
    plt.imshow(fake_image)
    plt.show()