In [None]:
%pip install -r requirements.txt

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

What is the difference between nn.LeakyReLU and nn.ReLU?
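nn.ReLU outputs max(0, x), so its gradient is 0 for every x < 0 and a unit stuck in that region stops learning. nn.LeakyReLU(negative_slope) outputs negative_slope * x for x < 0 instead, so a small gradient (0.1 with the nn.LeakyReLU(0.1) used here) still flows when the unit is not active.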
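A quick way to see the difference is to compare the gradients PyTorch computes for both activations on the same inputs. This is a minimal sketch: the sample input values are arbitrary, and the slope 0.1 matches the `nn.LeakyReLU(0.1)` used in the models here.

```python
import torch
import torch.nn as nn

# Two negative inputs and one positive input (arbitrary example values).
x = torch.tensor([-2.0, -0.5, 3.0], requires_grad=True)

relu_out = nn.ReLU()(x)
relu_out.sum().backward()          # x.grad now holds the elementwise derivative of ReLU
relu_grad = x.grad.clone()

x.grad = None                      # reset before the second backward pass
leaky_out = nn.LeakyReLU(0.1)(x)   # same negative slope as in the Discriminator
leaky_out.sum().backward()
leaky_grad = x.grad.clone()

print(relu_out.detach(), relu_grad)    # negatives become 0 and get gradient 0
print(leaky_out.detach(), leaky_grad)  # negatives are scaled by 0.1 and get gradient 0.1
```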

In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)
    
# Why are we using tanh in the generator and sigmoid in the discriminator?
# tanh outputs values in (-1, 1), the same range as the MNIST images after Normalize((0.5,), (0.5,)),
# so generated and real images are on the same scale.
# sigmoid outputs values in (0, 1), so the discriminator's output can be read as the probability
# that its input is real, which is what nn.BCELoss expects.
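The two output ranges can be checked directly. This minimal sketch uses arbitrary inputs; it also applies the inverse of `Normalize((0.5,), (0.5,))`, namely `x * 0.5 + 0.5`, which maps a tanh output back to the [0, 1] pixel range for display. (The training loop instead uses `make_grid(..., normalize=True)`, which rescales the grid by its own min and max.)

```python
import torch

z = torch.linspace(-10, 10, steps=101)  # arbitrary pre-activation values

tanh_out = torch.tanh(z)        # generator-style output, bounded by -1 and 1
sigmoid_out = torch.sigmoid(z)  # discriminator-style output, read as P(input is real)

print(tanh_out.min().item(), tanh_out.max().item())
print(sigmoid_out.min().item(), sigmoid_out.max().item())

# Undo Normalize((0.5,), (0.5,)): x_pixel = x * 0.5 + 0.5 maps [-1, 1] back to [0, 1].
pixels = tanh_out * 0.5 + 0.5
print(pixels.min().item(), pixels.max().item())

# nn.BCELoss expects probabilities in [0, 1], which the sigmoid output provides.
loss = torch.nn.BCELoss()(sigmoid_out, torch.ones_like(sigmoid_out))
print(loss.item())
```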

In [None]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 5e-4  # learning rate for Adam, chosen empirically; GANs are sensitive to it (a widely quoted rule of thumb for Adam, from Andrej Karpathy, is 3e-4)
z_dim = 128  # dimension of the latent noise vector z that the generator turns into an image
img_dim = 28 * 28 * 1
batch_size = 16
num_epochs = 25

In [None]:
# Initialize the models
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

In [None]:
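# Fixed noise
# Why do we need noise?
# The generator maps a random latent vector z to an image, so noise is its input. Here z is sampled
# from a standard normal distribution N(0, I). This particular batch is sampled once and reused at
# every logging step, so changes in the logged images reflect only the generator's training progress.
fixed_noise = torch.randn((batch_size, z_dim), device=device)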
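Why the noise is kept fixed can be shown with a toy model. In this sketch `toy_gen` is a hypothetical stand-in for the notebook's `Generator` with made-up sizes, and a training update is simulated by adding a constant to every parameter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_z_dim, toy_img_dim, n_samples = 8, 16, 4
toy_gen = nn.Sequential(nn.Linear(toy_z_dim, toy_img_dim), nn.Tanh())  # hypothetical stand-in generator

toy_fixed_noise = torch.randn(n_samples, toy_z_dim)  # sampled once, reused at every logging step

with torch.no_grad():
    out_a = toy_gen(toy_fixed_noise)
    out_b = toy_gen(toy_fixed_noise)                         # same z, same weights: identical output
    out_fresh = toy_gen(torch.randn(n_samples, toy_z_dim))   # new z: different output

    # Simulated training update: now only the weights differ between the two fixed-noise samples.
    for p in toy_gen.parameters():
        p.add_(0.1)
    out_after = toy_gen(toy_fixed_noise)

print(torch.equal(out_a, out_b), torch.equal(out_a, out_fresh), torch.equal(out_a, out_after))
# prints: True False False
```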

In [None]:
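# Convert the images to be between -1 and 1.
# The fully qualified torchvision.transforms path keeps this cell re-runnable: after the first run
# the name `transforms` refers to the Compose pipeline, not to the module.
transforms = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,), (0.5,))]
)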

This function is part of torchvision (PyTorch's companion library for computer vision), specifically the torchvision.transforms module. It defines a sequence of transformations that is applied to each image before it is fed to the network. Let's break down what each part of this function does:

transforms.Compose: This function combines several transformations into a single transformation pipeline. It takes a list of transformations as input and applies them sequentially to the input data.

transforms.ToTensor(): This transformation converts a PIL image or NumPy array into a PyTorch float tensor. It also reorders the dimensions from H x W x C to C x H x W and, for 8-bit images such as MNIST, scales pixel values from [0, 255] to [0.0, 1.0].

transforms.Normalize((0.5,), (0.5,)): This transformation normalizes each channel as (x - mean) / std, using the mean and standard deviation passed in (one value per channel; MNIST has a single channel). Here both are 0.5, so each pixel becomes (x - 0.5) / 0.5 = 2x - 1, which maps the [0, 1] output of ToTensor to [-1, 1]. These are not MNIST's actual statistics; they are chosen so that real images share the [-1, 1] range of the generator's tanh output.

So, in summary, the transforms.Compose function defines a transformation pipeline that converts input data into tensors and then normalizes those tensors. This pipeline is often used when preprocessing images or other data for input into neural networks.

In [None]:
# Download and Load the dataset
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Optimizers and loss function
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

loss_fn = nn.BCELoss()

In [None]:
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")

In the context of Generative Adversarial Networks (GANs) training on the MNIST dataset, these lines create two SummaryWriter objects, one for logging generated (fake) images and another for logging real images. The SummaryWriter is a utility provided by PyTorch's TensorBoard integration for logging various metrics and visualizations during model training.

Let's break down what each line does:

writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake"): This line creates a SummaryWriter object named writer_fake. It specifies a directory path runs/GAN_MNIST/fake where the logs for fake images will be stored. Typically, during GAN training, the generator network produces fake images, and you might want to visualize these images over time to see how the generator improves.

writer_real = SummaryWriter(f"runs/GAN_MNIST/real"): Similarly, this line creates another SummaryWriter object named writer_real, but this time it's for logging real images. In GAN training, real images are samples from the MNIST dataset that the discriminator learns to label as real. Logging a grid of real images next to the fake ones gives a visual reference for judging how close the generator's samples are to the data.

By using these SummaryWriter objects, you can log various information such as images, scalar values (e.g., loss), histograms, and more during the training process. This information can then be visualized in TensorBoard to gain insights into the training progress and the performance of the GAN.
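The notebook logs only image grids, although the same writers can record the losses too. This is a minimal sketch of logging scalars with `SummaryWriter.add_scalar`; the temporary log directory, the tag names, and the loss values are placeholders, not outputs of this notebook.

```python
import os
import tempfile
from torch.utils.tensorboard import SummaryWriter

# Hypothetical log directory; the notebook itself writes to runs/GAN_MNIST/fake and runs/GAN_MNIST/real.
log_dir = tempfile.mkdtemp(prefix="gan_mnist_demo_")
writer = SummaryWriter(log_dir)

# Illustrative placeholder (D loss, G loss) pairs, not results from this notebook.
placeholder_losses = [(0.69, 0.70), (0.60, 0.85), (0.55, 0.92)]
for demo_step, (loss_d, loss_g) in enumerate(placeholder_losses):
    writer.add_scalar("Loss/discriminator", loss_d, global_step=demo_step)
    writer.add_scalar("Loss/generator", loss_g, global_step=demo_step)

writer.close()  # flushes the event file to disk

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(event_files)
```

In the training loop, the equivalent would be calling `add_scalar` with `loss_disc.item()` and `loss_gen.item()` next to the `add_image` calls, then running `tensorboard --logdir runs` to view the curves.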

In [None]:
step = 0

for epoch in tqdm(range(num_epochs), desc="Training"):
    for batch_idx, (real, _) in enumerate(loader):
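        real = real.view(-1, img_dim).to(device)  # flatten 1x28x28 images to 784-dim vectors
        cur_batch_size = real.shape[0]  # may be smaller on the last batch; don't overwrite the batch_size hyperparameter

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)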
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        loss_disc_real = loss_fn(disc_real, torch.ones_like(disc_real)) # first parameter is prediction and second is target
        disc_fake = disc(fake).view(-1)
        loss_disc_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake)) # first parameter is prediction and second is target
        loss_disc = (loss_disc_real + loss_disc_fake) / 2 # average of both the losses
        disc.zero_grad()
        # retain_graph=True keeps the graph through gen(noise) alive, because the
        # generator step below backpropagates through `fake` a second time
        loss_disc.backward(retain_graph=True)
        opt_disc.step()

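        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # (the max log(D(G(z))) form, used below via BCE with target 1, gives stronger gradients early in training)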
        output = disc(fake).view(-1)
        loss_gen = loss_fn(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Occasionally print the losses and log image grids to TensorBoard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)
                writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)

                step += 1

# Save the model
torch.save(gen.state_dict(), "gen.pth")
torch.save(disc.state_dict(), "disc.pth")

# Close the tensorboard writer
writer_fake.close()
writer_real.close()
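A common alternative to `retain_graph=True` is to detach the fake batch in the discriminator step. The sketch below wraps one training step in a hypothetical helper, `gan_step_with_detach`, and runs it on tiny stand-in models with random placeholder data (sizes are made up). With `fake.detach()`, the discriminator loss sends no gradient into the generator, so `loss_disc.backward()` needs no retained graph and the generator step can still backpropagate through `fake`.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def gan_step_with_detach(gen, disc, opt_gen, opt_disc, loss_fn, real, z_dim):
    """One GAN training step that detaches `fake` in the discriminator update (no retain_graph)."""
    noise = torch.randn(real.shape[0], z_dim, device=real.device)
    fake = gen(noise)

    # Discriminator step: detach() cuts the graph at `fake`, so this backward pass sends no
    # gradient into the generator and leaves the generator's part of the graph untouched.
    disc_real = disc(real).view(-1)
    disc_fake = disc(fake.detach()).view(-1)
    loss_disc = (loss_fn(disc_real, torch.ones_like(disc_real))
                 + loss_fn(disc_fake, torch.zeros_like(disc_fake))) / 2
    opt_disc.zero_grad()
    loss_disc.backward()  # no retain_graph needed
    # With freshly created models these are still None: nothing flowed into the generator.
    gen_grads_after_disc = [p.grad for p in gen.parameters()]
    opt_disc.step()

    # Generator step: backpropagate through the same `fake`; its graph was never freed.
    output = disc(fake).view(-1)
    loss_gen = loss_fn(output, torch.ones_like(output))
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
    return loss_disc.item(), loss_gen.item(), gen_grads_after_disc


# Tiny stand-in models and random data (hypothetical sizes) to exercise the step.
torch.manual_seed(0)
toy_z_dim, toy_img_dim = 8, 16
toy_gen = nn.Sequential(nn.Linear(toy_z_dim, 32), nn.LeakyReLU(0.1), nn.Linear(32, toy_img_dim), nn.Tanh())
toy_disc = nn.Sequential(nn.Linear(toy_img_dim, 32), nn.LeakyReLU(0.1), nn.Linear(32, 1), nn.Sigmoid())
toy_real = torch.rand(4, toy_img_dim) * 2 - 1  # placeholder "real" batch in [-1, 1]

d_loss, g_loss, gen_grads = gan_step_with_detach(
    toy_gen, toy_disc,
    optim.Adam(toy_gen.parameters(), lr=5e-4),
    optim.Adam(toy_disc.parameters(), lr=5e-4),
    nn.BCELoss(), toy_real, toy_z_dim,
)
print(d_loss, g_loss, all(g is None for g in gen_grads))
```

The discriminator's update is the same either way: in the notebook's version, the generator gradients produced by `loss_disc.backward(retain_graph=True)` are discarded by `gen.zero_grad()` before the generator step, so detaching mainly saves that wasted backward pass through the generator.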

Results on changing hyperparameters:
1. lr = 5e-4, z_dim = 64, img_dim = 28 * 28 * 1, batch_size = 16, num_epochs = 25
   1. Results: D Loss = 0.5637 | G Loss = 0.9129
2. lr = 7e-4, z_dim = 32, img_dim = 28 * 28 * 1, batch_size = 16, num_epochs = 25
   1. Results: D Loss = 0.4870 | G Loss = 1.0671
3. lr = 3e-4, z_dim = 64, img_dim = 28 * 28 * 1, batch_size = 32, num_epochs = 25
   1. Results: D Loss = 0.5484 | G Loss = 0.9143
4. lr = 3e-4, z_dim = 64, img_dim = 28 * 28 * 1, batch_size = 16, num_epochs = 30
   1. Results: D Loss = 0.5484 | G Loss = 0.9143

Results: 2 > 1 > 3




How to improve the GAN?
1. Use different architectures for the generator and discriminator (e.g. DCGAN, WGAN, etc.)
2. Tune the hyperparameters (learning rate, batch size, latent dimension); GANs are very sensitive to them
3. Use different loss functions
4. Train for longer
5. Add regularization techniques (e.g. weight clipping in WGAN)
6. Use different types of normalization
7. Use different noise distributions for z (e.g. uniform instead of normal)
8. Use different types of optimizers

What is DCGAN?<br>
DCGAN stands for Deep Convolutional Generative Adversarial Network. It is a GAN whose generator and discriminator are built from convolutional layers (transposed convolutions in the generator) instead of the fully connected layers used in this notebook.
Convolutions share weights across image positions and model local spatial structure, so DCGANs typically produce sharper, more realistic images than fully connected GANs like the one here. They are a common baseline for image generation tasks.
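A DCGAN-style generator and discriminator for 28x28 MNIST images might look like the sketch below. The class names and layer sizes (`features=64`, a 7x7 starting resolution) are illustrative assumptions, not taken from this notebook; the structure follows the usual DCGAN guidelines: transposed convolutions to upsample, strided convolutions to downsample, batch normalization, ReLU in the generator, and LeakyReLU with slope 0.2 in the discriminator.

```python
import torch
import torch.nn as nn


class DCGenerator(nn.Module):
    """Hypothetical DCGAN-style generator for 1x28x28 images."""

    def __init__(self, z_dim=128, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, features*2, 7, 7)
            nn.ConvTranspose2d(z_dim, features * 2, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(features * 2, features, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            # 14x14 -> 28x28, one channel; tanh to match the [-1, 1] normalized pixels
            nn.ConvTranspose2d(features, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z.view(z.size(0), -1, 1, 1))


class DCDiscriminator(nn.Module):
    """Hypothetical DCGAN-style discriminator for 1x28x28 images."""

    def __init__(self, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # 28x28 -> 14x14 (no batch norm on the input layer, per the DCGAN guidelines)
            nn.Conv2d(1, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 14x14 -> 7x7
            nn.Conv2d(features, features * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2),
            # 7x7 -> 1x1 score, squashed to a probability for nn.BCELoss
            nn.Conv2d(features * 2, 1, kernel_size=7, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).view(-1)


dc_gen = DCGenerator()
dc_disc = DCDiscriminator()
dc_noise = torch.randn(4, 128)
imgs = dc_gen(dc_noise)
scores = dc_disc(imgs)
print(imgs.shape, scores.shape)
```

Because `DCGenerator` returns 1x28x28 images directly, a training loop using it would skip the `view(-1, 784)` flattening of the real batch.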

What is WGAN?<br>
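WGAN stands for Wasserstein Generative Adversarial Network. Instead of training the discriminator as a real/fake classifier with binary cross-entropy, WGAN trains a critic (the WGAN term for the discriminator, which outputs an unbounded score rather than a probability) whose objective estimates the Wasserstein-1 (earth mover's) distance between the real and generated distributions. This distance still provides useful gradients when the two distributions barely overlap, so WGAN training tends to be more stable and less prone to mode collapse than the original GAN formulation.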
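The switch from BCE to the WGAN objective happens mostly in the loss lines. A minimal sketch on toy models with placeholder data (sizes are hypothetical): the critic drops the final `Sigmoid`, and both losses are built from mean critic scores. It deliberately omits the Lipschitz constraint (weight clipping), without which the score difference is not a valid Wasserstein estimate.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_z_dim, toy_img_dim = 8, 16

# Critic: like the notebook's Discriminator but without the final Sigmoid, so it outputs
# an unbounded real-valued score instead of a probability.
critic = nn.Sequential(nn.Linear(toy_img_dim, 32), nn.LeakyReLU(0.1), nn.Linear(32, 1))
toy_gen = nn.Sequential(nn.Linear(toy_z_dim, 32), nn.LeakyReLU(0.1), nn.Linear(32, toy_img_dim), nn.Tanh())

toy_real = torch.rand(4, toy_img_dim) * 2 - 1  # placeholder "real" batch in [-1, 1]
toy_fake = toy_gen(torch.randn(4, toy_z_dim))

# The critic maximizes E[C(real)] - E[C(fake)] (the Wasserstein estimate, valid only while
# the critic is kept Lipschitz), so its loss is the negative. No labels, no BCELoss.
critic_real = critic(toy_real).view(-1)
critic_fake = critic(toy_fake.detach()).view(-1)
loss_critic = -(critic_real.mean() - critic_fake.mean())

# The generator tries to raise the critic's score on its samples.
loss_gen_wgan = -critic(toy_fake).view(-1).mean()

print(loss_critic.item(), loss_gen_wgan.item())
```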

What is weight clipping?<br>
Weight clipping is the technique the original WGAN uses to enforce the Lipschitz constraint that the Wasserstein estimate requires of the critic: after each critic update, every weight is clamped to a small range [-c, c] (c = 0.01 in the WGAN paper). Without such a constraint the critic's scores could grow without bound and its loss would no longer approximate the Wasserstein distance. Clipping is a crude way to enforce the constraint, though: the WGAN authors note that a badly chosen c can cause vanishing or exploding gradients, and the later WGAN-GP replaces clipping with a gradient penalty.
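A sketch of critic updates with weight clipping, assuming the settings reported in the original WGAN paper (clip value c = 0.01, RMSprop with learning rate 5e-5, several critic updates per generator update). The critic and the data are placeholders.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
toy_img_dim = 16
clip_value = 0.01  # c from the WGAN paper; a tunable hyperparameter

critic = nn.Sequential(nn.Linear(toy_img_dim, 32), nn.LeakyReLU(0.1), nn.Linear(32, 1))
opt_critic = optim.RMSprop(critic.parameters(), lr=5e-5)

toy_real = torch.rand(4, toy_img_dim) * 2 - 1  # placeholder real batch
toy_fake = torch.rand(4, toy_img_dim) * 2 - 1  # placeholder generator output (no graph attached)

for _ in range(5):  # several critic updates per generator update
    loss_critic = -(critic(toy_real).mean() - critic(toy_fake).mean())
    opt_critic.zero_grad()
    loss_critic.backward()
    opt_critic.step()
    # Weight clipping: after every critic update, clamp every parameter into [-c, c].
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-clip_value, clip_value)

max_abs_weight = max(p.abs().max().item() for p in critic.parameters())
print(max_abs_weight)
```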

How is weight clipping different from learning rate?<br>
Weight clipping constrains the values the critic's weights can take: after each update, every weight is clamped into [-c, c]. The learning rate, on the other hand, controls the size of each update: the optimizer moves the weights by roughly the learning rate times the (possibly rescaled) gradient. So the learning rate governs how fast the weights change, while clipping bounds where they can end up, regardless of step size. Both the learning rate and the clip value c are hyperparameters that affect GAN performance, but they act at different points in the update.
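A one-weight example makes the distinction concrete. The values here (w = 0.009, gradient -1.0, lr = 0.01, c = 0.01) are arbitrary illustrations: the learning rate decides how far one SGD step moves the weight, and clipping then caps the result at c.

```python
import torch

# One scalar weight with a hand-picked gradient (illustrative values, not from the notebook).
w = torch.nn.Parameter(torch.tensor([0.009]))
opt = torch.optim.SGD([w], lr=0.01)  # learning rate: scales the step
clip_value = 0.01                    # WGAN-style clip range [-c, c]

w.grad = torch.tensor([-1.0])        # pretend gradient from the critic loss
opt.step()                           # w <- w - lr * grad = 0.009 + 0.01 = 0.019
after_step = w.item()

with torch.no_grad():
    w.clamp_(-clip_value, clip_value)  # clipping caps the result, whatever the step size
after_clip = w.item()

print(f"after step: {after_step:.4f}, after clip: {after_clip:.4f}")
# prints: after step: 0.0190, after clip: 0.0100
```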

In [None]:
# Test the generator
with torch.no_grad():
    noise = torch.randn(9, z_dim).to(device)
    img = gen(noise).view(-1, 1, 28, 28)
    img_grid = torchvision.utils.make_grid(img, nrow=3, normalize=True)
    plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
    plt.show()