In [None]:
# gradient penalty is based on https://gist.github.com/cwkx/e85fefe8bffbe3b3598f8f582914eb12, which is released under the MIT license
# model structure and classes based off https://github.com/Lornatang/WassersteinGAN_GP-PyTorch, which is released under the Apache-2.0 license
# make sure you reference any code you have studied as above here

# imports
import math
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import time
import os
from google.colab import files
from torch import autograd
from tqdm import tqdm
import random

# hyperparameters
batch_size = 64       # samples per mini-batch
n_channels = 3        # number of image channels
latent_size = 512     # NOTE(review): not referenced by the visible cells, which use a 100-d latent - confirm
dataset = 'stl10'     # which branch of the data-loading cell is used
# dataset = 'mnist'
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Optional Google Drive integration: mounts Drive at /content/drive so that training can be
# saved and resumed, and so the dataset may not need to be downloaded again.
# Later cells in this notebook read and write paths under 'drive/...', so they expect this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Adam settings used for both the generator and the discriminator optimisers.
opt = {
    "lr": 1e-4,   # learning rate
    "b1": 0.5,    # first Adam beta
    "b2": 0.9,    # second Adam beta
}
N_CRITIC = 5      # the generator is updated once every N_CRITIC discriminator updates
N_SAMPLE = 100    # NOTE(review): not referenced by the visible cells - confirm it is still needed
SAVEFIGS = True   # when True, the per-epoch sample grid is also written to SAVEFOLDER
# Each run gets its own output folder, named after the current Unix timestamp.
SAVEFOLDER = str(time.time())
# makedirs also creates missing parent folders; plain os.mkdir raised FileNotFoundError
# whenever '0WGAN-Outputs' did not exist yet on a fresh Drive.
os.makedirs('drive/MyDrive/0WGAN-Outputs/'+SAVEFOLDER, exist_ok=True)

In [None]:
def cycle(iterable):
    """Yield the items of ``iterable`` forever, starting over whenever it runs out.

    Helper that makes it easy to pull "one more batch" with ``next(...)`` without
    worrying about where an epoch ends. ``iterable`` is re-iterated on every pass,
    so it must support being iterated more than once (e.g. a DataLoader or a list).
    """
    while True:
        yield from iterable

# Side length (in pixels) that every image is resized to. The generator and discriminator
# defined in this notebook are built for 64x64 images, so this is kept separate from
# batch_size (previously batch_size was reused as the image size, which only worked
# because both happened to be 64).
image_size = 64

_transforms = torchvision.transforms
# Scales ToTensor's [0, 1] output to [-1, 1], the range of the generator's tanh output.
_to_tanh_range = _transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# Shared pipeline for the natural-image datasets (CIFAR-10 and STL-10).
_rgb_transform = _transforms.Compose([
    _transforms.Resize((image_size, image_size)),
    _transforms.CenterCrop(image_size),
    _transforms.RandomHorizontalFlip(p=0.4),
    _transforms.ToTensor(),
    _to_tanh_range,
])

# you may use cifar10 or stl10 datasets
if dataset == 'cifar10':
    train_set = torchvision.datasets.CIFAR10('drive/My Drive/training/cifar10', train=True, download=True, transform=_rgb_transform)
    class_names = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# stl10 has larger images which are much slower to train on. You should develop your method with CIFAR-10 before experimenting with STL-10
elif dataset == 'stl10':
    train_set = torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', download=True, transform=_rgb_transform)
    class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck'] # these are slightly different to CIFAR-10

elif dataset == 'mnist':
    # Digits are replicated to 3 channels so the same networks can be used, and are not
    # flipped. They are now normalised to [-1, 1] like the other datasets; previously this
    # branch left them in [0, 1], which did not match the tanh generator output.
    train_set = torchvision.datasets.MNIST('drive/My Drive/training/mnist', train=True, download=True, transform=_transforms.Compose([
        _transforms.Grayscale(3),
        _transforms.Resize((image_size, image_size)),
        _transforms.CenterCrop(image_size),
        _transforms.ToTensor(),
        _to_tanh_range,
    ]))
    class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    # NOTE(review): Grayscale(3) above produces 3-channel tensors, so this value does not
    # describe the loaded data - kept unchanged in case other cells rely on it; confirm.
    n_channels = 1

else:
    # Fail with a clear message instead of a NameError on train_loader further down.
    raise ValueError(f"unsupported dataset {dataset!r}; expected 'cifar10', 'stl10' or 'mnist'")

# drop_last=True keeps every batch at exactly batch_size samples.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=batch_size, drop_last=True)
# Endless iterator over the loader, used to grab single batches with next().
train_iterator = iter(cycle(train_loader))

In [None]:
# let's view some of the training data
plt.rcParams['figure.dpi'] = 175
x,t = next(train_iterator)
x,t = x.to(device), t.to(device)
plt.grid(False)
# normalize=True rescales the grid to [0, 1] using the batch min/max. Without it, images
# that were normalised to [-1, 1] are clipped by imshow and come out far too dark.
preview_grid = torchvision.utils.make_grid(x, normalize=True)
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(preview_grid.cpu().permute(1, 2, 0))
plt.show()

**Define a WGAN-GP**

In [None]:
def grad_penalty(M, real_data, fake_data, lmbda=10):
    """Compute the WGAN-GP gradient penalty for the critic ``M``.

    Random per-sample interpolations between ``real_data`` and ``fake_data`` are fed
    through ``M``; the penalty is ``lmbda`` times the batch mean of
    ``(||dM/dx||_2 - 1) ** 2``, with the norm taken over all non-batch dimensions.

    Args:
        M: callable mapping a batch tensor to critic scores.
        real_data: batch of real samples, batch dimension first.
        fake_data: batch of generated samples, same shape and device as ``real_data``.
        lmbda: weight of the penalty term.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated into the parameters of ``M``.
    """
    # One mixing coefficient per sample, broadcast over every remaining dimension.
    # It is created on the inputs' device, so no module-level `device` is needed.
    alpha_shape = (real_data.size(0),) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    # Fresh leaf tensor: gradients are taken with respect to the interpolated points only,
    # never through whatever produced real_data / fake_data.
    interpolates = interpolates.detach().requires_grad_(True)
    int_disc = M(interpolates)
    gradients = torch.autograd.grad(
        outputs=int_disc,
        inputs=interpolates,
        grad_outputs=torch.ones_like(int_disc),
        create_graph=True,   # the penalty itself has to be differentiable
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lmbda


def weights_init(m):
    """Re-initialise one module in place; meant to be passed to ``nn.Module.apply``.

    The layer kind is decided from the class name:
      * names containing "Conv": weight drawn from N(0, 0.02)
      * names containing "BatchNorm": weight drawn from N(1, 0.02), bias set to zero
    Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if "BatchNorm" in layer_kind:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)
    
    
class Discriminator(nn.Module):
    """Convolutional critic that scores images.

    Four stride-2 convolutions halve the spatial size each time (64 -> 32 -> 16 -> 8 -> 4)
    and a final 4x4 convolution without padding reduces a 4x4 map to a single value, so
    a 64x64 input yields exactly one score per image. There is no output activation.

    Args:
        in_channels: number of channels of the input images. The default of 3 gives the
            same architecture (and state_dict keys) as before this parameter existed.

    NOTE(review): the WGAN-GP paper recommends against batch normalisation in the critic
    because the gradient penalty is defined per sample. The layers are left as they were
    here, since swapping them would change training behaviour and saved checkpoints.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()

        self.main = nn.Sequential(
            # No normalisation on the first block.
            nn.Conv2d(in_channels, 96, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(96, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # Collapses the remaining 4x4 map to one score.
            nn.Conv2d(512, 1, 4, 1, 0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the scores as a 1-D tensor (one entry per image for 64x64 inputs)."""
        return torch.flatten(self.main(input))


class Generator(nn.Module):
    """Transposed-convolution generator producing 64x64 images from latent vectors.

    The latent input has shape ``(N, latent_dim, 1, 1)``. The first layer expands it to a
    4x4 map and each of the four following stride-2 layers doubles the size
    (4 -> 8 -> 16 -> 32 -> 64). The final tanh bounds the output to [-1, 1].

    Args:
        latent_dim: number of channels of the latent input.
        out_channels: number of channels of the generated images.

    The defaults (100 and 3) give the same architecture and state_dict keys as before
    these parameters existed.
    """

    def __init__(self, latent_dim: int = 100, out_channels: int = 3):
        super().__init__()

        self.main = nn.Sequential(
            # (latent_dim, 1, 1) -> (512, 4, 4)
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(True),
            # (96, 32, 32) -> (out_channels, 64, 64)
            nn.ConvTranspose2d(96, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Map latent vectors of shape (N, latent_dim, 1, 1) to images in [-1, 1]."""
        return self.main(input)

In [None]:
# Inverse of Normalize(mean=0.5, std=0.5): Normalize computes (x - mean) / std, so with
# mean = -1 and std = 2 it maps [-1, 1] back to [0, 1] for display.
reverse_mean = -0.5/0.5
reverse_std = 1.0/0.5
unnormalise = torchvision.transforms.Compose([
    torchvision.transforms.Normalize(mean=(reverse_mean,) * 3, std=(reverse_std,) * 3)
])

# Build both networks on the target device first ...
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# ... then re-initialise their weights. Module.apply visits every sub-module and works
# in place, so its return value (the module itself) does not need to be re-assigned.
for net in (generator, discriminator):
    net.apply(weights_init)

# Both optimisers share the same Adam settings from `opt`.
adam_settings = dict(lr=opt['lr'], betas=(opt['b1'], opt['b2']))
optimiser_g = torch.optim.Adam(generator.parameters(), **adam_settings)
optimiser_d = torch.optim.Adam(discriminator.parameters(), **adam_settings)

discriminator.train()
generator.train()

# Epoch counter read and advanced by the training loop (re-running this cell resets it).
epoch = 0

**Main training loop**

In [None]:
  # training loop, you will want to train for more than 10 here!
epochs = 1000
# `epoch` comes from the setup cell (or a restored checkpoint), so re-running this cell
# continues training instead of starting again from zero.
while epoch < epochs:

    # array(s) for the performance measures (reset every epoch; nothing is appended to them yet)
    logs = {}
    gen_loss_arr = np.zeros(0)
    dis_loss_arr = np.zeros(0)
    grad_pen_arr = np.zeros(0)

    progress = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, data in progress:
        x = data[0].to(device)
        # Size the noise from the actual batch so real and fake batches always match.
        noise = torch.randn(x.size(0), 100, 1, 1, device=device)

        # ---- discriminator (critic) update ----
        discriminator.zero_grad()
        x_out = discriminator(x)
        x_d_err = torch.mean(x_out)
        D_x = x_d_err.item()

        # Generate fake images. No generator gradients are needed for the critic step,
        # so skip building the autograd graph for this forward pass.
        with torch.no_grad():
            fake = generator(noise)
        f_out = discriminator(fake)
        f_d_err = torch.mean(f_out)
        D_G_z1 = f_d_err.item()

        # Critic loss: D(fake) - D(real) + gradient penalty.
        gradient_penalty = grad_penalty(discriminator, x.detach(), fake)
        errD = -x_d_err + f_d_err + gradient_penalty
        errD.backward()
        optimiser_d.step()

        # ---- generator update, once every N_CRITIC critic updates ----
        if (i+1) % N_CRITIC == 0:
            generator.zero_grad()
            # Same noise as above, but this time with gradients flowing into the generator.
            fake = generator(noise)
            fake_out = discriminator(fake)
            errG = -torch.mean(fake_out)
            D_G_z2 = fake_out.mean().item()
            errG.backward()
            optimiser_g.step()

            progress.set_description(f"[{epoch + 1}/{epochs}][{i + 1}/{len(train_loader)}] "
                                                 f"Loss_D: {errD.item():.6f} Loss_G: {errG.item():.6f} "
                                                 f"D(x): {D_x:.6f} D(G(z)): {D_G_z1:.6f}/{D_G_z2:.6f}")

    # generate some examples of where we are up to (visualisation only, so no autograd graph)
    with torch.no_grad():
        sample = generator(torch.randn(x.size(0), 100, 1, 1, device=device))
        visualised_sample = unnormalise(sample)

    # plot some examples; make_grid returns (C, H, W) while imshow expects (H, W, C)
    plt.rcParams['figure.dpi'] = 100
    plt.grid(False)
    plt.imshow(torchvision.utils.make_grid(visualised_sample[:8]).cpu().permute(1, 2, 0))
    if SAVEFIGS:
        # Must happen before plt.show(), which may release the figure.
        plt.savefig('drive/MyDrive/0WGAN-Outputs/'+SAVEFOLDER+'/'+str(time.time())+'.jpg')
    plt.show()
    plt.pause(0.0001)
    epoch = epoch+1

    # Checkpoint after every epoch so training can be resumed later.
    torch.save({'G':generator.state_dict(), 'optimiser_G':optimiser_g.state_dict(), 'D':discriminator.state_dict(), 'optimiser_D':optimiser_d.state_dict(), 'epoch':epoch}, 'drive/My Drive/training/save.chkpt')

In [None]:
# now show a batch of data for the submission, right click and save the image for your report
# Sized by batch_size rather than a leftover `x` from an earlier cell, and sampled under
# no_grad because this is visualisation only.
with torch.no_grad():
    sample = unnormalise(generator(torch.randn(batch_size, 100, 1, 1, device=device)))

plt.rcParams['figure.dpi'] = 175
plt.grid(False)
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(torchvision.utils.make_grid(sample).cpu().permute(1, 2, 0))
plt.savefig("generatedgrid.png")
plt.show()

In [None]:
# now show some interpolations (note you do not have to do linear interpolations as shown here, you can do non-linear or gradient-based interpolation if you wish)
# The grid is col_size x col_size: the top row and bottom row are decoded from random
# latents, and every column blends linearly from its top latent to its bottom latent.
col_size = int(np.sqrt(batch_size))
n_grid = col_size * col_size  # equals batch_size whenever batch_size is a perfect square

# Sample latent codes. Previously generated *images* were blended here, which gives
# pixel-space cross-fades rather than an interpolation in the latent space.
z = torch.randn(batch_size, 100, 1, 1, device=device)

z0 = z[0:col_size].repeat(col_size,1,1,1) # z for top row
z1 = z[batch_size-col_size:].repeat(col_size,1,1,1) # z for bottom row

# One interpolation weight per grid row, going from 0 (top row) to 1 (bottom row).
t = torch.linspace(0,1,col_size).unsqueeze(1).repeat(1,col_size).view(n_grid,1,1,1).to(device)

lerp_z = (1-t)*z0 + t*z1 # linearly interpolate between two points in the latent space
with torch.no_grad():
    # decode the interpolated latents and map the images back to [0, 1] for display
    lerp_g = unnormalise(generator(lerp_z))

plt.rcParams['figure.dpi'] = 175
plt.grid(False)
# nrow=col_size keeps one interpolation step per grid row; make_grid returns (C, H, W)
# while imshow expects (H, W, C).
plt.imshow(torchvision.utils.make_grid(lerp_g, nrow=col_size).cpu().permute(1, 2, 0))
plt.show()

In [None]:
# optional example code to save your training progress for resuming later if you authenticated Google Drive previously
# torch.save({'G':generator.state_dict(), 'optimiser_G':optimiser_g.state_dict(), 'D':discriminator.state_dict(), 'optimiser_D':optimiser_d.state_dict(), 'epoch':epoch}, 'drive/My Drive/training/save.chkpt')

In [None]:
# # # optional example to resume training if you authenticated Google Drive previously
# # params = torch.load('drive/My Drive/training/save.chkpt')
# # params = torch.load('drive/My Drive/training/Copy of save.chkpt')
# params = torch.load('./stl10.chkpt')

# # params = torch.load('./save-2.chkpt')
# generator.load_state_dict(params['G'])
# optimiser_g.load_state_dict(params['optimiser_G'])
# discriminator.load_state_dict(params['D'])
# optimiser_d.load_state_dict(params['optimiser_D'])
# epoch = params['epoch']

# discriminator.train()
# generator.train()