# Generative Adversarial Network Training

Here we build and train a GAN to synthesize MNIST digits.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

## A generator maps a vector to an image.

We will use a simple convnet for the generator.

* It takes a 100-dimensional random input and produces an image.
* At the start it produces a random garbage image for any input vector.

In [None]:
# Generator
class Generator(nn.Sequential):
    """DCGAN-style generator mapping a latent vector to a 1x28x28 image.

    Input is expected as ``(batch, latent_dim, 1, 1)``. Three transposed
    convolutions grow the spatial size 1 -> 7 -> 14 -> 28, and the final
    ``Tanh`` squashes pixels into [-1, 1], the same range produced by the
    ``Normalize((0.5,), (0.5,))`` transform applied to the real data.
    """

    def __init__(self, latent_dim):
        # Same layer order as a positional nn.Sequential, so submodules are
        # numbered 0..5 and parameters are created/initialized in that order.
        stages = [
            # 1x1 -> 7x7
            nn.ConvTranspose2d(latent_dim, 128, kernel_size=7, stride=1, padding=0),
            nn.ReLU(inplace=True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # 14x14 -> 28x28, single (grayscale) output channel
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        super().__init__(*stages)

# Dimensionality of the noise vector fed to the generator.
latent_dim = 100
generator = Generator(latent_dim=latent_dim)

# A fixed batch of 12 latent vectors shaped (12, latent_dim, 1, 1); it is
# reused by show_samples so successive previews are directly comparable.
z_samples = torch.randn((12, latent_dim, 1, 1))

# Plot some samples
def show_samples(model=None, noise=None):
    """Plot one generated image per latent vector, in a single row.

    Args:
        model: generator to sample from. Defaults to the module-level
            ``generator``, looked up at call time so it follows reassignment.
        noise: latent batch of shape ``(n, latent_dim, 1, 1)``. Defaults to
            the module-level fixed ``z_samples``.
    """
    if model is None:
        model = generator
    if noise is None:
        noise = z_samples
    count = noise.size(0)

    with torch.no_grad():
        # (n, 1, H, W) -> (n, H, W). .cpu() lets GPU tensors be plotted too.
        images = model(noise)[:, 0].cpu().numpy()

    # squeeze=False keeps `axs` 2-D even when only one sample is requested.
    fig, axs = plt.subplots(1, count, figsize=(11, 1), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    plt.show()


show_samples()

## A discriminator is a classifier

This is a simple convnet, with a single (binary) class output prediction.

* It takes an image as input.
* And then it produces a logit as output.

(It will still need to be passed through a sigmoid (the two-class special case of softmax) to become a classification probability. We do that later, in the loss, by using `BCEWithLogitsLoss`, which computes binary cross-entropy with a built-in sigmoid that turns the logits into probabilities.)

In [None]:
# Discriminator
class Discriminator(nn.Sequential):
    """Convolutional binary classifier that emits one raw logit per image.

    Expects ``(batch, 1, 28, 28)`` input. Two stride-2 convolutions shrink the
    spatial size 28 -> 14 -> 7, and a final 7x7 convolution collapses it to
    1x1, so the output has shape ``(batch, 1, 1, 1)``. No sigmoid is applied
    here; the training objective consumes the raw score.
    """

    def __init__(self):
        # Built in the original order so submodule indices (0..5) and
        # parameter initialization order are unchanged.
        blocks = [
            # 28x28 -> 14x14
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 7x7 -> 1x1 logit
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False),
        ]
        super().__init__(*blocks)

## Training data

In [None]:
# Load MNIST. ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,))
# then maps them to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Preview the first 12 training images in one row.
fig, axs = plt.subplots(1, 12, figsize=(11, 1))
for index, ax in enumerate(axs):
    image, _label = dataset[index]
    ax.imshow(image.squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.show()

# The GAN training loop

There are a few tricks here. Some questions to think about.
 * How many optimizers are there? Why?
 * On each iteration, how many times is the generator run?
 * How many times is the discriminator run?  What is the purpose of each run?



In [None]:
# Parameters
batch_size = 64
epochs = 10
lr = 0.001
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator
generator = Generator(latent_dim)
discriminator = Discriminator()

# Loss functions and optimizers (one optimizer per network, each holding only
# that network's parameters).
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    for real_data, _ in dataloader:
        # ---- Discriminator step: real -> 1, fake -> 0 ----
        discriminator.zero_grad()

        # The true label for real data is 1
        real_estimate = discriminator(real_data)
        real_loss = criterion(real_estimate, torch.ones_like(real_estimate))

        # The true label for fake data is 0. Detaching stops this step's
        # backward pass at the discriminator, away from the generator's graph.
        z = torch.randn(real_data.size(0), latent_dim, 1, 1)
        fake_data = generator(z)
        fake_estimate = discriminator(fake_data.detach())
        fake_loss = criterion(fake_estimate, torch.zeros_like(fake_estimate))

        d_loss = real_loss + fake_loss
        # No retain_graph: the generator step below runs a fresh discriminator
        # forward pass, and the generator's graph was excluded via detach().
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step: fool the just-updated discriminator ----
        generator.zero_grad()

        # For the generator, the target label for fake data is 1
        fake_estimate_to_beat = discriminator(fake_data)
        g_loss = criterion(fake_estimate_to_beat, torch.ones_like(fake_estimate_to_beat))

        # This backward also writes gradients into the discriminator's
        # parameters; discriminator.zero_grad() clears them before they are
        # ever used by optimizer_D.
        g_loss.backward()
        optimizer_G.step()

    # Print the last batch's losses, then visualize the generator as it is
    # after this epoch (so the final trained generator is shown too).
    print(f"Epoch [{epoch+1}/{epochs}], Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")
    show_samples()



## Bonus: Wasserstein GAN

The following is a variant of the Wasserstein GAN (WGAN-GP). Instead of using cross-entropy loss, which saturates because it works in probabilities, it directly guides the discriminator to drive `fake_estimate` down towards negative infinity and `real_estimate` up towards positive infinity by minimizing `fake_estimate - real_estimate`.

Even though it avoids saturation by letting the discriminator output unbounded scores instead of probabilities in [0, 1], it would explode if left unconstrained. The WGAN-GP scheme avoids explosions in the discriminator by heavily penalizing any deviation of the gradient norm from one, measured at random points on straight lines between real and fake images. In the code below, that is `gradient_penalty`.

Using this discriminator, Wasserstein GAN training directly guides the generator to generate images that maximize `fake_estimate_to_beat` (making its fake images look as real as possible).

In [None]:
# Initialize generator and discriminator (the WGAN "critic")
generator = Generator(latent_dim)
discriminator = Discriminator()

# The gradient penalty constrains the critic's gradient for each input
# independently, but BatchNorm makes each output depend on the whole batch,
# which invalidates that per-sample penalty. Replace any BatchNorm2d with a
# single-group GroupNorm: per-sample normalization over (C, H, W) with the
# same channel count (i.e. layer normalization).
for index, layer in list(enumerate(discriminator)):
    if isinstance(layer, nn.BatchNorm2d):
        discriminator[index] = nn.GroupNorm(1, layer.num_features)

# Weight of the gradient-penalty term in the critic loss.
gp_weight = 20

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    for real_data, _ in dataloader:
        # ---- Critic step ----
        discriminator.zero_grad()

        real_estimate = discriminator(real_data)

        z = torch.randn(real_data.size(0), latent_dim, 1, 1)
        fake_data = generator(z)
        detached_fake_data = fake_data.detach()
        fake_estimate = discriminator(detached_fake_data)

        # Gradient penalty at random points between real and fake images.
        # `interpolated` is built only from tensors without grad history, so
        # it is a leaf we can mark as requiring grad and differentiate w.r.t.
        alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
        interpolated = (alpha * real_data + (1 - alpha) * detached_fake_data).requires_grad_(True)
        interpolated_estimate = discriminator(interpolated)
        # create_graph=True so the penalty can itself be backpropagated into
        # the critic's parameters.
        gradients = torch.autograd.grad(
            outputs=interpolated_estimate,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_estimate),
            create_graph=True,
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        # Two-sided penalty: deviation of each sample's gradient norm from 1.
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

        # WGAN-GP loss: fakes should be very negative and reals very positive,
        # BUT the critic's gradient norm should stay close to 1.
        d_loss = fake_estimate.mean() - real_estimate.mean() + gp_weight * gradient_penalty

        # No retain_graph: the generator step runs a fresh critic forward pass.
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        generator.zero_grad()

        # WGAN-GP generator loss: make the critic's score on this iteration's
        # fakes (still attached to the generator's graph) as high as possible.
        fake_estimate_to_beat = discriminator(fake_data)
        g_loss = -fake_estimate_to_beat.mean()

        # Gradients also land on the critic's parameters; they are cleared by
        # discriminator.zero_grad() before optimizer_D uses them.
        g_loss.backward()
        optimizer_G.step()

    # Print the last batch's losses, then visualize the generator as it is
    # after this epoch.
    print(f"Epoch [{epoch+1}/{epochs}], Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")
    show_samples()