# Generative Adversarial Networks (GANs) and Their Variants

Welcome to this notebook on Generative Adversarial Networks (GANs)! In this notebook, we will:

- Introduce the basic architecture of GANs, including the roles of the **Generator** and **Discriminator**
- Explain the training process and key challenges
- Implement a simple GAN (using a DCGAN-style architecture) on the MNIST dataset
- Discuss common GAN variants such as Conditional GANs, CycleGANs, and others

This notebook is designed for a practical, hands-on course in generative AI and avoids deep mathematical details while providing sufficient background to work with and build generative models.

## 1. Introduction to GANs

Generative Adversarial Networks (GANs) were introduced by Ian Goodfellow and colleagues in 2014. The core idea is to train two neural networks, a **Generator** and a **Discriminator**, that compete against each other:

- **Generator (G):** Takes in random noise (from a latent space) and attempts to generate realistic data (e.g., images).
- **Discriminator (D):** Receives both real data (from the training set) and fake data (produced by the generator) and tries to distinguish between the two.

During training, G improves at fooling D, while D gets better at detecting fakes. This adversarial process drives the generator to produce increasingly realistic outputs.

### Key Points

- **Adversarial Loss:** Both networks are trained with opposing objectives. The generator’s loss is designed to maximize the discriminator’s error, while the discriminator minimizes classification errors.
- **Training Challenges:** GANs are notoriously difficult to train because of issues such as mode collapse (the generator produces only a narrow range of outputs), non-convergence, and the delicate balance required between the two networks.
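
For reference, the opposing objectives are usually written as a single minimax value function. This is the standard formulation from the original GAN paper rather than anything specific to this notebook; here $p_{\text{data}}$ denotes the data distribution and $p_z$ the noise prior:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator tries to make both terms large, while the generator tries to make the second term small. In practice the generator is often trained to maximize $\log D(G(z))$ instead, which is what the training loop in this notebook does.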

In the following sections, we will build a simple GAN, explore its components, and then briefly discuss variants that extend the basic idea.

## 2. Environment Setup and Imports

We'll be using PyTorch for our implementation. Make sure you have the following libraries installed:

- `torch`
- `torchvision`
- `matplotlib`
- `numpy`

Let's import the necessary modules.

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np

# Set manual seed for reproducibility
manualSeed = 999
print("Random Seed: ", manualSeed)
torch.manual_seed(manualSeed)

# Decide which device we want to run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

## 3. Data Preparation

For this demonstration, we will use the MNIST dataset. Later, you could try using other datasets (e.g., CelebA) to see how the model adapts to more complex images.

We will normalize the images to the range [-1, 1] (a common practice for GANs) and load the data using PyTorch’s `DataLoader`.
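
As a quick sanity check of that range, the sketch below applies the same arithmetic that `transforms.Normalize((0.5,), (0.5,))` performs, `(x - mean) / std`, to a few pixel values in [0, 1]. The tensor `pixels` is a made-up example, not MNIST data.

```python
import torch

# Hypothetical pixel intensities as ToTensor() would produce them: floats in [0, 1]
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

# Normalize(mean=0.5, std=0.5) computes (x - mean) / std per channel
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std

print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```

This range matches the generator's final `Tanh` layer, which also outputs values in [-1, 1].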

In [None]:
# Hyperparameters
batch_size = 128
image_size = 28
nc = 1  # number of channels for MNIST (grayscale)
nz = 100  # size of the latent z vector

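# Build the transform: ToTensor() scales pixels to [0, 1], Normalize then maps them to [-1, 1]
transform = transforms.Compose([
    transforms.Resize(image_size),  # no-op for MNIST (already 28x28); useful for other datasets
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Create the dataset
dataset = dsets.MNIST(root='./data', train=True, download=True, transform=transform)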

# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

## 4. Defining the GAN Architecture

We will now define the **Generator** and **Discriminator** networks. For this example, we use a simple architecture similar to that of a Deep Convolutional GAN (DCGAN), adapted for MNIST.

### Generator

The generator takes a latent vector (random noise) as input and produces an image. We use a series of transposed convolutions to upsample the latent space to the image dimensions.
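
To see how a 1 x 1 latent input can reach 28 x 28, the sketch below applies the output-size formula for `nn.ConvTranspose2d` (assuming dilation 1 and no output padding) to the kernel, stride, and padding values used by the generator in this notebook. The helper name `conv_transpose_out` is made up for this illustration.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output side length of a transposed convolution (dilation 1, no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for each generator layer
layers = [(7, 1, 0), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector is treated as a 1 x 1 "image" with nz channels
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 7, 14, 28]
```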

### Discriminator

The discriminator is a convolutional network that takes an image as input and outputs a single scalar representing the probability that the input is real.
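
The discriminator has to shrink a 28 x 28 image down to that single value. A minimal sketch, using the standard output-size formula for `nn.Conv2d` (assuming dilation 1) and the kernel, stride, and padding values from the discriminator defined in this notebook; `conv_out` is a hypothetical helper:

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for each discriminator layer
layers = [(4, 2, 1), (4, 2, 1), (7, 1, 0)]

size = 28  # MNIST images are 28 x 28
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [28, 14, 7, 1]
```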

In [None]:
# Generator: maps a latent vector of shape (nz, 1, 1) to an image of shape (nc, 28, 28)
class Generator(nn.Module):
    def __init__(self, nz, ngf, nc):
        super().__init__()
        self.main = nn.Sequential(
            # Input is Z (nz x 1 x 1), going into a transposed convolution
            nn.ConvTranspose2d(nz, ngf * 4, 7, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # State size: (ngf*4) x 7 x 7
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # State size: (ngf*2) x 14 x 14
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output size: nc x 28 x 28
        )

    def forward(self, input):
        return self.main(input)

# Discriminator: maps an image of shape (nc, 28, 28) to a single probability of being real
class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super().__init__()
        self.main = nn.Sequential(
            # Input is nc x 28 x 28
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # State size: ndf x 14 x 14
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # State size: (ndf*2) x 7 x 7
            nn.Conv2d(ndf * 2, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)

## 5. Training the GAN

We now define the training loop. In each iteration:

1. **Discriminator Training:**
   - Train on real images from the dataset
   - Train on fake images produced by the generator
2. **Generator Training:**
   - Update the generator to produce images that fool the discriminator

For simplicity, we use the Binary Cross Entropy loss. You may explore other loss functions in advanced exercises.
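
To make the link between BCE and the adversarial objective concrete, the sketch below checks that `nn.BCELoss` with a target of 1 equals the mean of `-log(p)`, and with a target of 0 equals the mean of `-log(1 - p)`. The probabilities in `d_out` are made-up discriminator outputs.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs (probabilities that each input is real)
d_out = torch.tensor([0.9, 0.6, 0.2])

# Target 1 ("real"): BCE reduces to the mean of -log(p)
loss_real = criterion(d_out, torch.ones_like(d_out))
manual_real = -torch.log(d_out).mean()

# Target 0 ("fake"): BCE reduces to the mean of -log(1 - p)
loss_fake = criterion(d_out, torch.zeros_like(d_out))
manual_fake = -torch.log(1 - d_out).mean()

print(torch.allclose(loss_real, manual_real))  # True
print(torch.allclose(loss_fake, manual_fake))  # True
```

So minimizing BCE on real batches labeled 1 and fake batches labeled 0 amounts to maximizing log(D(x)) + log(1 - D(G(z))), and minimizing BCE on fake batches labeled 1 amounts to maximizing log(D(G(z))).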

### Training Loop

**Steps per Batch:**

1. **Train the Discriminator:**  
   - Use a batch of real images and label them as 1.
   - Generate a batch of fake images from noise and label them as 0.
   - Compute loss and backpropagate.
2. **Train the Generator:**  
   - Generate a batch of fake images.
   - Compute the loss against a target label of 1 (i.e., tricking the discriminator).
   - Backpropagate and update weights.
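
One detail worth isolating from step 1: while the discriminator is being trained on fake images, the generator should not receive gradients. Below is a minimal sketch of how `.detach()` achieves this, using two tiny linear layers as stand-ins (`toy_G` and `toy_D` are hypothetical, not the networks in this notebook).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_G = nn.Linear(4, 8)                               # stand-in generator
toy_D = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # stand-in discriminator
criterion = nn.BCELoss()

noise = torch.randn(16, 4)
fake = toy_G(noise)

# Discriminator step: detach() cuts the graph, so backward() never reaches toy_G
loss_D = criterion(toy_D(fake.detach()).view(-1), torch.zeros(16))
loss_D.backward()
g_grad_after_d_step = toy_G.weight.grad            # still None: G was not touched
d_has_grad = toy_D[0].weight.grad is not None

# Generator step: no detach, so gradients flow through toy_D back into toy_G
loss_G = criterion(toy_D(fake).view(-1), torch.ones(16))
loss_G.backward()
g_has_grad = toy_G.weight.grad is not None

print(g_grad_after_d_step is None, d_has_grad, g_has_grad)  # True True True
```

Detaching also saves work: the discriminator update does not need to backpropagate through the generator at all.
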
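
**Logging:**  
Every 100 batches we print the losses together with the mean discriminator outputs D(x) and D(G(z)). At the end of each epoch we store a grid of sample images in `img_list` for later visualization.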

In [None]:
# Hyperparameters for the model architecture
ngf = 64  # Generator feature map size
ndf = 64  # Discriminator feature map size

# Create the generator
netG = Generator(nz, ngf, nc).to(device)
print(netG)

# Create the discriminator
netD = Discriminator(nc, ndf).to(device)
print(netD)

# Loss function
criterion = nn.BCELoss()

# Setup optimizers for both G and D
lr = 0.0002
beta1 = 0.5
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop
num_epochs = 5  # Adjust epochs as needed
real_label = 1.
fake_label = 0.

img_list = []

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        netD.zero_grad()
        
        # Train with real images
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()
        
        # Train with fake images
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()
        
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ############################
        netG.zero_grad()
        label.fill_(real_label)  # Generator wants the discriminator to think these are real
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        # Output training stats
        if i % 100 == 0:
            print(f'[Epoch {epoch + 1}/{num_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} ' 
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
    
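    # Save the generator's output for visualization. Reseeding a local RNG yields
    # the same 64 latent vectors every epoch, so the saved grids are comparable.
    with torch.no_grad():
        noise_rng = torch.Generator().manual_seed(manualSeed)
        fixed_noise = torch.randn(64, nz, 1, 1, generator=noise_rng).to(device)
        fake = netG(fixed_noise).detach().cpu()
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))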

# Plot some generated images
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Generated Images")
# make_grid returns (C, H, W); imshow expects (H, W, C)
plt.imshow(img_list[-1].permute(1, 2, 0).numpy())
plt.show()

## 6. Exploring GAN Variants

Once you are comfortable with the basic GAN, you may explore the following variants:

- **Conditional GANs (cGANs):** Generate images conditioned on labels. For example, generate MNIST digits conditioned on the digit label.
- **DCGAN:** The architecture we used here is inspired by DCGAN. For higher resolution images, deeper networks may be needed.
- **CycleGAN:** Used for image-to-image translation without paired examples (e.g., converting horses to zebras).
- **InfoGAN:** Designed to learn interpretable and disentangled representations by maximizing mutual information.
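
As a starting point for the conditional variant, one common approach (an assumption here, not the only option) is to append a one-hot label to the latent vector before it enters the generator. The sketch below only builds that conditioned input; `num_classes` and `conditioned_noise` are hypothetical names, and the generator's first layer would need `nz + num_classes` input channels to accept it.

```python
import torch
import torch.nn.functional as F

nz = 100          # latent size, as in this notebook
num_classes = 10  # MNIST digits 0-9

def conditioned_noise(labels, nz, num_classes):
    """Concatenate random noise with one-hot labels along the channel dimension."""
    batch = labels.size(0)
    noise = torch.randn(batch, nz, 1, 1)
    one_hot = F.one_hot(labels, num_classes).float().view(batch, num_classes, 1, 1)
    return torch.cat([noise, one_hot], dim=1)

labels = torch.tensor([3, 7, 7, 0])
z = conditioned_noise(labels, nz, num_classes)
print(z.shape)  # torch.Size([4, 110, 1, 1])
```

The discriminator needs the label as well, for example as an extra input channel, so that it can judge whether an image matches its class.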

### Hands-On Exercise

1. Modify the generator and discriminator to accept label information (i.e., implement a conditional GAN on MNIST).
2. Experiment with different loss functions or network architectures to see how the output quality changes.
3. Try using a different dataset (such as CelebA) and adjust the network layers accordingly.

Each variant introduces new challenges and techniques that are valuable in professional AI work.

## 7. Conclusion and Next Steps

In this notebook, you learned:

- The basic concept of GANs and their adversarial training process
- How to implement a simple DCGAN on MNIST using PyTorch
- An overview of popular GAN variants and ideas for further exploration

### Recommended Tools and Platforms

- **PyTorch / TensorFlow:** Frameworks for deep learning
- **Weights & Biases or TensorBoard:** For tracking experiments
- **Google Colab or Kaggle Kernels:** For running experiments in the cloud