# Generative adversarial network (GAN) basics

The `13_gan_basics` notebook introduces the foundational concepts of Generative Adversarial Networks (GANs), a powerful class of models used for generating realistic data. GANs consist of two networks, a Generator and a Discriminator, that are trained simultaneously in a competitive setting. 

This notebook covers preparing the dataset, defining the Generator model responsible for creating synthetic data, and defining the Discriminator model, which distinguishes between real and generated data.

---

**Note:** GANs are a broad topic, so this notebook serves mainly as an introduction to the architecture itself. Training a GAN is covered in the notebook that follows this one.

## Table of contents

1. [Understanding GANs](#understanding-gans)
2. [Setting up the environment](#setting-up-the-environment)
3. [Preparing the dataset](#preparing-the-dataset)
4. [Defining the Generator model](#defining-the-generator-model)
5. [Defining the Discriminator model](#defining-the-discriminator-model)

## Understanding GANs

Generative Adversarial Networks (GANs) are a class of machine learning models used to generate new data that resembles a given dataset. GANs consist of two neural networks that compete against each other: the **generator** and the **discriminator**. The generator attempts to create realistic data, while the discriminator tries to distinguish between real data (from the training set) and fake data (created by the generator).

### **Key concepts**
The generator and the discriminator play a zero-sum game: every improvement in the generator's ability to fool the discriminator counts as a loss for the discriminator, and every improvement in the discriminator's accuracy makes the generator's task harder. Training aims for a balance in which the generated data is indistinguishable from real data.

Key components of GANs include:
- **Generator**: Takes random noise as input and transforms it into synthetic data.
- **Discriminator**: Evaluates whether the input data is real or generated.
- **Adversarial Training**: The generator and discriminator are trained simultaneously, improving their performance iteratively.
- **Loss Function**: Typically based on a minimax game, balancing the generator’s goal of "fooling" the discriminator and the discriminator’s goal of accurate classification.
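
For reference, the minimax game named in the loss function bullet is usually written as the objective below. The notebook itself does not define these symbols, so they are introduced here for illustration: $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the generator's output for a noise vector $z$, $p_{\text{data}}$ is the distribution of real data, and $p_z$ is the noise prior.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator tries to maximize $V$ by pushing $D(x)$ toward 1 and $D(G(z))$ toward 0, while the generator tries to minimize it by pushing $D(G(z))$ toward 1.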

GANs are known for their ability to produce high-quality synthetic data, making them powerful tools for generative tasks.

### **Applications**
GANs are widely used in a variety of fields, including:
- **Image generation**: Creating realistic images, such as faces or landscapes, from random inputs.
- **Data augmentation**: Generating additional samples for training datasets, particularly for imbalanced datasets.
- **Image-to-image translation**: Tasks like converting sketches to photos or black-and-white images to color.
- **Video generation**: Generating frames for video content, animations, or simulations.
- **Anomaly detection**: Training on normal data and using the discriminator to identify anomalies.

### **Advantages**
- **Versatile generative capabilities**: Capable of creating high-quality data across various domains.
- **Unsupervised learning**: Requires no labeled data for training, only real examples.
- **Realism**: Generates synthetic data that can closely resemble real-world data.
- **Wide applicability**: Can be used in fields ranging from art and entertainment to scientific simulations.

### **Challenges**
- **Training instability**: GANs are notoriously difficult to train. A common failure is mode collapse, where the generator produces only a narrow range of outputs instead of covering the variety in the training data.
- **Hyperparameter sensitivity**: Requires careful tuning of learning rates, model architectures, and optimization strategies.
- **Evaluation difficulty**: Measuring the quality of generated data is non-trivial and often subjective.
- **Resource intensity**: Training GANs can be computationally expensive, especially for high-resolution outputs.

## Setting up the environment


##### **Q1: How do you install the necessary libraries for working with GANs in PyTorch?**


In [1]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install numpy matplotlib scikit-learn pandas

##### **Q2: How do you import the required modules for building and training GANs in PyTorch?**


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

##### **Q3: How do you set up the environment to use a GPU, and how do you fall back to CPU if necessary in PyTorch?**


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


##### **Q4: How do you set a random seed in PyTorch to ensure reproducibility when training a GAN?**

In [5]:
seed = 42
torch.manual_seed(seed)
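if device.type == "cuda":  # Compare the device type, not the torch.device object, with a string
    torch.cuda.manual_seed_all(seed)  # Seed every visible GPU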
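
A quick way to see what the seed buys you is to draw noise twice with the same seed and compare. The sketch below is illustrative only: `sample_noise` is a hypothetical helper, not part of the notebook, and it runs on the CPU for simplicity.

```python
import torch


def sample_noise(seed: int) -> torch.Tensor:
    # Re-seed the global generator, then draw a small batch of latent vectors.
    torch.manual_seed(seed)
    return torch.randn(4, 100, 1, 1)


first = sample_noise(42)
second = sample_noise(42)
different = sample_noise(43)

print(torch.equal(first, second))     # same seed, identical noise
print(torch.equal(first, different))  # different seed, different noise
```

Seeding alone does not guarantee bit-identical results for every GPU operation, so treat this as a baseline rather than a full reproducibility guarantee.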

## Preparing the dataset


##### **Q5: How do you load an image dataset such as SVHN using `torchvision.datasets` in PyTorch?**


In [6]:
dataset = torchvision.datasets.SVHN(root="./data", split="train", download=True)

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/train_32x32.mat


100%|██████████| 182040794/182040794 [07:06<00:00, 426516.86it/s] 


##### **Q6: How do you apply transformations such as resizing and normalization to the dataset to prepare it for training a GAN?**


In [7]:
transform = transforms.Compose([
    transforms.Resize(64),  # Upscale the 32x32 SVHN images to 64x64
    transforms.ToTensor(),  # Convert PIL images to float tensors in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Map each RGB channel from [0, 1] to [-1, 1]
])

transformed_dataset = torchvision.datasets.SVHN(root="./data", split="train", download=True, transform=transform)

Using downloaded and verified file: ./data/train_32x32.mat
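
The normalization step is plain arithmetic, `(x - mean) / std`, so with a mean and standard deviation of 0.5 a pixel in `[0, 1]` lands in `[-1, 1]`, which matches the range of the `Tanh` output used by the Generator later on. The sketch below reproduces that arithmetic on a few values; `normalize` and `denormalize` are hypothetical helper names, with `denormalize` being the inverse you would typically apply before plotting generated images.

```python
import torch


def normalize(x: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    # Same arithmetic that transforms.Normalize applies per channel.
    return (x - mean) / std


def denormalize(x: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    # Inverse mapping, clamped to the valid pixel range for display.
    return (x * std + mean).clamp(0.0, 1.0)


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
scaled = normalize(pixels)
restored = denormalize(scaled)

print(scaled)    # values now lie in [-1, 1]
print(restored)  # back to the original [0, 1] values
```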


##### **Q7: How do you create DataLoaders in PyTorch to efficiently load batches of data for GAN training?**

In [8]:
batch_size = 128
dataloader = DataLoader(transformed_dataset, batch_size=batch_size, shuffle=True)
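
One detail worth checking is the shape of each batch, since the final batch is smaller whenever the dataset size is not a multiple of `batch_size`. The sketch below uses a small synthetic `TensorDataset` as a stand-in for SVHN (an assumption made to avoid a download); `demo_loader` and the 300-sample size are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 300 random "images" with dummy labels.
images = torch.randn(300, 3, 64, 64)
labels = torch.zeros(300, dtype=torch.long)
demo_loader = DataLoader(TensorDataset(images, labels), batch_size=128, shuffle=True)

batch_shapes = [tuple(batch.shape) for batch, _ in demo_loader]

print(len(demo_loader))  # 3 batches
print(batch_shapes)      # two full batches of 128, then a final batch of 44
```

If a fixed batch size is required, for example when the noise tensor is created once with `batch_size` rows, passing `drop_last=True` to `DataLoader` discards the incomplete final batch.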

## Defining the Generator model


##### **Q8: How do you define the architecture of the Generator model using PyTorch’s `nn.Module`?**


In [9]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),  # Latent vector size of 100
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),  # Output 3-channel image (RGB)
            nn.Tanh()  # Output between -1 and 1
        )

    def forward(self, input):
        return self.main(input)
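
To see why this stack turns a `1x1` latent vector into a `64x64` image, the spatial size can be traced layer by layer. For `nn.ConvTranspose2d` with default dilation and no output padding, the output size is `(in - 1) * stride - 2 * padding + kernel`. The sketch below applies that formula to the five layers above; `conv_transpose_out` is a hypothetical helper name.

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output size of ConvTranspose2d, assuming dilation=1 and output_padding=0.
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for each ConvTranspose2d layer in the Generator.
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector has spatial size 1x1
generator_sizes = [size]
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    generator_sizes.append(size)

print(generator_sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer expands `1x1` to `4x4`, and each following layer doubles the resolution.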

##### **Q9: How do you create the latent vector (noise) that serves as input to the Generator model?**


In [10]:
latent_vector_size = 100
noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)

##### **Q10: How do you implement the forward pass for the Generator model in PyTorch to output fake data?**

In [11]:
generator = Generator().to(device)

fake_images = generator(noise)

## Defining the Discriminator model


##### **Q11: How do you define the architecture of the Discriminator model using PyTorch’s `nn.Module`?**


In [12]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # Output between 0 and 1 (probability)
        )

    def forward(self, input):
        return self.main(input)
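
The Discriminator mirrors the Generator: each strided convolution halves the resolution until a single value per image remains. For `nn.Conv2d` with default dilation, the output size is `floor((in + 2 * padding - kernel) / stride) + 1`. The sketch below traces a `64x64` input through the five layers above; `conv_out` is a hypothetical helper name.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output size of Conv2d, assuming dilation=1.
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each Conv2d layer in the Discriminator.
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are 64x64
discriminator_sizes = [size]
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    discriminator_sizes.append(size)

print(discriminator_sizes)  # [64, 32, 16, 8, 4, 1]
```

Because the final spatial size is `1x1`, the output tensor has shape `(batch, 1, 1, 1)` and is usually flattened before being compared with labels.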

##### **Q12: How do you implement the forward pass for the Discriminator model to classify real and fake data?**


In [13]:
discriminator = Discriminator().to(device)

real_output = discriminator(torch.randn(batch_size, 3, 64, 64, device=device))  # Random tensors standing in for a batch of real images
fake_output = discriminator(fake_images.detach())  # Fake images from the Generator
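
The discriminator outputs have shape `(batch, 1, 1, 1)`, so they are typically flattened and scored against labels of 1 for real and 0 for fake. The sketch below shows one common way to do this with `nn.BCELoss`; this is an assumption about how the outputs are used later, since training is left to the next notebook. The tensors `real_probs` and `fake_probs` are made-up stand-ins for discriminator outputs.

```python
import torch
import torch.nn as nn

batch = 8

# Stand-ins for discriminator outputs: probabilities with shape (batch, 1, 1, 1).
real_probs = torch.full((batch, 1, 1, 1), 0.9)
fake_probs = torch.full((batch, 1, 1, 1), 0.1)

criterion = nn.BCELoss()
real_labels = torch.ones(batch)   # real images are labeled 1
fake_labels = torch.zeros(batch)  # generated images are labeled 0

# Flatten (batch, 1, 1, 1) to (batch,) so predictions and labels line up.
loss_real = criterion(real_probs.view(-1), real_labels)
loss_fake = criterion(fake_probs.view(-1), fake_labels)
d_loss = loss_real + loss_fake

print(d_loss.item())  # small, because both sets of predictions are mostly correct
```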

##### **Q13: How do you initialize the weights for both the Generator and Discriminator models in PyTorch?**

In [14]:
def weights_init(m):  # Custom weights initialization function for both Generator and Discriminator
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
generator.apply(weights_init)
discriminator.apply(weights_init)  # Applying weight initialization

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
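
To confirm that an initialization function like this does what it claims, you can apply it to a small block and inspect the resulting statistics. The sketch below uses an `isinstance`-based variant, which is an alternative to matching on class names and not the notebook's own function; `weights_init_isinstance` and `block` are illustrative names.

```python
import torch
import torch.nn as nn


def weights_init_isinstance(m: nn.Module) -> None:
    # Same scheme as weights_init above, but dispatching on layer type.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)


torch.manual_seed(0)
block = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
)
block.apply(weights_init_isinstance)

conv_weight = block[0].weight
bn_layer = block[1]

print(conv_weight.mean().item(), conv_weight.std().item())  # close to 0.0 and 0.02
print(bn_layer.weight.mean().item())                        # close to 1.0
print(bn_layer.bias.abs().sum().item())                     # exactly 0.0
```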

In [None]:
import shutil
import os

if os.path.exists('data'):
    shutil.rmtree('data')
    print("Folder 'data' has been deleted.")
else:
    print("Folder 'data' does not exist.")