# Generative adversarial network (GAN) basics

The `13_gan_basics` notebook introduces the foundational concepts of Generative Adversarial Networks (GANs), a powerful class of models used for generating realistic data. GANs consist of two networks, a Generator and a Discriminator, that are trained simultaneously in a competitive setting. 

This notebook covers preparing the dataset, defining the Generator model responsible for creating synthetic data, and defining the Discriminator model, which distinguishes between real and generated data.

---

**Note:** GANs are a broad topic, so this notebook is mostly an introduction to the concepts and the two model definitions. Training a GAN is covered in the notebook right after this one.

## Table of contents

1. [Understanding GANs](#understanding-gans)
2. [Setting up the environment](#setting-up-the-environment)
3. [Preparing the dataset](#preparing-the-dataset)
4. [Defining the Generator model](#defining-the-generator-model)
5. [Defining the Discriminator model](#defining-the-discriminator-model)

## Understanding GANs

Generative Adversarial Networks (GANs) are a class of machine learning models used to generate new data that resembles a given dataset. GANs consist of two neural networks that compete against each other: the **generator** and the **discriminator**. The generator attempts to create realistic data, while the discriminator tries to distinguish between real data (from the training set) and fake data (created by the generator).

### **How GANs work**

GANs are based on a game-theoretic scenario where two players — the generator and the discriminator — are set against each other in a zero-sum game. The generator’s goal is to produce data that looks as real as possible, while the discriminator’s goal is to detect whether the input data is real (from the dataset) or fake (from the generator). Through this adversarial process, the generator gradually improves its ability to create realistic data, and the discriminator becomes better at distinguishing between real and generated data.
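
Written out, this game is usually expressed as the minimax objective from the original GAN formulation. The symbols are not defined elsewhere in this notebook and are introduced here as assumptions: $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the generator's output for a latent vector $z$, $p_{\text{data}}$ is the real data distribution, and $p_z$ is the noise distribution.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The discriminator raises $V$ by assigning high probability to real samples and low probability to generated ones, while the generator lowers $V$ by making $D(G(z))$ large.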

#### **Generator**

The generator is a neural network responsible for creating new data instances from random noise or latent vectors. Its goal is to generate data that is indistinguishable from the real data provided in the training set. The generator doesn’t have access to the real data directly. Instead, it learns through feedback from the discriminator.

The generator takes a random input (often referred to as the **latent vector**) and maps it to a data space that resembles the real data. Initially, the generator produces poor-quality samples, but over time, it learns to generate more realistic data as it receives feedback from the discriminator.

#### **Discriminator**

The discriminator is a separate neural network tasked with determining whether a given data instance is real or generated (fake). It acts as a binary classifier, outputting a probability indicating whether the input data is from the real dataset or the generator.

The discriminator is trained on both real data (labeled as real) and generated data (labeled as fake). Its goal is to correctly classify real and generated data. As the generator improves, the discriminator must also improve to distinguish between real and high-quality generated data.
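
Because the discriminator is a binary classifier, its training signal is typically a binary cross-entropy loss over real samples (label 1) and generated samples (label 0). The sketch below uses plain Python and made-up probabilities to show the arithmetic. The helper `bce` and the values in `d_on_real` and `d_on_fake` are illustrative assumptions, not outputs of the models in this notebook.

```python
import math


def bce(prob, label):
    """Binary cross-entropy for one prediction; label is 1.0 (real) or 0.0 (fake)."""
    eps = 1e-12  # keep the argument of log() strictly inside (0, 1)
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


# Hypothetical discriminator outputs: estimated probability that the input is real.
d_on_real = [0.9, 0.8, 0.7]  # ideally close to 1
d_on_fake = [0.2, 0.1, 0.4]  # ideally close to 0

loss_real = sum(bce(p, 1.0) for p in d_on_real) / len(d_on_real)
loss_fake = sum(bce(p, 0.0) for p in d_on_fake) / len(d_on_fake)
d_loss = loss_real + loss_fake

print(f"loss on real: {loss_real:.4f}, loss on fake: {loss_fake:.4f}, total: {d_loss:.4f}")
```

The loss shrinks as predictions on real data approach 1 and predictions on generated data approach 0, which is the behavior the discriminator is trained toward.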

### **The adversarial process**

The training process of GANs involves alternating between two phases:

1. **Training the discriminator**: The discriminator is presented with both real data from the training set and fake data produced by the generator. It is trained to classify real data as real and fake data as fake. The discriminator’s goal is to maximize its ability to correctly classify the inputs.
2. **Training the generator**: The generator is trained to produce data that fools the discriminator. The generator’s objective is to generate data that the discriminator misclassifies as real. The generator receives feedback from the discriminator in the form of gradients, which help it adjust its weights to produce more convincing fake data.

This adversarial process is iterative, with both networks improving over time. As the generator becomes better at creating realistic data, the discriminator must also improve its ability to detect fakes. The two networks are in constant competition, pushing each other to become more accurate.
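
The two alternating phases can be sketched on a deliberately tiny problem. This is a toy under stated assumptions, not the models defined later in this notebook: the "real" data is one-dimensional and drawn from N(4, 1), the generator is linear, `G(z) = a*z + b`, the discriminator is logistic, `D(x) = sigmoid(w*x + c)`, the gradients are derived by hand, and the generator uses the commonly used loss `-log D(G(z))`. The function name `train_toy_gan` and all hyperparameters are hypothetical.

```python
import math
import random


def sigmoid(s):
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def train_toy_gan(steps=500, lr=0.02, seed=0):
    rng = random.Random(seed)
    a, b = 1.0, 0.0  # generator parameters: G(z) = a*z + b
    w, c = 0.1, 0.0  # discriminator parameters: D(x) = sigmoid(w*x + c)

    for _ in range(steps):
        x_real = rng.gauss(4.0, 1.0)  # one sample of "real" data
        z = rng.gauss(0.0, 1.0)       # one latent sample
        x_fake = a * z + b

        # Phase 1: update the discriminator on one real and one fake sample.
        # Loss: -log D(x_real) - log(1 - D(x_fake)); generator is held fixed.
        d_real = sigmoid(w * x_real + c)
        d_fake = sigmoid(w * x_fake + c)
        grad_w = (d_real - 1.0) * x_real + d_fake * x_fake
        grad_c = (d_real - 1.0) + d_fake
        w -= lr * grad_w
        c -= lr * grad_c

        # Phase 2: update the generator against the just-updated discriminator.
        # Loss: -log D(G(z)); the gradient reaches a and b through x_fake.
        d_fake = sigmoid(w * x_fake + c)
        grad_x = (d_fake - 1.0) * w  # d(-log D(x)) / dx at x = x_fake
        a -= lr * grad_x * z
        b -= lr * grad_x

    return a, b, w, c


a, b, w, c = train_toy_gan()
print(f"generator: a={a:.3f}, b={b:.3f}; discriminator: w={w:.3f}, c={c:.3f}")
```

The point of the sketch is the structure of the loop: the discriminator is updated with the generator held fixed, then the generator is updated using gradients that pass through the discriminator. No particular final values are implied, since even this toy setup can oscillate rather than settle.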

### **Role of the latent space**

The generator in a GAN takes random noise as input. This noise is a point sampled from a **latent space**, usually under a simple distribution such as a standard normal. The latent space typically has far fewer dimensions than the data itself, and the generator learns to map points in it to samples that follow the distribution of the real dataset. By exploring different points in the latent space, the generator can create diverse data samples that resemble the training data.

In practical applications, exploring the latent space allows us to generate a wide range of data, even from random inputs. For example, in image generation, different points in the latent space correspond to different styles or variations of generated images.
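
One common way to explore the latent space is to interpolate between two latent vectors and generate a sample at each intermediate point. The sketch below only builds the interpolation path, using plain Python lists. The helper `lerp`, the number of steps, and the seed are illustrative assumptions; `latent_dim = 100` mirrors the latent vector size used later in this notebook.

```python
import random


def lerp(z_start, z_end, t):
    """Linear interpolation between two latent vectors, with t in [0, 1]."""
    return [(1.0 - t) * s + t * e for s, e in zip(z_start, z_end)]


rng = random.Random(42)
latent_dim = 100

# Two random points in the latent space, drawn from a standard normal.
z_a = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
z_b = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]

# Evenly spaced points on the straight line from z_a to z_b.
num_steps = 5
path = [lerp(z_a, z_b, i / (num_steps - 1)) for i in range(num_steps)]

print(len(path), len(path[0]))  # 5 100
```

With a trained generator, each vector on the path would be converted to a tensor of shape `(1, 100, 1, 1)` and passed through the model, which tends to produce a gradual transition between the two generated images.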

### **Challenges in training GANs**

Training GANs is known to be challenging for several reasons:

- **Mode collapse**: Mode collapse occurs when the generator produces limited diversity in its outputs, effectively collapsing into generating only a few variations of data. This happens when the generator focuses too much on fooling the discriminator in specific ways without covering the entire data distribution.
- **Vanishing gradients**: If the discriminator becomes too strong early in training, it may classify all generator outputs as fake with high confidence. This results in very small gradient updates for the generator, hindering its ability to improve.
- **Training instability**: Since GANs involve two networks learning simultaneously, their losses are interdependent, leading to potential instability in the training process. Finding the right balance between the generator and discriminator’s learning rates and capacities is critical for stable training.
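
The vanishing-gradient problem can be made concrete with a little calculus. Assume the discriminator outputs `sigmoid(a)` for a generated sample, where `a` is its raw score (logit). In the original minimax formulation the generator minimizes `log(1 - sigmoid(a))`, whose derivative with respect to `a` is `-sigmoid(a)`. A widely used alternative is to minimize `-log(sigmoid(a))` instead, whose derivative is `sigmoid(a) - 1`. The function names below are illustrative, and the alternative loss is mentioned as common practice rather than as something this notebook prescribes.

```python
import math


def sigmoid(a):
    """Numerically stable logistic function."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)): the minimax generator loss."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)): the alternative generator loss."""
    return sigmoid(a) - 1.0


# More negative logits mean the discriminator is more confident the sample is fake.
for logit in (0.0, -2.0, -5.0, -10.0):
    print(
        f"logit={logit:6.1f}  D={sigmoid(logit):.6f}  "
        f"saturating grad={grad_saturating(logit):+.6f}  "
        f"non-saturating grad={grad_non_saturating(logit):+.6f}"
    )
```

As the discriminator grows more confident that a sample is fake, the first gradient shrinks toward zero while the second stays close to -1, which is why the choice of generator loss affects how much learning signal survives early in training.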

### **Applications of GANs**

GANs have gained popularity for their ability to generate high-quality, realistic data across a variety of domains. Some of the common applications of GANs include:

- **Image generation**: GANs can generate realistic images from random noise. They are widely used in tasks like photo-realistic image generation, face synthesis, and style transfer.
- **Data augmentation**: GANs can generate additional data samples to augment training datasets, which is useful in fields where labeled data is scarce.
- **Super-resolution**: GANs can be used to enhance the resolution of images, generating high-quality versions of low-resolution inputs.
- **Video generation**: GANs can generate video frames based on previous ones, leading to applications in video synthesis and prediction.
- **Art and creativity**: GANs are being used in creative fields to generate art, music, and other forms of media, often producing novel and unique outputs.
- **Text-to-image generation**: GANs can be used to generate images from textual descriptions, enabling tasks like automatic illustration.

### **Variants of GANs**

Over time, several variants of the basic GAN framework have been developed to address specific challenges or extend GANs' capabilities:

- **DCGAN (Deep Convolutional GAN)**: This variant uses convolutional neural networks in the generator and discriminator to improve the quality of generated images.
- **Conditional GAN (cGAN)**: In cGANs, both the generator and discriminator are conditioned on additional information, such as class labels. This allows the generator to produce data that belongs to a specific class, enabling more control over the generated outputs.
- **CycleGAN**: This model enables image-to-image translation tasks without the need for paired training examples. It has been applied to tasks such as converting photos to paintings or transferring styles between different image domains.
- **Wasserstein GAN (WGAN)**: The WGAN improves training stability by using a different distance metric (the Wasserstein distance) to measure the difference between real and generated data distributions. This helps mitigate issues related to vanishing gradients and mode collapse.

## Setting up the environment


##### **Q1: How do you install the necessary libraries for working with GANs in PyTorch?**


In [1]:
# !pip install torch torchvision matplotlib

##### **Q2: How do you import the required modules for building and training GANs in PyTorch?**


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

##### **Q3: How do you set up the environment to use a GPU, and how do you fall back to CPU if necessary in PyTorch?**


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


##### **Q4: How do you set a random seed in PyTorch to ensure reproducibility when training a GAN?**

In [5]:
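seed = 42
torch.manual_seed(seed)
if device.type == "cuda":  # `device` is a torch.device, so check its type rather than comparing it to a string
    torch.cuda.manual_seed_all(seed)  # Seed every visible GPU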

## Preparing the dataset


##### **Q5: How do you load an image dataset such as SVHN (or, similarly, MNIST or CIFAR-10) using `torchvision.datasets` in PyTorch?**


In [6]:
dataset = torchvision.datasets.SVHN(root="./data", split="train", download=True)

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/train_32x32.mat


100%|██████████| 182040794/182040794 [07:06<00:00, 426516.86it/s] 


##### **Q6: How do you apply transformations such as resizing and normalization to the dataset to prepare it for training a GAN?**


In [7]:
transform = transforms.Compose([
    transforms.Resize(64),  # Upscale the 32x32 SVHN images to 64x64
    transforms.ToTensor(),  # Convert PIL images to float tensors in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Per-channel (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1]
])

transformed_dataset = torchvision.datasets.SVHN(root="./data", split="train", download=True, transform=transform)

Using downloaded and verified file: ./data/train_32x32.mat
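
The normalization step is simple arithmetic: `ToTensor` scales pixel values to the range [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then computes `(x - 0.5) / 0.5`. The sketch below reproduces that calculation on single values in plain Python, along with the inverse that is useful when displaying generated images. The helper names `normalize` and `denormalize` are illustrative and are not torchvision functions.

```python
def normalize(x, mean=0.5, std=0.5):
    """Same per-value arithmetic as transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, from [-1, 1] back to [0, 1], e.g. before plotting."""
    return y * std + mean


for pixel in (0.0, 0.25, 0.5, 1.0):
    y = normalize(pixel)
    print(f"{pixel:.2f} -> {y:+.2f} -> {denormalize(y):.2f}")
```

Mapping the inputs to [-1, 1] matches the range of a `Tanh` output layer, so real and generated images reach the discriminator on the same scale.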


##### **Q7: How do you create DataLoaders in PyTorch to efficiently load batches of data for GAN training?**

In [8]:
batch_size = 128
dataloader = DataLoader(transformed_dataset, batch_size=batch_size, shuffle=True)
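
It can help to know how many batches one pass over the data produces and whether the last batch is smaller than the rest. The arithmetic below assumes the SVHN `train` split contains 73,257 images, which is the commonly documented size; check `len(transformed_dataset)` to confirm it in your environment.

```python
import math

num_images = 73257  # assumed size of the SVHN "train" split
batch_size = 128

# DataLoader keeps the final, smaller batch by default, so round up.
num_batches = math.ceil(num_images / batch_size)
last_batch_size = num_images - (num_batches - 1) * batch_size

print(num_batches, last_batch_size)  # 573 41
```

Because the final batch holds fewer than `batch_size` images, any code that assumes a fixed batch size (for example, a pre-allocated noise tensor) should either read the size from the batch itself or pass `drop_last=True` to the `DataLoader`.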

## Defining the Generator model


##### **Q8: How do you define the architecture of the Generator model using PyTorch’s `nn.Module`?**


In [9]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),  # Latent vector (100, 1, 1) -> (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # -> (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # -> (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),  # -> (64, 32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),  # -> (3, 64, 64), a 3-channel (RGB) image
            nn.Tanh()  # Output between -1 and 1, matching the normalized training images
        )

    def forward(self, input):
        return self.main(input)
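
To see why these five layers turn a 1x1 latent input into a 64x64 image, the spatial size can be traced with the transposed-convolution output formula. For dilation 1 and no output padding (the defaults used here), the size along each dimension is `(in - 1) * stride - 2 * padding + kernel`. The helper name `conv_transpose_out` is illustrative.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of ConvTranspose2d along one dimension (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for each ConvTranspose2d layer in the Generator above.
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector enters as a 1x1 "image" with 100 channels
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer expands the 1x1 input to 4x4, and each later layer, with kernel 4, stride 2, and padding 1, doubles the spatial size.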

##### **Q9: How do you create the latent vector (noise) that serves as input to the Generator model?**


In [10]:
latent_vector_size = 100
noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)

##### **Q10: How do you implement the forward pass for the Generator model in PyTorch to output fake data?**

In [11]:
generator = Generator().to(device)

fake_images = generator(noise)

## Defining the Discriminator model


##### **Q11: How do you define the architecture of the Discriminator model using PyTorch’s `nn.Module`?**


In [12]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),  # Image (3, 64, 64) -> (64, 32, 32)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # -> (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # -> (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),  # -> (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),  # -> (1, 1, 1), a single score per image
            nn.Sigmoid()  # Output between 0 and 1 (probability that the input is real)
        )

    def forward(self, input):
        return self.main(input)
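
The Discriminator reduces a 64x64 image to a single value per image. The spatial size after each `Conv2d` layer follows `floor((in + 2 * padding - kernel) / stride) + 1` for dilation 1, which the short trace below applies layer by layer. The helper name `conv_out` is illustrative.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of Conv2d along one dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each Conv2d layer in the Discriminator above.
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are 64x64
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

Each of the first four layers halves the spatial size, and the final 4x4 convolution with no padding collapses the remaining 4x4 feature map to 1x1, so the model's output has shape `(batch_size, 1, 1, 1)`.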

##### **Q12: How do you implement the forward pass for the Discriminator model to classify real and fake data?**


In [13]:
discriminator = Discriminator().to(device)

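# Random tensor standing in for a batch of real images (used here only to check shapes)
real_output = discriminator(torch.randn(batch_size, 3, 64, 64, device=device))
# detach() keeps gradients from flowing back into the Generator
fake_output = discriminator(fake_images.detach())  # Each output has shape (batch_size, 1, 1, 1)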

##### **Q13: How do you initialize the weights for both the Generator and Discriminator models in PyTorch?**

In [14]:
def weights_init(m):
    """DCGAN-style initialization, applied to every submodule via `Module.apply`."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:  # Matches both Conv2d and ConvTranspose2d
        nn.init.normal_(m.weight.data, 0.0, 0.02)  # Weights: mean 0, std 0.02
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)  # Scale: mean 1, std 0.02
        nn.init.constant_(m.bias.data, 0)  # Shift: 0

generator.apply(weights_init)
discriminator.apply(weights_init)  # apply() returns the module, so the notebook echoes its structure

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

In [None]:
import shutil
import os

if os.path.exists('data'):
    shutil.rmtree('data')
    print("Folder 'data' has been deleted.")
else:
    print("Folder 'data' does not exist.")