# Package Setup and Initialization
Import all required libraries

In [9]:
import torch as th
import torchvision
from torch.utils.data import DataLoader

import os
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import numpy as np

# Setup configuration
Set up the hyperparameters used by the networks and the training loop.

In [10]:
# Network
NOISE_SIZE = 100                            # channel count of the (N, NOISE_SIZE, 1, 1) latent tensor fed to the Generator
NOISE_TYPE = 'normal'                       # NOTE(review): not read by any cell shown here; noise is always drawn with th.randn
CRITIC_FEATURE_MAP_DEPTH = 64               # in WGAN the Discriminator is called the Critic
GENERATOR_FEATURE_MAP_DEPTH = 64            # base channel width passed to Generator(...)

# Training
SAVE_CHECKPOINT_EVERY = 10                  # a checkpoint is written when epoch % SAVE_CHECKPOINT_EVERY == 0
SAVE_IMAGE_EVERY = 10                       # sample grids are written when epoch % SAVE_IMAGE_EVERY == 0
BATCH_SIZE = 64
EPOCHS = 200                                # the training loop iterates range(EPOCHS + 1), i.e. epochs 0..EPOCHS inclusive
DISCRIMINATOR_LR = 5e-5                     # learning rate of the Critic optimizer
GENERATOR_LR = 1e-5                         # learning rate of the Generator optimizer
TRUE_LABEL_VALUE = 1                        # NOTE(review): not referenced in the cells shown here (the Wasserstein loss uses no labels)
FAKE_LABEL_VALUE = 0                        # NOTE(review): not referenced in the cells shown here

# WGAN params
NUM_EPOCHS = 5                              # NOTE(review): not referenced in the cells shown here; the training loop uses EPOCHS
CRITIC_ITERATIONS = 5                       # number of Critic updates performed per Generator update
# WEIGHT_CLIP = 0.1

# WGAN-GP params
LAMBDA_GP = 10                              # weight of the gradient-penalty term in the WGAN-GP critic loss

# Version nr
VERSION = 21                                # printed once when training starts

# Setup device and data

In [11]:
# Device: use the GPU when one is available, otherwise fall back to the CPU.
device = th.device('cuda' if th.cuda.is_available() else 'cpu')

print("Running on device:", device)

# Dataset
# ToTensor() produces floats in [0, 1], but the Generator ends in Tanh and
# therefore emits values in [-1, 1]. Rescale the real images to [-1, 1] as
# well (x -> (x - 0.5) / 0.5) so the Critic compares samples living in the
# same value range instead of being able to separate them by range alone.
# NOTE(review): no resize/crop is applied; the Critic and Generator layer
# stacks imply 96x96 images, so the files are presumably already that size.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

#data_directory = "/kaggle/input/"
data_directory = "./data/faces/"
# ImageFolder expects the images to sit inside at least one sub-folder of
# data_directory; the class labels it derives are ignored during training.
dataset = torchvision.datasets.ImageFolder(data_directory, transform=transform)

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

Running on device: cuda


# Network
Critic (Discriminator) and Generator
Note that, unlike the DCGAN discriminator, the Critic in WGAN does not have a sigmoid activation function in its last layer.

In [12]:
# DISCRIMINATOR
class CriticBlock(th.nn.Module):
    """One convolutional stage of the Critic.

    Three flavours exist, selected by the ``first`` / ``last`` flags:

    * first:   4x4 stride-2 convolution + LeakyReLU (no normalisation)
    * last:    3x3 stride-1 unpadded convolution, raw score (no Sigmoid in WGAN)
    * default: 4x4 stride-2 convolution + affine InstanceNorm + LeakyReLU

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        first: build the input stage.
        last: build the output stage.

    Raises:
        AssertionError: if both ``first`` and ``last`` are set.
    """

    def __init__(self, in_channels: int, out_channels: int, first: bool = False, last: bool = False) -> None:
        assert not (first and last)  # a block cannot be the first and the last one at once
        super().__init__()

        halves_resolution = first or not last
        if halves_resolution:
            layers = [th.nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
            if not first:
                # WGAN-GP does not use BatchNorm for the Critic (LayerNorm or InstanceNorm)
                layers.append(th.nn.InstanceNorm2d(out_channels, affine=True))
            layers.append(th.nn.LeakyReLU(0.2, inplace=True))
        else:
            # Output stage: the unbounded score is returned as-is.
            layers = [th.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=False)]

        self.main = th.nn.Sequential(*layers)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Run the input through this stage's layers."""
        return self.main(x)

class Critic(th.nn.Module):
    """WGAN critic: a stack of six ``CriticBlock`` stages scoring RGB images.

    Five stride-2 stages each halve the spatial resolution, then a final
    unpadded 3x3 convolution reduces the map to a single score channel
    (a 96x96 input therefore ends up as a 1x1 map).

    Args:
        feature_map_depth: channel width of the first stage; later stages use
            2x, 4x, 8x and 8x this value.
    """

    def __init__(self, feature_map_depth: int) -> None:
        super().__init__()
        widths = [feature_map_depth * factor for factor in (1, 2, 4, 8, 8)]

        stages = [CriticBlock(3, widths[0], first=True)]
        for channels_in, channels_out in zip(widths, widths[1:]):
            stages.append(CriticBlock(channels_in, channels_out))
        stages.append(CriticBlock(widths[-1], 1, last=True))

        self.main = th.nn.Sequential(*stages)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Return the raw (unbounded) critic score map for ``x``."""
        return self.main(x)



# GENERATOR
class GeneratorBlock(th.nn.Module):
    """One transposed-convolution stage of the Generator.

    Three flavours exist, selected by the ``first`` / ``last`` flags:

    * first:   3x3 stride-1 unpadded transposed conv + BatchNorm + ReLU
    * last:    4x4 stride-2 transposed conv + Tanh (output in [-1, 1])
    * default: 4x4 stride-2 transposed conv + BatchNorm + ReLU

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed conv.
        first: build the input stage.
        last: build the output stage.

    Raises:
        AssertionError: if both ``first`` and ``last`` are set.
    """

    def __init__(self, in_channels: int, out_channels: int, first: bool = False, last: bool = False) -> None:
        assert not (first and last)  # a block cannot be the first and the last one at once
        super().__init__()

        kernel, stride, padding = (3, 1, 0) if first else (4, 2, 1)
        layers = [th.nn.ConvTranspose2d(in_channels, out_channels, kernel, stride, padding, bias=False)]

        if last and not first:
            # Output stage: squash pixel values, no normalisation.
            layers.append(th.nn.Tanh())
        else:
            layers.extend([th.nn.BatchNorm2d(out_channels), th.nn.ReLU(True)])

        self.main = th.nn.Sequential(*layers)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Run the input through this stage's layers."""
        return self.main(x)

class Generator(th.nn.Module):
    """Generator: maps a latent tensor to an RGB image with six ``GeneratorBlock`` stages.

    The first stage uses a 3x3 stride-1 unpadded transposed convolution, so a
    (N, noise_size, 1, 1) latent becomes a 3x3 map; each of the five following
    stride-2 stages doubles the resolution (3 -> 6 -> 12 -> 24 -> 48 -> 96).
    The last stage ends in Tanh, so pixel values lie in [-1, 1].

    Args:
        noise_size: channel count of the latent input.
        feature_map_depth: base channel width; stages use 8x, 8x, 4x, 2x and
            1x this value before the final 3-channel output.
    """

    def __init__(self, noise_size: int, feature_map_depth: int) -> None:
        super().__init__()
        widths = [feature_map_depth * factor for factor in (8, 8, 4, 2, 1)]

        # Input stage: no stride, expands the 1x1 latent to a 3x3 map.
        stages = [GeneratorBlock(noise_size, widths[0], first=True)]
        for channels_in, channels_out in zip(widths, widths[1:]):
            stages.append(GeneratorBlock(channels_in, channels_out))
        stages.append(GeneratorBlock(widths[-1], 3, last=True))

        self.main = th.nn.Sequential(*stages)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Generate images from the latent batch ``x``."""
        return self.main(x)


# Optimizer and creating network

In [13]:
# Initialize weights
def weights_init(model):
    """Re-initialise one module in place; meant to be used with ``Module.apply``.

    The module is dispatched on its class name:

    * names containing ``'Conv'``: weights drawn from N(0, 0.02)
    * names containing ``'BatchNorm'``: weights drawn from N(1, 0.02), biases set to 0
    * anything else is left untouched
    """
    layer_kind = type(model).__name__

    if 'Conv' in layer_kind:
        th.nn.init.normal_(model.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        th.nn.init.normal_(model.weight.data, mean=1.0, std=0.02)
        th.nn.init.constant_(model.bias.data, val=0)


# Create network
# Each model is moved to the target device first and then has its Conv /
# BatchNorm parameters re-initialised in place by weights_init.
generator = Generator(NOISE_SIZE, GENERATOR_FEATURE_MAP_DEPTH).to(device).apply(weights_init)
critic = Critic(CRITIC_FEATURE_MAP_DEPTH).to(device).apply(weights_init)

# Optimizer (WGAN uses RMSprop, WGAN-GP uses Adam)
critic_optimizer = th.optim.RMSprop(params=critic.parameters(), lr=DISCRIMINATOR_LR)
generator_optimizer = th.optim.RMSprop(params=generator.parameters(), lr=GENERATOR_LR)

# Put both networks in training mode; the trailing expression also makes the
# notebook display the Critic's layer summary.
generator.train()
critic.train()

Critic(
  (main): Sequential(
    (0): CriticBlock(
      (main): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (1): CriticBlock(
      (main): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (2): CriticBlock(
      (main): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
    )
    (3): CriticBlock(
      (main): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

# Utility functions

In [14]:
# Constants
# results_path = "kaggle"
# experiment_name = "working"
FULL_PATH = 'output'  # root folder for checkpoints and sample images
# Latent batch reused for every "fixed" sample grid so progress is comparable across epochs
fixed_noise = th.randn(64, NOISE_SIZE, 1, 1, device=device)

# Create the output folder if it doesn't exist yet (idempotent, and always
# the same folder the helpers below write into).
os.makedirs(FULL_PATH, exist_ok=True)

# Utility functions
def save_model_checkpoint(epoch: int) -> None:
    """Write the networks' and optimizers' state to ``<FULL_PATH>/<epoch>/checkpoint.th``.

    Args:
        epoch: epoch number; used both as the sub-folder name and stored in
            the checkpoint under the ``'epoch'`` key.
    """
    make_epoch_directories(epoch)
    target_file = f'{FULL_PATH}/{epoch}/checkpoint.th'
    # The key names keep the "discriminator" wording even though the model is the Critic.
    snapshot = {
        'epoch': epoch,
        'generator_model_state_dict': generator.state_dict(),
        'discriminator_model_state_dict': critic.state_dict(),
        'generator_optimizer_state_dict': generator_optimizer.state_dict(),
        'discriminator_optimizer_state_dict': critic_optimizer.state_dict(),
    }
    th.save(snapshot, target_file)


def make_epoch_directories(epoch: int) -> None:
    """Ensure the per-epoch output folder ``<FULL_PATH>/<epoch>`` exists.

    Missing parent folders (including ``FULL_PATH`` itself) are created too,
    and calling this again for an existing folder is a no-op.

    Args:
        epoch: epoch number used as the sub-folder name.
    """
    epoch_path = f'{FULL_PATH}/{epoch}'
    # makedirs(exist_ok=True) avoids the check-then-create race and does not
    # fail when the root output folder has not been created yet.
    os.makedirs(epoch_path, exist_ok=True)


def save_model_image(epoch: int) -> None:
    """Save two 64-image sample grids for ``epoch`` under ``<FULL_PATH>/<epoch>/images``.

    ``fixed.png`` is generated from the module-level ``fixed_noise`` batch,
    ``random.png`` from freshly drawn noise.

    The generator's train/eval mode is left untouched, so the samples are
    produced in whatever mode the generator currently is.

    Args:
        epoch: epoch number used as the sub-folder name.
    """
    make_epoch_directories(epoch)
    image_path = f'{FULL_PATH}/{epoch}/images'
    os.makedirs(image_path, exist_ok=True)

    random_noise = th.randn(64, NOISE_SIZE, 1, 1, device=device)
    # Sampling only: skip autograd bookkeeping instead of building a graph
    # for 128 images and throwing it away afterwards.
    with th.no_grad():
        fixed_fakes = generator(fixed_noise).cpu()
        random_fakes = generator(random_noise).cpu()

    save_image_grid(fixed_fakes, f'{image_path}/fixed.png', 'Fixed Noise')
    save_image_grid(random_fakes, f'{image_path}/random.png', 'Random Noise')


def save_image_grid(images, path: str, title: str) -> None:
    """Render up to 64 images as one grid and save it as a figure.

    Args:
        images: image batch tensor in (N, C, H, W) layout, as accepted by
            ``torchvision.utils.make_grid``; only the first 64 are used.
        path: destination file for the figure.
        title: title drawn above the grid.
    """
    # Build the grid wherever the images already live and move only the
    # final result to the CPU; there is no need to push the batch to the
    # training device just to bring it straight back.
    grid = vutils.make_grid(images[:64], padding=2, normalize=True).detach().cpu()

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis('off')
    ax.set_title(title)
    ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC, as expected by imshow
    fig.savefig(path)
    plt.close(fig)  # free the figure; this is called repeatedly during training

# Training loop
Main training loop. Note that we use the algorithm outlined in the WGAN paper: the Critic is updated more frequently than the Generator (5x), and we use the loss function described in the paper (see also https://machinelearningmastery.com/how-to-implement-wasserstein-loss-for-generative-adversarial-networks/).

From this blogpost:
    Critic Loss = [average critic score on real images] – [average critic score on fake images]
    Generator Loss = -[average critic score on fake images]


In [15]:

# MAIN TRAINING LOOP
def gradient_penalty(critic_model, real_images, fake_images):
    """Return the WGAN-GP gradient penalty for one batch.

    The critic is evaluated on random per-sample interpolations between real
    and fake images, and the penalty is the batch mean of
    (||d critic / d x||_2 - 1)^2. It is what enforces the Lipschitz
    constraint the Wasserstein loss relies on.

    Args:
        critic_model: the critic network.
        real_images: batch of real images, shape (N, C, H, W).
        fake_images: batch of generated images with the same shape; must not
            carry a graph back into the generator.

    Returns:
        Scalar tensor, differentiable w.r.t. the critic's parameters.
    """
    n_samples = real_images.size(0)
    epsilon = th.rand(n_samples, 1, 1, 1, device=real_images.device)
    mixed = (epsilon * real_images + (1 - epsilon) * fake_images).requires_grad_(True)
    mixed_scores = critic_model(mixed)
    gradients = th.autograd.grad(
        outputs=mixed_scores,
        inputs=mixed,
        grad_outputs=th.ones_like(mixed_scores),
        create_graph=True,  # the penalty itself has to be back-propagated
    )[0]
    gradient_norm = gradients.reshape(n_samples, -1).norm(2, dim=1)
    return th.mean((gradient_norm - 1) ** 2)


print("VERSION:", VERSION)
# range(EPOCHS + 1) so that the final epoch number is EPOCHS itself and gets
# a checkpoint / image dump when EPOCHS is a multiple of the save intervals.
for epoch in range(EPOCHS + 1):
    print('EPOCH: ', epoch)

    for batch_idx, (real, _labels) in enumerate(dataloader):
        real = real.to(device)
        batch_size = real.size(0)

        # TRAIN DISCRIMINATOR (CRITIC) MORE. (5x according to paper)
        for _ in range(CRITIC_ITERATIONS):
            noise = th.randn(batch_size, NOISE_SIZE, 1, 1, device=device)
            # The critic update must not back-propagate into the generator,
            # so the fakes are produced without a graph.
            with th.no_grad():
                fake = generator(noise)

            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake)

            # extra '-' because originally we want to maximize, so we minimize the negative.
            # LAMBDA_GP * gp is the addition for WGAN-GP; without it (or weight
            # clipping) the critic is unconstrained and the loss is unbounded.
            loss_critic = -(th.mean(critic_real) - th.mean(critic_fake)) + LAMBDA_GP * gp

            critic.zero_grad()
            loss_critic.backward()
            critic_optimizer.step()

        # TRAIN GENERATOR
        # Fresh fakes, this time with a graph through the generator.
        noise = th.randn(batch_size, NOISE_SIZE, 1, 1, device=device)
        fake = generator(noise)
        output = critic(fake).reshape(-1)
        loss_generator = -th.mean(output)
        generator.zero_grad()
        loss_generator.backward()
        generator_optimizer.step()


    # SAVE MODEL AND IMAGES
    if epoch % SAVE_CHECKPOINT_EVERY == 0:
        print('-> Saving model checkpoint')
        save_model_checkpoint(epoch)

    if epoch % SAVE_IMAGE_EVERY == 0:
        print('-> Saving model images')
        save_model_image(epoch)

VERSION: 21
EPOCH:  0
