MNIST Conditional GAN (CGAN) Exercise - Student Version
Welcome to this hands-on exercise on Conditional Generative Adversarial Networks! You'll build a CGAN that can generate specific digits (0-9) on command.
Learning Objectives

Understand the difference between GANs and Conditional GANs
Learn how to incorporate class labels into both Generator and Discriminator
Implement label embedding and concatenation techniques
Generate specific digits on demand
Analyze the quality and diversity of conditional generation

What makes CGANs special?
Unlike regular GANs that generate random samples, Conditional GANs let you control what gets generated by providing additional information (conditions) like class labels, text descriptions, or other attributes.


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

Using device: cpu


In [1]:
# Define transforms for MNIST data
transform = transforms.Compose([
    transforms.ToTensor(),
    # TODO: Add normalization transform to scale pixel values to [-1, 1]
    # YOUR CODE HERE:

])

# TODO: Load MNIST dataset (keep labels this time!)
# YOUR CODE HERE:
train_dataset = torchvision.datasets.MNIST(
    root='/content/', train=True, download=True, transform=None
)

# Create data loader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Dataset information
num_classes = 10  # Digits 0-9
print(f"Dataset size: {len(train_dataset)}")
print(f"Number of classes: {num_classes}")
print(f"Number of batches: {len(train_loader)}")

# Let's visualize some samples with their labels
def visualize_dataset_samples():
    # Get a batch of data
    data_iter = iter(train_loader)
    images, labels = next(data_iter)

    # Plot first 10 images with labels
    fig, axes = plt.subplots(2, 5, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].squeeze(), cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')

    plt.suptitle('MNIST Dataset Samples with Labels')
    plt.tight_layout()
    plt.show()

visualize_dataset_samples()

NameError: name 'transforms' is not defined

In [None]:
class LabelEmbedding(nn.Module):
    """
    Converts class labels into dense vector representations
    """
    def __init__(self, num_classes, embedding_dim):
        super(LabelEmbedding, self).__init__()
        # TODO: Create an embedding layer
        # Hint: Use nn.Embedding(num_classes, embedding_dim)
        # YOUR CODE HERE:
        self.embedding =

    def forward(self, labels):
        # TODO: Apply embedding to labels
        # YOUR CODE HERE:
        return

# Test the embedding layer
embedding_dim = 50
label_embedding = LabelEmbedding(num_classes, embedding_dim)

# Test with some sample labels
sample_labels = torch.tensor([0, 1, 2, 3, 4])
embedded = label_embedding(sample_labels)
print(f"Original labels shape: {sample_labels.shape}")
print(f"Embedded labels shape: {embedded.shape}")
print(f"Each label becomes a {embedding_dim}-dimensional vector")

In [None]:
class ConditionalGenerator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10, embedding_dim=50, img_dim=28*28):
        super(ConditionalGenerator, self).__init__()

        # TODO: Create label embedding layer
        # YOUR CODE HERE:
        self.label_embedding =

        # Calculate input dimension: noise + embedded label
        # TODO: Calculate the total input dimension
        # YOUR CODE HERE:
        input_dim =

        # TODO: Define the generator network
        # Architecture: input_dim -> 256 -> 512 -> 1024 -> img_dim
        # YOUR CODE HERE:
        self.model = nn.Sequential(
            # Layer 1: input_dim -> 256
            # Hint: Use nn.Linear followed by nn.LeakyReLU(0.2)

            # Layer 2: 256 -> 512

            # Layer 3: 512 -> 1024

            # Output layer: 1024 -> img_dim
            # Use Tanh activation for final layer

        )

    def forward(self, noise, labels):
        # TODO: Implement forward pass
        # Steps:
        # 1. Embed the labels
        # 2. Concatenate noise and embedded labels
        # 3. Pass through the network
        # YOUR CODE HERE:

        pass

In [None]:
class ConditionalDiscriminator(nn.Module):
    def __init__(self, img_dim=28*28, num_classes=10, embedding_dim=50):
        super(ConditionalDiscriminator, self).__init__()

        # TODO: Create label embedding layer
        # YOUR CODE HERE:
        self.label_embedding =

        # Calculate input dimension: image + embedded label
        # TODO: Calculate the total input dimension
        # YOUR CODE HERE:
        input_dim =

        # TODO: Define the discriminator network
        # Architecture: input_dim -> 1024 -> 512 -> 256 -> 1
        # YOUR CODE HERE:
        self.model = nn.Sequential(
            # Layer 1: input_dim -> 1024
            # Hint: Use nn.Linear, nn.LeakyReLU(0.2), nn.Dropout(0.3)

            # Layer 2: 1024 -> 512

            # Layer 3: 512 -> 256

            # Output layer: 256 -> 1
            # Use Sigmoid for binary classification

        )

    def forward(self, images, labels):
        # TODO: Implement forward pass
        # Steps:
        # 1. Flatten images if needed
        # 2. Embed the labels
        # 3. Concatenate images and embedded labels
        # 4. Pass through the network
        # YOUR CODE HERE:

        pass

SyntaxError: invalid syntax (ipython-input-9-1062803426.py, line 7)

In [None]:
# Hyperparameters
noise_dim = 100
embedding_dim = 50
learning_rate = 0.0002

# TODO: Create instances of Generator and Discriminator
# YOUR CODE HERE:
generator =
discriminator =

# Move models to device
generator.to(device)
discriminator.to(device)

# TODO: Define loss function and optimizers
# YOUR CODE HERE:
criterion =
gen_optimizer =
disc_optimizer =

print("Models initialized successfully!")
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters())}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters())}")

# Test the networks with dummy data
test_noise = torch.randn(5, noise_dim).to(device)
test_labels = torch.randint(0, num_classes, (5,)).to(device)
test_images = torch.randn(5, 28*28).to(device)

print(f"\nTesting networks:")
print(f"Generator output shape: {generator(test_noise, test_labels).shape}")
print(f"Discriminator output shape: {discriminator(test_images, test_labels).shape}")

In [None]:
def train_conditional_gan(generator, discriminator, train_loader, num_epochs=50):
    """
    Train the Conditional GAN for specified number of epochs
    """

    # Lists to store losses for plotting
    gen_losses = []
    disc_losses = []

    for epoch in range(num_epochs):
        gen_epoch_loss = 0
        disc_epoch_loss = 0

        for batch_idx, (real_images, real_labels) in enumerate(train_loader):
            batch_size = real_images.size(0)

            # Flatten images and move to device
            real_images = real_images.view(batch_size, -1).to(device)
            real_labels = real_labels.to(device)

            # Create binary labels for real/fake classification
            # TODO: Create real and fake binary labels
            # YOUR CODE HERE:
            real_binary_labels =
            fake_binary_labels =

            # =================================================================
            # TRAIN DISCRIMINATOR
            # =================================================================

            disc_optimizer.zero_grad()

            # Train on real images with their true labels
            # TODO: Get discriminator output for real images and calculate loss
            # YOUR CODE HERE:
            real_output =
            real_loss =

            # Train on fake images with random labels
            # TODO: Generate fake images with random labels
            # YOUR CODE HERE:
            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
            fake_images =
            fake_output =
            fake_loss =

            # Total discriminator loss
            # TODO: Calculate total discriminator loss and backpropagate
            # YOUR CODE HERE:
            disc_loss =
            disc_loss.backward()
            disc_optimizer.step()

            # =================================================================
            # TRAIN GENERATOR
            # =================================================================

            gen_optimizer.zero_grad()

            # TODO: Train generator to fool discriminator
            # Generate fake images and try to make discriminator classify them as real
            # YOUR CODE HERE:
            noise = torch.randn(batch_size, noise_dim).to(device)
            gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
            fake_images =
            fake_output =
            gen_loss =

            gen_loss.backward()
            gen_optimizer.step()

            # Accumulate losses
            gen_epoch_loss += gen_loss.item()
            disc_epoch_loss += disc_loss.item()

        # Average losses for the epoch
        gen_epoch_loss /= len(train_loader)
        disc_epoch_loss /= len(train_loader)

        gen_losses.append(gen_epoch_loss)
        disc_losses.append(disc_epoch_loss)

        # Print progress
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Generator Loss: {gen_epoch_loss:.4f}')
            print(f'Discriminator Loss: {disc_epoch_loss:.4f}')
            print('-' * 50)

    return gen_losses, disc_losses

In [None]:
def plot_losses(gen_losses, disc_losses):
    """Plot training losses"""
    plt.figure(figsize=(10, 5))
    plt.plot(gen_losses, label='Generator Loss')
    plt.plot(disc_losses, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Conditional GAN Training Losses')
    plt.legend()
    plt.grid(True)
    plt.show()

def generate_specific_digits(generator, digits_to_generate, num_samples_per_digit=5):
    """Generate specific digits on demand"""
    generator.eval()
    with torch.no_grad():
        # TODO: Generate images for specific digit classes
        # YOUR CODE HERE:
        all_images = []
        all_labels = []

        for digit in digits_to_generate:
            # Create noise and labels for this digit
            noise =
            labels =

            # Generate images
            fake_images =

            all_images.append(fake_images)
            all_labels.extend([digit] * num_samples_per_digit)

        # Concatenate all images
        all_images = torch.cat(all_images, dim=0)
        all_images = all_images.view(-1, 28, 28).cpu().numpy()

        # Plot results
        num_digits = len(digits_to_generate)
        fig, axes = plt.subplots(num_digits, num_samples_per_digit,
                                figsize=(num_samples_per_digit*2, num_digits*2))

        if num_digits == 1:
            axes = axes.reshape(1, -1)

        for i, digit in enumerate(digits_to_generate):
            for j in range(num_samples_per_digit):
                idx = i * num_samples_per_digit + j
                axes[i, j].imshow(all_images[idx], cmap='gray')
                axes[i, j].set_title(f'Generated {digit}')
                axes[i, j].axis('off')

        plt.tight_layout()
        plt.show()

def generate_all_digits(generator, num_samples_per_digit=8):
    """Generate samples for all digit classes (0-9)"""
    # TODO: Generate samples for each digit class
    # YOUR CODE HERE:
    digits_to_generate = list(range(10))  # 0, 1, 2, ..., 9
    generate_specific_digits(generator, digits_to_generate, num_samples_per_digit)

def compare_conditional_vs_real(train_loader, generator):
    """Compare real and conditionally generated images"""
    # Get real images and labels
    real_images, real_labels = next(iter(train_loader))

    # TODO: Generate fake images with the same labels as real images
    # YOUR CODE HERE:
    generator.eval()
    with torch.no_grad():
        noise =
        fake_images =
        fake_images = fake_images.view(-1, 28, 28)

    # Plot comparison for first 10 samples
    fig, axes = plt.subplots(2, 10, figsize=(15, 4))

    for i in range(10):
        # Real images
        axes[0, i].imshow(real_images[i].squeeze(), cmap='gray')
        axes[0, i].set_title(f'Real {real_labels[i].item()}')
        axes[0, i].axis('off')

        # Generated images
        axes[1, i].imshow(fake_images[i].cpu().numpy(), cmap='gray')
        axes[1, i].set_title(f'Generated {real_labels[i].item()}')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
print("Starting Conditional GAN training...")
print("This may take several minutes depending on your hardware.")

# TODO: Train the conditional GAN
# YOUR CODE HERE:
gen_losses, disc_losses =

print("Training completed!")

In [None]:
def interactive_digit_generation(generator):
    """
    Interactive function to generate any digit on demand
    """
    print("Interactive Digit Generation!")
    print("Enter digits (0-9) separated by spaces, or 'quit' to exit")

    while True:
        user_input = input("\nWhich digits would you like to generate? ")

        if user_input.lower() == 'quit':
            break

        try:
            # Parse user input
            digits = [int(d) for d in user_input.split() if d.isdigit() and 0 <= int(d) <= 9]

            if not digits:
                print("Please enter valid digits (0-9)")
                continue

            print(f"Generating digits: {digits}")
            generate_specific_digits(generator, digits, num_samples_per_digit=4)

        except Exception as e:
            print(f"Error: {e}. Please enter digits separated by spaces.")

# Uncomment to run interactive generation
# interactive_digit_generation(generator)

In [None]:
def evaluate_conditional_generation(generator, num_samples=100):
    """
    Evaluate the quality and diversity of conditional generation
    """
    generator.eval()

    # Generate samples for each class
    class_samples = {i: [] for i in range(10)}

    with torch.no_grad():
        for digit in range(10):
            noise = torch.randn(num_samples, noise_dim).to(device)
            labels = torch.full((num_samples,), digit, dtype=torch.long).to(device)
            fake_images = generator(noise, labels)
            fake_images = fake_images.view(num_samples, 28, 28).cpu().numpy()
            class_samples[digit] = fake_images

    # TODO: Calculate diversity within each class
    # Hint: You could calculate pixel-wise variance or other diversity metrics
    # YOUR CODE HERE:

    # Plot some statistics
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))

    for digit in range(10):
        row = digit // 5
        col = digit % 5

        # Calculate mean image for this digit
        mean_image = np.mean(class_samples[digit], axis=0)
        axes[row, col].imshow(mean_image, cmap='gray')
        axes[row, col].set_title(f'Mean Generated {digit}')
        axes[row, col].axis('off')

    plt.suptitle('Mean Generated Images for Each Digit Class')
    plt.tight_layout()
    plt.show()

    return class_samples

# Run evaluation
class_samples = evaluate_conditional_generation(generator)