In [21]:
import torch
from torch import nn

In [22]:
# Configurable variables
# -- paths --
# NOTE(review): DIR is not referenced by the code in this notebook; the
# helpers below hard-code './data/...' paths instead.
DIR = "./data/"

# -- model dimensions --
NOISE_DIMENSION = 50                    # length of each generator input noise vector
GENERATOR_OUTPUT_IMAGE_SHAPE = 28 * 28  # flattened 28x28 single-channel image (784 values)

# -- training schedule --
NUM_EPOCHS = 50
BATCH_SIZE = 128
PRINT_STATS_AFTER_BATCH = 50            # print losses + save a sample grid every N batches

# -- optimizer hyperparameters (shared by both networks) --
OPTIMIZER_LR = 2e-4
OPTIMIZER_BETAS = (0.5, 0.999)


## Generator

In [6]:
class Generator(nn.Module):
  """
    Vanilla (fully connected) GAN generator.

    Maps noise vectors of shape (batch, noise_dimension) to flattened images of
    shape (batch, output_dimension). The final Tanh keeps outputs in [-1, 1],
    the same range as MNIST pixels after Normalize((0.5,), (0.5,)).

    Args:
      noise_dimension: input noise size; NOISE_DIMENSION when None.
      output_dimension: output feature count; GENERATOR_OUTPUT_IMAGE_SHAPE when None.
  """
  def __init__(self, noise_dimension=None, output_dimension=None):
    super().__init__()
    if noise_dimension is None:
      noise_dimension = NOISE_DIMENSION
    if output_dimension is None:
      output_dimension = GENERATOR_OUTPUT_IMAGE_SHAPE

    layers = []
    in_features = noise_dimension
    for out_features in (128, 256, 512):
      layers += [
        # Linear layer: every input influences every output to a degree set by
        # the layer's weights. No bias, because the BatchNorm's learned shift
        # right after it makes a bias redundant.
        nn.Linear(in_features, out_features, bias=False),
        # BatchNorm re-centers and rescales the activations before the next
        # layer. NOTE: BatchNorm1d's 2nd positional argument is `eps`, not
        # `momentum`; the former `BatchNorm1d(n, 0.8)` silently used eps=0.8,
        # which mostly cancels the normalization. PyTorch defaults are used.
        nn.BatchNorm1d(out_features),
        # LeakyReLU keeps a small slope for negative inputs, which avoids the
        # "dying ReLU" problem.
        nn.LeakyReLU(0.25),
      ]
      in_features = out_features

    layers += [
      # NOTE(review): no BatchNorm follows this layer, so a bias would be
      # meaningful here. It is kept bias-free so that existing checkpoints
      # (state_dict keys) still load.
      nn.Linear(in_features, output_dimension, bias=False),
      # Tanh is similar to a Sigmoid but zero-centered, with range [-1, 1],
      # which matches the normalized training images.
      nn.Tanh(),
    ]
    # A flat Sequential keeps the original layer indices (layers.0 ... layers.10).
    self.layers = nn.Sequential(*layers)

  def forward(self, x):
    """Forward pass: noise (batch, noise_dimension) -> images (batch, output_dimension)."""
    return self.layers(x)

## Discriminator

In [23]:
class Discriminator(nn.Module):
  """
    Vanilla (fully connected) GAN discriminator.

    Maps flattened images of shape (batch, input_dimension) to one value per
    sample in [0, 1] (Sigmoid output). This is the probability form that
    nn.BCELoss expects.

    NOTE(review): nn.BCEWithLogitsLoss on raw logits is the numerically more
    stable alternative to Sigmoid + BCELoss. Switching would change the loss
    setup, so it is not done here.

    Args:
      input_dimension: input feature count; GENERATOR_OUTPUT_IMAGE_SHAPE when None.
  """
  def __init__(self, input_dimension=None):
    super().__init__()
    if input_dimension is None:
      input_dimension = GENERATOR_OUTPUT_IMAGE_SHAPE

    layers = []
    in_features = input_dimension
    for out_features in (1024, 512, 256):
      layers += [
        nn.Linear(in_features, out_features),
        nn.LeakyReLU(0.25),
        # Dropout regularizes the discriminator so it does not overpower the generator.
        nn.Dropout(0.3),
      ]
      in_features = out_features
    layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
    # A flat Sequential keeps the original layer indices (layers.0 ... layers.10).
    self.layers = nn.Sequential(*layers)

  def forward(self, x):
    """Forward pass: images (batch, input_dimension) -> scores (batch, 1)."""
    return self.layers(x)

## Utilities 

In [30]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Use the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def generate_noise(number_of_images = 1, noise_dimension = NOISE_DIMENSION, device=None):
  """Sample generator input noise from a standard normal distribution.

  Returns a tensor of shape (number_of_images, noise_dimension) allocated on
  ``device`` (PyTorch's default device when None).
  """
  shape = (number_of_images, noise_dimension)
  return torch.randn(*shape, device=device)

def generate_image(generator, epoch = 0, batch = 0, device=device, output_dir='./data/images'):
  """Render a 4x4 grid of generator samples and save it as a JPEG.

  The file is written to ``{output_dir}/epoch{epoch}_batch{batch}.jpg``.

  The generator is run in eval mode under ``torch.no_grad()``, and its previous
  train/eval mode is restored afterwards. Calling this in the middle of
  training therefore leaves BatchNorm in training mode and builds no autograd
  graph.

  Args:
    generator: model mapping noise from ``generate_noise`` to flattened 28x28 images.
    epoch, batch: used only to name the output file.
    device: device on which the noise is sampled; should match the generator's.
    output_dir: target directory, created (with parents) when missing.

  Returns:
    Path of the saved JPEG.
  """
  rows, cols = 4, 4

  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      # Only sample as many images as the grid shows.
      samples = generator(generate_noise(rows * cols, device=device)).cpu()
  finally:
    generator.train(was_training)

  fig, axes = plt.subplots(rows, cols, figsize=(10, 10))
  for ax, sample in zip(axes.flat, samples):
    ax.imshow(sample.reshape(28, 28).numpy(), cmap='gray')
    ax.axis('off')

  os.makedirs(output_dir, exist_ok=True)
  path = os.path.join(output_dir, f'epoch{epoch}_batch{batch}.jpg')
  fig.savefig(path)
  # This runs every PRINT_STATS_AFTER_BATCH batches during training. Without
  # closing, every figure would stay open for the whole run.
  plt.close(fig)
  return path


def save_models(generator, discriminator, epoch, directory='./data'):
  """Save both models' state dicts for ``epoch``.

  Writes ``generator_{epoch}.pth`` and ``discriminator_{epoch}.pth`` into
  ``directory``, creating it (with parents) when missing.

  Returns:
    (generator_path, discriminator_path).
  """
  os.makedirs(directory, exist_ok=True)
  paths = []
  for name, model in (('generator', generator), ('discriminator', discriminator)):
    path = os.path.join(directory, f'{name}_{epoch}.pth')
    torch.save(model.state_dict(), path)
    paths.append(path)
  return tuple(paths)


def print_training_progress(batch, generator_loss, discriminator_loss):
  """Print one line with the batch index and both losses in scientific notation."""
  stats = (batch, generator_loss, discriminator_loss)
  line = 'Losses after mini-batch %5d: generator %e, discriminator %e' % stats
  print(line)

In [31]:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

def prepare_dataset(root=None):
  """Build a shuffled DataLoader over the MNIST training split.

  Images are converted to tensors and normalized with mean 0.5 / std 0.5. This
  maps them from [0, 1] to [-1, 1], the range of the generator's Tanh output.

  Args:
    root: directory where MNIST is downloaded/cached; the current working
      directory when None (the previous hard-coded location).

  Returns:
    DataLoader yielding (images, labels) batches of at most BATCH_SIZE samples.
  """
  if root is None:
    root = os.getcwd()
  dataset = MNIST(root, download=True, train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
  ]))
  # Pinned host memory only speeds up host-to-GPU copies. Without CUDA it is
  # wasted work, and PyTorch may emit a warning about it.
  return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                    pin_memory=torch.cuda.is_available())

In [32]:
def initialize_models(device = device):
  """Create a fresh (Generator, Discriminator) pair and move both to ``device``.

  The generator is constructed first, then the discriminator, so weight
  initialization consumes the RNG in the same order as before.
  """
  models = tuple(model_class() for model_class in (Generator, Discriminator))
  for model in models:
    model.to(device)
  return models


def initialize_loss():
  """Return the binary cross-entropy criterion shared by both networks.

  The discriminator ends in a Sigmoid, so its outputs are probabilities, which
  is the input ``nn.BCELoss`` expects.
  """
  criterion = nn.BCELoss()
  return criterion


def initialize_optimizers(generator, discriminator):
  """Create one AdamW optimizer per network with the shared hyperparameters.

  Both use OPTIMIZER_LR and OPTIMIZER_BETAS. Weight decay is left at AdamW's
  default. NOTE(review): GANs are more commonly trained with plain Adam (no
  weight decay); confirm AdamW is intended.

  Returns:
    (generator_optimizer, discriminator_optimizer).
  """
  hyperparameters = dict(lr=OPTIMIZER_LR, betas=OPTIMIZER_BETAS)
  return tuple(
    torch.optim.AdamW(model.parameters(), **hyperparameters)
    for model in (generator, discriminator)
  )

def efficient_zero_grad(model):
  """Reset ``model``'s gradients by setting each parameter's ``.grad`` to None.

  Unlike zero-filling the gradient tensors, this skips a memory write; the
  next ``backward()`` allocates fresh gradients.
  """
  params = list(model.parameters())
  for p in params:
    p.grad = None


def forward_and_backward(model, data, loss_function, targets):
  """Run ``model`` on ``data``, backpropagate the loss and return it as a float.

  Gradients accumulate into every ``.grad`` the graph reaches. Zeroing them
  and stepping an optimizer are left to the caller.
  """
  loss = loss_function(model(data), targets)
  loss.backward()
  return loss.item()

## Generator and Discriminator Training

In [33]:
def train_step(generator, discriminator, real_data,
               loss_function, generator_optimizer, discriminator_optimizer, device = device):
  """Run one discriminator update followed by one generator update.

  Args:
    generator, discriminator: the two networks.
    real_data: a batch from the DataLoader; element 0 holds the real images.
    loss_function: criterion comparing discriminator outputs to 0/1 labels.
    generator_optimizer, discriminator_optimizer: optimizers for each network.
    device: device the images and labels are moved to.

  Returns:
    (generator_loss, discriminator_loss) as Python floats. The discriminator
    loss is the sum of its real-image and generated-image terms.
  """
  # 1. PREPARATION
  # Enter training mode explicitly. A caller (e.g. an image-preview helper)
  # may have left a model in eval mode, which would make BatchNorm use its
  # running statistics and disable Dropout during training.
  generator.train()
  discriminator.train()
  real_label, fake_label = 1.0, 0.0
  # The last batch of an epoch can be smaller than BATCH_SIZE, so everything
  # is sized from the actual batch. Images are flattened to (batch, features).
  real_images = real_data[0].to(device)
  actual_batch_size = real_images.size(0)
  real_images = real_images.view(actual_batch_size, -1)
  label = torch.full((actual_batch_size, 1), real_label, device=device)

  # 2. TRAINING THE DISCRIMINATOR
  efficient_zero_grad(discriminator)
  error_real_images = forward_and_backward(discriminator, real_images, loss_function, label)
  noise = generate_noise(actual_batch_size, device=device)
  generated_images = generator(noise)
  label.fill_(fake_label)
  # detach(): the discriminator update must not push gradients into the generator.
  error_generated_images = forward_and_backward(
    discriminator, generated_images.detach(), loss_function, label)
  discriminator_optimizer.step()

  # 3. TRAINING THE GENERATOR
  # The generator is trained to make the discriminator output "real" for its
  # samples. This backward also writes gradients into the discriminator; those
  # are cleared by efficient_zero_grad(discriminator) at the start of the next step.
  efficient_zero_grad(generator)
  label.fill_(real_label)
  error_generator = forward_and_backward(discriminator, generated_images, loss_function, label)
  generator_optimizer.step()

  # 4. COMPUTING RESULTS
  error_discriminator = error_real_images + error_generated_images
  return error_generator, error_discriminator

In [34]:
def perform_epoch(dataloader, generator, discriminator, loss_function,
                  generator_optimizer, discriminator_optimizer, epoch):
  """Train for one full pass over ``dataloader``.

  Every PRINT_STATS_AFTER_BATCH batches (starting with batch 0), the current
  losses are printed and a sample grid is saved via ``generate_image``.
  """
  for batch_no, real_data in enumerate(dataloader):
    generator_loss_val, discriminator_loss_val = train_step(
      generator, discriminator, real_data, loss_function,
      generator_optimizer, discriminator_optimizer)
    if batch_no % PRINT_STATS_AFTER_BATCH != 0:
      continue
    print_training_progress(batch_no, generator_loss_val, discriminator_loss_val)
    generate_image(generator, epoch, batch_no)

  # Checkpointing is currently disabled; call
  # save_models(generator, discriminator, epoch) here to re-enable it.

  # Release cached, unused CUDA memory held by PyTorch's allocator.
  torch.cuda.empty_cache()

In [None]:
def train_gan(seed=42):
  """
  Train the fully connected ("vanilla") GAN on MNIST.

  Seeds torch's global RNG, then builds the MNIST DataLoader, both models,
  the loss and the two optimizers. It then runs NUM_EPOCHS epochs of training.

  Args:
    seed: value passed to ``torch.manual_seed`` (42 by default, as before).

  Returns:
    The trained (generator, discriminator) pair, so they can be inspected or
    saved after the run.
  """
  torch.manual_seed(seed)

  dataloader = prepare_dataset()
  generator, discriminator = initialize_models()
  loss_function = initialize_loss()
  generator_optimizer, discriminator_optimizer = initialize_optimizers(generator, discriminator)

  for epoch in range(NUM_EPOCHS):
    print(f'Starting epoch {epoch}...')
    perform_epoch(dataloader, generator, discriminator, loss_function,
                  generator_optimizer, discriminator_optimizer, epoch)

  print('Finished :-)')
  return generator, discriminator


if __name__ == '__main__':
  # Report the selected device before the (long) training run starts.
  print(str(device))
  train_gan()

## Create GIF

In [51]:
from PIL import Image
from glob import glob

# Assemble the grid saved at batch 100 of every epoch into an animated GIF.
import re


def gif_frame_key(path):
    """Sort key for 'epoch{N}_batch{M}.jpg' frames: (numeric epoch, path).

    Plain lexicographic sorting puts 'epoch10' before 'epoch2', which scrambles
    the animation. Paths without an 'epoch<digits>_' prefix sort first (-1).
    """
    match = re.search(r"epoch(\d+)_", os.path.basename(path))
    epoch = int(match.group(1)) if match else -1
    return (epoch, path)


frame_pattern = "./data/images/*_batch100.jpg"
image_paths = sorted(glob(frame_pattern), key=gif_frame_key)
output_path = "output.gif"

images = []
for image_path in image_paths:
    # Copy each frame into memory and close its file right away, instead of
    # keeping one open handle per frame until the GIF is written.
    with Image.open(image_path) as img:
        images.append(img.copy())

if images:
    images[0].save(output_path, save_all=True, append_images=images[1:], loop=0, duration=250)
else:
    print(f"No frames match {frame_pattern!r}; run the training cells first.")