$$
\newcommand{\mat}[1]{\boldsymbol {#1}}
\newcommand{\mattr}[1]{\boldsymbol {#1}^\top}
\newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}}
\newcommand{\vec}[1]{\boldsymbol {#1}}
\newcommand{\vectr}[1]{\boldsymbol {#1}^\top}
\newcommand{\rvar}[1]{\mathrm {#1}}
\newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}}
\newcommand{\diag}{\mathop{\mathrm {diag}}}
\newcommand{\set}[1]{\mathbb {#1}}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\bm}[1]{{\bf #1}}
\newcommand{\bb}[1]{\bm{\mathrm{#1}}}
$$

# Part 3: Generative Adversarial Networks
<a id=part3></a>

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [None]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile
import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

# Shared assertion helper used by the sanity checks in later cells.
test = unittest.TestCase()

# Default font size for every figure drawn in this notebook.
plt.rcParams['font.size'] = 12

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

In [None]:
# Make the parent directory importable for the imports that follow.
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')
from project.gan import *

import cs3600.plot as plot
import cs3600.download
from hw4.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL

# Datasets are cached in a hidden folder under the user's home directory.
DATA_DIR = pathlib.Path.home() / '.pytorch-datasets'

# A custom URL from the answers module takes precedence over the default archive.
DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip' if CUSTOM_DATA_URL is None else CUSTOM_DATA_URL

# Fetch and extract the archive; only the second return value (dataset_dir) is used.
_, dataset_dir = cs3600.download.download_data(
    out_path=DATA_DIR,
    url=DATA_URL,
    extract=True,
    force=False,
)

Create a `Dataset` object that will load the extracted images:

In [None]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

# Side length (in pixels) that every image is resized to.
im_size = 64

# Preprocessing pipeline applied to each image as it is loaded.
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize(size=(im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# ImageFolder is rooted at the parent of the extracted dataset directory.
ds_gwb = ImageFolder(root=os.path.dirname(dataset_dir), transform=tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [None]:
# Show 50 dataset images on a 5-row grid; the return value is not needed.
_ = plot.dataset_first_n(ds_gwb, 50, nrows=5, figsize=(15, 10))
print(f'Found {len(ds_gwb)} images in dataset folder.')

In [None]:
# Take the first (image, label) pair and add a leading batch dimension.
x0, y0 = ds_gwb[0]
x0 = x0[None].to(device)
print(x0.shape)

# A single sample must be one 3-channel image of im_size x im_size.
test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))

## Generative Adversarial Nets (GANs)
<a id=part3_2></a>

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
from hw4.answers import part3_gan_hyperparams
# Seed torch's global RNG before any model is constructed.
torch.manual_seed(42)

# Training-variant switches read by the model and optimizer setup in this cell.
vanilla, wgan, spectral = False, True, True

# Hyperparams
hp = part3_gan_hyperparams()
batch_size, z_dim = hp['batch_size'], hp['z_dim']

# Data
dl_train = DataLoader(ds_gwb, batch_size=batch_size, shuffle=True)
# From here on im_size holds the shape of one dataset sample, not the scalar
# side length assigned when the transforms were defined.
im_size = ds_gwb[0][0].shape

# Model: the discriminator class depends on the `spectral` switch.
dsc = (SNDiscriminator() if spectral else Discriminator(im_size)).to(device)
gen = Generator(z_dim, featuremap_size=64).to(device)

# Apply the project's weight initialisation to both networks.
weights_init(dsc)
weights_init(gen)

# Optimizer
def create_optimizer(model_params, opt_params):
    """Build a torch.optim optimizer from a config dictionary.

    :param model_params: parameters to optimize, passed to the optimizer as-is.
    :param opt_params: dict with a 'type' entry naming an attribute of
        torch.optim; every other entry is forwarded to the optimizer
        constructor as a keyword argument. The dict is not modified.
    :return: the constructed optimizer instance.
    :raises KeyError: if 'type' is missing or names nothing in torch.optim.
    """
    optimizer_type = opt_params['type']
    # Everything except the 'type' selector is a constructor keyword argument.
    ctor_kwargs = {key: value for key, value in opt_params.items() if key != 'type'}
    optimizer_cls = vars(optim)[optimizer_type]
    return optimizer_cls(model_params, **ctor_kwargs)

# Optimizers: one training setup is chosen, checked in the priority order
# vanilla -> wgan -> spectral.
if vanilla:
    # vanilla GAN: each network uses its own optimizer hyperparameters.
    dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
    gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])
elif wgan:
    # WGAN: RMSprop with the same fixed learning rate for both networks.
    gen_optimizer = torch.optim.RMSprop(gen.parameters(), lr=0.00005)
    dsc_optimizer = torch.optim.RMSprop(dsc.parameters(), lr=0.00005)
elif spectral:
    dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
    # NOTE(review): the generator is configured from
    # hp['discriminator_optimizer'] here, not hp['generator_optimizer'].
    # Confirm this is intentional.
    gen_optimizer = create_optimizer(gen.parameters(), hp['discriminator_optimizer'])
else:
    # Without this guard the optimizers would be undefined, or left over from
    # an earlier run of this cell and bound to the previous models' parameters.
    raise ValueError('No optimizer setup selected: set one of vanilla, wgan or spectral to True.')


def dsc_loss_fn(y_data, y_generated):
    """Discriminator loss with label settings taken from the hyperparameters.

    Forwards both arguments to discriminator_loss_fn together with
    hp['data_label'] and hp['label_noise'], and returns its result.
    """
    data_label, label_noise = hp['data_label'], hp['label_noise']
    return discriminator_loss_fn(y_data, y_generated, data_label, label_noise)

def gen_loss_fn(y_generated):
    """Generator loss with the data label taken from the hyperparameters.

    Forwards y_generated to generator_loss_fn together with hp['data_label']
    and returns its result.
    """
    data_label = hp['data_label']
    return generator_loss_fn(y_generated, data_label)

def gen_wgan_loss(y_generated):
    """Notebook wrapper around the WGAN generator loss from project.gan.

    Forwards y_generated together with hp['data_label'] and returns the result.
    """
    # This wrapper shadows the name star-imported from project.gan, so a bare
    # call to gen_wgan_loss would re-enter this one-argument wrapper with two
    # arguments and raise TypeError. Resolve the implementation through the
    # module instead.
    # NOTE(review): assumes project.gan defines gen_wgan_loss taking
    # (y_generated, data_label) -- confirm.
    import project.gan as gan_module
    return gan_module.gen_wgan_loss(y_generated, hp['data_label'])

def dsc_wgan_loss(y_data, y_generated):
    """Notebook wrapper around the WGAN discriminator loss from project.gan.

    Forwards both arguments unchanged and returns the result.
    """
    # This wrapper shadows the name star-imported from project.gan, so a bare
    # call to dsc_wgan_loss would recurse into this wrapper until
    # RecursionError. Resolve the implementation through the module instead.
    # NOTE(review): assumes project.gan defines dsc_wgan_loss taking
    # (y_data, y_generated) -- confirm.
    import project.gan as gan_module
    return gan_module.dsc_wgan_loss(y_data, y_generated)

def dsc_wgan_gp_loss(y_data, y_generated, dsc_model):
    """Notebook wrapper around the gradient-penalty discriminator loss from project.gan.

    Forwards y_generated and dsc_model and returns the result.
    """
    # This wrapper shadows the name star-imported from project.gan, so a bare
    # call to dsc_wgan_gp_loss would re-enter this three-argument wrapper with
    # two arguments and raise TypeError. Resolve the implementation through
    # the module instead.
    # NOTE(review): y_data is accepted but not forwarded, matching the original
    # call. Confirm the signature of project.gan.dsc_wgan_gp_loss.
    import project.gan as gan_module
    return gan_module.dsc_wgan_gp_loss(y_generated, dsc_model)


# Training
# Base path for checkpoints; the "_final" variant marks a finished run.
checkpoint_file = 'checkpoints/wgan_sn_gp'
checkpoint_file_final = checkpoint_file + '_final'

# Delete any checkpoint stored under the base name before training.
if os.path.isfile(checkpoint_file + '.pt'):
    os.remove(checkpoint_file + '.pt')

# Show hypers
# NOTE(review): these overrides are applied after batch_size was read into the
# DataLoader and after the optimizers were created earlier in this cell, so
# they only change the printed dictionary (and later reads of hp).
hp['batch_size'] = 64
hp['discriminator_optimizer']['betas'] = (0.0, 0.9)

print(hp)
print(vanilla, wgan, spectral)

In [None]:
import IPython.display
import tqdm

# --- Run configuration ---
num_epochs = 1000
# NOTE(review): n_cpu, latent_dim, img_size, channels and clip_value are not
# read anywhere in this cell; kept in case other cells rely on them.
n_cpu = 1
latent_dim = 100
img_size = 28
channels = 1
# Every n_critic-th batch is flagged via critic=True when calling train_batch.
n_critic = 5
clip_value = 0.01
# NOTE(review): these flags are re-set here after the optimizers were built
# under different values in the setup cell. Confirm the intended combination.
vanilla = True
wgan = False
mode = 'wgan_gd'

# Select the loss callables handed to train_batch. The vanilla losses remain
# the default when neither flag is set.
if wgan and not vanilla:
    gen_loss = gen_wgan_loss
    dsc_loss = dsc_wgan_loss
else:
    gen_loss = gen_loss_fn
    dsc_loss = dsc_loss_fn

# A final checkpoint short-circuits training: load it and run zero epochs.
if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    num_epochs = 0
    # NOTE(review): torch.load unpickles the stored object, so only load
    # checkpoint files from a trusted source.
    gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device)
    checkpoint_file = checkpoint_file_final


try:
    dsc_avg_losses, gen_avg_losses = [], []
    for epoch_idx in range(num_epochs):
        # We'll accumulate batch losses and show an average once per epoch.
        dsc_losses, gen_losses = [], []
        print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')

        with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
            for batch_idx, (x_data, _) in enumerate(dl_train):
                critic = batch_idx % n_critic == 0
                x_data = x_data.to(device)
                # The per-batch results get their own names so they do not
                # overwrite the selected loss callables.
                dsc_batch_loss, gen_batch_loss = train_batch(
                    dsc, gen,
                    dsc_loss, gen_loss,
                    dsc_optimizer, gen_optimizer,
                    x_data, critic=critic, mode=mode)
                # When no generator loss is reported for this batch, carry the
                # previous one forward, but only if one exists.
                if not gen_batch_loss and gen_losses:
                    gen_batch_loss = gen_losses[-1]
                dsc_losses.append(dsc_batch_loss)
                if gen_batch_loss is not None:
                    gen_losses.append(gen_batch_loss)
                pbar.update()

        dsc_avg_losses.append(np.mean(dsc_losses))
        gen_avg_losses.append(np.mean(gen_losses))
        print(f'Discriminator loss: {dsc_avg_losses[-1]}')
        print(f'Generator loss:     {gen_avg_losses[-1]}')

        # Periodic checkpoint, stored under the base name plus the epoch index.
        if epoch_idx % 15 == 0:
            save_checkpoint(gen, dsc_avg_losses, gen_avg_losses, checkpoint_file+str(epoch_idx))
            print('Saved checkpoint.')

        # Show a few generated samples after every epoch.
        samples = gen.sample(5, with_grad=False)
        fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
        IPython.display.display(fig)
        # Close the figure so that open figures do not pile up across epochs.
        plt.close(fig)
except KeyboardInterrupt:
    print('\n *** Training interrupted by user')

In [None]:
# Plot images from best or last model
# If a checkpoint exists under the current checkpoint name, its generator
# replaces the in-memory one; otherwise the in-memory generator is sampled.
if os.path.isfile(checkpoint_file + '.pt'):
    gen = torch.load(checkpoint_file + '.pt', map_location=device)
print('*** Images Generated from best model:')
# Draw 15 samples without gradient tracking and move them to the CPU for plotting.
samples = gen.sample(n=15, with_grad=False)
samples = samples.cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))