# Generative Adversarial Networks
A generative adversarial network (GAN) pits two neural networks against each other, with the ultimate goal of producing a model that can generate novel output from random noise. The "discriminator" takes an input image and guesses whether it is a real image from the dataset or a generated one. The "generator" takes a noise vector as input and outputs an image, which should hopefully look like the images in the dataset!


![Model Architecture](model_architecture.png "A schematic showing the model architecture")

In [None]:
#Imports
import torch, torchvision
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module, Sequential, Linear, ReLU, BatchNorm1d, Sigmoid, Flatten, Unflatten, Conv2d
import matplotlib.pyplot as plt

In [None]:
SEED = 1234567890
torch.manual_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Optionally force the CPU instead (more efficient for smaller models)
# DEVICE = "cpu"
IMAGE_WIDTH = 28
NOISE_SIZE = 49
print(f"Using random seed: {SEED}")
print(f"Using Device: {DEVICE}")

## Data
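The raw MNIST data is stored in a [B, H, W] uint8 tensor with pixel values from 0 to 255, so we convert it to float32 and divide by 255 so that each pixel value lies between 0 and 1. We also flatten each 28x28 image into a 784-element vector, since our models operate on flat vectors.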

In [None]:
# Dataloading
class MNISTDataset(Dataset):
    def __init__(self):
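        mnist_data = torchvision.datasets.MNIST('data/mnist', download=True)
        # Convert uint8 pixels in [0, 255] to float32 in [0, 1] and move them to DEVICE
        self.data = mnist_data.data.to(dtype=torch.float32, device=DEVICE) / 255
        # Flatten each [28, 28] image into a 784-element vector
        self.data = self.data.flatten(1)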
    
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        return self.data[idx]
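
As a quick check of the conversion described above, here is a minimal sketch that applies the same two steps (rescale, then flatten) to a small synthetic tensor. The random `raw` tensor is only a stand-in for the real MNIST data, so nothing needs to be downloaded.

```python
import torch

# Synthetic stand-in for the raw MNIST tensor: 4 images, 28x28, uint8 pixels in [0, 255]
raw = torch.randint(0, 256, (4, 28, 28), dtype=torch.uint8)

# Same two steps as MNISTDataset.__init__: rescale to [0, 1], then flatten each image
scaled = raw.to(dtype=torch.float32) / 255
flat = scaled.flatten(1)

print(flat.shape)                            # torch.Size([4, 784])
print(flat.min().item(), flat.max().item())  # both values lie within [0, 1]
```

Dividing before converting would not work here, which is why the cast to float32 comes first.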

### Dataloading Preview

Take a look at the images from the dataset.

In [None]:
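def display_tensor(tensor):
    # Reshape flat vectors back into [N, H, W] images on the CPU for plotting
    images = tensor.detach().cpu().reshape(-1, IMAGE_WIDTH, IMAGE_WIDTH).numpy()
    # squeeze=False keeps ax two-dimensional, so a single image also works
    _, ax = plt.subplots(ncols=len(images), squeeze=False)
    for i, image in enumerate(images):
        ax[0, i].set_xticks([])
        ax[0, i].set_yticks([])
        ax[0, i].imshow(image, cmap='gray')
    plt.show()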

In [None]:
mnist_data = MNISTDataset()
display_tensor(mnist_data[0:10])

## Creating a Discriminator
Before we can start generating images, we need to make our discriminator, a model which will predict whether an input image is from the MNIST dataset or was generated some other way. We'll be dealing with large models, so we'll define a ResidualLayer module to make it easy to implement residual connections for speeding up training.

In [None]:
class ResidualLayer(Module):
    def __init__(self, sequential):
        super().__init__()
        self.sequential = sequential

    def forward(self, x):
        return x + self.sequential(x)
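
To see what the residual connection does, here is a small sketch. The class is repeated so the snippet runs on its own, and the 8-feature `Linear` block is just an illustrative choice. It shows that the wrapped block must keep the feature size unchanged, and that when the inner block outputs zeros the layer reduces to the identity.

```python
import torch
from torch.nn import Module, Sequential, Linear, ReLU

class ResidualLayer(Module):
    """Same definition as in the cell above, repeated so this snippet runs on its own."""
    def __init__(self, sequential):
        super().__init__()
        self.sequential = sequential

    def forward(self, x):
        return x + self.sequential(x)

torch.manual_seed(0)
inner = Sequential(Linear(8, 8), ReLU())
layer = ResidualLayer(inner)

x = torch.randn(3, 8)
print(layer(x).shape)  # torch.Size([3, 8]): the wrapped block must preserve the feature size

# Zero the inner weights and bias: the inner branch then outputs ReLU(0) = 0,
# so the layer reduces to the identity mapping.
with torch.no_grad():
    inner[0].weight.zero_()
    inner[0].bias.zero_()
identity_out = layer(x)
print(torch.equal(identity_out, x))  # True
```

Because the input is always passed through unchanged, gradients have a direct path around the inner block, which is the usual motivation for residual connections.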

This defines the architecture of the discriminator. Feel free to play around with it if you want to see how things change!

In [None]:
# Discriminator
def create_discriminator():
    return Sequential(
            ResidualLayer(Sequential(
                BatchNorm1d(784), 
                Linear(784, 784),       
                ReLU()
            )),
            BatchNorm1d(784),
            Linear(784, 392),
            ReLU(),
            ResidualLayer(Sequential(
                BatchNorm1d(392),
                Linear(392, 392),
                ReLU()
            )),
            BatchNorm1d(392),
            Linear(392, 1),
            Sigmoid()
        ).to(device=DEVICE)

Here's an alternative discriminator with a convolutional neural network (CNN) architecture. You can plug it in later to see how its results differ, but most notably, it's a bit slower.

In [None]:
# Alternative Discriminator Architecture
def create_cnn_discriminator():
    return Sequential(
            Unflatten(1, (1, IMAGE_WIDTH, IMAGE_WIDTH)),
            Conv2d(1, 16, kernel_size=7, padding=3),
            ReLU(),
            Conv2d(16, 32, kernel_size=5, padding=2),
            ReLU(),
            Conv2d(32, 64, kernel_size=3, padding=1),
            ReLU(),
            Flatten(),
            BatchNorm1d(64 * IMAGE_WIDTH * IMAGE_WIDTH),
            Linear(64 * IMAGE_WIDTH * IMAGE_WIDTH, IMAGE_WIDTH * IMAGE_WIDTH),
            ReLU(),
            BatchNorm1d(IMAGE_WIDTH * IMAGE_WIDTH),
            Linear(IMAGE_WIDTH * IMAGE_WIDTH, 1),
            Sigmoid()
        ).to(device=DEVICE)
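
Each convolution above uses a padding of (kernel_size - 1) / 2, so the 28x28 spatial size is preserved all the way through. The sketch below rebuilds only the convolutional front end and pushes a dummy batch through it to confirm the flattened feature size. It also counts the parameters of the first `Linear` layer, which is probably a large part of why this variant is slower.

```python
import torch
from torch.nn import Sequential, Unflatten, Conv2d, ReLU, Flatten

IMAGE_WIDTH = 28

# Only the convolutional front end of the CNN discriminator, with the same layer settings
conv_stack = Sequential(
    Unflatten(1, (1, IMAGE_WIDTH, IMAGE_WIDTH)),
    Conv2d(1, 16, kernel_size=7, padding=3),
    ReLU(),
    Conv2d(16, 32, kernel_size=5, padding=2),
    ReLU(),
    Conv2d(32, 64, kernel_size=3, padding=1),
    ReLU(),
    Flatten(),
)

dummy_batch = torch.zeros(2, IMAGE_WIDTH * IMAGE_WIDTH)
with torch.no_grad():
    features = conv_stack(dummy_batch)
print(features.shape)  # torch.Size([2, 50176]), i.e. 64 * 28 * 28 features per image

# Weights plus biases in the first Linear layer that follows the convolutions
n_pixels = IMAGE_WIDTH * IMAGE_WIDTH
first_linear_params = (64 * n_pixels) * n_pixels + n_pixels
print(first_linear_params)  # 39338768
```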

As a loss function, we use the negative log likelihood of the discriminator making a correct prediction. To take a training step for the discriminator, we'll pass it a variety of images and labels indicating whether each image is from the MNIST dataset (1) or from our generator model (0). We can then calculate the loss, perform backpropagation, and take a gradient descent step.

In [None]:
def discriminator_loss_fn(predictions, labels, epsilon = 1e-5):
    return -torch.mean(torch.log(epsilon + 1 - torch.abs(predictions - labels))) + epsilon

# Train Discriminator    
def discriminator_train_step(discriminator, real_images, fake_images, optimizer):
    optimizer.zero_grad()

    images = torch.cat([real_images, fake_images])

    # Create labels for the samples
    labels = torch.cat((torch.ones(len(real_images), device=DEVICE), 
                        torch.zeros(len(fake_images), device=DEVICE)))
    
    # Compute predictions and update the parameters
    predictions = discriminator(images).squeeze(dim = 1)
    loss = discriminator_loss_fn(predictions, labels)
    loss.backward()
    optimizer.step()

    return loss
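
For labels that are exactly 0 or 1, the expression `1 - |prediction - label|` is the probability the discriminator assigned to the correct class, so this loss should match standard binary cross-entropy up to the small `epsilon` terms. The sketch below compares it against `torch.nn.functional.binary_cross_entropy` on a few hand-picked values; the example numbers are arbitrary.

```python
import torch
import torch.nn.functional as F

def discriminator_loss_fn(predictions, labels, epsilon=1e-5):
    # Copied from the cell above so this snippet runs on its own
    return -torch.mean(torch.log(epsilon + 1 - torch.abs(predictions - labels))) + epsilon

predictions = torch.tensor([0.9, 0.2, 0.6, 0.3])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

custom_loss = discriminator_loss_fn(predictions, labels)
bce_loss = F.binary_cross_entropy(predictions, labels)

# The two values agree up to the small epsilon terms
print(custom_loss.item(), bce_loss.item())
print(torch.allclose(custom_loss, bce_loss, atol=1e-4))  # True
```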

As a test, let's generate some images which are just noise and feed them to the discriminator to see if it can learn to distinguish complete noise from actual MNIST data.

In [None]:
fake_images = torch.randn(5, IMAGE_WIDTH, IMAGE_WIDTH)
display_tensor(fake_images)

In [None]:
torch.manual_seed(SEED)

discriminator = create_discriminator()
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr = 0.0002, betas=(0.5, 0.9))
for real_images in DataLoader(mnist_data, batch_size=256):
    # Generate random noise to train against
    fake_images = torch.randn_like(real_images)

    loss = discriminator_train_step(discriminator, real_images, fake_images, discriminator_optimizer)
    print("Loss: {:10.7f}".format(loss), end='\r')

You should see the loss get very small very quickly: the discriminator has learned to distinguish pure noise from real MNIST digits! Now that we have a way of benchmarking our generator model, we're ready to start training it.

## Creating a Generator

We'll now create a feed-forward network to serve as the generator. Generating new images turns out to be a more complicated task than classifying them, so this network is deeper than the discriminator (nine linear layers versus four) and has somewhat more parameters. It takes as input a 49-element vector of Gaussian noise, and outputs a 784-element vector with elements between 0 and 1 which can be reshaped into a 28x28 image.

In [None]:
# Generator
def create_generator():
    return Sequential(
            ResidualLayer(Sequential(
                BatchNorm1d(49),
                Linear(49, 49),
                ReLU()
            )),
            BatchNorm1d(49),
            Linear(49, 98),
            ReLU(),
            ResidualLayer(Sequential(
                BatchNorm1d(98),
                Linear(98, 98),
                ReLU()
            )),
            BatchNorm1d(98),
            Linear(98, 196),
            ReLU(),
            ResidualLayer(Sequential(
                BatchNorm1d(196),
                Linear(196, 196),
                ReLU()
            )),
            BatchNorm1d(196),
            Linear(196, 392),
            ReLU(),
            ResidualLayer(Sequential(
                BatchNorm1d(392),
                Linear(392, 392),
                ReLU()
            )),
            BatchNorm1d(392),
            Linear(392, 784),
            ResidualLayer(Sequential(
                BatchNorm1d(784),
                Linear(784, 784),
                ReLU()
            )),
            Sigmoid()
        ).to(device=DEVICE)
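
To get a feel for the relative sizes of the two networks, the parameters of their `Linear` layers can be tallied directly from the layer widths. This is a rough sketch: the helper `linear_param_count` is hypothetical, and the `BatchNorm1d` parameters are left out for simplicity.

```python
def linear_param_count(layer_sizes):
    """Weights plus biases for a list of (in_features, out_features) Linear layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in layer_sizes)

# (in_features, out_features) of every Linear layer, read off the two architectures
discriminator_layers = [(784, 784), (784, 392), (392, 392), (392, 1)]
generator_layers = [(49, 49), (49, 98), (98, 98), (98, 196), (196, 196),
                    (196, 392), (392, 392), (392, 784), (784, 784)]

print(linear_param_count(discriminator_layers))  # 1077609
print(linear_param_count(generator_layers))      # 1229900
```

Most of the parameters in both networks sit in the widest layers (784 and 392 units), so the generator's extra depth adds relatively few parameters.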

As a loss function, we use the negative log likelihood of the discriminator making an *incorrect* prediction, since the generator's goal is to fool the discriminator as often as possible. To take a training step for the generator, we'll pass it a batch of noise vectors, have it generate images, then send those images to the discriminator for classification. Based on the output of that model, we calculate this loss function, then perform backpropagation all the way back to the generator and take a gradient descent step.

In [None]:
def generator_loss_fn(predictions, epsilon = 1e-5):
    return -torch.mean(torch.log(epsilon + predictions))

# Train Generator
def generator_train_step(generator, discriminator, batch_size, generator_optimizer, discriminator_optimizer):
    generator_optimizer.zero_grad()
    discriminator_optimizer.zero_grad()

    noise = torch.randn(batch_size, NOISE_SIZE, device=DEVICE)
    generated_images = generator(noise)
    
    # Compute predictions and update the parameters
    predictions = discriminator(generated_images).squeeze(dim = 1)
    loss = generator_loss_fn(predictions)
    loss.backward()
    generator_optimizer.step()

    return loss
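
The original GAN paper (see Sources) notes that minimizing log(1 - D(G(z))) gives the generator very weak gradients early in training, when the discriminator easily rejects its samples, and suggests maximizing log D(G(z)) instead. That is the form used by `generator_loss_fn` here. The sketch below compares the gradient of the two forms with respect to a single discriminator logit. The logit value of -5 is an arbitrary illustration, and `epsilon` is omitted for clarity.

```python
import torch

# A logit of -5 corresponds to D(G(z)) = sigmoid(-5), roughly 0.007:
# the discriminator is confident the sample is fake.
logit_a = torch.tensor(-5.0, requires_grad=True)
logit_b = torch.tensor(-5.0, requires_grad=True)

# Loss used in this notebook: -log D(G(z))
non_saturating_loss = -torch.log(torch.sigmoid(logit_a))
non_saturating_loss.backward()

# Original minimax form: log(1 - D(G(z)))
minimax_loss = torch.log(1 - torch.sigmoid(logit_b))
minimax_loss.backward()

print(logit_a.grad.item())  # about -0.993: a strong learning signal
print(logit_b.grad.item())  # about -0.007: almost no signal
```

Both losses push the logit in the same direction, but the first keeps a usable gradient precisely when the generator is doing badly.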

Now let's train our model! But first we need to define an evaluation function, so we can see what's going on.

In [None]:
# Evaluation
def evaluate(generator, discriminator, dataloader):
    # Variables for tracking stats as we iterate through the data in batches
    total_samples = 0
    correct_on_real_images = 0
    correct_on_generated_images = 0

    # No gradients are needed for evaluation, so skip building the autograd graph
    with torch.no_grad():
        for batch in dataloader:
            # Real images are classified correctly when the prediction is above 0.5
            real_preds = discriminator(batch).squeeze(dim=1)
            correct_on_real_images += (real_preds > 0.5).sum().item()

            # Generated images are classified correctly when the prediction is below 0.5
            noise = torch.randn(len(batch), NOISE_SIZE, device=DEVICE)
            generated_preds = discriminator(generator(noise)).squeeze(dim=1)
            correct_on_generated_images += (generated_preds < 0.5).sum().item()

            # Track how many images we've seen
            total_samples += len(batch)

    # Calculate accuracies as plain Python floats
    real_correct_acc = correct_on_real_images / total_samples
    generated_correct_acc = correct_on_generated_images / total_samples

    return real_correct_acc, generated_correct_acc

And now we can train! We train for 10 epochs (depending on your hardware, each epoch should take ~30 seconds). For each minibatch of 64 real images, we take one discriminator training step on those real images plus 64 images from the generator, then one generator training step on 64 fresh noise inputs. You can watch the accuracy of the discriminator on both real and generated images change from epoch to epoch, and also see a sample of generated images at the end of each epoch.

In [None]:
torch.manual_seed(SEED)

NUM_EPOCHS = 10

discriminator = create_discriminator() # Try replacing this with: create_cnn_discriminator()
generator = create_generator()

discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr = 0.0002, betas=(0.5, 0.9))
generator_optimizer = torch.optim.Adam(generator.parameters(), lr = 0.0001, betas=(0.5, 0.9))

batch_size=64

test_noise = torch.randn(6, NOISE_SIZE, device=DEVICE)

for epoch in range(NUM_EPOCHS):
    print("-------------------------------------------------")
    print(f"Epoch {1 + epoch}")
    # Training Loop
    for real_images in tqdm(DataLoader(mnist_data, batch_size=batch_size)):
        # Discriminator training step
        discriminator.train()
        generator.eval()

        noise = torch.randn(batch_size, NOISE_SIZE, device=DEVICE)
        # Detach so this step does not backpropagate into the generator
        fake_images = generator(noise).detach()
        discriminator_train_step(discriminator, real_images, fake_images, discriminator_optimizer)

        # Generator training subroutine
        discriminator.eval()
        generator.train()
        generator_train_step(generator, discriminator, batch_size, generator_optimizer, discriminator_optimizer)
    
    # Evaluate current performance (switch to eval mode first so BatchNorm uses its running statistics)
    generator.eval()
    discriminator.eval()
    real_correct_acc, generated_correct_acc = evaluate(generator, discriminator, DataLoader(mnist_data, batch_size=2048))
    print()
    print(f"Real Image Accuracy:      {real_correct_acc:.4f}")
    print(f"Generated Image Accuracy: {generated_correct_acc:.4f}")
    print()
    print("Example images:")
    display_tensor(generator(test_noise))

Depending on your seed, the generated images might be complete noise, they might all be variations of the same digit, or, if you're lucky, you'll see a wide variety of beautiful digits! These models are quite sensitive to initial conditions, so results can vary widely from one training run to the next. Feel free to modify the code to run more training epochs, run it with different seeds, or play with the architecture of the model (for example, by swapping in the CNN discriminator for the default feed-forward one).

The phenomenon where most of the generated images converge to being the same one or two digits is a fairly common one, referred to as "mode collapse." There are difficulties getting the model to learn several discrete categories simultaneously without also producing images somehow "between" those categories, which will be recognized as not being numbers, and thus be caught easily by the discriminator. Therefore, the generator can be punished heavily for generating too diverse a spread of images, and so it often learns to be good at producing just a couple specific digits well, rather than all of them.
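
One crude way to spot complete mode collapse numerically is to measure how far apart the images in a generated batch are from one another. The sketch below is an assumption-laden illustration rather than a standard metric: `mean_pairwise_distance` is a hypothetical helper, and the random tensors stand in for generator output.

```python
import torch
import torch.nn.functional as F

def mean_pairwise_distance(images):
    """Hypothetical helper: average L2 distance between all pairs of flattened images."""
    return F.pdist(images.flatten(1)).mean()

torch.manual_seed(0)
varied_batch = torch.rand(16, 784)                  # stand-in for a diverse batch
collapsed_batch = torch.rand(1, 784).repeat(16, 1)  # 16 copies of one image

print(mean_pairwise_distance(varied_batch).item())     # clearly above zero
print(mean_pairwise_distance(collapsed_batch).item())  # 0.0
```

In the notebook, this could be applied as `mean_pairwise_distance(generator(test_noise))` after each epoch. A value that shrinks toward zero suggests the samples are becoming near-identical, though it will not detect a generator that has collapsed onto two or three distinct digits.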

![And that's it! You may now submit lab 10.](final_image.png "And that's it! You may now submit lab 10")

## Sources
- [Original GAN paper](https://arxiv.org/pdf/1406.2661)
- [Useful article discussing ways of troubleshooting GAN models](https://jonathan-hui.medium.com/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b)
- [A little more about mode collapse](https://neptune.ai/blog/gan-loss-functions)