# Generative Adversarial Network Training

Here we build and train a GAN to synthesize MNIST digits.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

## A generator maps a vector to an image.

We will use a simple transposed-convolution network (a convnet run "in reverse") for that.

In [None]:
# Generator
class Generator(nn.Sequential):
    """Transposed-convolution stack that turns a latent vector into an image.

    For an input of shape ``(N, latent_dim, 1, 1)`` the three transposed
    convolutions grow the spatial size 1 -> 7 -> 14 -> 28, producing a
    single-channel ``(N, 1, 28, 28)`` output. The closing ``Tanh`` keeps
    every pixel in ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the latent input.
    """

    def __init__(self, latent_dim):
        layers = [
            # 1x1 -> 7x7
            nn.ConvTranspose2d(latent_dim, 128, kernel_size=7, stride=1, padding=0),
            nn.ReLU(inplace=True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # 14x14 -> 28x28, single output channel
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        super().__init__(*layers)

# Number of channels in the latent (noise) vector fed to the generator.
latent_dim = 100
generator = Generator(latent_dim)

# Sample some random noise
# A fixed batch of 12 latent vectors; show_samples() reuses it on every call,
# so successive plots are drawn from the same inputs.
z_samples = torch.randn(12, latent_dim, 1, 1)

# Plot some samples
def show_samples():
    """Plot the generator's images for the fixed noise batch ``z_samples``.

    The module-level names ``generator`` and ``z_samples`` are looked up at
    call time, so the plot always reflects whichever generator is currently
    bound to ``generator``. The number of panels follows the batch size of
    ``z_samples`` instead of being hard-coded.
    """
    # No gradients are needed for visualisation; this also lets us call
    # .numpy() directly on the result.
    with torch.no_grad():
        images = generator(z_samples).reshape(-1, 28, 28).numpy()

    count = images.shape[0]
    # squeeze=False keeps ``axs`` two-dimensional even when count == 1, so
    # the row can always be indexed the same way.
    fig, axs = plt.subplots(1, count, figsize=(11, 1), squeeze=False)
    for ax, image in zip(axs[0], images):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    plt.show()

# Preview the output of the freshly constructed generator (no training yet).
show_samples()

## A discriminator is a classifier

This is a simple convnet that outputs a single logit per image: the binary (real vs. fake) prediction.

In [None]:
# Discriminator
class Discriminator(nn.Sequential):
    """Convolutional binary classifier producing one raw logit per image.

    For an input of shape ``(N, 1, 28, 28)`` the strided convolutions shrink
    the spatial size 28 -> 14 -> 7, and the final 7x7 convolution collapses
    it to ``(N, 1, 1, 1)``. No sigmoid is applied here, so the output is an
    unnormalised score.
    """

    def __init__(self):
        layers = [
            # 28x28 -> 14x14
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 7x7 -> 1x1, single logit
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False),
        ]
        super().__init__(*layers)

# The GAN training loop

There are a few tricks here.
 * How many optimizers are there? Why?
 * On each iteration, how many times is the generator run?
 * How many times is the discriminator run?  What is the purpose of each run?

In [None]:
# Parameters
batch_size = 64  # images per training batch
epochs = 10  # number of full passes over the training set
lr = 0.001  # learning rate used by both Adam optimizers below

# Initialize generator and discriminator
# Rebinding ``generator`` here replaces the earlier preview model with fresh
# weights; show_samples() looks the name up at call time and will use this one.
generator = Generator(latent_dim)
discriminator = Discriminator()

# Loss functions and optimizers
# BCEWithLogitsLoss applies the sigmoid internally, which matches the
# discriminator ending in a plain convolution (raw logits).
criterion = nn.BCEWithLogitsLoss()
# One optimizer per network: each one only ever updates its own parameters.
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Load MNIST dataset
# Normalize computes (x - 0.5) / 0.5, mapping ToTensor's [0, 1] pixels to
# [-1, 1] -- the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Training loop
#
# Each iteration runs the generator once and the discriminator three times:
#   1. on real images            -> discriminator loss
#   2. on detached fake images   -> discriminator loss
#   3. on the same fake images, after the discriminator update and with the
#      generator graph attached  -> generator loss
for epoch in range(epochs):
    # Visualise the fixed noise batch before this epoch's updates.
    show_samples()

    for batch_idx, (real_data, _) in enumerate(dataloader):
        # ---- Train discriminator ----
        discriminator.zero_grad()

        real_output = discriminator(real_data)

        # Match the noise batch to the real batch (the last batch may be smaller).
        z = torch.randn(real_data.size(0), latent_dim, 1, 1)
        fake_data = generator(z)
        # detach() cuts the generator out of this graph, so the discriminator
        # loss produces no generator gradients.
        fake_output = discriminator(fake_data.detach())

        real_labels = torch.ones_like(real_output)
        fake_labels = torch.zeros_like(fake_output)
        real_loss = criterion(real_output, real_labels)
        fake_loss = criterion(fake_output, fake_labels)
        d_loss = real_loss + fake_loss

        # No retain_graph needed: the generator step below builds a fresh
        # discriminator graph, and the generator's own graph is untouched by
        # this backward pass because of the detach() above.
        d_loss.backward()
        optimizer_D.step()

        # ---- Train generator ----
        generator.zero_grad()
        # Re-run the (just updated) discriminator, this time keeping the
        # connection to the generator so gradients flow back into it.
        fake_output = discriminator(fake_data)

        # The generator wants its fakes to be classified as real (label 1).
        generator_targets = torch.ones_like(fake_output)
        g_loss = criterion(fake_output, generator_targets)

        # This also accumulates gradients in the discriminator's parameters;
        # they are never applied and are cleared by discriminator.zero_grad()
        # at the start of the next iteration.
        g_loss.backward()
        optimizer_G.step()

    # Print the losses of the last batch of this epoch.
    print(f"Epoch [{epoch+1}/{epochs}], Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")

# Visualise the generator once more so the result of the final epoch is shown.
show_samples()

