### 0. Imports

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision import datasets, transforms
import tqdm
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import torch.nn.functional as F
import matplotlib.pyplot as plt
from gan import Generator, Discriminator

### 1. Setup

In [None]:
# Select the runtime environment; this decides where outputs (content_path)
# and the CIFAR-10 download (data_path) live.
WORKING_ENV = 'PAPERSPACE'
assert WORKING_ENV in ['LABS', 'COLAB', 'PAPERSPACE']

if WORKING_ENV == 'COLAB':
    # Colab: load the data-table extension and mount Google Drive at /content/drive/.
    from google.colab import drive
    %load_ext google.colab.data_table
    # NOTE(review): the Drive folder is named 'vae' although this notebook trains a GAN — confirm intended.
    content_path = '/content/drive/MyDrive/vae'
    data_path = './data/'
    drive.mount('/content/drive/')

else:
    # LABS / PAPERSPACE: install ipywidgets via an IPython shell escape; outputs go under /notebooks.
    !pip install ipywidgets
    content_path = '/notebooks'
    data_path = './data/'

In [None]:

# Per-channel statistics for the input Normalize transform. With mean = std = 0.5
# it maps ToTensor's [0, 1] pixel range onto [-1, 1].
mean = torch.Tensor([0.5] * 3)
std = torch.Tensor([0.5] * 3)
# Inverse transform: Normalize(m, s) is undone by Normalize(-m / s, 1 / s),
# i.e. x -> x * s + m, applied channel-wise.
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

def denorm(x, channels=None, w=None, h=None, resize=False):
    """Undo the dataset normalisation, optionally reshaping to images first.

    Args:
        x: normalised tensor. Without ``resize`` it must already be image-shaped
            (``(..., C, H, W)``), because ``unnormalize`` is a torchvision
            ``Normalize``. With ``resize=True`` it can be anything viewable as
            ``(x.size(0), channels, w, h)``, e.g. a flat ``(N, channels*w*h)`` batch.
        channels, w, h: target shape used when ``resize`` is True.
        resize: reshape ``x`` with ``view`` before unnormalising.

    Returns:
        The unnormalised tensor (a new tensor; ``x`` is not modified).

    Raises:
        ValueError: if ``resize`` is True but ``channels``, ``w`` or ``h`` is missing.
    """
    if resize:
        if channels is None or w is None or h is None:
            # Fail loudly here instead of printing and then crashing inside view().
            raise ValueError('Number of channels, width and height must be provided for resize.')
        # Reshape *before* unnormalising: Normalize rejects tensors with fewer than
        # 3 dims, so flat inputs must be turned into images first.
        x = x.view(x.size(0), channels, w, h)
    return unnormalize(x)

def show(img):
    """Display a 3-D image tensor with matplotlib.

    The tensor (assumed channel-first, (C, H, W), e.g. a make_grid result) is
    copied to the CPU, converted to NumPy and transposed to (H, W, C) for imshow.
    """
    hwc = np.transpose(img.cpu().numpy(), (1, 2, 0))
    plt.imshow(hwc)

# Create the output folders. content_path is a str, so paths are built with
# os.path.join (str / str raises TypeError). The training cell writes into
# GAN/historical_averaging, so that sub-directory (and GAN itself) is created here.
os.makedirs(os.path.join(content_path, 'GAN', 'historical_averaging'), exist_ok=True)

GPU = True
# Use CUDA only when it is both requested and available; otherwise run on the CPU.
device = torch.device("cuda" if GPU and torch.cuda.is_available() else "cpu")
print(f'Using {device}')

# Reproducibility: seed torch's global RNG, and on GPU also force cuDNN to
# pick deterministic algorithms.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

### 2. Load data

In [None]:
batch_size = 128
image_size = 32  # CIFAR-10 resolution; not referenced elsewhere in this notebook as shown

# ToTensor yields [0, 1] pixels; Normalize(0.5, 0.5) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# note - data_path was initialized at the top of the notebook
cifar10_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
# Shuffle the training set so every epoch sees a fresh batch order; the global
# seed set earlier keeps that order reproducible. The test loader stays ordered.
loader_train = DataLoader(cifar10_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(cifar10_test, batch_size=batch_size)

### 3. Define hyperparameters and initialise models

In [None]:
# Training schedule.
num_epochs = 20

# Learning rates. The cells below only read the per-network rates;
# the discriminator runs at half the generator's rate.
learning_rate = 2e-4
learning_rate_G = 2e-4
learning_rate_D = 1e-4

# Size of the latent noise vector z (sampled with shape (N, latent_vector_size, 1, 1)).
latent_vector_size = 150

# Feature-map widths. NOTE(review): not passed to Generator()/Discriminator() in
# this notebook — confirm the architectures in gan.py use matching widths.
gen_fm = 150
disc_fm = 64

In [None]:
def weights_init(m):
    """DCGAN-style weight initialisation, meant to be passed to ``Module.apply``.

    * Class name contains ``'Conv'`` (Conv*, ConvTranspose*): weight ~ N(0, 0.02).
    * Class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.

    A parameter is only touched when the module actually has it. Containers
    whose name merely contains ``'Conv'`` (e.g. a ``ConvBlock``) have no
    ``weight``. ``BatchNorm`` with ``affine=False`` stores ``weight``/``bias``
    as ``None``. Both cases are skipped rather than raising ``AttributeError``.
    """
    classname = type(m).__name__
    # nn.Module.__getattr__ raises AttributeError for unknown names, so getattr's
    # default handles modules without the parameter.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)


def init_models():
    """Build the generator and discriminator on ``device``.

    Each network is initialised with ``weights_init`` (when enabled). Its
    trainable-parameter count and architecture are printed, followed by the
    combined count.

    Returns:
        tuple: ``(model_G, model_D, params_G, params_D)``, the two models
        followed by their trainable-parameter counts.
    """
    use_weights_init = True

    def _build(model_cls, count_message):
        # Construct, initialise and report one network. It is called for G first,
        # then D, so random initialisation happens in the same order as before.
        model = model_cls().to(device)
        if use_weights_init:
            model.apply(weights_init)
        n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(count_message.format(n_trainable))
        print(model)
        print('\n')
        return model, n_trainable

    model_G, params_G = _build(
        Generator, "Total number of parameters in Generator is: {}")
    model_D, params_D = _build(
        Discriminator, "Total number of parameters in Discriminator is: {}")

    print("Total number of parameters is: {}".format(params_G + params_D))
    return model_G, model_D, params_G, params_D

### 4. Train the model

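***Define the losses and initialise the optimisers***
NB: we are training the GAN with historical averaging (Salimans et al., 2016, "Improved Techniques for Training GANs"). Each network gets an extra loss term that penalises its parameters for drifting away from their running average. This is why the variables carry the `_hist` suffix.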

In [None]:
def G_loss_function(y, y_hat):
    """Mean binary cross-entropy used for the generator update.

    Args:
        y: target probabilities (the loop passes the "real" label here).
        y_hat: discriminator outputs, which must already lie in [0, 1].

    Returns:
        Scalar tensor: ``BCE(y_hat, y)`` averaged over elements.
    """
    return F.binary_cross_entropy(y_hat, y)

def D_loss_function(y, y_hat):
    """Mean binary cross-entropy used for the discriminator update.

    Args:
        y: target probabilities (the real or fake label for the batch).
        y_hat: discriminator outputs, which must already lie in [0, 1].

    Returns:
        Scalar tensor: ``BCE(y_hat, y)`` averaged over elements.
    """
    return F.binary_cross_entropy(y_hat, y)

def loss_hist(params, params_avg):
    """Historical-averaging penalty: mean squared distance to the running average.

    Args:
        params: current (flattened) parameters; gradients flow through these.
        params_avg: running average of the parameters (same shape as ``params``).

    Returns:
        Scalar tensor ``mean((params - params_avg) ** 2)``.
    """
    deviation = params - params_avg
    return deviation.pow(2).mean()

In [None]:
# Instantiate both networks for the historical-averaging run.
model_G_hist, model_D_hist, params_G_hist, params_D_hist = init_models()

# One Adam optimiser per network; beta1 is lowered from Adam's default 0.9 to 0.5.
beta1 = 0.5
optimizerG = torch.optim.Adam(
    model_G_hist.parameters(), lr=learning_rate_G, betas=(beta1, 0.999)
)
optimizerD = torch.optim.Adam(
    model_D_hist.parameters(), lr=learning_rate_D, betas=(beta1, 0.999)
)

# Smoothed BCE targets (0.9 / 0.1 instead of 1 / 0).
real_label = 0.9
fake_label = 0.1

# Fixed latent batch, re-used after every epoch to snapshot the generator.
fixed_noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)

Logging

In [None]:
# Loss history: one value per epoch (*_hist) and one value per batch (*_all).
train_losses_G_hist, train_losses_D_hist = [], []
train_losses_G_all, train_losses_D_all = [], []

Training loop

In [None]:
# Every image/model written by this cell goes here; make sure the folder exists.
os.makedirs(content_path + '/GAN/historical_averaging', exist_ok=True)

for epoch in range(num_epochs):
    # Running means of each network's flattened parameters. They restart from
    # zero every epoch, so the "history" spans the current epoch only.
    param_avg_G = torch.zeros(params_G_hist, device=device)
    param_avg_D = torch.zeros(params_D_hist, device=device)
    n = 0  # batches processed so far this epoch (weight of the running mean)
    # Epoch-level loss accumulators. They must be initialised here, outside the
    # batch loop; resetting them per batch would record only the last batch.
    train_loss_D = 0.0
    train_loss_G = 0.0
    with tqdm.tqdm(loader_train, unit="batch") as tepoch:
        for i, data in enumerate(tepoch):

            #######################################################################
            #                  ** TRAIN DISCRIMINATOR WITH REAL **
            #######################################################################
            model_D_hist.zero_grad()
            real_cpu = data[0].to(device)
            size = real_cpu.size(0)
            output_real = model_D_hist(real_cpu).view(-1)
            target = torch.full((size,), real_label, dtype=torch.float, device=device)
            errD_real = D_loss_function(target, output_real)
            errD_real.backward()
            D_x = output_real.mean().item()

            #######################################################################
            #                  ** TRAIN DISCRIMINATOR WITH FAKE **
            #######################################################################
            noise = torch.randn(size, latent_vector_size, 1, 1, device=device)
            fake_image = model_G_hist(noise)
            target = target.fill_(fake_label)
            # detach(): the discriminator step must not backpropagate into G.
            output_fake = model_D_hist(fake_image.detach()).view(-1)

            #######################################################################
            #                          ** UPDATE GRADIENTS **
            #######################################################################
            errD_fake = D_loss_function(target, output_fake)
            errD_fake.backward()

            #######################################################################
            #                          ** PERFORM HISTORICAL AVERAGING **
            #######################################################################
            # Flattened D parameters, still attached to the graph so the penalty
            # below produces gradients for D.
            param_D_current = torch.cat([param.view(-1) for param in model_D_hist.parameters()])
            # Incremental mean that already includes the current parameters, so the
            # penalty is exactly zero on the first batch of each epoch.
            param_avg_D = (n * param_avg_D + param_D_current.detach()) / (n + 1)
            errD_hist_avg = loss_hist(param_D_current, param_avg_D)
            errD_hist_avg.backward()
            errD = errD_real + errD_fake + errD_hist_avg
            D_G_z1 = output_fake.mean().item()
            train_loss_D += errD.item()
            optimizerD.step()

            #######################################################################
            #                     ** UPDATE GENERATOR NETWORK **
            #######################################################################
            # Only G's grads are cleared here; the grads this pass leaves on D are
            # cleared by model_D_hist.zero_grad() at the start of the next batch.
            model_G_hist.zero_grad()
            # Re-score the (non-detached) fakes with the freshly updated D.
            output = model_D_hist(fake_image).view(-1)
            # G is trained to make D label its samples as real.
            target = target.fill_(real_label)
            errG_standard = G_loss_function(target, output)
            errG_standard.backward()
            D_G_z2 = output.mean().item()

            #######################################################################
            #            ** PERFORM HISTORICAL AVERAGING ON GENERATOR **
            #######################################################################
            param_G_current = torch.cat([param.view(-1) for param in model_G_hist.parameters()])
            param_avg_G = (n * param_avg_G + param_G_current.detach()) / (n + 1)
            errG_avg = loss_hist(param_G_current, param_avg_G)
            errG_avg.backward()
            errG = errG_avg + errG_standard
            train_loss_G += errG.item()
            optimizerG.step()
            train_losses_D_all.append(errD.item())
            train_losses_G_all.append(errG.item())
            n += 1

            # Logging
            if i % 50 == 0:
                tepoch.set_description(f"Epoch {epoch}")
                tepoch.set_postfix(D_G_z=f"{D_G_z1:.3f}/{D_G_z2:.3f}", D_x=D_x,
                                   Loss_D=errD.item(), Loss_G=errG.item())

    if epoch == 0:
        # real_cpu holds the last training batch of the first epoch.
        save_image(denorm(real_cpu.cpu()).float(), content_path + '/GAN/historical_averaging/real_samples.png')
    # NOTE(review): G is still in train mode here, so any BatchNorm layers in gan.py
    # use (and update running stats from) this batch — confirm that is intended.
    with torch.no_grad():
        fake = model_G_hist(fixed_noise)
        # f-string so each epoch gets its own file instead of overwriting one
        # literally named 'fake_samples_epoch{epoch}.png'.
        save_image(denorm(fake.cpu()).float(), content_path + f'/GAN/historical_averaging/fake_samples_epoch{epoch}.png')
    # Record the mean per-batch loss of this epoch.
    train_losses_D_hist.append(train_loss_D / max(n, 1))
    train_losses_G_hist.append(train_loss_G / max(n, 1))


# Export both networks as TorchScript by tracing: G on the fixed noise batch,
# D on the last batch of generated samples (`fake`, from the final epoch).
# NOTE(review): the models are never switched to eval() in this notebook, so
# they are traced in train mode — call .eval() first if inference export is wanted.
torch.jit.save(
    torch.jit.trace(model_G_hist, fixed_noise),
    content_path + '/GAN/historical_averaging/GAN_G_model.pth',
)
torch.jit.save(
    torch.jit.trace(model_D_hist, fake),
    content_path + '/GAN/historical_averaging/GAN_D_model.pth',
)


### 5. Show generator samples

In [None]:
# Sample 100 latent vectors, render them as a 10x10 grid, then save and display it.
input_noise = torch.randn(100, latent_vector_size, 1, 1, device=device)
with torch.no_grad():
    generated = model_G_hist(input_noise).cpu()
    # `range=` is deprecated in favour of `value_range=` and rejected by recent
    # torchvision releases; it was None (the default) anyway, so it is omitted.
    generated = make_grid(denorm(generated)[:100], nrow=10, padding=2, normalize=False,
                          scale_each=False, pad_value=0)
    plt.figure(figsize=(15, 15))
    # content_path is a str: join with os.path (str / str raises TypeError) and
    # make sure the target directory exists before saving.
    output_dir = os.path.join(content_path, 'CW_GAN')
    os.makedirs(output_dir, exist_ok=True)
    save_image(generated, os.path.join(output_dir, 'Teaching_final.png'))
    show(generated)

# Display the first 64 real test images (8x8 grid) for comparison with the samples.
it = iter(loader_test)
sample_inputs, _ = next(it)
fixed_input = sample_inputs[:64]

# `range=` omitted: it is rejected by recent torchvision and None is the default.
img = make_grid(denorm(fixed_input), nrow=8, padding=2, normalize=False,
                scale_each=False, pad_value=0)
plt.figure(figsize=(15, 15))
show(img)