# ProGAN implementation

[Original video](https://youtu.be/nkQHASviYac)

[Source code](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/ProGAN)

[Paper walkthrough](https://youtu.be/lhs78if-E7E)

[Paper 1](https://arxiv.org/abs/1710.10196), [paper 2](https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf)

[CelebA-HQ dataset](https://www.kaggle.com/lamsimon/celebahq)

## Import libraries

In [None]:
import os
import torch
import random
import numpy as np
import torchvision
import torch.nn as nn
import multiprocessing
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from math import log2
from tqdm.notebook import tqdm
from scipy.stats import truncnorm
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# Enable the cuDNN autotuner for additional performance.
# NOTE: the flag is `benchmark` (singular); assigning to a misspelled name
# only creates an unused attribute and leaves the autotuner disabled.
torch.backends.cudnn.benchmark = True

## Model

![ProGAN fade](https://miro.medium.com/max/1580/1*-lY_AywUNxaWVmdo0qQ5sA.png)

![ProGAN model](https://aisc.ai.science/static/post-assets/gan-collaborative-post/image20.png)

In [None]:
# Channel multipliers relative to `in_channels`, ordered from the 4x4 stage upward.
# Generator block i maps factors[i] -> factors[i+1] channels; the Discriminator
# walks the same list backwards (factors[i] -> factors[i-1]).
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]

# Weight-scaled convolutional layer (equalized learning rate)
class WSConv2d(nn.Module):
    """Conv2d whose input is multiplied by a constant scale at run time.

    Inspired by:
        https://github.com/nvnbny/progressive_growing_of_gans/blob/master/modelUtils.py

    The scale is sqrt(gain / (in_channels * kernel_size ** 2)). The weights are
    initialised from a standard normal and the bias (kept outside the wrapped
    conv so that it is not scaled) starts at zero.
    """
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # move the bias out of the conv: it is added after the scaled
        # convolution, so the scale never touches it
        self.bias = self.conv.bias
        self.conv.bias = None

        # in-place initialisation (trailing '_')
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        return self.conv(scaled_input) + channel_bias


# pixel-wise feature normalization, used in place of BatchNorm
class PixelNorm(nn.Module):
    """Divide each pixel's feature vector by its root-mean-square over channels."""
    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8  # keeps the division finite for all-zero pixels

    def forward(self, x):
        # mean of squares over the channel axis; keepdim so the division broadcasts
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)


# two 3x3 weight-scaled convolutions bundled together to keep the models compact
class ConvBlock(nn.Module):
    """conv -> LeakyReLU(0.2) [-> PixelNorm], applied twice.

    PixelNorm is applied after each activation only when `use_pixelnorm` is set.
    """
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pixelnorm = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pixelnorm:
                x = self.pn(x)
        return x


class Generator(nn.Module):
    """Progressively growing generator.

    `forward(x, alpha, steps)` maps a (N, z_dim, 1, 1) latent to an image whose
    side is 4 * 2**steps. For steps > 0 the output blends, with weight `alpha`,
    the RGB projection of the newest block with the RGB projection of its
    upscaled input.
    """
    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        # latent 1x1 --> 4x4 feature map
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )
        self.initial_rgb = WSConv2d(in_channels, img_channels,
                                    kernel_size=1, stride=1, padding=0)
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        # consecutive pairs of channel factors: factors[i] --> factors[i+1]
        for factor_in, factor_out in zip(factors, factors[1:]):
            block_in = int(in_channels * factor_in)
            block_out = int(in_channels * factor_out)
            self.prog_blocks.append(ConvBlock(block_in, block_out))
            self.rgb_layers.append(WSConv2d(block_out, img_channels,
                                            kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        # alpha is expected to be a scalar in [0, 1]; both images share one shape
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)  # squashed into (-1, 1)

    def forward(self, x, alpha, steps):
        features = self.initial(x)  # z_dim 1x1 --> 4x4

        # steps=0 (4x4), steps=1 (8x8), steps=2 (16x16), etc.
        if steps == 0:
            return self.initial_rgb(features)

        for block_idx in range(steps):
            upscaled = F.interpolate(features, scale_factor=2, mode='nearest')
            features = self.prog_blocks[block_idx](upscaled)

        # RGB projections of the last block's input and of its output
        rgb_from_upscaled = self.rgb_layers[steps-1](upscaled)
        rgb_from_block = self.rgb_layers[steps](features)

        return self.fade_in(alpha, rgb_from_upscaled, rgb_from_block)


class Discriminator(nn.Module):
    """Progressively growing discriminator (critic), mirroring the Generator.

    `prog_blocks` / `rgb_layers` are stored from the highest resolution down to
    the lowest, so an input of side 4 * 2**steps enters at index
    `len(prog_blocks) - steps`. `forward(x, alpha, steps)` returns a (N, 1)
    tensor of unbounded scores.
    """
    def __init__(self, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        for i in range(len(factors)-1, 0, -1):  # move backward
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i-1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in_c,
                                            kernel_size=1, stride=1, padding=0))

        # mirror the initial rgb of the Generator; for final 4x4 image resolution
        initial_rgb = WSConv2d(img_channels, in_channels,
                                    kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # block for 4x4 resolution; +1 input channel for the minibatch-std map
        self.final_block = nn.Sequential(
            WSConv2d(in_channels+1, in_channels),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2),
            # we use WSConv2d instead of fully-connected layer, this is the same
            WSConv2d(in_channels, 1, kernel_size=1, stride=1, padding=0),
            nn.Flatten(),
        )

    def fade_in(self, alpha, downscaled, out):
        # linear blend between the new block's output and the downscaled shortcut
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean per-feature std across the batch.

        (N,C,H,W) --> std over the batch dim (C,H,W) --> mean scalar
        --> (N,1,H,W) filled with that scalar, concatenated to `x` on dim=1.
        """
        if x.shape[0] > 1:
            std_scalar = torch.std(x, dim=0).mean()
        else:
            # The unbiased std of a single sample is NaN and would poison the
            # whole loss; a lone sample has no batch variation, so use 0.
            std_scalar = x.new_zeros(())
        batch_statistics = std_scalar.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        # add it like an additional channel
        return torch.cat([x, batch_statistics], dim=1)  # C --> C + 1

    def forward(self, x, alpha, steps):
        # steps=0 (4x4), steps=1 (8x8), steps=2 (16x16), etc.
        cur_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out)

        # shortcut path: pool the image first, then use the next (lower
        # resolution) from-RGB layer; blended with the new block's pooled output
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        # remaining blocks, each followed by a 2x downscale
        for step in range(cur_step+1, len(self.prog_blocks)):
            out = self.avg_pool(self.prog_blocks[step](out))

        out = self.minibatch_std(out)
        return self.final_block(out)


def test():
    """Shape smoke test: push noise through Generator and Discriminator at every
    resolution from 4x4 to 1024x1024 and check the output shapes."""
    z_dim = 100
    in_channels = 256
    img_channels = 3
    batch_size = 5
    gen = Generator(z_dim, in_channels, img_channels)
    disc = Discriminator(in_channels, img_channels)

    # num_steps 0..8 correspond to image sizes 4, 8, ..., 1024
    for num_steps in range(9):
        img_size = 4 * 2 ** num_steps
        noise = torch.randn((batch_size, z_dim, 1, 1))
        image = gen(noise, alpha=0.5, steps=num_steps)
        assert image.shape == (batch_size, img_channels, img_size, img_size)
        score = disc(image, alpha=0.5, steps=num_steps)
        assert score.shape == (batch_size, 1)
        print(f'Success! At img size: {img_size}')

test()

Success! At img size: 4
Success! At img size: 8
Success! At img size: 16
Success! At img size: 32
Success! At img size: 64
Success! At img size: 128
Success! At img size: 256
Success! At img size: 512
Success! At img size: 1024


## Configuration parameters

In [None]:
# Resolution at which training starts; main() turns it into a step index
# via log2(size / 4).
START_TRAIN_AT_IMG_SIZE = 4
# Root folder handed to datasets.ImageFolder in get_loader().
DATASET = 'celeba_hq'
# Checkpoint files written after every epoch when SAVE_MODEL is set.
CHECKPOINT_DISC = 'disc-celeb.pth.tar'
CHECKPOINT_GEN = 'gen-celeb.pth.tar'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAVE_MODEL = True
# Checkpoints are only loaded when both files above exist.
LOAD_MODEL = True
LEARNING_RATE = 1e-3
# Batch size per step (index 0 = 4x4, 1 = 8x8, ...).
BATCH_SIZES = [512, 256, 128, 128, 64, 64, 32, 16, 8]
CHANNELS_IMG = 3
Z_DIM = 256  # should be 512 in original paper
IN_CHANNELS = 256  # should be 512 in original paper
# Weight of the gradient penalty term in the discriminator loss.
LAMBDA_GP = 10
# Number of epochs to train at each step.
PROGRESSIVE_EPOCHS = [10] * len(BATCH_SIZES)
# Fixed latent batch used for the image previews logged during training.
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
# Number of DataLoader worker processes.
NUM_WORKERS = multiprocessing.cpu_count()

# !mv disc.pth.tar disc-celeb.pth.tar
# !mv gen.pth.tar gen-celeb.pth.tar

## Utils
  * For the **plot_to_tensorboard** function, see [Pytorch TensorBoard](https://colab.research.google.com/drive/1uftzYqL8gwmp2wvvujBHvRxfGLOxnO54).
  * For the **gradient_penalty** function, see [WGAN-GP implementation](https://colab.research.google.com/drive/193niPgYt2Qm8Ok5Dy2_z2xocwN-bygXd).
  * For the **save_checkpoint** and **load_checkpoint** functions, see [How to save and load models in Pytorch](https://colab.research.google.com/drive/1h5G53mEm_ez7zJj1qju-RF8FrjBlNCX2).
  * For the **seed_everything** function (not used here), see [Quick Tips](https://colab.research.google.com/drive/1ZNBzRnUG2cJvYemO2Gz5JqOKIbuaBa4Y).

In [None]:
# Log both losses and a grid of real / fake samples to TensorBoard
def plot_to_tensorboard(writer, loss_disc, loss_gen, real, fake, tensorboard_step):
    """Write the current losses and image previews to TensorBoard.

    Args:
        writer: SummaryWriter that receives the scalars and images.
        loss_disc: discriminator loss value to log.
        loss_gen: generator loss value to log.
        real: batch of real images; at most the first 8 are plotted.
        fake: batch of generated images; at most the first 8 are plotted.
        tensorboard_step: global step under which everything is recorded.
    """
    writer.add_scalar('Loss discriminator', loss_disc, global_step=tensorboard_step)
    # the generator loss was accepted as an argument but never logged before
    writer.add_scalar('Loss generator', loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot; normalize=True rescales to [0, 1]
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image('Real', img_grid_real, global_step=tensorboard_step)
        writer.add_image('Fake', img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(disc, real, fake, alpha, train_step, device):
    """Gradient penalty: mean of (||grad of disc wrt mixed images||_2 - 1)^2.

    Each real image is mixed with the (detached) fake at the same index using
    one random coefficient per sample. `alpha` and `train_step` are simply
    forwarded to `disc`.
    """
    n_samples, channels, height, width = real.shape
    # one coefficient per sample, copied over every channel and pixel
    beta = torch.rand((n_samples, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    mixed_images = real * beta + fake.detach() * (1 - beta)
    mixed_images.requires_grad_(True)

    # discriminator scores for the mixed images
    scores = disc(mixed_images, alpha, train_step)

    # d(scores) / d(mixed_images); create_graph keeps the penalty differentiable
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=mixed_images,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    flat_grads = grads.view(grads.shape[0], -1)
    per_sample_norm = flat_grads.norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename):
    """Save the model and optimizer state dicts to `filename` in a single file."""
    print('=> Saving checkpoint')
    torch.save(
        {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore `model` and `optimizer` from `checkpoint_file`, then force `lr`.

    The file is expected to hold the 'state_dict' and 'optimizer' entries;
    tensors are mapped onto DEVICE while loading.
    """
    print('=> Loading checkpoint')
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])

    # The optimizer state just loaded carries the checkpoint's learning rate;
    # overwrite it, otherwise the old value silently stays in effect.
    for group in optimizer.param_groups:
        group['lr'] = lr


def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic."""
    os.environ['PYTHONHASHSEED'] = str(seed)

    # same order as always: python, numpy, torch CPU, torch CUDA (current, all)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def generate_examples(gen, step, truncation=0.7, n=100):
    """Save `n` generated images as saved_examples/img_<i>.jpg.

    Latents are drawn from a normal truncated to [-truncation, truncation]
    (truncation trick); it is unclear whether it helps, so plain torch.randn
    would be a fine replacement. The generator is put in eval mode while
    sampling and switched to train mode at the end.
    """
    out_dir = 'saved_examples'
    os.makedirs(out_dir, exist_ok=True)

    gen.eval()
    alpha = 1.0  # fully faded-in network
    with torch.no_grad():
        for i in range(n):
            latent = truncnorm.rvs(-truncation, truncation, size=(1, Z_DIM, 1, 1))
            noise = torch.tensor(latent, device=DEVICE, dtype=torch.float32)
            img = gen(noise, alpha, step)
            # shift by 0.5 and scale by 0.5 before writing the image
            save_image(img * 0.5 + 0.5, os.path.join(out_dir, f'img_{i}.jpg'))
    gen.train()

## Prepare Train

In [None]:
def get_loader(step):
    """Build the DataLoader and dataset for the resolution of `step`.

    step 0 -> 4x4, 1 -> 8x8, 2 -> 16x16, ...; the batch size is taken from
    BATCH_SIZES[step]. Returns `(loader, dataset)`.
    """
    image_size = 4 * 2 ** step
    print(f'Current image size: {image_size}')

    mean = [0.5] * CHANNELS_IMG
    std = [0.5] * CHANNELS_IMG
    pipeline = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean, std),
    ]
    dataset = datasets.ImageFolder(root=DATASET,
                                   transform=transforms.Compose(pipeline))
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZES[step],
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset


def train(disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen,
          tensorboard_step, writer, scaler_gen, scaler_disc, epoch, num_epochs):
    """Run one epoch of discriminator/generator training at resolution `step`.

    Args:
        disc, gen: the two networks; both are called as `net(x, alpha, step)`.
        loader: iterable of (real_images, _) batches.
        dataset: only its length is used, for the alpha schedule.
        step: current resolution step (also indexes PROGRESSIVE_EPOCHS).
        alpha: fade-in weight at the start of this epoch.
        opt_disc, opt_gen: optimizers of the two networks.
        tensorboard_step: running TensorBoard step counter.
        writer: SummaryWriter passed on to plot_to_tensorboard.
        scaler_gen, scaler_disc: GradScalers used for the mixed-precision steps.
        epoch, num_epochs: only used for the progress-bar description.

    Returns:
        (tensorboard_step, alpha) as they stand after the epoch.
    """

    loop = tqdm(loader, leave=False)
    loop.set_description(f'Epoch [{epoch}/{num_epochs}]')

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        # Train Discriminator: min -(E[D(real)] - E[D(fake)])
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(DEVICE)
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            disc_real = disc(real, alpha, step)
            # detach: the discriminator update must not backpropagate into gen
            disc_fake = disc(fake.detach(), alpha, step)
            gp = gradient_penalty(disc, real, fake, alpha, step, DEVICE)

            loss_disc = (
                -(torch.mean(disc_real) - torch.mean(disc_fake))
                + LAMBDA_GP * gp

                # keep the discriminator output from drifting too far away from zero
                + 0.001 * torch.mean(disc_real ** 2)
            )

        opt_disc.zero_grad()
        scaler_disc.scale(loss_disc).backward()
        scaler_disc.step(opt_disc)
        scaler_disc.update()

        # Train Generator: max E[D(fake)] or min -(E[D(fake)])
        # (`fake` is not detached here, so gradients reach the generator)
        with torch.cuda.amp.autocast():
            gen_fake = disc(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Grow alpha in proportion to the samples seen and cap it at 1, so that
        # it reaches 1 after half of this step's epochs and stays there.
        alpha += cur_batch_size / (len(dataset) * PROGRESSIVE_EPOCHS[step] * 0.5)
        alpha = min(alpha, 1)

        # every 500 batches: log losses plus real / fixed-noise fake previews
        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(writer,
                                loss_disc.item(),
                                loss_gen.item(),
                                real.detach(),
                                fixed_fakes.detach(),
                                tensorboard_step,)
            tensorboard_step += 1

        loop.set_postfix(gp=gp.item(), loss_disc=loss_disc.item(),)

    return tensorboard_step, alpha


def main():
    """Train the ProGAN, growing from START_TRAIN_AT_IMG_SIZE up to the largest
    configured step, checkpointing after every epoch and saving example images
    after every step."""
    # note: according to the WGAN paper the discriminator should be called
    # a critic (it no longer outputs between [0, 1]), but really who cares..
    disc = Discriminator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
    gen = Generator(Z_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)

    # optimizers, plus gradient scalers for FP16 training
    adam_betas = (0.0, 0.99)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=adam_betas)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=adam_betas)
    scaler_disc = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    # for tensorboard plotting
    writer = SummaryWriter(LOG_DIR)

    # resume only when both checkpoint files are present
    checkpoints_present = all(
        os.path.exists(path) for path in (CHECKPOINT_DISC, CHECKPOINT_GEN)
    )
    if LOAD_MODEL and checkpoints_present:
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)

    gen.train()
    disc.train()

    tensorboard_step = 0
    # step matching the image size chosen in the configuration
    first_step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))

    # one iteration per remaining resolution; `step` advances with it
    for step, num_epochs in enumerate(PROGRESSIVE_EPOCHS[first_step:],
                                      start=first_step):
        alpha = 1e-5  # start with very low alpha
        loader, dataset = get_loader(step)

        for epoch in range(1, num_epochs+1):
            tensorboard_step, alpha = train(
                disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen,
                tensorboard_step, writer, scaler_gen, scaler_disc,
                epoch, num_epochs)

            if SAVE_MODEL:
                save_checkpoint(disc, opt_disc, CHECKPOINT_DISC)
                save_checkpoint(gen, opt_gen, CHECKPOINT_GEN)

        generate_examples(gen, step)

## Download datasets: "CelebA-HQ" and "Cats vs Dogs"

CelebA-HQ: 2.55 GB

Cats vs Dogs: 800 MB

In [None]:
# Get dataset from Kaggle

# Colab's file access feature
from google.colab import files

# Upload `kaggle.json` file
uploaded = files.upload()

# Retrieve uploaded file and print results
for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))


# Then copy kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!ls ~/.kaggle

# Download the dataset
# !kaggle datasets list -s celeb
!kaggle datasets download -d lamsimon/celebahq
# !kaggle competitions download -c dogs-vs-cats

kaggle.json
Downloading celebahq.zip to /content
100% 2.54G/2.55G [01:15<00:00, 63.5MB/s]
100% 2.55G/2.55G [01:15<00:00, 36.2MB/s]


In [None]:
# Extract data
import zipfile
import tarfile

def extract(fname):
    """Unpack the archive `fname` into the current working directory.

    Supported extensions: .tar.gz / .tgz, .tar, .tar.bz2 / .tbz and .zip.

    Raises:
        ValueError: if `fname` has none of the supported extensions.
    """
    if fname.endswith(('.tar.gz', '.tgz')):
        archive = tarfile.open(fname, mode='r:gz')
    elif fname.endswith('.tar'):
        archive = tarfile.open(fname, mode='r:')
    elif fname.endswith(('.tar.bz2', '.tbz')):
        archive = tarfile.open(fname, mode='r:bz2')
    elif fname.endswith('.zip'):
        archive = zipfile.ZipFile(fname, mode='r')
    else:
        # previously this fell through and failed on an unbound variable
        raise ValueError(f'Unsupported archive type: {fname!r}')

    # context manager closes the archive even if extraction fails
    with archive:
        archive.extractall()

extract('celebahq.zip')
# extract('train.zip')

# remove archives
!rm celebahq.zip
# !rm test1.zip train.zip sampleSubmission.csv

# move train into subfolder for datasets.ImageFolder
# !mkdir -p cats-dogs/images
# !mv -T train cats-dogs/images

Copy saved models from Google Drive if necessary.

In [None]:
# from google.colab import drive
# drive.mount('/content/gdrive')

# copy_from = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

# # !ls -hal '$copy_from'

# !cp '$copy_from'/'gen-celeb.pth.tar' .
# !cp '$copy_from'/'disc-celeb.pth.tar' .

# !cp '$copy_from'/'gen-cats-dogs.pth.tar' .
# !cp '$copy_from'/'disc-cats-dogs.pth.tar' .

Mounted at /content/gdrive


# Run TensorBoard

In [None]:
# Run TensorBoard

# Delete previous logs dir
LOG_DIR = 'logs/gan1'
if os.path.exists(LOG_DIR):
    !rm -rf $LOG_DIR

# To fix the error, because PyTorch and TensorFlow are installed both:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir $LOG_DIR

# Reload TensorBoard
%reload_ext tensorboard

## Train

In [None]:
START_TRAIN_AT_IMG_SIZE = 32  # 4

main()

## Train cats and dogs

In [None]:
START_TRAIN_AT_IMG_SIZE = 64  # 4
DATASET = 'cats-dogs'
CHECKPOINT_GEN = 'gen-cats-dogs.pth.tar'
CHECKPOINT_DISC = 'disc-cats-dogs.pth.tar'

In [None]:
# Run TensorBoard

# Delete previous logs dir
LOG_DIR = 'logs/cats-dogs'
if os.path.exists(LOG_DIR):
    !rm -rf $LOG_DIR

# To fix the error, because PyTorch and TensorFlow are installed both:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir $LOG_DIR

# Reload TensorBoard
%reload_ext tensorboard

In [None]:
torch.cuda.empty_cache()

main()

Copy saved models to Google Drive if necessary.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

copy_to = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

# !ls -hal '$copy_to'

!cp 'gen-celeb.pth.tar' '$copy_to'
!cp 'disc-celeb.pth.tar' '$copy_to'

!cp 'gen-cats-dogs.pth.tar' '$copy_to'
!cp 'disc-cats-dogs.pth.tar' '$copy_to'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
