# GAN - MNIST

- Build the generator and discriminator components of a GAN from scratch
- Create generator and discriminator loss functions
- Train the GAN and visualize the generated images

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed PyTorch's random number generator so weight initialization, noise
# sampling and data shuffling start from the same random state on every run.
torch.manual_seed(0)

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """
    Plot the first `num_images` images of a batch as a grid with 5 images per row.

    Parameters:
        image_tensor: a batch of images (flattened or not); it is reshaped to (-1, *size)
        num_images: how many images from the start of the batch to plot
        size: the (channels, height, width) of a single image
    """
    # Leave the autograd graph and the GPU before handing data to matplotlib
    batch = image_tensor.detach().cpu()
    batch = batch.view(-1, *size)[:num_images]
    grid = make_grid(batch, nrow=5)
    # make_grid yields (C, H, W); imshow wants the channel axis last
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

## Generator
- Create a Generator
- Write a function that builds a single layer/block of the generator's NN
- Each block should include a `linear transformation` that maps the input to another shape and `batch normalization` for training stability,
- followed by a non-linear `ReLU` activation so the output can be transformed in complex ways

In [None]:
def get_generator_block(input_dim, output_dim):
    '''
    Build one hidden block of the generator network.

    Params:
        input_dim: size of the incoming feature vector, a scalar
        output_dim: size of the outgoing feature vector, a scalar
    Returns:
        an nn.Sequential applying Linear -> BatchNorm1d -> ReLU
    '''
    layers = [
        nn.Linear(input_dim, output_dim),
        # Normalize activations over the batch to stabilize training
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [None]:
def test_gen_block(in_features, out_features, num_test=1000):
    """Check get_generator_block's layer order, output shape and output spread."""
    block = get_generator_block(in_features, out_features)

    # Expected layer order: Linear -> BatchNorm1d -> ReLU
    expected_types = (nn.Linear, nn.BatchNorm1d, nn.ReLU)
    assert len(block) == len(expected_types)
    for layer, expected in zip(block, expected_types):
        assert type(layer) == expected

    # Check shape
    out = block(torch.randn(num_test, in_features))
    assert tuple(out.shape) == (num_test, out_features)

    # Batch norm makes the pre-activations roughly N(0, 1); ReLU of N(0, 1)
    # has a standard deviation of about 0.58
    spread = out.std()
    assert spread > 0.55
    assert spread < 0.65

test_gen_block(25, 12)
test_gen_block(15, 28)
print('Success!')

#### Generator Class
- The noise vector dimension
- The image dimension
- The initial hidden dimension

Using these parameters:
- Build a NN with 5 layers: four generator blocks followed by a final linear layer
- Starting from a noise vector, the generator applies non-linear transformations via the block function
- until the tensor is mapped to the size of the output image, which is the same size as the real MNIST images
- The final layer needs no batch normalization or ReLU; instead its output is scaled into (0, 1) with a `sigmoid`

In [None]:
class Generator(nn.Module):
    """
    Fully connected GAN generator that maps noise vectors to flattened images.

    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_dim: the dimension of the flattened image, a scalar
            (MNIST images are 28 x 28 = 784, hence the default)
        hidden_dim: the width of the first hidden block, a scalar; each
            following block doubles it (h, 2h, 4h, 8h)
    """
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()

        # Layer widths: noise -> h -> 2h -> 4h -> 8h
        widths = [z_dim] + [hidden_dim * 2 ** k for k in range(4)]
        blocks = [
            get_generator_block(d_in, d_out)
            for d_in, d_out in zip(widths[:-1], widths[1:])
        ]
        self.gen = nn.Sequential(
            *blocks,
            # Project to image size; no batch norm or ReLU on the output layer
            nn.Linear(widths[-1], im_dim),
            # Squash pixel values into (0, 1)
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """
        Generate a batch of images from a batch of noise vectors.

        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        Returns:
            generated images with dimensions (n_samples, im_dim)
        """
        images = self.gen(noise)
        return images

    def get_gen(self):
        """Return the underlying nn.Sequential (inspected by the tests)."""
        return self.gen

In [None]:
def test_generator(z_dim, im_dim, hidden_dim, num_test=10000):
    """Check the Generator's module count, output shape and sigmoid-bounded output."""
    gen = Generator(z_dim, im_dim, hidden_dim).get_gen()

    # Four hidden blocks + final Linear + Sigmoid = six modules
    assert len(gen) == 6
    noise = torch.randn(num_test, z_dim)
    out = gen(noise)

    # Check that the output shape is correct
    assert tuple(out.shape) == (num_test, im_dim)
    # A sigmoid keeps every value strictly inside (0, 1)
    assert out.max() < 1, "Make sure to use a sigmoid"
    assert out.min() > 0, "Make sure to use a sigmoid"
    # Without batch norm on the output layer the spread stays small
    spread = out.std()
    assert spread > 0.05, "Don't use batchnorm here"
    assert spread < 0.15, "Don't use batchnorm here"

test_generator(5, 10, 20)
test_generator(20, 8, 24)
print('Success!')

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Sample generator input noise from a standard normal distribution.

    Parameters:
        n_samples: the number of noise vectors to generate, a scalar
        z_dim: the dimension of each noise vector, a scalar
        device: the device to create the tensor on, e.g. 'cpu' or 'cuda'
    Returns:
        a tensor of shape (n_samples, z_dim) with entries drawn from N(0, 1)
    '''
    # Sample directly on the target device instead of allocating on the CPU
    # and copying it over on every call (this runs once per training step).
    return torch.randn(n_samples, z_dim, device=device)

In [None]:
def test_gen_noise(n_samples, z_dim, device='cpu'):
    '''
    Check that get_noise returns standard-normal noise of the requested shape
    on the requested device.

    Parameters:
        n_samples: number of noise vectors to request
        z_dim: dimension of each noise vector
        device: device string the noise should live on, e.g. 'cpu' or 'cuda'
    '''
    noise = get_noise(n_samples, z_dim, device)

    assert tuple(noise.shape) == (n_samples, z_dim)
    # Make sure a normal distribution was used: sample std should be ~1
    assert torch.abs(noise.std() - torch.tensor(1.0)) < 0.01
    # str(noise.device) is e.g. 'cpu' or 'cuda:0', so compare by prefix
    assert str(noise.device).startswith(device)

test_gen_noise(1000, 100, 'cpu')
if torch.cuda.is_available():
    # Repeat the check on the GPU when one is present
    test_gen_noise(1000, 32, 'cuda')

print('Success!')

## Discriminator

In [None]:
def get_discriminator_block(input_dim, output_dim):
    '''
    Build one hidden block of the discriminator network.

    Params:
        input_dim: size of the incoming feature vector, a scalar
        output_dim: size of the outgoing feature vector, a scalar
    Returns:
        an nn.Sequential applying Linear -> LeakyReLU(negative_slope=0.2)
    '''
    linear = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(linear, activation)

In [None]:
# Verify the discriminator block function
def test_disc_block(in_features, out_features, num_test=10000):
    """Check get_discriminator_block's length, output shape and LeakyReLU slope."""
    block = get_discriminator_block(in_features, out_features)

    # Linear + LeakyReLU
    assert len(block) == 2
    out = block(torch.randn(num_test, in_features))

    # Check that the shape is right
    assert tuple(out.shape) == (num_test, out_features)

    # The most negative output divided by the most positive one roughly
    # recovers the LeakyReLU slope (0.2)
    slope_ratio = -out.min() / out.max()
    assert slope_ratio > 0.1
    assert slope_ratio < 0.3
    spread = out.std()
    assert spread > 0.3
    assert spread < 0.5

test_disc_block(25, 12)
test_disc_block(15, 28)
print("Success!")

### Build Discriminator
- The image dimension
- The hidden dimension

Using the above params:
- Build a NN with 4 layers: three discriminator blocks followed by a final linear layer
- Start with the image tensor and transform it until it returns a single output value
- The output classifies whether an image is fake or real
- No sigmoid is needed after the output layer because it is included in the loss function (`BCEWithLogitsLoss`)
- Write the discriminator's forward pass, which takes in an image tensor to be classified

In [None]:
class Discriminator(nn.Module):
    '''
    Fully connected GAN discriminator that scores flattened images as real or fake.

    Values:
        im_dim: the dimension of the flattened image, a scalar
            (28 x 28 = 784 for MNIST, hence the default)
        hidden_dim: the width of the last hidden block, a scalar; the earlier
            blocks are 4x and 2x as wide
    '''
    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        # Layer widths: image -> 4h -> 2h -> h
        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        layers = [
            get_discriminator_block(d_in, d_out)
            for d_in, d_out in zip(widths, widths[1:])
        ]
        # Final projection to a single raw score; no sigmoid here because the
        # training loss (BCEWithLogitsLoss) applies it internally.
        layers.append(nn.Linear(hidden_dim, 1))
        self.disc = nn.Sequential(*layers)

    def forward(self, image):
        '''
        Score a batch of flattened images.

        Parameters:
            image: a flattened image tensor with dimensions (n_samples, im_dim)
        Returns:
            a tensor of shape (n_samples, 1) holding one raw score per image
        '''
        scores = self.disc(image)
        return scores

    def get_disc(self):
        '''Return the underlying nn.Sequential (inspected by the tests).'''
        return self.disc

In [None]:
# Print the layer structure of a small discriminator (im_dim=5, hidden_dim=10) for inspection
print(Discriminator(5, 10).get_disc())

In [None]:
# Verify the discriminator class
def test_discriminator(z_dim, hidden_dim, num_test=100):
    '''
    Check the Discriminator's structure and output shape.

    Parameters:
        z_dim: the discriminator's input (flattened image) dimension; despite
            the name, it is passed to Discriminator as im_dim
        hidden_dim: the discriminator's hidden dimension
        num_test: number of random input vectors to score
    '''
    disc = Discriminator(z_dim, hidden_dim).get_disc()

    # Three hidden blocks plus the final linear layer
    assert len(disc) == 4

    # One score per input row
    test_input = torch.randn(num_test, z_dim)
    test_output = disc(test_input)
    assert tuple(test_output.shape) == (num_test, 1)

    # Make sure there's no sigmoid: the network must end in a bare Linear layer
    # that emits raw logits, because BCEWithLogitsLoss applies the sigmoid itself.
    # (Inspect the model; the random input says nothing about the output layer.)
    assert isinstance(disc[-1], nn.Linear)
    assert not any(isinstance(m, nn.Sigmoid) for m in disc.modules())

test_discriminator(5, 10)
test_discriminator(20, 8)
print("Success!")

## Training
- criterion: the loss function
- n_epochs: the number of passes over the full dataset
- z_dim: noise vector dimension
- display_step: how often (in training steps) to display the images
- batch_size: the number of images per pass
- lr: the learning rate
- device: cuda/cpu


In [None]:
# Applies the sigmoid internally, so the discriminator outputs raw logits
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001

# ToTensor() converts each MNIST image to a float tensor with values in [0, 1]
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True
)

# Use the GPU only when one is present; a hard-coded 'cuda' fails on CPU-only machines
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Force CPU execution; this overrides the device chosen in the previous cell.
device = 'cpu'

- Initialize your Generator, Discriminator and Optimizers
- Each optimizer only takes the parameters of one particular model

In [None]:
# Each optimizer receives only its own model's parameters
gen = Generator(z_dim=z_dim).to(device)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr)

These are the steps you will need to complete:
- Create noise vectors and generate a batch (num_images) of fake images. 
       Make sure to pass the device argument to the noise.
- Get the discriminator's prediction of the fake image 
       and calculate the loss. Don't forget to detach the generator!
       (Remember the loss function you set earlier -- criterion. You need a 
       'ground truth' tensor in order to calculate the loss. 
       For example, a ground truth tensor for a fake image is all zeros.)
- Get the discriminator's prediction of the real image and calculate the loss.
- Calculate the discriminator's loss by averaging the real and fake loss
       and set it to disc_loss.
       
Note: Please do not use concatenation in your solution. The tests are being updated to 
      support this, but for now, average the two losses as described in the last step above.
      
*Important*: You should NOT write your own loss function here - use criterion(pred, true)!

In [None]:
def get_disc_loss(
    gen, disc, criterion,
    real, num_images, z_dim, device
):
    '''
    Compute the discriminator loss for one batch.

    Generated images are scored against a target of zeros (fake) and real
    images against a target of ones (real); the two criterion values are
    averaged.

    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, called as criterion(prediction, target)
        real: a batch of real images
        num_images: the number of fake images to generate
        z_dim: the dimension of the noise vector, a scalar
        device: the device type

    Returns:
        disc_loss: the loss value for the current batch
    '''
    # Fake half: detach the generated images so this loss cannot send
    # gradients into the generator
    fake_images = gen(get_noise(num_images, z_dim, device=device)).detach()
    fake_scores = disc(fake_images)
    fake_loss = criterion(fake_scores, torch.zeros_like(fake_scores, device=device))

    # Real half: real images should be scored as 1
    real_scores = disc(real)
    real_loss = criterion(real_scores, torch.ones_like(real_scores, device=device))

    return (fake_loss + real_loss) / 2
    

In [None]:
def test_disc_reasonable(num_images = 10):
    '''
    Check get_disc_loss against stub gen/disc/criterion with hand-computable results.

    The stubs are torch.zeros_like / torch.ones_like as "generators",
    mean-pooling lambdas or a bias-free Linear as "discriminators", and
    torch.mul or a sum as "criterion", so each expected loss can be worked
    out by hand.

    Parameters:
        num_images: batch size used for both the fake and the real images
    '''
    # Don't use explicit casts to cuda - use the device argument
    import inspect, re
    lines = inspect.getsource(get_disc_loss)
    assert (re.search(r"to\(.cuda.\)", lines)) is None
    assert (re.search(r"\.cuda\(\)", lines)) is None

    z_dim = 64
    # Fakes are all 0 (score 0, target 0), reals all 1 (score 1, target 1):
    # loss = (0 * 0 + 1 * 1) / 2 = 0.5
    gen = torch.zeros_like
    disc = lambda x: x.mean(1)[:, None]
    criterion = torch.mul
    real = torch.ones(num_images, z_dim)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(disc_loss.mean() - 0.5) < 1e-5)

    # Fakes are all 1 (score 1, target 0), reals all 0 (score 0, target 1):
    # every element of the loss is 0
    gen = torch.ones_like
    criterion = torch.mul
    real = torch.zeros(num_images, z_dim)
    assert torch.all(
        torch.abs(get_disc_loss(
            gen, disc, criterion, real, num_images, z_dim, 'cpu'
        )) < 1e-5
    )

    # Shifted discriminator: fake score 1 + 10 = 11 (target 0), real score
    # 0 + 10 = 10 (target 1): loss = (11 * 0 + 10 * 1) / 2 = 5
    gen = lambda x: torch.ones(num_images, 10)
    disc = lambda x: x.mean(1)[:, None] + 10
    criterion = torch.mul
    real = torch.zeros(num_images, 10)
    assert torch.all(
        torch.abs(
            get_disc_loss(
                gen, disc, criterion, real, num_images, z_dim, 'cpu'
            ).mean() - 5
        ) < 1e-5
    )

    # Gradient check with a bias-free linear discriminator and a criterion
    # that just sums its arguments. Fake inputs are 1 and real inputs 0.5, so
    # each weight's gradient is (num_images * 1 + num_images * 0.5) / 2,
    # i.e. 7.5 for num_images = 10, and |7.5 - 11.25| = 3.75.
    gen = torch.ones_like
    disc = nn.Linear(64, 1, bias=False)
    real = torch.ones(num_images, 64) * 0.5
    disc.weight.data = torch.ones_like(disc.weight.data) * 0.5
    criterion = lambda x, y: torch.sum(x) + torch.sum(y)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu').mean()
    disc_loss.backward()
    assert torch.isclose(torch.abs(disc.weight.grad.mean() - 11.25), torch.tensor(3.75))

In [None]:
def test_disc_loss(max_tests=10):
    '''
    Run a few real discriminator updates on MNIST batches and sanity-check them.

    For each batch this checks three things:
    - the loss of the freshly initialized models is close to 0.68;
    - the backward pass leaves the generator without gradients (the fakes
      must be detached);
    - the optimizer step actually changes the discriminator's weights.

    Parameters:
        max_tests: number of batches to check before stopping
    '''
    z_dim = 64
    gen = Generator(z_dim).to(device)
    disc = Discriminator().to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    num_steps = 0

    for real, _ in dataloader:
        cur_batch_size = len(real)
        # Flatten each image into one vector for the fully connected discriminator
        real = real.view(cur_batch_size, -1).to(device)

        # Zero out the gradient before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss; untrained models are expected to land
        # near 0.68 (about ln 2, the BCE of an uninformative logit)
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        assert (disc_loss - 0.68).abs() < 0.05

        # Backpropagate; this graph is not reused, so it need not be retained
        disc_loss.backward()

        # Check that the fakes were detached: no gradient reached the generator
        assert gen.gen[0][0].weight.grad is None

        # Update optimizer
        old_weight = disc.disc[0][0].weight.data.clone()
        disc_opt.step()
        new_weight = disc.disc[0][0].weight.data

        # Check that some discriminator weights changed
        assert not torch.all(torch.eq(old_weight, new_weight))
        num_steps += 1
        if num_steps >= max_tests:
            break

test_disc_reasonable()
test_disc_loss()
print('Success')

In [None]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    '''
    Compute the generator loss for one batch.

    The generator is rewarded when the discriminator scores its images as
    real, so the predictions on generated images are compared against ones.

    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, called as criterion(prediction, target)
        num_images: the number of images to generate
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        gen_loss: the loss value for the current batch
    '''
    noise = get_noise(num_images, z_dim, device=device)
    # Nothing is detached here: gradients must flow through disc back into
    # gen's parameters (the noise itself never requires grad).
    fake_images = gen(noise)

    disc_pred = disc(fake_images)
    target = torch.ones_like(disc_pred, device=device)
    return criterion(disc_pred, target)

In [None]:
def test_gen_reasonable(num_images=10):
    '''
    Check get_gen_loss against stub gen/disc/criterion with hand-computable results.

    With disc = identity and criterion = torch.mul, the loss equals the
    generated "images" times the all-ones target, so it has the noise shape
    (num_images, z_dim).

    Parameters:
        num_images: number of noise vectors / images to generate
    '''
    # Don't use explicit casts to cuda - use the device argument
    import inspect, re
    lines = inspect.getsource(get_gen_loss)
    assert (re.search(r"to\(.cuda.\)", lines)) is None
    assert (re.search(r"\.cuda\(\)", lines)) is None

    z_dim = 64
    # All-zero "generator": the loss is 0 everywhere
    gen = torch.zeros_like
    disc = nn.Identity()
    criterion = torch.mul # Multiply
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(gen_loss_tensor) < 1e-5)
    # The shape follows from get_noise(num_images, z_dim)
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)

    # All-one "generator": the loss is 1 * 1 = 1 everywhere
    gen = torch.ones_like
    disc = nn.Identity()
    criterion = torch.mul # Multiply
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(gen_loss_tensor - 1) < 1e-5)
    # The shape follows from get_noise(num_images, z_dim)
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)
    

def test_gen_loss(num_images):
    '''
    Check that get_gen_loss gives a sensible loss for untrained models and
    that one optimizer step on it changes the generator's weights.

    Parameters:
        num_images: number of images to generate for the check
    '''
    z_dim = 64
    gen = Generator(z_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    disc = Discriminator().to(device)

    gen_loss = get_gen_loss(gen, disc, criterion, num_images, z_dim, device)

    # Check that the loss is reasonable
    assert (gen_loss - 0.7).abs() < 0.1
    gen_loss.backward()
    # Snapshot the first layer's weights outside autograd before the step
    old_weight = gen.gen[0][0].weight.detach().clone()
    gen_opt.step()
    new_weight = gen.gen[0][0].weight
    assert not torch.all(torch.eq(old_weight, new_weight))


test_gen_reasonable(10)
test_gen_loss(18)
print("Success!")

In [None]:
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
# Enables runtime sanity checks on the generator's first layer. Named so it
# does not shadow the `test_generator` test function defined earlier.
run_generator_checks = True
gen_loss = False
error = False

for epoch in range(n_epochs):
    # Dataloader returns the batches (labels are unused)
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        disc_opt.zero_grad()
        disc_loss = get_disc_loss(
            gen, disc, criterion, real,
            cur_batch_size, z_dim, device
        )
        # The generator step below builds a fresh graph, so this one does not
        # need to be retained after backward
        disc_loss.backward()
        disc_opt.step()

        # Snapshot generator weights so the runtime check can confirm they change
        if run_generator_checks:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        if run_generator_checks:
            try:
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.005 and epoch == 0)
                assert torch.any(gen.gen[0][0].weight.detach() != old_generator_weights)
            except AssertionError:
                # Record the failure but keep training; only failed checks are
                # caught, so real errors and KeyboardInterrupt still propagate
                error = True
                print('Runtime tests have failed')

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step
        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization Code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(f'Epoch {epoch}, step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}')
            # No gradients are needed just to draw samples
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, z_dim, device=device)
                fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1