<a href="https://colab.research.google.com/github/emmaguo13/ml-fundies/blob/main/GAN_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a GAN on MNIST

In [None]:
from typing import Tuple
import math

import torch
import torch.nn as nn
import wandb  # if you don't have wandb set up, install it (pip install wandb) and ask Divi for the login

In [None]:
%load_ext autoreload
%autoreload 2
# reload edited modules automatically so you don't have to restart the kernel after changing other files

In the first part of this notebook, we will create the Generator and Discriminator models.

For a simple example of how to do this, check `models.py` in the accompanying repo. The models in that file have only one layer; we want to train larger models so that they produce better images.

## Generator
This is the model that takes in random noise and outputs an image. We will choose the sizes of the noise vector and the images later in this notebook, so just make sure your model takes in a vector of size `input_length` and outputs an image with `output_length` pixels (the flattened image size). Give your model multiple layers; the rest of the design is up to you. If you choose to train on larger images (e.g., 64x64 or 128x128), consider using some transposed convolution ("deconvolution") layers.
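
If you do use transposed convolutions, it helps to check how each layer changes the spatial size. The sketch below is a plain-Python helper (the name `conv_transpose_out` is hypothetical, not part of PyTorch) that evaluates the output-size formula from the `nn.ConvTranspose2d` documentation, assuming square inputs and the default `dilation=1`. It shows that `kernel_size=4, stride=2, padding=1` doubles the size, so two such layers take a 7x7 feature map to 28x28, the size of an MNIST digit.

```python
def conv_transpose_out(size: int, kernel_size: int, stride: int = 1,
                       padding: int = 0, output_padding: int = 0) -> int:
    """Output height/width of nn.ConvTranspose2d for a square input (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


# Upsample 7x7 -> 14x14 -> 28x28 with kernel_size=4, stride=2, padding=1
sizes = [7]
for _ in range(2):
    sizes.append(conv_transpose_out(sizes[-1], kernel_size=4, stride=2, padding=1))
print(sizes)  # [7, 14, 28]
```

A common design starts the generator with a linear layer that produces a small feature map (for example 7x7 with many channels) and then upsamples it with layers like these; that layout is one option, not a requirement of this notebook.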

## Discriminator
This is the model that takes in an image and outputs the probability that the image is real. It takes in an image with `output_length` pixels and outputs a single value. You can think of it as a binary classification model, which is why we train it with binary cross-entropy loss (`nn.BCELoss`); since that loss expects probabilities, end the model with a sigmoid. Return one value per image, i.e. a tensor of shape `(batch_size,)`, because the training loop compares the output against label tensors of that shape. As before, create your Discriminator with multiple layers and the correct input and output sizes. If you choose to train on larger images (e.g., 64x64 or 128x128), consider using some convolutional layers.
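
To make the loss concrete, here is binary cross-entropy for a single prediction written out in plain Python. The helper `bce` is only illustrative; in the notebook you use `nn.BCELoss`, which by default averages this quantity over the batch. For a predicted probability p and label y, the loss is -(y log p + (1 - y) log(1 - p)): confident correct guesses cost little, confident wrong guesses cost a lot, and a discriminator that outputs 0.5 for everything pays ln 2 ≈ 0.693, a useful reference point when reading the losses in wandb.

```python
import math


def bce(p: float, y: float) -> float:
    """Binary cross-entropy for one predicted probability p and label y (0 or 1)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


print(round(bce(0.9, 1), 3))  # confident and correct: 0.105
print(round(bce(0.1, 1), 3))  # confident and wrong: 2.303
print(round(bce(0.5, 0), 3))  # undecided: ln 2 = 0.693
```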

In [None]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_length: int, output_length: int):
        super(Generator, self).__init__()
        # TODO instantiate the layers

    def forward(self, x):
        # TODO pass x through the layers + activations and return it
        return


class Discriminator(nn.Module):
    def __init__(self, input_length: int):
        super(Discriminator, self).__init__()
        # TODO instantiate the layers

    def forward(self, x):
        # TODO pass x through the layers + activations and return it
        return 
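
For reference, one possible design is a small multilayer perceptron, sketched below under several assumptions: the class names `MLPGenerator` and `MLPDiscriminator`, the hidden size of 256, and LeakyReLU with slope 0.2 are illustrative choices, not the notebook's required solution. The generator ends in a sigmoid so its outputs lie in [0, 1] like the `ToTensor` images, and the discriminator flattens its input and squeezes its output to shape `(batch_size,)` so it works with `nn.BCELoss` and the training loop's labels.

```python
import torch
import torch.nn as nn


class MLPGenerator(nn.Module):
    """Hypothetical MLP generator: noise of size input_length -> flattened image."""

    def __init__(self, input_length: int, output_length: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_length, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, output_length),
            nn.Sigmoid(),  # pixel values in [0, 1], matching ToTensor()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MLPDiscriminator(nn.Module):
    """Hypothetical MLP discriminator: image -> probability that it is real."""

    def __init__(self, input_length: int, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),  # (N, 1, H, W) -> (N, H*W); already-flat (N, H*W) input is unchanged
            nn.Linear(input_length, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # probability, as nn.BCELoss expects
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(1)  # (N, 1) -> (N,)


# Shape check with 28x28 single-channel images and 64-dimensional noise
g, d = MLPGenerator(64, 28 * 28), MLPDiscriminator(28 * 28)
fake = g(torch.randn(8, 64)).view(8, 1, 28, 28)
print(fake.shape, d(fake).shape)  # torch.Size([8, 1, 28, 28]) torch.Size([8])
```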


## Loading the Dataset
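For this notebook, we are going to train on the MNIST dataset of handwritten digits, so our generator will learn to create pictures of digits. (`CelebA` is also imported below if you want to try generating faces instead.)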

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import CelebA, MNIST
import torchvision.transforms as transforms 

# TODO fill in these parameters with whatever you want
image_size = None  # should be a tuple (height, width), e.g. (28, 28) to keep MNIST's native size
batch_size = None

# Resize the images to the shape we specify, then convert them to tensors (ToTensor also scales pixel values to [0, 1])
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor()
])

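# this will download the MNIST dataset for you; if the download fails,
# find the dataset online
dataset = MNIST("~/datasets", train=True, download=True, transform=transform)
# shuffle=True reshuffles the data every epoch; drop_last=True drops the final, smaller batch
# so that every batch has exactly batch_size images (the training loop relies on this)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)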

Now let's look at some of the images we will be training on.

In [None]:
import matplotlib.pyplot as plt

samples, label = next(iter(data_loader))
sample = samples[0].numpy().transpose(1, 2, 0)  # PyTorch puts channels in the first dim; matplotlib expects them in the last

plt.imshow(sample, cmap="gray")
plt.show()
print(label[0])

Now that we have loaded our dataset and chosen our image size, let's instantiate our generator and discriminator.

In [None]:
# Choose the size of the noise vector to give to our generator
input_length = None

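# TODO: instantiate the models. For single-channel images, the generator's output_length
# (and the discriminator's input_length) is image_size[0] * image_size[1].
generator = None
discriminator = None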

# It is always good to print out your model 
# to make sure it is what you are expecting
print(generator)
print(discriminator)

In [None]:
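# Sanity check to make sure your shapes are correct
outputs = generator(torch.zeros(batch_size, input_length))
outputs = outputs.view(batch_size, 1, image_size[0], image_size[1])
# The discriminator receives images of shape (batch_size, 1, H, W), so flatten
# them inside its forward() if it uses linear layers
fake_prob = discriminator(outputs)
print(fake_prob.shape)  # should be torch.Size([batch_size]) to match the labels used in training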

# visualize the image created by an untrained generator, this should look like noise
output = outputs[0].detach().numpy().transpose(1,2,0)
plt.imshow(output, cmap="gray")
plt.show()

## Optimizers
We need two optimizers, one for each model. Choose a learning rate for each below, and feel free to experiment with them!

In [None]:
generator_learning_rate = None
discriminator_learning_rate = None

# Create optimizers with the learning rates of your choosing
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)
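
If you need a starting point, the DCGAN paper (Radford et al., 2015) trained both networks with Adam at a learning rate of 0.0002 and β1 = 0.5 instead of Adam's default 0.9. These are widely reused GAN settings, not values tuned for this notebook. `torch.optim.Adam` takes β1 and β2 through its `betas` argument, as sketched below with a stand-in model.

```python
import torch
import torch.nn as nn

# Stand-in model; in the notebook pass generator.parameters() or discriminator.parameters()
model = nn.Linear(10, 1)

# lr=2e-4 and betas=(0.5, 0.999) follow the DCGAN paper's settings
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
print(optimizer.param_groups[0]["lr"], optimizer.param_groups[0]["betas"])
```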

## Training
We now have everything we need to train our GAN except the training loop. Fill in the missing code, then run the cells below to train your GAN and check out wandb to see how it is doing!

In [None]:
def train(
    generator,
    discriminator,
    generator_optimizer,
    discriminator_optimizer,
    data_loader,
    batch_size,
    input_length,
    img_shape: tuple,
    epochs: int = 3,
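) -> Tuple[nn.Module, nn.Module]:
    """Trains the GAN.

    Args:
        generator: The generator model to train.
        discriminator: The discriminator model to train.
        generator_optimizer: The optimizer for the generator's parameters.
        discriminator_optimizer: The optimizer for the discriminator's parameters.
        data_loader: Yields (images, labels) batches of real data.
        batch_size: The number of examples in a training batch.
        input_length: The size of the noise vector given to the generator.
        img_shape: The (height, width) of the images.
        epochs: The number of epochs to train for.

    Returns:
        generator: The trained generator model
        discriminator: The trained discriminator model
    """
    # loss is binary cross entropy loss
    loss = nn.BCELoss()
    img_h, img_w = img_shape

    for i in range(epochs):
        for sample in data_loader:
            # zero the gradients on each iteration
            generator_optimizer.zero_grad()

            # Here we create the noise input for the generator (standard normal noise is the
            # usual choice) and pass it through the generator to create our fake data
            noise = torch.randn(batch_size, input_length)
            generated_data = generator(noise)
            generated_data = generated_data.view(batch_size, 1, img_h, img_w)  # reshape to (N, C, H, W) with 1 channel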

            # Here we get real data
            true_data = sample[0]
            # TODO: create the labels for the true data, this should be a tensor of size batch_size.
            # Remember we are doing binary classification here.
            true_labels = None

            # TODO: Train the generator
            # AKA do a forward pass and get the loss (loss function is defined above)
            # We label the fake data as real here (i.e. we invert the labels) and only step the generator's
            # optimizer, because we want the generator to make images the discriminator classifies as real.
            generator_discriminator_out = None
            generator_loss = None
            
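            # Notice that we do not call .step() on the discriminator_optimizer, so the discriminator's
            # parameters are not updated while we train the generator. The backward pass still accumulates
            # gradients in the discriminator, but discriminator_optimizer.zero_grad() clears them below.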
            generator_loss.backward()
            generator_optimizer.step()

            # TODO: Train the discriminator on the true data
            # AKA do a forward pass and get the loss on the true data
            # We don't invert the labels here, why?
            discriminator_optimizer.zero_grad()
            true_discriminator_out = None
            true_discriminator_loss = None

            # Now we do a forward pass using the fake data and get the loss with inverted labels
            # We add .detach() here so that we do not backprop into the generator when we train the discriminator.
            # If you're not sure what this does, that's OK! Ask us; we love to answer questions.
            generator_discriminator_out = discriminator(generated_data.detach())
            generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros(batch_size))
            discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2
            
            discriminator_loss.backward()
            discriminator_optimizer.step()
            
            # .item() converts each scalar tensor to a plain Python float before logging
            wandb.log({"Generator Loss": generator_loss.item(),
                       "Discriminator Loss (on real images)": true_discriminator_loss.item(),
                       "Discriminator Loss (on fake images)": generator_discriminator_loss.item(),
                       "Discriminator Guess (on fake images)": generator_discriminator_out.mean().item(),
                       "Discriminator Guess (on real images)": true_discriminator_out.mean().item()})

    return generator, discriminator
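
The `.detach()` call in the discriminator step is what keeps the discriminator's loss from updating the generator. The toy sketch below uses tiny `nn.Linear` stand-ins rather than the notebook's models: it runs the same fake batch through a stand-in discriminator with and without `.detach()` and checks which parameters receive gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 3)  # stand-in generator: noise (4) -> "image" (3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())  # stand-in discriminator
bce = nn.BCELoss()

fake = gen(torch.randn(2, 4))

# Discriminator step: detach() cuts the graph, so no gradient reaches the generator
loss_d = bce(disc(fake.detach()).squeeze(1), torch.zeros(2))
loss_d.backward()
no_grad_to_gen = gen.weight.grad is None
print(no_grad_to_gen, disc[0].weight.grad is not None)  # True True

# Generator step: without detach(), gradients flow back through the discriminator into the generator
loss_g = bce(disc(fake).squeeze(1), torch.ones(2))
loss_g.backward()
print(gen.weight.grad is not None)  # True
```

The generator step also adds gradients to the discriminator's parameters, which is why the training loop calls `discriminator_optimizer.zero_grad()` before the discriminator's own backward pass.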

In [None]:
wandb.init(project='gan-notebook')

# fill in a name to keep track of which wandb runs are yours
your_name = None
wandb.config.name = your_name

# fill in the number of epochs you want to train for
epochs = 10
wandb.config.epochs = epochs

trained_generator, trained_discriminator = train(generator, discriminator, 
                                                 generator_optimizer, discriminator_optimizer, 
                                                 data_loader, batch_size,
                                                 input_length, image_size, epochs)