# Progressive GAN

* `Progressive Growing of GANs for Improved Quality, Stability, and Variation`, [arXiv:1710.10196](https://arxiv.org/abs/1710.10196)
  * Tero Karras, Timo Aila, Samuli Laine, and Jaakko Lehtinen

* This code targets TensorFlow 2.0
* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers) and [`tf.losses`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/losses)

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers

sys.path.append(os.path.dirname(os.path.abspath('.')))
from utils.image_utils import *
from utils.ops import *

os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setting hyperparameters

In [None]:
# Training Flags (hyperparameter configuration)
model_name = 'progressive-gan'
# Directory that the checkpoint cell below uses as `checkpoint_dir`
train_dir = os.path.join('train', model_name, 'exp1')
dataset_name = 'cifar10'
# Only CIFAR-10 is handled by the data-loading cell
assert dataset_name in ['cifar10']

# Number of epochs spent in each phase at every resolution
training_phase_epoch = 1 # epochs of plain training per resolution
transition_phase_epoch = 1 # epochs of fade-in per resolution (alpha ramps up from 0)

save_model_epochs = 10 # referenced only by commented-out checkpointing code in the cells visible here
print_steps = 20 # logging interval base; the training loop divides it by the current resolution
save_images_epochs = 1 # referenced only by commented-out image-saving code in the cells visible here
batch_size = 16
learning_rate_D = 1e-3
learning_rate_G = 1e-3
k = 1 # the number of steps of learning D before learning G (Not used in this code)
num_examples_to_generate = 16 # number of images drawn when printing samples
noise_dim = 512 # size of the latent vector fed to the generator
gp_lambda = 10 # weight of the gradient-penalty term in the discriminator loss

CIFAR_SIZE = 32 # used below as the target (final) resolution

## Load the CIFAR10 dataset

In [None]:
# Load the training images from tf.keras (the second split returned by
# load_data() is discarded; the labels are loaded but not used afterwards
# in the cells visible here)
if dataset_name == 'cifar10':
  (train_images, train_labels), _ = \
      tf.keras.datasets.cifar10.load_data()
else:
  # Not reachable while the assert on `dataset_name` only allows 'cifar10'
  pass

train_images = train_images.astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]

## Set up dataset with `tf.data`

### create input pipeline with `tf.data.Dataset`

In [None]:
def resize(image, size):
  """Resize `image` to a `size` x `size` square with nearest-neighbor sampling."""
  target_shape = [size, size]
  return tf.image.resize(image, target_shape,
                         method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

#### create 4x4 dataset

In [None]:
#tf.random.set_seed(219)
# for train
N = len(train_images)
# Pipeline: slice the images -> shuffle over the whole set -> shrink to 4x4 -> batch
train_dataset4 = (
    tf.data.Dataset.from_tensor_slices(train_images[:N])
    .shuffle(buffer_size=N)
    .map(lambda image: resize(image, 4))
    .batch(batch_size=batch_size)
)
print(train_dataset4)

#### create 8x8 dataset

In [None]:
#tf.random.set_seed(219)
# Pipeline: slice the images -> shuffle over the whole set -> shrink to 8x8 -> batch
train_dataset8 = (
    tf.data.Dataset.from_tensor_slices(train_images[:N])
    .shuffle(buffer_size=N)
    .map(lambda image: resize(image, 8))
    .batch(batch_size=batch_size)
)
print(train_dataset8)

#### create 16x16 dataset

In [None]:
#tf.random.set_seed(219)
# Pipeline: slice the images -> shuffle over the whole set -> shrink to 16x16 -> batch
train_dataset16 = (
    tf.data.Dataset.from_tensor_slices(train_images[:N])
    .shuffle(buffer_size=N)
    .map(lambda image: resize(image, 16))
    .batch(batch_size=batch_size)
)
print(train_dataset16)

#### create 32x32 dataset

In [None]:
#tf.random.set_seed(219)
# Pipeline: slice the images -> shuffle over the whole set -> batch
# (no resize step: the images are used at their stored size)
train_dataset32 = (
    tf.data.Dataset.from_tensor_slices(train_images[:N])
    .shuffle(buffer_size=N)
    .batch(batch_size=batch_size)
)
print(train_dataset32)

## Create the generator and discriminator models

In [None]:
class PixelNormalization(tf.keras.Model):
  """Scale every feature vector along the last axis to (approximately) unit RMS.

  Each vector is multiplied by the reciprocal square root of the mean of its
  squared entries; `epsilon` is added to that mean before the rsqrt.
  """
  def __init__(self, epsilon=1e-8, name='PixelNorm'):
    super(PixelNormalization, self).__init__(name=name)
    self.epsilon = epsilon

  def call(self, inputs):
    # Adapted from the official PGGAN code (https://github.com/tkarras/progressive_growing_of_gans).
    # The original works on [bs, c, h, w] data; here data is [bs, h, w, c],
    # so the mean is taken over the last (channel) axis.
    mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
    inverse_rms = tf.math.rsqrt(mean_square + self.epsilon)
    return inputs * inverse_rms

In [None]:
class G_Block(tf.keras.Model):
  """Generator stage: upsample, then two 3x3 leaky-ReLU convs, each followed by pixel norm."""
  def __init__(self, filters, name):
    super(G_Block, self).__init__(name=name)
    # Both convolutions share the same configuration.
    conv_kwargs = dict(kernel_size=3, padding='same',
                       activation=tf.nn.leaky_relu,
                       kernel_initializer='he_normal')
    self.upsample = layers.UpSampling2D()
    self.conv1 = layers.Conv2D(filters, **conv_kwargs)
    self.conv2 = layers.Conv2D(filters, **conv_kwargs)
    self.pn = PixelNormalization()

  def call(self, inputs):
    features = self.upsample(inputs)
    features = self.pn(self.conv1(features))
    features = self.pn(self.conv2(features))
    return features

In [None]:
class G_Init_Block(tf.keras.Model):
  """First generator stage: latent vector -> 4x4 feature map with `filters` channels.

  A dense layer produces `filters * 4 * 4` activations, which are reshaped to
  [-1, 4, 4, filters], pixel-normalized, passed through a 3x3 leaky-ReLU conv
  and pixel-normalized again.
  """
  def __init__(self, filters, name):
    super(G_Init_Block, self).__init__(name=name)
    self.filters = filters
    self.dense = layers.Dense(filters * 4 * 4, activation=tf.nn.leaky_relu,
                              kernel_initializer='he_normal')
    self.conv = layers.Conv2D(filters, 3, padding='same', activation=tf.nn.leaky_relu,
                              kernel_initializer='he_normal')
    self.pn = PixelNormalization()

  def call(self, inputs):
    dense = self.dense(inputs)
    # Reshape BEFORE normalizing: PixelNormalization reduces over the last
    # axis, so it must see [bs, 4, 4, filters] to normalize each pixel's
    # channel vector. Normalizing the flat dense output would instead average
    # over all 4*4*filters activations at once.
    dense = tf.reshape(dense, shape=[-1, 4, 4, self.filters])
    dense = self.pn(dense)
    conv = self.conv(dense)
    conv = self.pn(conv)

    return conv

In [None]:
class to_RGB(tf.keras.Model):
  """Project feature maps to 3 output channels with a 1x1 convolution (no activation)."""
  def __init__(self, name):
    super(to_RGB, self).__init__(name=name)
    self.conv = layers.Conv2D(filters=3, kernel_size=1, padding='same',
                              kernel_initializer='he_normal')

  def call(self, inputs):
    return self.conv(inputs)

In [None]:
class Generator(tf.keras.Model):
  """Build a generator that maps latent space to real space.
    G(z): z -> x

  The network is grown progressively: `current_resolution` selects how many
  upsampling blocks are executed, and in the 'transition' phase the output is
  a blend of the new resolution and the upsampled previous resolution.
  """
  def __init__(self):
    super(Generator, self).__init__()
    # Sub-model names start with a number followed by 'x'; get_generator_tvars
    # parses that leading number from the variable names to decide which
    # variables to train at a given resolution.
    self.block1 = G_Init_Block(512, '4x4')  # [bs, 4, 4, 512]
    self.block2 = G_Block(512, '8x8')       # [bs, 8, 8, 512]
    self.block3 = G_Block(512, '16x16')     # [bs, 16, 16, 512]
    self.block4 = G_Block(512, '32x32')     # [bs, 32, 32, 512]
    self.to_RGB = to_RGB('0xto_rgb')          # [bs, height, width, 3]
    self.upsample = layers.UpSampling2D()

  def call(self, inputs, current_resolution, current_phase, alpha=0.0):
    """Run the model.

    Args:
      inputs: latent vectors; the training steps in this file pass
        [bs, noise_dim].
      current_resolution: output resolution; blocks above it are skipped
        (the file uses 4, 8, 16 and 32).
      current_phase: 'transition' enables the fade-in blend; any other value
        (the file uses 'training') returns the plain output.
      alpha: blend weight of the new resolution during 'transition'.

    Returns:
      Generated images of shape [bs, current_resolution, current_resolution, 3].

    Raises:
      ValueError: if 'transition' is requested at the base resolution, where
        no previous resolution exists to fade in from.
    """
    if current_phase == 'transition' and current_resolution <= 4:
      # Previously this surfaced as an UnboundLocalError on `prev_outputs`.
      raise ValueError(
          "The 'transition' phase needs a previous resolution to blend with; "
          "got current_resolution={}".format(current_resolution))

    outputs = self.block1(inputs)
    prev_outputs = None

    # Each block runs only when the requested resolution is above the
    # resolution that feeds it; remember the features one stage below.
    for input_resolution, block in ((4, self.block2),
                                    (8, self.block3),
                                    (16, self.block4)):
      if current_resolution > input_resolution:
        prev_outputs = outputs
        outputs = block(outputs)

    generated_images = self.to_RGB(outputs)

    if current_phase == 'transition':
      # Fade-in: blend the new image with the upsampled image of the
      # previous stage (both produced by the shared to_RGB layer).
      prev_images = self.upsample(self.to_RGB(prev_outputs))
      generated_images = alpha * generated_images + (1. - alpha) * prev_images

    return generated_images

In [None]:
class D_Block(tf.keras.Model):
  """Discriminator stage: two 3x3 leaky-ReLU convs followed by average pooling."""
  def __init__(self, filters1, filters2, name):
    super(D_Block, self).__init__(name=name)
    # Settings shared by both convolutions; only the filter count differs.
    conv_kwargs = dict(kernel_size=3, padding='same',
                       activation=tf.nn.leaky_relu,
                       kernel_initializer='he_normal')
    self.conv1 = layers.Conv2D(filters1, **conv_kwargs)
    self.conv2 = layers.Conv2D(filters2, **conv_kwargs)
    self.downsample = layers.AveragePooling2D()

  def call(self, inputs):
    features = self.conv2(self.conv1(inputs))
    return self.downsample(features)

In [None]:
class D_Last_Block(tf.keras.Model):
  """Final discriminator stage: 3x3 conv, 4x4 conv, flatten, dense to one logit.

  Args:
    filters1: number of filters of the first (3x3) convolution.
    filters2: number of filters of the second (4x4) convolution.
    name: model name.
  """
  def __init__(self, filters1, filters2, name):
    super(D_Last_Block, self).__init__(name=name)
    self.conv1 = layers.Conv2D(filters1, 3, padding='same', activation=tf.nn.leaky_relu,
                               kernel_initializer='he_normal')
    # Use `filters2` here (it was accepted but ignored before), matching how
    # D_Block assigns filters1/filters2 to its two convolutions. The only
    # caller in this file passes equal values, so its behavior is unchanged.
    self.conv2 = layers.Conv2D(filters2, 4, padding='same', activation=tf.nn.leaky_relu,
                               kernel_initializer='he_normal')
    self.flatten = layers.Flatten()
    self.dense = layers.Dense(1, kernel_initializer='he_normal')

  def call(self, inputs):
    conv1 = self.conv1(inputs)
    conv2 = self.conv2(conv1)
    flatten = self.flatten(conv2)
    # No activation: the output is an unbounded logit.
    dense = self.dense(flatten)

    return dense

In [None]:
class from_RGB(tf.keras.Model):
  """Lift an image to `filters` feature channels with a 1x1 leaky-ReLU convolution."""
  def __init__(self, filters, name):
    super(from_RGB, self).__init__(name=name)
    self.conv = layers.Conv2D(filters, kernel_size=1, padding='same',
                              activation=tf.nn.leaky_relu,
                              kernel_initializer='he_normal')

  def call(self, inputs):
    return self.conv(inputs)

In [None]:
class Discriminator(tf.keras.Model):
  """Build a discriminator (critic) that scores an image batch.
    D(x): x -> logit (one unbounded value per image; no sigmoid is applied)
  """
  def __init__(self):
    super(Discriminator, self).__init__()
    # Sub-model names start with a number followed by 'x'; that leading number
    # is what get_discriminator_tvars parses from the variable names.
    self.from_RGB = from_RGB(512, '0xfrom_rgb')   # [bs, h, w, 3] -> [bs, h, w, 512]
    self.block1 = D_Block(512, 512, '32x32')    # 32x32 -> [bs, 16, 16, 512]
    self.block2 = D_Block(512, 512, '16x16')    # 16x16 -> [bs, 8, 8, 512]
    self.block3 = D_Block(512, 512, '8x8')      # 8x8 -> [bs, 4, 4, 512]
    self.block4 = D_Last_Block(512, 512, '4x4') # 4x4 -> [bs, 1]
    self.downsample = layers.AveragePooling2D()

  def call(self, inputs, current_resolution, current_phase, alpha=0.0):
    """Run the model.

    Args:
      inputs: images at `current_resolution`.
      current_resolution: resolution of `inputs`; stages meant for higher
        resolutions are skipped.
      current_phase: 'transition' enables the fade-in blend after the newest
        stage; any other value runs the stages without blending.
      alpha: blend weight of the newest stage during 'transition'.

    Returns:
      Logits produced by the last block.
    """
    #assert current_resolution in [4, 8, 16, 32]
    #assert current_phase in ['training', 'transition']
    in_transition = current_phase == 'transition'

    features = self.from_RGB(inputs)

    if in_transition:
      # Shortcut path: downsampled image fed through the same from_RGB layer.
      smoothing_inputs = self.from_RGB(self.downsample(inputs))

    # (run the block when resolution is above this, resolution at which this
    #  block is the newest one, block)
    stages = ((16, 32, self.block1),
              (8, 16, self.block2),
              (4, 8, self.block3))
    for lower_bound, fade_resolution, block in stages:
      if current_resolution > lower_bound:
        features = block(features)
        if in_transition and current_resolution == fade_resolution:
          features = alpha * features + (1. - alpha) * smoothing_inputs

    return self.block4(features)

In [None]:
# Module-level model instances; the loss/step/helper functions below refer to
# these two names directly.
generator = Generator()
discriminator = Discriminator()

## Define the loss functions and the optimizer

In [None]:
# use logits for consistency with previous code I made
# `tf.losses` and `tf.keras.losses` are the same API (alias)
# These two loss objects are used by GANLoss below (bce on raw logits, mse on
# sigmoid outputs).
bce = tf.losses.BinaryCrossentropy(from_logits=True)
mse = tf.losses.MeanSquaredError()

In [None]:
def WGANLoss(logits, is_real=True):
  """Computes the Wasserstein GAN loss.

  Args:
    logits (`2-rank Tensor`): logits
    is_real (`bool`): True returns the negated mean (`-` sign),
      False returns the plain mean (`+` sign).

  Returns:
    loss (`0-rank Tensor`): the WGAN loss value.
  """
  mean_logit = tf.reduce_mean(logits)
  return -mean_logit if is_real else mean_logit

In [None]:
def GANLoss(logits, is_real=True, use_lsgan=True):
  """Computes the standard GAN or LSGAN loss of `logits` against constant labels.

  Args:
    logits (`2-rank Tensor`): logits.
    is_real (`bool`): True labels everything `1`, False labels everything `0`.
    use_lsgan (`bool`): True selects the LSGAN loss (mean squared error on the
      sigmoid of the logits), False the standard GAN loss (binary
      cross-entropy on the logits).

  Returns:
    loss (`0-rank Tensor`): the selected loss value.
  """
  make_labels = tf.ones_like if is_real else tf.zeros_like
  labels = make_labels(logits)

  if not use_lsgan:
    return bce(labels, logits)
  return mse(labels, tf.nn.sigmoid(logits))

In [None]:
def discriminator_loss(real_logits, fake_logits):
  """WGAN critic loss: -mean(real_logits) + mean(fake_logits)."""
  loss_on_real = WGANLoss(logits=real_logits, is_real=True)
  loss_on_fake = WGANLoss(logits=fake_logits, is_real=False)
  total = loss_on_real + loss_on_fake
  return total

In [None]:
def generator_loss(fake_logits):
  """WGAN generator loss: -mean(fake_logits), i.e. fakes are scored as if real."""
  loss = WGANLoss(logits=fake_logits, is_real=True)
  return loss

In [None]:
# One Adam optimizer per network, identical settings apart from the learning
# rate variable used (beta_1=0.0, beta_2=0.99, epsilon=1e-8).
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_D, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate_G, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints are written under the training directory; create it if missing.
checkpoint_dir = train_dir
if not tf.io.gfile.exists(checkpoint_dir):
  tf.io.gfile.makedirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both optimizers and both models so they are saved/restored together.
# NOTE: in the cells visible here, the save/restore calls are commented out.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Training

In [None]:
# Keep one fixed batch of latent vectors for generation (prediction) so that
# samples from different points in training are comparable, e.g. when
# visualizing progress in the animated GIF.
# Shape: [num_examples_to_generate, noise_dim], uniform in [-1, 1).
const_random_vector_for_saving = tf.random.uniform([num_examples_to_generate, noise_dim],
                                                   minval=-1.0, maxval=1.0)

### Define training one step function

In [None]:
def get_discriminator_tvars(current_resolution):
  """Return the discriminator variables to train at `current_resolution`.

  A variable is kept when the integer parsed from its name is not greater
  than `current_resolution`. The integer is the text before the first 'x' in
  the second '/'-separated component of `var.name` (presumably the sub-model
  name such as '32x32' or '0xfrom_rgb' -- confirm against the actual
  variable names of the TensorFlow version in use).
  """
  def stage_resolution(var):
    stage_name = var.name.split('/')[1]
    return int(stage_name.split('x')[0])

  return [var for var in discriminator.trainable_variables
          if current_resolution >= stage_resolution(var)]

In [None]:
def get_generator_tvars(current_resolution):
  """Return the generator variables to train at `current_resolution`.

  A variable is kept when the integer parsed from its name is not greater
  than `current_resolution`. The integer is the text before the first 'x' in
  the second '/'-separated component of `var.name` (presumably the sub-model
  name such as '8x8' or '0xto_rgb' -- confirm against the actual variable
  names of the TensorFlow version in use).
  """
  def stage_resolution(var):
    stage_name = var.name.split('/')[1]
    return int(stage_name.split('x')[0])

  return [var for var in generator.trainable_variables
          if current_resolution >= stage_resolution(var)]

In [None]:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def discriminator_train_step(images, current_resolution, current_phase, alpha=0.0):
  """Run one discriminator (critic) update with the WGAN loss plus gradient penalty.

  Args:
    images: batch of real images at `current_resolution`, [bs, h, w, 3].
    current_resolution: resolution the networks currently operate at.
    current_phase: 'training' or 'transition' (forwarded to both networks).
    alpha: fade-in weight forwarded to both networks.

  Returns:
    (gen_loss, disc_loss): the generator loss computed on the same fake
    logits (returned for logging; the generator is not updated here) and the
    discriminator loss including the gradient penalty.
  """
  # Take the batch size from the data so a smaller final batch still works.
  num_images = tf.shape(images)[0]

  # generating noise from a uniform distribution
  noise = tf.random.uniform([num_images, noise_dim], minval=-1.0, maxval=1.0)

  with tf.GradientTape() as disc_tape:
    generated_images = generator(noise, current_resolution, current_phase, alpha)

    real_logits = discriminator(images, current_resolution, current_phase, alpha)
    fake_logits = discriminator(generated_images, current_resolution, current_phase, alpha)

    # interpolation of x hat for gradient penalty : epsilon * real image + (1 - epsilon) * generated image
    # One epsilon per sample; the [bs, 1, 1, 1] shape broadcasts over height,
    # width and channels. (Stacking/expanding it further would raise its rank
    # above 4 and make the products broadcast to a non-image shape.)
    epsilon = tf.random.uniform([num_images, 1, 1, 1])
    interpolated_images_4gp = epsilon * images + (1. - epsilon) * generated_images
    # The inner tape lives inside `disc_tape` so that the penalty, which is a
    # function of these gradients, can itself be differentiated.
    with tf.GradientTape() as gp_tape:
      gp_tape.watch(interpolated_images_4gp)
      interpolated_images_logits = discriminator(interpolated_images_4gp, current_resolution, current_phase, alpha)

    gradients_of_interpolated_images = gp_tape.gradient(interpolated_images_logits, interpolated_images_4gp)
    # Per-sample L2 norm of the gradient, penalized for deviating from 1.
    norm_grads = tf.sqrt(tf.reduce_sum(tf.square(gradients_of_interpolated_images), axis=[1, 2, 3]))
    gradient_penalty_loss = tf.reduce_mean(tf.square(norm_grads - 1.))

    disc_loss = discriminator_loss(real_logits, fake_logits) + gp_lambda * gradient_penalty_loss
    gen_loss = generator_loss(fake_logits)

  # Only the variables selected for the current resolution are updated.
  d_tvars = get_discriminator_tvars(current_resolution)
  gradients_of_discriminator = disc_tape.gradient(disc_loss, d_tvars)
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, d_tvars))

  return gen_loss, disc_loss

In [None]:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def generator_train_step(current_resolution, current_phase, alpha=0.0):
  """Run one generator update; the discriminator is only used for scoring.

  Args:
    current_resolution: resolution the networks currently operate at.
    current_phase: 'training' or 'transition' (forwarded to both networks).
    alpha: fade-in weight forwarded to both networks.
  """
  # Latent batch drawn uniformly from [-1, 1)
  latent = tf.random.uniform([batch_size, noise_dim], minval=-1.0, maxval=1.0)

  with tf.GradientTape() as tape:
    fakes = generator(latent, current_resolution, current_phase, alpha)
    logits_on_fakes = discriminator(fakes, current_resolution, current_phase, alpha)
    loss = generator_loss(logits_on_fakes)

  # Only the variables selected for the current resolution are updated.
  trainable = get_generator_tvars(current_resolution)
  grads = tape.gradient(loss, trainable)
  generator_optimizer.apply_gradients(zip(grads, trainable))

In [None]:
def print_log(global_epoch, step, global_step, start_time, disc_loss, gen_loss):
  """Clear the cell output and print one progress line.

  The fractional epoch is `global_epoch` plus the share of the current epoch
  completed after `step`; the throughput figures are derived from the time
  elapsed since `start_time`.
  """
  fraction_of_epoch = (step+1) / float(num_batches_per_epoch)
  epochs = global_epoch + fraction_of_epoch
  duration = time.time() - start_time
  examples_per_sec = batch_size / float(duration)

  display.clear_output(wait=True)
  template = "Epochs: {:.2f} global_step: {} loss_D: {:.3g} loss_G: {:.3g} ({:.2f} examples/sec; {:.3f} sec/batch)"
  message = template.format(epochs, global_step, disc_loss, gen_loss,
                            examples_per_sec, duration)
  print(message)

In [None]:
def print_samples(current_resolution, random_vector_for_sampling=None):
  """Generate images at `current_resolution` and pass them to print_or_save_sample_images.

  Args:
    current_resolution: resolution at which the generator is run
      (in the 'training' phase, i.e. without fade-in blending).
    random_vector_for_sampling: optional latent vectors; when None, a fresh
      batch of `num_examples_to_generate` vectors is drawn uniformly from
      [-1, 1).
  """
  if random_vector_for_sampling is None:
    # Same [n, noise_dim] layout as the noise used by the training steps and
    # by const_random_vector_for_saving.
    random_vector_for_sampling = tf.random.uniform([num_examples_to_generate, noise_dim],
                                                   minval=-1.0, maxval=1.0)
  sample_images = generator(random_vector_for_sampling, current_resolution, 'training')
  print_or_save_sample_images(sample_images.numpy(), num_examples_to_generate)

In [None]:
# Initialize full size networks for making full size static graph
# Calling both steps once at the target resolution in the 'transition' phase
# executes every generator and discriminator block, so all model variables
# exist before the real training starts. (Presumably needed because a
# tf.function may only create variables on its first call -- confirm.)
# NOTE: these calls also apply one optimizer update computed on random input.
TARGET_SIZE = CIFAR_SIZE
_, _ = discriminator_train_step(tf.random.normal([batch_size, TARGET_SIZE, TARGET_SIZE, 3]), TARGET_SIZE, 'transition')
generator_train_step(TARGET_SIZE, 'transition')

In [None]:
print('Start Training.')
num_batches_per_epoch = int(N / batch_size)
global_step = 1 #tf.Variable(0, trainable=False)
global_epoch = 0
num_learning_critic = 0

# Datasets for the resolutions that are grown into, in order: 8, 16, 32
train_datasets = [train_dataset8, train_dataset16, train_dataset32]

# 4 x 4 training phase
current_resolution = 4
# Log more often at higher resolutions, but never let the interval reach 0:
# `print_steps // current_resolution` is 0 once the resolution exceeds
# print_steps, and `global_step % 0` raises ZeroDivisionError.
log_interval = max(1, print_steps // current_resolution)
for epoch in range(training_phase_epoch):
  for step, images in enumerate(train_dataset4):
    start_time = time.time()

    gen_loss, disc_loss = discriminator_train_step(images, current_resolution, 'training')
    generator_train_step(current_resolution, 'training')
    if global_step % log_interval == 0:
      print_log(global_epoch, step, global_step, start_time, disc_loss, gen_loss)
      print_samples(current_resolution)

    global_step += 1
  global_epoch += 1


for resolution, train_dataset in enumerate(train_datasets):
  # index 0, 1, 2 -> resolution 8, 16, 32
  current_resolution = 2**(resolution+3)
  log_interval = max(1, print_steps // current_resolution)

  # transition phase
  for epoch in range(transition_phase_epoch):
    for step, images in enumerate(train_dataset):
      start_time = time.time()
      # alpha ramps linearly from 0 towards 1 over the whole transition phase.
      # It is passed as a tensor: a Python float would be treated as a new
      # constant by tf.function and trigger a retrace on every step.
      alpha = tf.constant(
          (epoch * num_batches_per_epoch + step) / float(transition_phase_epoch * num_batches_per_epoch),
          dtype=tf.float32)
      gen_loss, disc_loss = discriminator_train_step(images, current_resolution, 'transition', alpha)
      generator_train_step(current_resolution, 'transition', alpha)

      if global_step % log_interval == 0:
        print_log(global_epoch, step, global_step, start_time, disc_loss, gen_loss)
        print_samples(current_resolution)

      global_step += 1
    global_epoch += 1

  # training phase
  for epoch in range(training_phase_epoch):
    for step, images in enumerate(train_dataset):
      start_time = time.time()
      gen_loss, disc_loss = discriminator_train_step(images, current_resolution, 'training')
      generator_train_step(current_resolution, 'training')

      if global_step % log_interval == 0:
        print_log(global_epoch, step, global_step, start_time, disc_loss, gen_loss)
        print_samples(current_resolution)

      global_step += 1
    global_epoch += 1


#   if (epoch + 1) % save_images_epochs == 0:
#     display.clear_output(wait=True)
#     print("This images are saved at {} epoch".format(epoch+1))
#     sample_images = generator(const_random_vector_for_saving, training=False)
#     print_or_save_sample_images(sample_images.numpy(), num_examples_to_generate,
#                                 is_square=True, is_save=True, epoch=epoch+1,
#                                 checkpoint_dir=checkpoint_dir)

#   # saving (checkpoint) the model every save_epochs
#   if (epoch + 1) % save_model_epochs == 0:
#     checkpoint.save(file_prefix=checkpoint_prefix)

print('Training Done.')

In [None]:
# generating after the final epoch
# display.clear_output(wait=True)
# sample_images = generator(const_random_vector_for_saving, training=False)
# print_or_save_sample_images(sample_images.numpy(), num_examples_to_generate,
#                             is_square=True, is_save=True, epoch=epoch+1,
#                             checkpoint_dir=checkpoint_dir)

## Restore the latest checkpoint

In [None]:
# restoring the latest checkpoint in checkpoint_dir
#checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
# display_image(max_epochs, checkpoint_dir)

## Generate a GIF of all the saved images.

In [None]:
# filename = model_name + '_' + dataset_name + '.gif'
# generate_gif(filename, checkpoint_dir)

In [None]:
# display.Image(filename=filename + '.png')