In [5]:
import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import uuid

In [6]:
# Configurable variables
NUM_EPOCHS = 50
NOISE_DIMENSION = 50
BATCH_SIZE = 128
TRAIN_ON_GPU = False
UNIQUE_RUN_ID = str(uuid.uuid4())
PRINT_STATS_AFTER_BATCH = 50
OPTIMIZER_LR = 2e-4
OPTIMIZER_BETAS = (0.5, 0.999)
GENERATOR_OUTPUT_IMAGE_SHAPE = 28 * 28 * 1

In [None]:
# Speed Ups
torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)
torch.bachends.cudnn.benchmark = True

## Generator

In [26]:
class Generator(nn.Module):
    """
    Vanilla GAN Generator
    """
    def __init__(self,):
        super().__init__()
        self.layers = nn.Sequential(
            # First upsampling
            nn.Linear(NOISE_DIMENSION, 128, bias=False),
            nn.BatchNorm1d(128, momentum=0.8),
            nn.LeakyReLU(0.2),
            # Second upsampling
            nn.Linear(128, 256, bias=False),
            nn.BatchNorm1d(256, momentum=0.8),
            nn.LeakyReLU(0.2),
            # Third upsampling
            nn.Linear(256, 512, bias=False),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.LeakyReLU(0.2),
            # Final upsampling
            nn.Linear(512, GENERATOR_OUTPUT_IMAGE_SHAPE, bias=False),
            nn.Tanh()
        )
        
    def forward(self, x):
        """Forward pass"""
        return self.layers(x)

## Discriminator

In [11]:
class Discriminator(nn.Module):
    """
    Vanilla GAN Discriminator
    """
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(GENERATOR_OUTPUT_IMAGE_SHAPE, 1024),
            nn.LeakyReLU(0.25),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.25),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.25),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        """Forward pass"""
        return self.layers(x)

## Housekeeping functions

In [31]:
def get_device():
    """ Retrieve device based on settings and availability. """
    return torch.device("cuda:0" if torch.cuda.is_available() and TRAIN_ON_GPU else "cpu")

def make_directory_for_run():
    """ Make a directory for this training run """
    print(f"Preparing training run {UNIQUE_RUN_ID}")
    if not os.path.exists("./runs"):
        os.mkdir('./runs')
    os.mkdir(f"./runs/{UNIQUE_RUN_ID}")
    
def generate_image(generator, epoch=0, batch=0, device=get_device()):
    """ Genereate subplots with generated examples. """
    images = []
    noise = generate_noise(BATCH_SIZE, device=device)
    generator.eval()
    images = generator(noise)
    plt.figure(figsize=(10, 10))
    for i in range(16):
        # Get iamge
        image = images[i]
        # Convert image back onto CPU and reshape
        image = image.cpu().detach().numpy()
        image = np.reshape(image, (28, 28))
        # Plot
        plt.subplot(4, 4, i+1)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    if not os.path.exists(f"./runs/{UNIQUE_RUN_ID}/images"):
        os.mkdir(f"./runs/{UNIQUE_RUN_ID}/images")
    plt.savefig(f"./runs/{UNIQUE_RUN_ID}/images/epoch{epoch}_batch{batch}.jpg")
    
def save_models(generator, discriminator, epoch):
    """ Save models at specific point in time. """
    torch.save(generator.state_dict(), f"./runs/{UNIQUE_RUN_ID}/generator_{epoch}.pth")
    torch.save(discriminator.state_dict(), f"./runs/{UNIQUE_RUN_ID}/discriminator_{epoch}.pth")
    
def print_training_progress(batch, generator_loss, discriminator_loss):
    """ Print training progress. """
    print(f"Losses after mini-batch {batch:5d}: generator {generator_loss:e}, discriminator {discriminator_loss:e}")

In [23]:
def prepare_dataset():
    """ Prepare dataset through DataLoader """
    # Prepare MNIST dataset
    dataset = MNIST(os.getcwd(), download=True, train=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)) # range changes to [-1, 1] to match tanh in Generator
    ]))
    # Batch and shuffle data with DataLoader
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    # Return dataset through DataLoader
    return trainloader

In [18]:
def initialize_models(device = get_device()):
    """ Initialize Generator and Discriminator models """
    generator = Generator()
    discriminator = Discriminator()
    # Move models to specific device
    generator.to(device)
    discriminator.to(device)
    # Return models
    return generator, discriminator

def initialize_loss():
    """ Initialize loss function. """
    return nn.BCELoss()

def initialize_optimizers(generator, discriminator):
    """ Initialize optimizers for Generator and Discriminator """
    generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=OPTIMIZER_LR, betas=OPTIMIZER_BETAS)
    discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=OPTIMIZER_LR, betas=OPTIMIZER_BETAS)
    return generator_optimizer, discriminator_optimizer

In [19]:
def generate_noise(number_of_images = 1, noise_dimension = NOISE_DIMENSION, device=None):
    """ Generate noise for number_of_images images, with a specific noise_dimension """
    return torch.randn(number_of_images, noise_dimension, device=device)

def efficient_zero_grad(model):
    """
    Apply zero_grad more efficiently
    Source: https://betterprogramming.pub/how-to-make-your-pytorch-code-run-faster-93079f3c1f7b
    """
    for param in model.parameters():
        param.grad = None
        
def forward_and_backward(model, data, loss_function, targets):
    """
    Perform forward and backward pass in a generic way. Returns loss value.
    """
    outputs = model(data)
    error = loss_function(outputs, targets)
    error.backward()
    return error.item()

## Training

In [20]:
def perform_train_step(generator, discriminator, real_data, \
    loss_function, generator_optimizer, discriminator_optimizer, device=get_device()):
    """ Perform a single training step. """
    
    # 1. PREPARATION
    # Set real and fake labels.
    real_label, fake_label = 1.0, 0.0
    # Get images on CPU or PGU as configured and available
    # Also set 'actual batch size', which can eb smaller than BATCH_SIZE in some cases.
    # This is because the last batch of an epoch may have less examples than the others.
    real_images = real_data[0].to(device) # retrieve the image of (image, label) from the original MNIST dataset
    actual_batch_size = real_images.size(0)
    label = torch.full((actual_batch_size, 1), real_label, device=device)
    
    # 3. TRAINING THE DISCRIMINATOR
    # Zero the gradients for discriminator
    efficient_zero_grad(discriminator)
    # Forward + backward on real images, reshaped
    real_images = real_images.view(real_images.size(0), -1) # (batch_size, 1 x 24 x 24)
    error_real_images = forward_and_backward(discriminator, real_images, \
        loss_function, label)
    # Forward + backward on generated images
    noise = generate_noise(actual_batch_size, device=device)
    generated_images = generator(noise)
    label.fill_(fake_label)
    error_generated_images = forward_and_backward(discriminator, \
        generated_images.detach(), loss_function, label)
    # Optim for discriminator
    discriminator_optimizer.step()
    
    # 3. TRAINING THE GENERATOR
    # Forward + backward + optim for generator, including zero grad
    efficient_zero_grad(generator)
    label.fill_(real_label)
    error_generator = forward_and_backward(discriminator, generated_images, loss_function, label)
    generator_optimizer.step()
    
    # 4. COMPUTING RESULTS
    # Compute loss values in floats for discriminator, which is joint loss.
    error_discriminator = error_real_images + error_generated_images
    # Return generator and discriminator loss so that it can be printed.
    return error_generator, error_discriminator

In [28]:
def perform_epoch(dataloader, generator, discriminator, loss_function, \
    generator_optimizer, discriminator_optimizer, epoch):
    """ Perform a single epoch. """
    for batch_no, real_data in enumerate(dataloader, 0):
        # Perform training step
        generator_loss_val, discriminator_loss_val = perform_train_step(generator, \
            discriminator, real_data, loss_function, \
            generator_optimizer, discriminator_optimizer)
        # Print statistics and generate iamge after every n-th batch
        if batch_no % PRINT_STATS_AFTER_BATCH == 0:
            print_training_progress(batch_no, generator_loss_val, discriminator_loss_val)
            generate_image(generator, epoch, batch_no)
        # Save models on epoch completion
        save_models(generator, discriminator, epoch)
        # Clear memory after every epoch
        torch.cuda.empty_cache()

In [None]:
def train_gan():
    """ Train the GAN """
    # Make directory for unique run
    make_directory_for_run()
    # Set fixed random number seed
    torch.manual_seed(42)
    # Get prepared dataset
    dataloader = prepare_dataset()
    # Initialize models
    generator, discriminator = initialize_models()
    # Intialize loss and optimizers
    loss_function = initialize_loss()
    generator_optimizer, discriminator_optimizer = initialize_optimizers(generator, discriminator)
    # Train the model
    for epoch in range(NUM_EPOCHS):
        print(f"Starting epoch {epoch}...")
        perform_epoch(dataloader, generator, discriminator, loss_function, \
            generator_optimizer, discriminator_optimizer, epoch)
    # Finished :-)
    print(f"Finished unique run {UNIQUE_RUN_ID}")
    
if __name__ == '__main__':
    train_gan()

Preparing training run f8b9562c-0cfc-42dd-86f4-fc1bcbcda720
Starting epoch 0...




Losses after mini-batch     0: generator 6.821358e-01, discriminator 1.371052e+00
Losses after mini-batch    50: generator 9.486220e-01, discriminator 1.265845e+00
Losses after mini-batch   100: generator 9.288982e-01, discriminator 1.391765e+00
Losses after mini-batch   150: generator 9.200873e-01, discriminator 1.214300e+00
Losses after mini-batch   200: generator 1.325961e+00, discriminator 9.432145e-01
Losses after mini-batch   250: generator 8.926033e-01, discriminator 7.523752e-01
Losses after mini-batch   300: generator 1.207826e+00, discriminator 7.872662e-01
Losses after mini-batch   350: generator 1.794275e+00, discriminator 5.709856e-01
Losses after mini-batch   400: generator 2.630622e+00, discriminator 1.079310e+00
Losses after mini-batch   450: generator 1.310829e+00, discriminator 6.749234e-01
Starting epoch 1...




Losses after mini-batch     0: generator 3.286405e+00, discriminator 7.502441e-01
Losses after mini-batch    50: generator 2.526597e+00, discriminator 6.209961e-01
Losses after mini-batch   100: generator 2.311579e+00, discriminator 9.383871e-01
Losses after mini-batch   150: generator 2.358335e+00, discriminator 5.270703e-01
Losses after mini-batch   200: generator 2.459430e+00, discriminator 6.058400e-01
Losses after mini-batch   250: generator 2.022813e+00, discriminator 5.303159e-01
Losses after mini-batch   300: generator 1.952054e+00, discriminator 5.427356e-01
Losses after mini-batch   350: generator 2.460001e+00, discriminator 5.836852e-01
Losses after mini-batch   400: generator 1.933246e+00, discriminator 5.173903e-01
Losses after mini-batch   450: generator 1.244366e+00, discriminator 7.715344e-01
Starting epoch 2...




Losses after mini-batch     0: generator 1.844283e+00, discriminator 6.430355e-01


  plt.figure(figsize=(10, 10))


Losses after mini-batch    50: generator 2.950277e+00, discriminator 4.986420e-01
Losses after mini-batch   100: generator 1.866525e+00, discriminator 1.042526e+00
Losses after mini-batch   150: generator 1.802803e+00, discriminator 5.897691e-01
Losses after mini-batch   200: generator 2.252990e+00, discriminator 6.471228e-01
Losses after mini-batch   250: generator 2.357257e+00, discriminator 7.078981e-01
Losses after mini-batch   300: generator 1.921964e+00, discriminator 6.540326e-01
Losses after mini-batch   350: generator 2.455935e+00, discriminator 5.927333e-01
Losses after mini-batch   400: generator 2.647455e+00, discriminator 5.649123e-01
Losses after mini-batch   450: generator 2.049061e+00, discriminator 4.340605e-01
Starting epoch 3...




Losses after mini-batch     0: generator 1.776047e+00, discriminator 7.464757e-01
Losses after mini-batch    50: generator 8.638223e-01, discriminator 1.294404e+00
Losses after mini-batch   100: generator 5.532578e-01, discriminator 1.399839e+00
Losses after mini-batch   150: generator 2.133494e+00, discriminator 6.503580e-01
Losses after mini-batch   200: generator 2.199848e+00, discriminator 7.439395e-01
Losses after mini-batch   250: generator 1.457270e+00, discriminator 7.415901e-01
Losses after mini-batch   300: generator 1.546610e+00, discriminator 6.207586e-01
Losses after mini-batch   350: generator 2.911987e+00, discriminator 7.024477e-01
Losses after mini-batch   400: generator 4.125157e+00, discriminator 9.426774e-01
Losses after mini-batch   450: generator 1.488790e+00, discriminator 7.143639e-01
Starting epoch 4...




Losses after mini-batch     0: generator 6.714826e-01, discriminator 8.527733e-01
Losses after mini-batch    50: generator 2.120138e+00, discriminator 5.561877e-01
Losses after mini-batch   100: generator 2.181667e+00, discriminator 5.622258e-01
Losses after mini-batch   150: generator 9.359910e-01, discriminator 7.202892e-01
Losses after mini-batch   200: generator 1.850514e+00, discriminator 6.086128e-01
Losses after mini-batch   250: generator 2.190278e+00, discriminator 6.233092e-01
Losses after mini-batch   300: generator 1.984682e+00, discriminator 5.689805e-01
Losses after mini-batch   350: generator 1.712351e+00, discriminator 6.819035e-01
Losses after mini-batch   400: generator 1.516251e+00, discriminator 5.978690e-01
Losses after mini-batch   450: generator 3.147737e+00, discriminator 1.065923e+00
Starting epoch 5...




Losses after mini-batch     0: generator 2.033937e+00, discriminator 8.341357e-01
Losses after mini-batch    50: generator 2.613985e+00, discriminator 8.813044e-01
Losses after mini-batch   100: generator 1.902510e+00, discriminator 6.652126e-01
Losses after mini-batch   150: generator 3.131765e+00, discriminator 8.461483e-01
Losses after mini-batch   200: generator 1.541710e+00, discriminator 7.788189e-01
Losses after mini-batch   250: generator 1.723683e+00, discriminator 7.408310e-01
Losses after mini-batch   300: generator 1.850818e+00, discriminator 7.455605e-01
Losses after mini-batch   350: generator 2.017582e+00, discriminator 6.853232e-01
Losses after mini-batch   400: generator 2.370290e+00, discriminator 7.878802e-01
Losses after mini-batch   450: generator 5.348853e-01, discriminator 1.458541e+00
Starting epoch 6...




Losses after mini-batch     0: generator 1.571191e+00, discriminator 7.706105e-01
Losses after mini-batch    50: generator 7.680866e-01, discriminator 9.325773e-01
Losses after mini-batch   100: generator 1.607902e+00, discriminator 8.353691e-01
Losses after mini-batch   150: generator 1.773142e+00, discriminator 7.642600e-01
Losses after mini-batch   200: generator 1.783865e+00, discriminator 6.878004e-01
Losses after mini-batch   250: generator 7.145129e-01, discriminator 1.052205e+00
Losses after mini-batch   300: generator 2.506669e+00, discriminator 9.438508e-01
Losses after mini-batch   350: generator 1.362371e+00, discriminator 7.660693e-01
Losses after mini-batch   400: generator 2.252081e+00, discriminator 7.970077e-01
Losses after mini-batch   450: generator 1.243782e+00, discriminator 8.663449e-01
Starting epoch 7...




Losses after mini-batch     0: generator 8.245392e-01, discriminator 9.423319e-01
Losses after mini-batch    50: generator 1.675701e+00, discriminator 7.774864e-01
Losses after mini-batch   100: generator 2.532477e+00, discriminator 9.672236e-01
Losses after mini-batch   150: generator 1.963042e+00, discriminator 8.228761e-01
Losses after mini-batch   200: generator 1.489735e+00, discriminator 1.090393e+00
Losses after mini-batch   250: generator 1.384515e+00, discriminator 8.818959e-01
Losses after mini-batch   300: generator 1.464323e+00, discriminator 9.132618e-01
Losses after mini-batch   350: generator 1.581639e+00, discriminator 8.709914e-01
Losses after mini-batch   400: generator 1.121489e+00, discriminator 1.030180e+00
Losses after mini-batch   450: generator 1.726229e+00, discriminator 8.733801e-01
Starting epoch 8...




Losses after mini-batch     0: generator 1.180437e+00, discriminator 9.028890e-01
Losses after mini-batch    50: generator 1.170083e+00, discriminator 8.742793e-01
Losses after mini-batch   100: generator 2.136473e+00, discriminator 1.066518e+00
Losses after mini-batch   150: generator 7.499750e-01, discriminator 9.631000e-01
Losses after mini-batch   200: generator 1.810656e+00, discriminator 1.011894e+00
Losses after mini-batch   250: generator 1.299252e+00, discriminator 9.459539e-01
Losses after mini-batch   300: generator 2.095550e+00, discriminator 1.005014e+00
Losses after mini-batch   350: generator 2.435631e+00, discriminator 1.157133e+00
Losses after mini-batch   400: generator 1.803940e+00, discriminator 8.878756e-01
Losses after mini-batch   450: generator 1.273365e+00, discriminator 9.479349e-01
Starting epoch 9...




Losses after mini-batch     0: generator 1.898897e+00, discriminator 1.080953e+00
Losses after mini-batch    50: generator 1.068114e+00, discriminator 9.329784e-01
Losses after mini-batch   100: generator 1.978964e+00, discriminator 8.214071e-01
Losses after mini-batch   150: generator 1.287008e+00, discriminator 8.903011e-01
Losses after mini-batch   200: generator 2.062446e+00, discriminator 9.928671e-01
Losses after mini-batch   250: generator 1.391501e+00, discriminator 9.111294e-01
Losses after mini-batch   300: generator 1.718695e+00, discriminator 9.032275e-01
