
DCGAN with Spectral Normalisation Layers
==============
`https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html`

Based on the PyTorch tutorial by Nathan Inkawhich `<https://github.com/inkawhich>`





**Define imports**

In [0]:
%matplotlib inline
!pip install -q torch torchvision livelossplot
!pip install torchsummary

from __future__ import print_function
import argparse
import os
from os.path import isfile
import random
from matplotlib.colors import Normalize
import torch
import torch.nn as nn
import torch.nn.parallel
from time import sleep
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchsummary import summary
import torchvision.datasets as dset
from livelossplot import PlotLosses
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as plt_image
import matplotlib.animation as animation
from IPython.display import HTML

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**Define a class for the Custom Dataset**

In [0]:
class CustomDataset(CIFAR10):
    """CIFAR-10 restricted to a chosen subset of its classes.

    The parent class loads the requested split; afterwards ``self.data`` and
    ``self.targets`` are replaced by only those samples whose label belongs to
    ``class_names``, grouped class by class in the order the names were given.
    """

    # CIFAR-10 label index for every accepted class name. The short aliases
    # 'plane' and 'car' are kept alongside the official names.
    _NAME_TO_INDEX = {
        'plane': 0, 'airplane': 0,
        'car': 1, 'automobile': 1,
        'bird': 2,
        'cat': 3,
        'deer': 4,
        'dog': 5,
        'frog': 6,
        'horse': 7,
        'ship': 8,
        'truck': 9,
    }

    def __init__(self, root: str, class_names: list, train=True, transform=None, target_transform=None,
                 download=False):
        """Load the requested CIFAR-10 split and keep only ``class_names``.

        Args:
            root: Directory the CIFAR-10 archive is stored in / downloaded to.
            class_names: Names of the classes to keep (see ``get_index``).
            train: ``True`` for the training split, ``False`` for the test split.
            transform: Optional transform applied to each image.
            target_transform: Optional transform applied to each label.
            download: Download the data if it is not already under ``root``.
        """
        # ``train`` must be forwarded explicitly: the parent defaults to the
        # training split, so omitting it makes train=False load training data.
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

        self.image_classes = [self.get_index(name) for name in class_names]
        self.targets, self.data = self.get_classes(self.image_classes)

    def get_index(self, class_name: str) -> int:
        """Return the CIFAR-10 label index for ``class_name``.

        Raises:
            KeyError: If ``class_name`` is not a known CIFAR-10 class name.
        """
        return self._NAME_TO_INDEX[class_name]

    def get_classes(self, class_indexes: list):
        """Select the samples whose label is in ``class_indexes``.

        Returns:
            A ``(targets, data)`` pair of NumPy arrays, grouped by class in the
            order of ``class_indexes`` (duplicates are ignored).
        """
        all_targets = np.asarray(self.targets)
        # dict.fromkeys drops duplicate indexes while preserving their order,
        # so each class contributes its samples exactly once.
        keep = np.concatenate(
            [np.flatnonzero(all_targets == ind) for ind in dict.fromkeys(class_indexes)]
        )
        return all_targets[keep], self.data[keep]

Define Settings
------

-  **batch_size** - the batch size used in training. The DCGAN paper
   uses a batch size of 128; this notebook uses 64
-  **num_epochs** - number of training epochs to run. Training for
   longer will probably lead to better results but will also take much
   longer
-  **lr** - learning rate for training. As described in the DCGAN paper,
   this number should be 0.0002
-  **beta1** - beta1 hyperparameter for the Adam optimizers. As described in
   the paper, this number should be 0.5




In [0]:
# Number of images per mini-batch
batch_size = 64

# Number of training epochs
num_epochs = 50

# Step size used by the Adam optimisers
lr = 2e-4

# beta1 (first-moment decay) used by the Adam optimisers
beta1 = 0.5

Prepare Data
------

- **select_classes** - list of classes to train on
- valid combinations include ['bird', 'horse'] and ['plane', 'horse']

In [0]:
def cycle(iterable):
    """Yield the items of ``iterable`` endlessly, restarting it whenever it is exhausted.

    The iterable is re-iterated on every pass (nothing is cached), so a
    shuffling data loader produces a fresh order each time round.
    """
    while True:
        yield from iterable

# CIFAR-10 label names, indexed by label id (used to caption the preview images)
class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
select_classes = ['plane', 'horse']

# Random horizontal flips for augmentation; ToTensor scales pixels to [0, 1]
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Create the datasets (both use the same transform)
training_set = CustomDataset('root', select_classes, train=True, download=True, transform=train_transform)
testing_set = CustomDataset('root', select_classes, train=False, download=True, transform=train_transform)

# Create the data loaders; drop_last keeps every batch at exactly batch_size
train_loader = DataLoader(training_set, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(testing_set, shuffle=True, batch_size=batch_size, drop_last=True)

data_iterator = iter(cycle(train_loader))
print('Dataset contains the following image classes: ' + str(select_classes))
print(f'> Size of dataset {len(train_loader.dataset) + len(test_loader.dataset)}')

# Visualise some of the data
plt.figure(figsize=(10,10))
images, labels = next(data_iterator)

# The grid is 8x8; cap at the batch length so a batch_size below 64 does not
# raise IndexError.
for i in range(min(64, len(images))):
    plt.subplot(8, 8, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # imshow expects height x width x channels, the tensors are channels-first
    plt.imshow(images[i].permute(1, 2, 0), cmap=plt.cm.binary)
    plt.xlabel(class_names[labels[i]])

In [0]:
def weights_init(m):
    """Re-initialise one module in place; intended for use with ``nn.Module.apply``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Every other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

Define Models
-------------
Create the Generator and Discriminator


In [0]:
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to a 32 x 32 image.

    Args:
        f: Base feature-map width; the hidden layers use f*8, f*4, f*2 and f channels.
        latent_dim: Number of channels of the latent input (shape ``latent_dim x 1 x 1``).
        out_channels: Number of channels of the generated image.

    The output passes through ``Tanh``, so its values lie in [-1, 1].
    """

    def __init__(self, f=32, latent_dim=100, out_channels=3):
        super(Generator, self).__init__()
        self.generate = nn.Sequential(
            # input is Z (latent_dim x 1 x 1), going into a transposed convolution
            nn.ConvTranspose2d(latent_dim, f * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(f * 8),
            nn.ReLU(True),
            # state size. (f*8) x 4 x 4
            nn.ConvTranspose2d(f * 8, f * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(f * 4),
            nn.ReLU(True),
            # state size. (f*4) x 8 x 8
            nn.ConvTranspose2d(f * 4, f * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(f * 2),
            nn.ReLU(True),
            # state size. (f*2) x 16 x 16
            nn.ConvTranspose2d(f * 2, f, 4, 2, 1, bias=False),
            nn.BatchNorm2d(f),
            nn.ReLU(True),
            # state size. (f) x 32 x 32
            # 1x1 kernel with stride 1: mixes channels only, spatial size is unchanged
            nn.ConvTranspose2d(f, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
            # output size. (out_channels) x 32 x 32
        )

    def forward(self, input):
        """Generate a batch of images from a batch of latent vectors."""
        return self.generate(input)


class Discriminator(nn.Module):
    """DCGAN discriminator with spectral normalisation on every convolution.

    Args:
        f: Base feature-map width; the hidden layers use f, f*2, f*4 and f*8 channels.
        in_channels: Number of channels of the input image.

    For a 32 x 32 input the output has shape ``N x 1 x 1 x 1`` and, because of
    the final ``Sigmoid``, lies in [0, 1].
    """

    def __init__(self, f=32, in_channels=3):
        super(Discriminator, self).__init__()
        self.discriminate = nn.Sequential(
            # input is (in_channels) x 32 x 32
            torch.nn.utils.spectral_norm(nn.Conv2d(in_channels, f, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (f) x 16 x 16
            torch.nn.utils.spectral_norm(nn.Conv2d(f, f * 2, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(f * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (f*2) x 8 x 8
            torch.nn.utils.spectral_norm(nn.Conv2d(f * 2, f * 4, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(f * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (f*4) x 4 x 4
            torch.nn.utils.spectral_norm(nn.Conv2d(f * 4, f * 8, 4, 2, 1, bias=False)),
            nn.BatchNorm2d(f * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (f*8) x 2 x 2
            torch.nn.utils.spectral_norm(nn.Conv2d(f * 8, 1, 2, 2, 0, bias=False)),
            nn.Sigmoid()
            # output size. 1 x 1 x 1
        )

    def forward(self, input):
        """Return the discriminator's output for a batch of images."""
        return self.discriminate(input)

Instantiate the generator and the discriminator, then apply the
``weights_init`` function to both

In [0]:
# Build both networks on the selected device
discriminator_network = Discriminator().to(device)
generator_network = Generator().to(device)

# Re-initialise every sub-module: conv weights ~ N(0, 0.02),
# batch-norm weights ~ N(1, 0.02) with zero bias.
for network in (generator_network, discriminator_network):
    network.apply(weights_init)

# Print each model followed by its torchsummary layer table
reports = (
    ('------------------ Generator Architecture -------------------',
     generator_network, (100, 1, 1)),
    ('\n------------------ Discriminator Architecture -------------------',
     discriminator_network, (3, 32, 32)),
)
for banner, network, input_size in reports:
    print(banner)
    print(network)
    summary(network, input_size=input_size)

Create checkpoints to save weights during training

In [0]:
def save_checkpoint(dis, gen, d_opt, g_opt, epoch, d_losses, g_losses, gen_loss_arr, dis_loss_arr, plot):
    """Write the full training state to ``checkpoint.tar`` and return it.

    The stored epoch is ``epoch + 1``, i.e. the epoch to resume from. Besides
    the model and optimiser state dicts, the loss histories and the live-plot
    object are stored as given.
    """
    check_point = dict(
        epoch=epoch + 1,
        d_model_state=dis.state_dict(),
        g_model_state=gen.state_dict(),
        d_optimiser_state=d_opt.state_dict(),
        g_optimiser_state=g_opt.state_dict(),
        d_losses=d_losses,
        g_losses=g_losses,
        live_g=gen_loss_arr,
        live_d=dis_loss_arr,
        plot=plot,
    )
    torch.save(check_point, 'checkpoint.tar')
    return check_point

def load_checkpoint(discriminator, generator, optimiser_D, optimiser_G):
    """Restore training state from ``checkpoint.tar``.

    The model and optimiser state dicts are loaded in place into the objects
    passed in.

    Returns:
        ``(epoch, d_losses, g_losses, gen_loss_arr, dis_loss_arr, plot)`` as
        stored in the checkpoint.
    """
    # Map stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    first_param = next(discriminator.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    # NOTE: the checkpoint contains arbitrary pickled objects (loss arrays, the
    # live-plot object), so the weights-only loader cannot read it. Only load
    # checkpoint files that you created yourself.
    try:
        check_point = torch.load('checkpoint.tar', map_location=target_device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the weights_only keyword
        check_point = torch.load('checkpoint.tar', map_location=target_device)

    epoch = check_point['epoch']
    discriminator.load_state_dict(check_point['d_model_state'])
    generator.load_state_dict(check_point['g_model_state'])
    optimiser_G.load_state_dict(check_point['g_optimiser_state'])
    optimiser_D.load_state_dict(check_point['d_optimiser_state'])
    d_losses, g_losses = check_point['d_losses'], check_point['g_losses']
    gen_loss_arr, dis_loss_arr = check_point['live_g'], check_point['live_d']
    plot = check_point['plot']
    return epoch, d_losses, g_losses, gen_loss_arr, dis_loss_arr, plot

**Create the Optimisers and Loss Function**

In [0]:
# Binary cross-entropy loss, used for both networks
loss_function = nn.BCELoss()

# One Adam optimiser per network, sharing the same learning rate and betas
adam_settings = dict(lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(discriminator_network.parameters(), **adam_settings)
optimizerG = optim.Adam(generator_network.parameters(), **adam_settings)

**Main Training Loop**

In [0]:
# Live loss plot and the loss histories filled in during training
liveplot = PlotLosses()
G_losses, D_losses = [], []

# The generator is updated once every `step` discriminator updates
step = 5

def _train_on_loader(loader, epoch, num_epochs, total_batches, gen_log, dis_log):
    """Run one pass over ``loader``: update D on every batch, G on every ``step``-th batch.

    Losses logged on every ``step * 10``-th batch are appended to ``gen_log`` /
    ``dis_log`` (this epoch only) and to the module-level ``G_losses`` /
    ``D_losses`` histories.
    """
    for i, data in enumerate(loader, 0):

        #----------------------------------------------------
        # Update Discriminator network: maximize log(D(x)) + log(1 - D(G(z)))
        ## Train with all-real batch
        discriminator_network.zero_grad()
        real_batch = data[0].to(device)
        b_size = real_batch.size(0)
        # BCELoss requires floating-point targets. An integer fill value makes
        # torch.full return an int64 tensor on current PyTorch, so the dtype is
        # given explicitly.
        label = torch.full((b_size,), 1.0, dtype=torch.float, device=device)
        output = discriminator_network(real_batch).view(-1)
        errD_real = loss_function(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        noise = torch.randn(b_size, 100, 1, 1, device=device)
        fake_batch = generator_network(noise)
        label.fill_(0.0)  # fake label
        # detach() keeps the discriminator update from back-propagating into G
        output = discriminator_network(fake_batch.detach()).view(-1)
        errD_fake = loss_function(output, label)
        # Gradients accumulate on top of those from the all-real batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        if i % step != 0:
            continue

        #----------------------------------------------------
        # Update Generator network: maximize log(D(G(z)))
        generator_network.zero_grad()
        label.fill_(1.0)  # fake labels are real for generator cost
        # D was just updated, so run the fake batch through it again
        output = discriminator_network(fake_batch).view(-1)
        errG = loss_function(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % (step * 10) == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, total_batches,
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            gen_log.append(errG.item())
            dis_log.append(errD.item())


def train(data_loaders: list, num_epochs: int, liveplot, epoch=0, settings=None):
    """Train the module-level generator and discriminator.

    Args:
        data_loaders: Loaders iterated one after another within each epoch.
        num_epochs: Epoch index at which training stops (exclusive).
        liveplot: Live-plot object updated with the mean logged losses per epoch.
        epoch: Epoch index to start from.
        settings: Optional dict restored from a checkpoint, with keys 'epoch',
            'D_losses', 'G_losses' and 'plot'. When given, it overrides
            ``epoch`` and ``liveplot`` and restores the loss histories.

    After every epoch a grid of generated samples is drawn and a checkpoint is
    written with ``save_checkpoint``.
    """
    if settings is not None:
        epoch = settings['epoch']
        liveplot = settings['plot']
        # Restore the loss histories so that the final loss plot and later
        # checkpoints still contain the epochs trained before the resume.
        G_losses[:] = list(settings['G_losses'])
        D_losses[:] = list(settings['D_losses'])

    # Total number of batches per epoch, shown in the progress line
    total_batches = sum(len(loader) for loader in data_loaders)

    print("Starting Training Loop...")
    for epoch in range(epoch, num_epochs):
        gen_log, dis_log = [], []

        # Iterate over each data loader in a single epoch (the training and testing set)
        for loader in data_loaders:
            _train_on_loader(loader, epoch, num_epochs, total_batches, gen_log, dis_log)

        gen_loss_arr = np.array(gen_log, dtype=float)
        dis_loss_arr = np.array(dis_log, dtype=float)

        # Update the graph after each epoch
        liveplot.update({
            'Generator loss': gen_loss_arr.mean(),
            'Discriminator loss': dis_loss_arr.mean()
        })
        liveplot.draw()
        sleep(1.)

        # Sample the generator once, without building an autograd graph, and
        # show the result as an image grid (channels-first -> channels-last).
        with torch.no_grad():
            fake = generator_network(torch.randn(64, 100, 1, 1, device=device)).cpu()
        plt.grid(False)
        plt.imshow(torchvision.utils.make_grid(fake).permute(1, 2, 0).clamp(0, 1), cmap=plt.cm.binary)

        # save model weights after each epoch
        save_checkpoint(discriminator_network, generator_network, optimizerD, optimizerG, epoch,
                        D_losses, G_losses, gen_loss_arr, dis_loss_arr, liveplot)
    

## ------ READ IN CHECKPOINT FILE IF IT EXISTS --------
# Start fresh unless a checkpoint from an earlier run is present
settings = None
if isfile('checkpoint.tar'):
    epoch, loss_d, loss_g, gen_loss_arr, dis_loss_arr, liveplot = load_checkpoint(
        discriminator_network, generator_network, optimizerD, optimizerG)
    settings = dict(
        epoch=epoch,
        D_losses=loss_d,
        G_losses=loss_g,
        gen_loss_arr=gen_loss_arr,
        dis_loss_arr=dis_loss_arr,
        plot=liveplot,
    )

# NOTE(review): the literal 250 is used here rather than the `num_epochs = 50`
# setting defined earlier - confirm which value is intended.
train([train_loader, test_loader], 250, liveplot, settings=settings)

**Plot the generator and discriminator losses recorded during training**

In [0]:
# Draw both loss histories on a single set of axes
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for history, name in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(history, label=name)
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

**Plot the fake images**

In [0]:
plt.figure(figsize=(8,8))

# Sampling only: no autograd graph is needed
with torch.no_grad():
    g = generator_network(torch.randn(64, 100, 1, 1, device=device)).cpu()

# Channels-first -> channels-last for imshow, clamped to the displayable [0, 1]
# range. Built from every sample, not from a leftover loop index.
img_list = [sample.permute(1, 2, 0).clamp(0, 1) for sample in g]

for i, img in enumerate(img_list):
    plt.subplot(8, 8, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(img, cmap=plt.cm.binary)