In [0]:
%%capture
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision
!pip install 'livelossplot==0.3.0'

**Main imports**

In [0]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from torch.autograd import Variable
from livelossplot import PlotLosses
from torch.utils.data import ConcatDataset

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _use_cuda else 'cpu')
# Legacy float-tensor constructor for the same device; used below to build inputs in place.
Tensor = torch.cuda.FloatTensor if _use_cuda else torch.FloatTensor

**Import dataset**

In [0]:
# Helper that keeps handing out batches by restarting a finite iterable whenever it runs dry.
def cycle(iterable):
    """Yield the items of ``iterable`` endlessly, one complete pass after another.

    Unlike ``itertools.cycle`` nothing is cached: every pass iterates ``iterable``
    afresh. Never terminates; an empty iterable makes it spin forever.
    """
    while True:
        yield from iterable
            
# Sampling helper: locate the dataset positions whose label belongs to the wanted classes.
def get_samples(dataset, class_labels):
  """Return, in ascending order, every index ``i`` with ``dataset[i][1] in class_labels``.

  Each item is fetched once via ``dataset[i]`` and its label is read from position 1,
  so the scan costs one full item fetch per sample in the dataset.
  """
  return [idx for idx in range(len(dataset)) if dataset[idx][1] in class_labels]

class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Pool the CIFAR-10 training and test splits (in that order); both are used for learning.
to_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
dataset = ConcatDataset([
    torchvision.datasets.CIFAR10('data', train=is_train, download=True, transform=to_tensor)
    for is_train in (True, False)
])

# Keep only the bird and horse images (labels 2 and 7).
image_ind = get_samples(dataset, [class_names.index('bird'), class_names.index('horse')])

# Random batches of 16 bird/horse images. The subset sampler does the shuffling, so
# shuffle stays False; drop_last keeps every batch exactly 16 images long.
bird_horse_sampler = torch.utils.data.sampler.SubsetRandomSampler(image_ind)
dataset_loader = torch.utils.data.DataLoader(dataset=dataset, sampler=bird_horse_sampler,
                                             batch_size=16, shuffle=False, drop_last=True)

# Endless stream of batches that restarts the loader each time it is exhausted.
dataset_iterator = cycle(dataset_loader)

print(f'> Size of dataset {len(dataset_loader.dataset)}')
print(f'> Size of samples {len(image_ind)}')

**View some of the selected bird and horse images**

In [0]:
# Preview the first 25 selected images in a 5x5 grid, captioned with their class names.
plt.figure(figsize=(10,10))
for panel in range(25):
    image, label = dataset_loader.dataset[image_ind[panel]]
    plt.subplot(5, 5, panel + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # ToTensor output is channels-first (C, H, W); imshow wants channels-last (H, W, C).
    # The cmap argument has no effect on RGB input.
    plt.imshow(image.permute(1, 2, 0), cmap=plt.cm.binary)
    plt.xlabel(class_names[label])

**Define a simple model**

In [0]:
# Define the components of the deep convolutional generative adversarial network (DCGAN).

# Define the Generator network.
class GANGen(nn.Module):
  """DCGAN-style generator: maps a batch of 100-d noise vectors to 3x32x32 images in [-1, 1].

  Layout: a linear projection to a 128x8x8 feature map, two stages of
  (batch-norm, 2x upsample, 3x3 conv) that grow it to 32x32, a final 3x3 conv down to
  3 channels, and tanh. Stage one applies no activation between its batch-norm and conv.
  """

  def __init__(self):
    super(GANGen, self).__init__()
    side = 32 // 4  # two 2x upsamples turn an 8x8 map into 32x32

    # Project the 100-d latent vector to 128 feature maps of side x side.
    self.seq = nn.Linear(100, 128 * side ** 2)

    # Stage one: 128 channels, 8x8 -> 16x16.
    self.bn_one = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    self.up_one = nn.Upsample(scale_factor=2)
    self.conv_one = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

    # Stage two: 128 -> 64 channels, 16x16 -> 32x32.
    self.bn_two = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    self.up_two = nn.Upsample(scale_factor=2)
    self.conv_two = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)

    # Output head: 64 -> 3 channels, squashed into [-1, 1].
    self.bn_three = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    self.conv_three = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
    self.tanh = nn.Tanh()

    # Nonlinearity shared by stages two and three; in-place to avoid an extra buffer.
    self.activ = nn.LeakyReLU(0.1, inplace=True)

  def forward(self, x):
    """Decode noise of shape (batch, 100) into images of shape (batch, 3, 32, 32)."""
    side = 32 // 4
    out = self.seq(x).view(x.shape[0], 128, side, side)

    pipeline = (
        self.bn_one, self.up_one, self.conv_one,
        self.bn_two, self.activ, self.up_two, self.conv_two,
        self.bn_three, self.activ, self.conv_three,
        self.tanh,
    )
    for layer in pipeline:
      out = layer(out)
    return out
    
    
# Define the Discriminator network.
class GANDis(nn.Module):
  """DCGAN-style discriminator: scores each 3x32x32 image with a probability of being real.

  Three 3x3 stride-2 convolutions take the input from 3x32x32 to 64x4x4
  (channels 3 -> 16 -> 32 -> 64). Each is followed by LeakyReLU(0.1) and Dropout2d(0.2);
  the second and third blocks are then batch-normalised. A linear layer + sigmoid maps
  the flattened 64*4*4 features to one score per image, giving shape (batch, 1).
  """

  def __init__(self):
    super(GANDis, self).__init__()

    # First convolutional layer (3x32x32 -> 16x16x16).
    self.conv_one = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)

    # Second convolutional layer (-> 32x8x8).
    self.conv_two = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)

    # Batch normalisation two.
    self.bn_two = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)

    # Third convolutional layer (-> 64x4x4).
    self.conv_three = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)

    # Batch normalisation three.
    self.bn_three = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)

    # Classifier head: flattened 64x4x4 features -> one score in (0, 1).
    self.seq = nn.Sequential(nn.Linear(64*4*4, 1), nn.Sigmoid())

    # Channel-wise dropout, shared by all three blocks.
    self.drop = nn.Dropout2d(0.2)

    # Shared activation; in-place to save an activation buffer.
    self.activ = nn.LeakyReLU(0.1, inplace=True)

  def forward(self, x):
    """Return (batch, 1) real/fake probabilities for an image batch of shape (batch, 3, 32, 32)."""
    x = self.conv_one(x)
    x = self.activ(x)
    x = self.drop(x)

    x = self.conv_two(x)
    x = self.activ(x)
    x = self.drop(x)
    x = self.bn_two(x)

    x = self.conv_three(x)
    x = self.activ(x)
    x = self.drop(x)
    x = self.bn_three(x)

    # Flatten per sample. Keeping the batch dimension explicit (rather than view(-1, 1024))
    # makes a wrongly sized input fail in the linear layer instead of being silently
    # re-chunked into a different number of "samples".
    x = x.view(x.size(0), -1)

    return self.seq(x)
    

# Create the networks on the selected device.
generator = GANGen().to(device)
discriminator = GANDis().to(device)

# Binary cross-entropy on the discriminator's sigmoid output, used by both players.
loss = torch.nn.BCELoss()

# Adam for both networks, with beta1 = 0.5 as in the DCGAN paper.
opt_g = torch.optim.Adam(generator.parameters(), lr=0.0005, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=0.0005, betas=(0.5, 0.999))

# Count parameters by summing element counts; unlike parameters_to_vector this does not
# materialise a flattened copy of every weight just to measure its length.
print(f'> Number of generator parameters {sum(p.numel() for p in generator.parameters())}')
print(f'> Number of discriminator parameters {sum(p.numel() for p in discriminator.parameters())}')

# Effectively "train until interrupted": the loop below makes one loader pass per epoch.
num_epoch = 1000000
liveplot = PlotLosses()

**Main training loop**

In [0]:
# Discriminator targets for one batch. They are fixed at 16 rows, which relies on the
# loader's batch_size=16 with drop_last=True (every batch has exactly 16 images).
ones = torch.FloatTensor(16, 1).fill_(1.0)
zeros = torch.FloatTensor(16, 1).fill_(0.0)
valid = Variable(ones, requires_grad=False)
fake = Variable(zeros, requires_grad=False)
valid, fake = valid.to(device), fake.to(device)

# Training loop: one pass over the bird/horse loader per epoch.
for epoch in range(num_epoch):

  # Per-batch losses for this epoch. Lists grow in O(1) per append; np.append would copy
  # the whole array on every batch.
  generator_losses = []
  discriminator_losses = []

  for i, images in enumerate(dataset_loader):

    # Each batch is (images, labels); the labels are not needed by the GAN.
    images = images[0].to(device)
    real_imgs = Variable(images.type(Tensor))

    # TRAIN THE GENERATOR NETWORK # ~~~~~~~~~~~~~~~~~~~~~~~~~~~
    opt_g.zero_grad()

    # Generate 16 fake images from standard-normal noise.
    fake_images = generator(Variable(Tensor(np.random.normal(0, 1, (16, 100)))))

    # The generator is rewarded when the discriminator labels its fakes as real.
    loss_g = loss(discriminator(fake_images), valid)

    # Back-propagate; this also leaves gradients on the discriminator's parameters,
    # which opt_d.zero_grad() clears before the discriminator update below.
    loss_g.backward()
    opt_g.step()

    # TRAIN THE DISCRIMINATOR NETWORK # ~~~~~~~~~~~~~~~~~~~~~~~
    opt_d.zero_grad()

    # detach() keeps the discriminator loss from back-propagating into the generator.
    loss_f = loss(discriminator(fake_images.detach()), fake)
    loss_r = loss(discriminator(real_imgs), valid)
    loss_d = 0.5 * (loss_f + loss_r)

    loss_d.backward()
    opt_d.step()

    # Record the scalar losses as Python floats.
    generator_losses.append(loss_g.item())
    discriminator_losses.append(loss_d.item())

    # Draw a 4x4 grid of the first fake batch of every epoch.
    if i == 0:
      plt.figure(figsize=(8,8))
      for t in range(16):
        plt.subplot(4,4,t+1)
        plt.grid(False)
        plt.imshow(fake_images[t].cpu().data.numpy().transpose((1,2,0)))

  # Keep the epoch's losses as arrays under the original names, then log their means.
  generator_loss_arr = np.array(generator_losses)
  discriminator_loss_arr = np.array(discriminator_losses)
  liveplot.update({
      'Generator Loss': generator_loss_arr.mean(),
      'Discriminator Loss': discriminator_loss_arr.mean()
  })
  liveplot.draw()

**Generate a Pegasus by sampling the joint latent space of a GAN trained on both horses and birds**

In [0]:
# Look for a pegasus in the GAN's latent space: decode 16 fresh standard-normal
# latent vectors and inspect the results by eye.
latent_codes = np.random.normal(0, 1, (16, 100))
noise = Variable(Tensor(latent_codes))
fake_images = generator(noise)

# Lay the 16 candidates out in a 4x4 grid (subplot positions are 1-based).
plt.figure(figsize=(8,8))
for slot, candidate in enumerate(fake_images, start=1):
    plt.subplot(4, 4, slot)
    plt.grid(False)
    plt.imshow(candidate.cpu().data.numpy().transpose((1, 2, 0)))

In [0]:
import matplotlib.image

# Index (0-15) of the generated sample to keep — presumably picked by eye from the
# 4x4 grid drawn in the previous cell; update it after re-sampling.
PEGASUS_INDEX = 3

# imsave raises ValueError for float RGB data outside [0, 1], but the generator's tanh
# output can be negative. Clip to the same range imshow displays (imshow clips with a warning).
pegasus = np.clip(fake_images[PEGASUS_INDEX].cpu().data.numpy().transpose((1, 2, 0)), 0, 1)
matplotlib.image.imsave('data/pegasus.png', pegasus)

# Colab-only: trigger a browser download of the saved image.
from google.colab import files
files.download( "data/pegasus.png" )