In [1]:
#Generating Pegasus with DCGAN

#Citations to relevant components of code are referenced throughout the notebook in single line comments

In [2]:
#Import necessary libraries
from __future__ import print_function
import argparse
import random
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.utils as vutils
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
# Ignore excessive warnings: only records of severity ERROR or above pass the root logger.
# (The former `logging.propagate = False` merely set an unused attribute on the
# logging module -- `propagate` is a per-Logger attribute -- so it has been dropped.)
import logging
logging.getLogger().setLevel(logging.ERROR)

Set hyperparameter values during initialization

In [3]:
# Number of workers for dataloader
# NOTE: the DataLoader created in main() is not given this value, so it currently has no effect
workers = 2
# Mini-batch size passed to the training DataLoader
batch_size = 128
# Training image size - CIFAR-10 is 32x32 by default; STL-10 is downscaled to these dimensions by the Resize/CenterCrop transforms in STL_preprocessing
# NOTE: those transforms hard-code 32 and do not read this constant
image_size = 32
# Number of channels in training images - 3 when using RGB, 1 for grayscale
nc = 3
# Size of z latent vector (number of input channels of the generator)
nz = 100
# Base number of feature maps in the generator (ngf) and discriminator (ndf)
ngf = 64
ndf = 64
# Number of training epochs - observed that with the current hyperparameters, the best pegasus images are generated between epochs 250-450, before the DCGAN begins producing images that too closely resemble individual class labels
epochs = 600
# Learning rate for both Adam optimizers - deemed optimal at 0.0003 after many experimental trial runs (initially tested with 0.0002 as proposed by Radford et al. in the DCGAN paper)
lr = 0.0003
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs to utilize
ngpu = 1 # Only values > 1 change behaviour (models get wrapped in nn.DataParallel); CPU vs. GPU is decided by torch.cuda.is_available() in main(), not by this value

#Referenced for tuning and GAN optimizations: https://github.com/soumith/ganhacks

Data preparation: CIFAR-10 and STL-10 datasets


In [4]:
def STL_preprocessing():
    """Download the STL-10 training split and keep its plane, horse and bird images.

    Every image is resized/cropped to 32x32 (for compatibility with CIFAR-10),
    converted to a tensor and normalised with mean 0.5 / std 0.5 per channel.
    At most 500 samples are kept per class, in dataset order.

    Returns:
        list of ``(image_tensor, label)`` pairs exactly as yielded by the dataset.
    """
    per_class_cap = 500

    # (bucket name, STL-10 class index) - deer were also considered but
    # ultimately not utilized in the final implementation
    wanted_classes = (("planes", 0), ("horses", 6), ("birds", 1))

    stl_pipeline = transforms.Compose([
        # Scale STL-10 down from 96x96 to 32x32 for compatibility with CIFAR-10
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    stl_train = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=stl_pipeline,
    )

    kept = {bucket: 0 for bucket, _ in wanted_classes}
    dataset_total = []

    # Walk the dataset once, dropping each sample into its class bucket
    # until that bucket is full.
    for sample in stl_train:
        for bucket, class_index in wanted_classes:
            if sample[1] == class_index and kept[bucket] < per_class_cap:
                dataset_total.append(sample)
                # For experimentation when using STL only, consider appending
                # horse samples twice to double their weight.
                kept[bucket] += 1

    # Check data counts for each class
    print("Num planes (STL-10) ", kept["planes"])
    print("Num horses (STL-10) ", kept["horses"])
    print("Num birds (STL-10) ", kept["birds"])
    print("Total (STL-10) ", len(dataset_total))
    print("")

    return dataset_total

In [5]:
def CIFAR_preprocessing(STL_cleaned):
    """Append selected CIFAR-10 planes, horses and birds to the prepared STL-10 samples.

    The CIFAR-10 train and test splits are concatenated so more images are
    available, then walked once in order. Per-class quotas (1800 planes,
    5000 horses, 900 birds) define each class's share of the final dataset.

    CIFAR images are normalised with mean 0.5 / std 0.5 per channel, i.e. to
    [-1, 1]. This is the same range STL_preprocessing produces and the range of
    the generator's Tanh output, so the discriminator sees real and generated
    images on one scale.

    Args:
        STL_cleaned: list of ``(image_tensor, label)`` pairs from STL-10.
            It is copied, not modified.

    Returns:
        New list holding the STL-10 samples followed by the selected CIFAR-10
        samples. Labels keep each dataset's own class indices (they differ
        between STL-10 and CIFAR-10 for birds and horses).
    """
    cifar_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # ToTensor yields [0, 1]; shift/scale to [-1, 1] like the STL-10 samples
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    dataset_train = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=cifar_transform)
    dataset_test = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=cifar_transform)

    #Merge training and test sets s.t. more data available for pegasus generation
    cifar_dataset = torch.utils.data.ConcatDataset((dataset_train, dataset_test))

    num_planes = 0
    num_birds = 0
    num_horses = 0

    #Start from a copy of the cleaned STL-10 images so the caller's list is left untouched
    dataset_total = list(STL_cleaned)

    #CIFAR-10 class labels - deer were also considered but ultimately not utilized in the final implementation
    plane_label = 0
    bird_label = 2
    horse_label = 7

    #Iterate through CIFAR and sort in corresponding bucket
    for sample in cifar_dataset:
        label = sample[1]
        #Set relevant quantities for each class to define its proportion of the total dataset
        if label == plane_label and num_planes < 1800:
            dataset_total.append(sample)
            num_planes += 1
        elif label == horse_label and num_horses < 5000:
            #Appending horse samples twice was tried to maintain horse structure
            dataset_total.append(sample)
            num_horses += 1
        elif label == bird_label and num_birds < 900:
            dataset_total.append(sample)
            num_birds += 1

    #Check data counts for each class (the counters only cover the CIFAR-10 samples added here)
    print("Num planes (CIFAR-10) ", num_planes)
    print("Num horses (CIFAR-10) ", num_horses)
    print("Num birds (CIFAR-10) ", num_birds)
    print("Total (STL-10 and CIFAR-10 merged) ", len(dataset_total))
    print("")

    return dataset_total

In [6]:
#Additional data cleaning and processing functions which were used during experimentation 

#Transform images to grayscale - experimented as possible solution to force whiteness of the pegasus
def grayscale_transform():
    """Build a transform that converts an image to one grayscale channel and then to a tensor."""
    steps = [
        torchvision.transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

#Upscale CIFAR-10 to 96x96 to train alongside full-resolution STL-10; similar transforms were used when testing on 64x64 CIFAR-10 and STL-10 combined
def CIFAR_to96():
    """Build a transform that resizes to 96, converts to a tensor and normalises with mean 0.5 / std 0.5 per channel."""
    upscale = transforms.Resize(96)
    to_tensor = transforms.ToTensor()
    normalise = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    return transforms.Compose([upscale, to_tensor, normalise])



DCGAN: Define Generator and Discriminator


In [7]:
#Implementation follows PyTorch documentation DCGAN tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [8]:
#Generator
##By default the DCGAN implementation outlined in the PyTorch documentation operates on 64x64 images. The structure was modified so that the Generator outputs 32x32 images and the Discriminator takes 32x32 input, so as to work with CIFAR-10: https://wandb.ai/sairam6087/dcgan/reports/DCGAN-on-CIFAR-10--Vmlldzo5NjMyOQ

#G and D operating on 32x32 input/output
class Generator(nn.Module):
    """DCGAN generator producing nc x 32 x 32 images from latent vectors.

    Four transposed convolutions upsample a (nz x 1 x 1) input through
    4x4, 8x8 and 16x16 feature maps to a 32x32 output squashed by Tanh.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        stages = [
            # nz x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # (ngf*2) x 16 x 16 -> nc x 32 x 32, values in [-1, 1]
            nn.ConvTranspose2d(ngf * 2, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Map a batch of latent vectors to a batch of generated images."""
        generated = self.main(input)
        return generated

In [9]:
#Discriminator

class Discriminator(nn.Module):
    """DCGAN discriminator scoring nc x 32 x 32 images.

    Three stride-2 convolutions reduce the input through 16x16, 8x8 and 4x4
    feature maps; a final 4x4 convolution plus Sigmoid yields one score per image.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        # Leaky ReLU (slope 0.2) is used as the activation throughout the discriminator
        stages = [
            # nc x 32 x 32 -> ndf x 16 x 16
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ndf x 16 x 16 -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 8 x 8 -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 4 x 4 -> 1 x 1 x 1 score in (0, 1)
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Return the discriminator's score map for a batch of images."""
        scores = self.main(input)
        return scores

In [10]:
#Initialize weights randomly from a normal distribution
def weights_init(m):
    """Re-initialise a single module in place; intended for use with ``Module.apply``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_kind = m.__class__.__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

Define the training function


In [11]:
#https://wandb.ai/sairam6087/dcgan/reports/DCGAN-on-CIFAR-10--Vmlldzo5NjMyOQ - furthers the DCGAN Pytorch implementation by adding label smoothing
def train(gen, disc, device, dataloader, optimizerG, optimizerD, criterion, epoch, iters):
    """Train the generator and discriminator for one pass over ``dataloader``.

    Each batch performs one DCGAN step: the discriminator is updated on a real
    batch and on a detached fake batch, then the generator is updated against
    the freshly updated discriminator. Targets use label smoothing (0.9 for
    real, 0.1 for fake) instead of hard 1/0 labels.

    Args:
        gen: generator module, called with noise of shape (N, nz, 1, 1).
        disc: discriminator module; its output is flattened to one score per sample.
        device: torch device on which batches, noise and labels are placed.
        dataloader: iterable of batches whose first element is the image tensor.
        optimizerG: optimizer stepping the generator's parameters.
        optimizerD: optimizer stepping the discriminator's parameters.
        criterion: loss called as ``criterion(scores, targets)``.
        epoch: current epoch number, using the 1-based numbering of main()
            (the last epoch equals ``epochs``).
        iters: value of the iteration counter at the start of this call.

    Returns:
        The iteration counter after this pass (``iters`` plus the number of
        batches processed), so a caller may carry it across epochs.
    """
    gen.train()
    disc.train()
    # Latent batch used only to render the preview grid
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training (with label smoothing)
    real_label = 0.9
    fake_label = 0.1
    last_batch = len(dataloader) - 1

    for i, data in enumerate(dataloader, 0):

        # ---- Update Discriminator ----

        # Train with all-real batch
        disc.zero_grad()
        real_batch = data[0].to(device)
        b_size = real_batch.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = disc(real_batch).view(-1)
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()

        # Train with all-fake batch
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = gen(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator
        output = disc(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        # Gradients accumulate on top of those from the all-real batch
        errD_fake.backward()
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        # ---- Update Generator ----
        gen.zero_grad()
        # The generator wants its fakes to be scored as real
        label.fill_(real_label)
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = disc(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        # Update G
        optimizerG.step()

        # Show generator samples periodically and after the very last batch of training.
        # main() numbers epochs 1..epochs, so the final epoch is `epochs`.
        if (iters % 500 == 0) or ((epoch == epochs) and (i == last_batch)):
            print("Current epoch ", epoch)
            print("")
            #Print training stats
            print("Generator loss: ", errG.item())
            print("Discriminator loss: ", errD.item())
            print("")
            with torch.no_grad():
                preview = gen(fixed_noise).detach().cpu()
            # normalize=True rescales the grid to [0, 1]; without it the negative
            # part of the Tanh output would be clipped by imshow
            grid = torchvision.utils.make_grid(preview, normalize=True)
            plt.rcParams['figure.dpi'] = 175
            plt.grid(False)
            # (C, H, W) -> (H, W, C) as expected by imshow
            plt.imshow(grid.permute(1, 2, 0).numpy())
            plt.show()

            #To improve training speed, training and candidate image generation was carried out using wandb.ai: https://wandb.ai/
        iters += 1

    return iters

**Main training loop**

In [12]:
def main():
    """Prepare the merged STL-10/CIFAR-10 dataset, build the DCGAN and train it.

    Steps: pick the device, seed all random number generators, load and merge
    the two datasets, create and initialise the generator and discriminator,
    set up BCE loss and Adam optimizers, then call train() once per epoch.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    #Set seeds (Python, NumPy and PyTorch) for reproducibility
    manualSeed = 42
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.backends.cudnn.deterministic = True

    #Load the dataset
    print("Loading the dataset...")
    print("")

    #First pre-process STL-10
    prepared_STL = STL_preprocessing()

    #Next process CIFAR-10
    dataset_merged = CIFAR_preprocessing(prepared_STL)

    # drop_last=True discards a final batch smaller than batch_size
    trainloader = torch.utils.data.DataLoader(dataset_merged, shuffle=True, batch_size=batch_size, drop_last=True)

    # Create the generator
    netG = Generator(ngpu).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))

    # Apply weights_init to every submodule: conv weights ~ N(0, 0.02),
    # batch-norm weights ~ N(1, 0.02) with zero bias
    netG.apply(weights_init)

    # Create the Discriminator
    netD = Discriminator(ngpu).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))

    # Same initialisation scheme as for the generator
    netD.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # The counter is passed by value and never updated here, so every epoch
    # starts counting from 0 and train() shows a preview on the first batch
    # of each epoch.
    iters = 0
    for epoch in range(1, epochs + 1):
        train(netG, netD, device, trainloader, optimizerG, optimizerD, criterion, epoch, iters)
        

Start training

In [None]:
# Kick off dataset preparation and the full training run.
# The guard is true when this cell is executed in the notebook (or the file is run as a script).
if __name__ == '__main__':
    main()

Dataset - Google Drive Integrations


In [None]:
# Optional Google Drive integration (only works inside Google Colab): mounts your Drive at /content/drive.
# A mounted Drive can be used to save and resume training, and may speed up re-downloading the dataset.
# NOTE: none of the cells shown above read from or write to /content/drive; the datasets are stored under ./data.
from google.colab import drive
drive.mount('/content/drive')