![image.png](https://i.imgur.com/a3uAqnb.png)

# Generative Adversarial Networks (GANs) - MNIST Image Generation

Welcome to this hands-on exercise on Generative Adversarial Networks (GANs)! You'll build a GAN from scratch to generate handwritten digits similar to those in the MNIST dataset.

## Learning Objectives
- Understand GAN architecture and training dynamics
- Implement Generator and Discriminator networks
- Learn the adversarial training process
- Visualize and analyze results

---

## 1. Introduction to GANs

GANs consist of two neural networks competing against each other:
- **Generator**: Creates fake data from random noise
- **Discriminator**: Distinguishes between real and fake data

The generator tries to fool the discriminator, while the discriminator tries to correctly identify fake data. This adversarial process leads to increasingly realistic generated data.
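Formally, this competition is usually written as the minimax game from the original GAN paper (Goodfellow et al., 2014). Here $D(x)$ is the discriminator's estimated probability that $x$ is real, $G(z)$ is the image the generator produces from noise $z \sim p_z$, and $p_{\text{data}}$ is the distribution of real images:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The discriminator maximizes $V$ by pushing $D(x)$ toward 1 and $D(G(z))$ toward 0. In practice the generator is commonly trained to maximize $\log D(G(z))$ rather than to minimize $\log(1 - D(G(z)))$, because the latter gives weak gradients early in training, when the discriminator rejects fakes easily.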

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # Scale pixel values from [0, 1] to [-1, 1]: (x - 0.5) / 0.5
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Load the MNIST training set and pass the transform defined above
# (with transform=None the loader would yield PIL images, which cannot be batched)
train_dataset = torchvision.datasets.MNIST(
    root='/content/', train=True, download=True, transform=transform
)

# Create data loader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

print(f"Dataset size: {len(train_dataset)}")
print(f"Number of batches: {len(train_loader)}")


Dataset size: 60000
Number of batches: 938





**Question**: Why do we normalize to [-1, 1] instead of [0, 1]?

## Generator

The Generator takes a random noise vector as input and transforms it into a fake image. It learns a mapping from the noise (latent) space to the data space.

In [None]:
class Generator(nn.Module):
    def __init__(self, noise_dim=100, img_dim=28*28):
        super(Generator, self).__init__()

        # Fully connected generator: noise_dim -> 256 -> 512 -> 1024 -> img_dim
        self.model = nn.Sequential(
            # Layer 1: noise_dim -> 256
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),

            # Layer 2: 256 -> 512
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),

            # Layer 3: 512 -> 1024
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),

            # Output layer: 1024 -> img_dim
            # Tanh keeps outputs in [-1, 1], matching the normalized data
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        # x: (batch_size, noise_dim) -> (batch_size, img_dim)
        return self.model(x)
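Once `forward` is implemented, a quick sanity check is to push a batch of noise through the network and confirm the output shape and range. The sketch below rebuilds the suggested layer stack as a standalone `nn.Sequential`; the helper name `build_generator` is hypothetical and exists only so the block runs on its own:

```python
import torch
import torch.nn as nn

def build_generator(noise_dim=100, img_dim=28 * 28):
    # Hypothetical standalone copy of the suggested architecture:
    # noise_dim -> 256 -> 512 -> 1024 -> img_dim, LeakyReLU(0.2) between layers
    return nn.Sequential(
        nn.Linear(noise_dim, 256), nn.LeakyReLU(0.2),
        nn.Linear(256, 512), nn.LeakyReLU(0.2),
        nn.Linear(512, 1024), nn.LeakyReLU(0.2),
        nn.Linear(1024, img_dim), nn.Tanh(),  # Tanh bounds outputs to [-1, 1]
    )

gen = build_generator()
noise = torch.randn(16, 100)  # batch of 16 noise vectors
with torch.no_grad():
    samples = gen(noise)      # flattened images, one row per sample

print(samples.shape)
print(samples.min().item(), samples.max().item())

# Reshape for display, as the plotting helpers later in the notebook do
images = samples.view(-1, 28, 28)
```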

**Question**: Why do we use Tanh activation in the generator's output layer?

## Discriminator

The Discriminator is a binary classifier that distinguishes real images from fake ones. Its output is the probability that the input image is real.

In [None]:
class Discriminator(nn.Module):
    def __init__(self, img_dim=28*28):
        super(Discriminator, self).__init__()

        # Classifies images as real (1) or fake (0)
        # Architecture: img_dim -> 1024 -> 512 -> 256 -> 1
        self.model = nn.Sequential(
            # Layer 1: img_dim -> 1024
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),

            # Layer 2: 1024 -> 512
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),

            # Layer 3: 512 -> 256
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),

            # Output layer: 256 -> 1
            # Sigmoid squashes the score into (0, 1) for binary classification
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: (batch_size, img_dim) -> (batch_size, 1) probabilities
        return self.model(x)
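The final Sigmoid pairs with `nn.BCELoss`, which expects probabilities in $[0, 1]$. For reference, PyTorch also provides `nn.BCEWithLogitsLoss`, which folds the Sigmoid into the loss and is documented as more numerically stable. The check below, using arbitrary example logits, shows the two give the same value; switching would mean removing the final `nn.Sigmoid()`, which this notebook does not do:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0], [-1.0], [0.5], [-3.0]])  # arbitrary raw discriminator scores
labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])    # 1 = real, 0 = fake

# Approach used in this notebook: Sigmoid inside the model, BCELoss on probabilities
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), labels)

# Alternative: raw logits passed straight to BCEWithLogitsLoss
loss_with_logits = nn.BCEWithLogitsLoss()(logits, labels)

print(loss_sigmoid_bce.item(), loss_with_logits.item())
```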

## Initialize Networks and Optimizers

In [None]:
# Hyperparameters
noise_dim = 100
learning_rate = 0.0002

# Create instances of Generator and Discriminator
generator = Generator(noise_dim=noise_dim)
discriminator = Discriminator()

# Move models to device
generator.to(device)
discriminator.to(device)

# Binary cross-entropy loss (pairs with the discriminator's Sigmoid output)
# and a separate Adam optimizer for each network
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

print("Models initialized successfully!")
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters())}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters())}")
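The printed parameter counts can be checked by hand: an `nn.Linear(in_features, out_features)` layer has `in_features * out_features` weights plus `out_features` biases, and the LeakyReLU, Tanh, and Sigmoid layers have none. A small sketch of that arithmetic for the suggested layer sizes (the helper `linear_params` is illustrative, not part of PyTorch):

```python
def linear_params(in_features, out_features):
    # Weight matrix (out x in) plus one bias per output unit
    return in_features * out_features + out_features

gen_sizes = [100, 256, 512, 1024, 28 * 28]   # noise_dim -> ... -> img_dim
disc_sizes = [28 * 28, 1024, 512, 256, 1]    # img_dim -> ... -> 1

# Sum over consecutive (in, out) pairs, one pair per Linear layer
gen_total = sum(linear_params(a, b) for a, b in zip(gen_sizes, gen_sizes[1:]))
disc_total = sum(linear_params(a, b) for a, b in zip(disc_sizes, disc_sizes[1:]))

print(f"Generator parameters: {gen_total}")
print(f"Discriminator parameters: {disc_total}")
```

If the networks follow the suggested architecture, these totals should match the counts printed by the initialization cell.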

## Training Function
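Inside the training loop, binary cross-entropy turns the adversarial game into two classification losses: for a predicted probability $p$, the loss is $-\log p$ when the target is 1 and $-\log(1 - p)$ when the target is 0. The check below uses made-up discriminator outputs (illustrative values, not results from this notebook) to confirm that the discriminator loss equals $-[\log D(x) + \log(1 - D(G(z)))]$ and that scoring fakes against real labels gives the generator loss $-\log D(G(z))$, each averaged over the batch:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # default mean reduction over the batch

d_real = torch.tensor([[0.9], [0.8]])  # D(x): hypothetical scores for real images
d_fake = torch.tensor([[0.2], [0.1]])  # D(G(z)): hypothetical scores for fake images
ones = torch.ones_like(d_real)
zeros = torch.zeros_like(d_fake)

# Discriminator: real images labelled 1, fake images labelled 0
disc_loss = criterion(d_real, ones) + criterion(d_fake, zeros)
manual_disc = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

# Generator: fakes scored against *real* labels (non-saturating loss)
gen_loss = criterion(d_fake, ones)
manual_gen = -torch.log(d_fake).mean()

print(disc_loss.item(), manual_disc.item())
print(gen_loss.item(), manual_gen.item())
```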

In [None]:
def train_gan(generator, discriminator, train_loader, num_epochs=50):
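    """
    Train the GAN for the specified number of epochs.

    Relies on globals defined earlier: gen_optimizer, disc_optimizer,
    criterion, noise_dim, and device. Returns the per-epoch average
    generator and discriminator losses.
    """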

    # Lists to store losses for plotting
    gen_losses = []
    disc_losses = []

    for epoch in range(num_epochs):
        gen_epoch_loss = 0
        disc_epoch_loss = 0

        for batch_idx, (real_images, _) in enumerate(train_loader):
            batch_size = real_images.size(0)

            # Flatten images and move to device
            real_images = real_images.view(batch_size, -1).to(device)

            # Create labels
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # =================================================================
            # TRAIN DISCRIMINATOR
            # =================================================================

            disc_optimizer.zero_grad()

            # Train on real images: target label 1
            real_output = discriminator(real_images)
            real_loss = criterion(real_output, real_labels)

            # Train on fake images: target label 0
            # detach() blocks gradients from reaching the generator in this step
            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_images = generator(noise).detach()
            fake_output = discriminator(fake_images)
            fake_loss = criterion(fake_output, fake_labels)

            # Total discriminator loss, then backpropagate and update D
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            disc_optimizer.step()

            # =================================================================
            # TRAIN GENERATOR
            # =================================================================

            gen_optimizer.zero_grad()

            # Train the generator to fool the discriminator: generate new fakes
            # (no detach(), so gradients flow back into the generator)
            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_images = generator(noise)
            fake_output = discriminator(fake_images)
            # Use real_labels because we want the discriminator to think fake images are real.
            # backward() also writes gradients into D's parameters; they are cleared by
            # disc_optimizer.zero_grad() before the next discriminator update.
            gen_loss = criterion(fake_output, real_labels)

            gen_loss.backward()
            gen_optimizer.step()

            # Accumulate losses
            gen_epoch_loss += gen_loss.item()
            disc_epoch_loss += disc_loss.item()

        # Average losses for the epoch
        gen_epoch_loss /= len(train_loader)
        disc_epoch_loss /= len(train_loader)

        gen_losses.append(gen_epoch_loss)
        disc_losses.append(disc_epoch_loss)

        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Generator Loss: {gen_epoch_loss:.4f}')
            print(f'Discriminator Loss: {disc_epoch_loss:.4f}')
            print('-' * 50)

    return gen_losses, disc_losses
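The discriminator step calls `generator(noise).detach()`, while the generator step does not. The toy example below uses two small `nn.Linear` layers as stand-ins for the networks (an assumption made to keep it tiny) to show the effect: a backward pass through detached fakes leaves the generator's gradients at `None`, while a backward pass without `detach()` populates them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen = nn.Linear(4, 3)   # toy stand-in for the generator
disc = nn.Linear(3, 1)  # toy stand-in for the discriminator
noise = torch.randn(5, 4)

# Discriminator-style step: detach() cuts the graph, so nothing reaches gen
fake = gen(noise).detach()
disc(fake).mean().backward()
grad_after_detached = gen.weight.grad  # still None

# Generator-style step: no detach(), so gradients flow through disc into gen
disc.zero_grad()
disc(gen(noise)).mean().backward()
grad_after_attached = gen.weight.grad  # now a tensor

print(grad_after_detached is None, grad_after_attached is not None)
```

Detaching in the discriminator step avoids computing generator gradients that `disc_optimizer` would never use.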

## Visualization Functions

In [None]:
def plot_losses(gen_losses, disc_losses):
    """Plot training losses"""
    plt.figure(figsize=(10, 5))
    plt.plot(gen_losses, label='Generator Loss')
    plt.plot(disc_losses, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('GAN Training Losses')
    plt.legend()
    plt.grid(True)
    plt.show()

def generate_and_display_images(generator, num_images=16):
    """Generate and display fake images"""
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_images, noise_dim).to(device)
        fake_images = generator(noise)

        # Reshape images for display
        fake_images = fake_images.view(fake_images.size(0), 28, 28)
        fake_images = fake_images.cpu().numpy()

        # Plot on a fixed 4x4 grid (at most 16 images are shown)
        fig, axes = plt.subplots(4, 4, figsize=(8, 8))
        for i, ax in enumerate(axes.flat):
            if i < num_images:
                ax.imshow(fake_images[i], cmap='gray')
            ax.axis('off')  # hide axes on every cell, including empty ones

        plt.suptitle('Generated Images')
        plt.tight_layout()
        plt.show()

def compare_real_vs_fake(train_loader, generator, num_images=8):
    """Compare real and generated images side by side"""
    # Get real images
    real_images = next(iter(train_loader))[0][:num_images]

    # Generate fake images
    generator.eval()
    with torch.no_grad():
        # Generate fake images from random noise
        noise = torch.randn(num_images, noise_dim).to(device)
        fake_images = generator(noise)
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

    # Plot comparison
    fig, axes = plt.subplots(2, num_images, figsize=(12, 4))

    for i in range(num_images):
        # Real images
        axes[0, i].imshow(real_images[i].squeeze(), cmap='gray')
        axes[0, i].set_title('Real' if i == 0 else '')
        axes[0, i].axis('off')

        # Fake images
        axes[1, i].imshow(fake_images[i].cpu().numpy(), cmap='gray')
        axes[1, i].set_title('Generated' if i == 0 else '')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()
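By default `imshow` rescales each image to its own minimum and maximum, which can make a weak sample look higher-contrast than it is. One option, sketched below with a hypothetical helper `to_display`, is to map the $[-1, 1]$ values back to $[0, 1]$ (the inverse of the `Normalize(0.5, 0.5)` step) and fix the color limits with `vmin`/`vmax`:

```python
import matplotlib.pyplot as plt
import torch

def to_display(img):
    # Hypothetical helper: invert Normalize(0.5, 0.5), i.e. x * 0.5 + 0.5,
    # and clamp to [0, 1] in case of small numerical overshoot
    return (img * 0.5 + 0.5).clamp(0, 1).cpu().numpy()

img = torch.tanh(torch.randn(28, 28))  # stand-in for one generated image in [-1, 1]

fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(to_display(img), cmap='gray', vmin=0, vmax=1)  # fixed intensity scale
ax.axis('off')
plt.show()
```

With fixed limits, pixel intensities are comparable across images, so real and generated samples are shown on the same scale.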

## Training Execution

In [None]:
gen_losses, disc_losses = train_gan(generator, discriminator, train_loader, num_epochs=50)

print("Training completed!")

## Results and Analysis

In [None]:
plot_losses(gen_losses, disc_losses)

# Generate and display sample images
print("Generated Images:")
generate_and_display_images(generator)


# Compare real vs fake images
print("Real vs Generated Images Comparison:")
compare_real_vs_fake(train_loader, generator)

# Contributed by: Abdullah Jan
