# Implementing a Progressive GAN

Mara Rehmer and Linus kleine Kruthaup

In [4]:
!pip install MTCNN

Collecting MTCNN
Installing collected packages: MTCNN
Successfully installed MTCNN-0.1.0


In [17]:
import os
import time  # used for the periodic previews during training

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from mtcnn.mtcnn import MTCNN
from keras.layers import Add, Conv2D, Layer, LeakyReLU, Reshape, Dense, AveragePooling2D, UpSampling2D, Flatten
from keras import backend, Sequential, Model, Input
from keras.initializers import RandomNormal
from keras.optimizers import Adam


## 1 Preparing the data

The dataset we chose to train our ProGAN on is CelebA (https://www.kaggle.com/jessicali9530/celeba-dataset). It consists of 202,599 pictures of celebrity faces. The original dataset also provides attribute annotations for each image, e.g. whether the person is bald or wears glasses. For our purposes, we are only interested in the images.

As a starting point we used the img_align_celeba version of the dataset, in which all images are already cropped and aligned to the same dimensions. This preprocessing is not sufficient for the task at hand, however: we additionally need to extract only the face from each picture, since we are not interested in the background.

To extract the faces, we used a pre-trained neural network called MTCNN to perform face detection on the images (https://github.com/ipazc/mtcnn). The model's detect_faces function takes an array of pixels as input and returns a list with one dictionary per detected face. Each dictionary holds the coordinates of facial keypoints such as the nose, eyes and mouth, the network's confidence in the detection, and a bounding box given as x, y, width and height, which can be used to crop the face out of the image.
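
To make this return format concrete, here is a minimal sketch of how one detection entry could be turned into a crop box. The `detection` dictionary is a hand-written stand-in with made-up numbers, not real model output, and `box_to_crop` is a hypothetical helper introduced only for illustration. It clamps the box to the image in case the reported box extends past an image edge.

```python
# Hand-written stand-in for one entry of the list returned by detect_faces.
# All numbers are made up for illustration.
detection = {
    'box': [-3, 40, 110, 140],  # x, y, width, height
    'confidence': 0.99,
    'keypoints': {'left_eye': (30, 90), 'right_eye': (75, 88), 'nose': (52, 115),
                  'mouth_left': (35, 145), 'mouth_right': (72, 144)},
}


def box_to_crop(box, image_size):
    """Convert (x, y, width, height) into a (left, upper, right, lower)
    crop box that is clamped to the image bounds."""
    x, y, width, height = box
    img_w, img_h = image_size
    left, upper = max(0, x), max(0, y)
    right, lower = min(img_w, x + width), min(img_h, y + height)
    return left, upper, right, lower


# the aligned CelebA images are 178 x 218 pixels
crop_box = box_to_crop(detection['box'], (178, 218))
print(crop_box)  # (0, 40, 107, 180)
```

The resulting tuple has the layout that `Image.crop` expects.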

Additionally, we only work with a subset of the original dataset to keep runtimes low.

The following code chunk takes quite some time to execute, so the resulting compressed file is available on our GitHub for easier reuse.

In [None]:
!unzip img_align_celeba.zip

In [4]:
def load_image(filename):
    with Image.open(filename) as image:
        # make sure the image has three colour channels (still a PIL Image, not an array)
        image = image.convert('RGB')
    return image

In [None]:
def crop_faces(model, directory, n_images, crop_size=(128, 128)):
    faces = list()
    for filename in os.listdir(directory):
        image = load_image(os.path.join(directory, filename))
        # detect_faces returns a list with one entry per detected face
        face_data = model.detect_faces(np.asarray(image))
        if face_data:
            # crop the first detected face and resize it to the target size
            x, y, width, height = face_data[0]['box']
            image_cropped = image.crop((x, y, x + width, y + height)).resize(crop_size)
            faces.append(np.asarray(image_cropped))
            # progress output: number of faces collected so far
            print(len(faces))
        if len(faces) >= n_images:
            break
    return faces

mtcnn = MTCNN()
np.savez_compressed('img_align_celeba_128.npz', crop_faces(mtcnn, '/content/img_align_celeba/', 10000))

We plot a few of the resulting images to check that everything worked.

In [6]:
def plot_faces(faces, n):
    # show the first n*n faces on an n x n grid
    for i in range(n*n):
        plt.subplot(n, n, 1 + i)
        plt.axis('off')
        plt.imshow(faces[i])
    plt.show()

data = np.load('img_align_celeba_128.npz')
plot_faces(data['arr_0'], 5)


## 2 Model

The general architecture of a ProGAN features a discriminator and a generator, similar to a basic GAN. The core idea of a ProGAN, however, is to train the model on images of incrementally increasing size. This is achieved by first training a model at a small starting resolution, e.g. 4x4. A new block with a higher resolution, e.g. 8x8, is then slowly faded in with the help of a custom layer.
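
As a quick illustration of the growth schedule, the sketch below lists the image side length at each stage, assuming the resolution doubles every time a block is added. The function name `resolutions` is hypothetical. Six stages starting at 4x4 end at 128x128, which matches the crop size used in section 1.

```python
def resolutions(n_blocks, base=4):
    """Image side length handled at each growth stage, doubling per stage."""
    return [base * 2 ** i for i in range(n_blocks)]


print(resolutions(6))  # [4, 8, 16, 32, 64, 128]
```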

### 2.1 Custom Layers

#### 2.1.1 Weighted Sum

This custom layer is an extension of the Add merge layer. It combines the activations from two input layers, e.g. two input paths in a discriminator or two output paths in a generator model. Its alpha variable controls the influence of each path: with alpha = 0 only the first input contributes, with alpha = 1 only the second. During training, we update the alpha of every WeightedSum layer in a model depending on the current training step.

In [7]:
class WeightedSum(Add):
    # init with default value
    def __init__(self, alpha=0.0, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = backend.variable(alpha, name='ws_alpha')

    # output a weighted sum of inputs
    def _merge_function(self, inputs):
        # only supports a weighted sum of two inputs
        assert (len(inputs) == 2)
        # ((1-a) * input1) + (a * input2)
        output = ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1])
        return output
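
The blending rule can be checked outside of Keras with a few lines of NumPy. This is only a sketch of the arithmetic in `_merge_function`; the function `weighted_sum` and the toy arrays are made up for illustration.

```python
import numpy as np


def weighted_sum(old, new, alpha):
    """Blend two activation arrays: alpha=0 keeps `old`, alpha=1 keeps `new`."""
    return (1.0 - alpha) * old + alpha * new


old_path = np.zeros((2, 2))  # stands in for the output of the old path
new_path = np.ones((2, 2))   # stands in for the output of the new block

for alpha in (0.0, 0.5, 1.0):
    print(alpha, weighted_sum(old_path, new_path, alpha)[0, 0])
```

At alpha = 0.5 both paths contribute equally, so every entry of the blended array is 0.5.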

#### 2.1.2 Pixel Normalization

During training, the magnitudes in the generator and discriminator can grow too large as a result of the competition between the two networks. To avoid this, the ProGAN uses a normalization layer that normalizes the feature vector in each pixel of the activation map to unit length. Pixel normalization is only applied in the generator.

The formula for the normalization is defined as:

$b_{x,y} = a_{x,y} / \sqrt{\frac{1}{N} \sum_{j=0}^{N-1} (a^j_{x,y})^2 + \epsilon}$, where $\epsilon = 10^{-8}$ is a small constant that prevents division by zero,
$N$ is the number of feature maps, and $a_{x,y}$ and $b_{x,y}$ are the original and the normalized feature vector in pixel $(x,y)$, respectively.

In [8]:
class PixelNormalization(Layer):
    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        return inputs / backend.sqrt(backend.mean(inputs ** 2.0, axis=-1, keepdims=True) + 1.0e-8)
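
The effect of the layer can be verified with a small NumPy sketch of the same formula. The function `pixel_norm` and the random test activations are assumptions made for this example. After normalization, the mean of the squared features across the channel axis should be very close to 1 at every pixel.

```python
import numpy as np


def pixel_norm(a, eps=1e-8):
    """Normalize the feature vector at every pixel (channels on the last axis)."""
    return a / np.sqrt(np.mean(a ** 2, axis=-1, keepdims=True) + eps)


rng = np.random.default_rng(0)
# batch of 2 activation maps, 4x4 pixels, 16 feature maps, deliberately large values
activations = rng.normal(scale=5.0, size=(2, 4, 4, 16))
normalized = pixel_norm(activations)

# mean squared feature value per pixel
mean_square = np.mean(normalized ** 2, axis=-1)
print(mean_square.min(), mean_square.max())
```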

#### 2.1.3 Minibatch Standard deviation

GANs often struggle to capture all the variation of a training dataset. Minibatch standard deviation tries to improve on this by computing feature statistics over a minibatch's activations. More precisely, we first compute the standard deviation of each feature at each spatial location over the minibatch and then average these values into a single number. This number is replicated into a constant feature map, which is concatenated to the activation maps.

In [9]:
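class MinibatchStdev(Layer):
    def __init__(self, **kwargs):
        super(MinibatchStdev, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        # standard deviation of each feature at each location, across the batch
        square_diffs = backend.square(inputs - backend.mean(inputs, axis=0, keepdims=True))
        var = backend.mean(square_diffs, axis=0, keepdims=True) + 1e-8
        stdev = backend.sqrt(var)
        # average over all features and locations to get a single value
        mean_stdev = backend.mean(stdev, keepdims=True)
        # replicate the value into one extra feature map per sample
        shape = backend.shape(inputs)
        output = backend.tile(mean_stdev, (shape[0], shape[1], shape[2], 1))
        # append the new feature map along the channel axis
        return backend.concatenate([inputs, output], axis=-1)

    def compute_output_shape(self, input_shape):
        # the output has one additional channel
        input_shape = list(input_shape)
        input_shape[-1] += 1
        return tuple(input_shape)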
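
The same computation can be sketched in plain NumPy to see what happens to the tensor shape. The function `minibatch_stdev` and the random batch are assumptions for this example, and channels are assumed to be on the last axis.

```python
import numpy as np


def minibatch_stdev(x, eps=1e-8):
    """Append one constant feature map holding the mean standard deviation."""
    # standard deviation of every feature at every location, over the batch axis
    stdev = np.sqrt(np.var(x, axis=0, keepdims=True) + eps)
    # average over all features and locations into a single scalar
    mean_stdev = stdev.mean()
    # replicate the scalar into one extra channel per sample
    extra = np.full(x.shape[:-1] + (1,), mean_stdev, dtype=x.dtype)
    return np.concatenate([x, extra], axis=-1)


batch = np.random.default_rng(0).normal(size=(8, 4, 4, 128)).astype('float32')
out = minibatch_stdev(batch)
print(batch.shape, '->', out.shape)  # (8, 4, 4, 128) -> (8, 4, 4, 129)
```

The appended channel holds the same value everywhere, so it carries information about the batch as a whole and not about a single image.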

#### 2.1.4 Equalized Learning rate

Instead of carefully initializing the network's weights, the paper uses an equalized learning rate: the weights are initialized from a standard normal distribution and then scaled at each layer with a constant.
The scaling is applied at runtime, on every forward pass, which keeps the weights of all layers at a similar scale during training.

The weights are scaled as follows:

$\hat{w}_i = w_i / c$, where $w_i$ are the weights and $c$ is the per-layer normalization constant from He's initializer. In the layer below, this amounts to multiplying the kernel by $\sqrt{2 / n}$, where $n$ is the fan-in of the layer.

The equalized learning rate can be implemented via a custom Conv2D layer.


In [11]:
class EqualizedConv2D(Conv2D):
    def __init__(self, *args, **kwargs):
        self.scale = 1.0
        super(EqualizedConv2D, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        # fan-in of a convolution: kernel height * kernel width * input channels
        # (assumes channels-last data)
        fan_in = np.prod(self.kernel_size) * int(input_shape[-1])
        # He's constant, applied to the kernel at runtime in call()
        self.scale = np.sqrt(2 / fan_in)
        return super(EqualizedConv2D, self).build(input_shape)

    def call(self, inputs):
        # convolve with the rescaled kernel, then add bias and activation as usual
        outputs = backend.conv2d(inputs, self.kernel*self.scale, strides=self.strides, padding=self.padding, data_format=self.data_format, dilation_rate=self.dilation_rate)
        outputs = backend.bias_add(outputs, self.bias, data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
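
As a worked example of the scaling constant, the sketch below computes it for a 3x3 convolution with 128 input channels. It assumes the fan-in is defined as kernel height times kernel width times input channels, as in He initialization. `he_scale` is a hypothetical helper name.

```python
import numpy as np


def he_scale(kernel_size, in_channels):
    """Runtime multiplier sqrt(2 / fan_in) for a convolution kernel."""
    fan_in = kernel_size[0] * kernel_size[1] * in_channels
    return np.sqrt(2.0 / fan_in)


scale = he_scale((3, 3), 128)
print(scale)  # sqrt(2 / 1152) = 1 / 24
```

A kernel drawn from a standard normal distribution is therefore multiplied by roughly 0.042 in this layer before the convolution is applied.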

### 2.2 Loss

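The paper trains with a Wasserstein loss (WGAN-GP) and also reports results with a least-squares loss. We use a plain Wasserstein loss without gradient penalty: with labels of 1 and -1, it is the mean of the product of label and prediction.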

In [12]:
def wasserstein_loss(y_true, y_pred):
    return backend.mean(y_true * y_pred)
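
The following NumPy sketch evaluates the same expression on made-up critic scores. `wasserstein_loss_np` and the score arrays are assumptions for this example. The labels follow the convention used by the sample generators later in the notebook: 1 for real images and -1 for generated ones.

```python
import numpy as np


def wasserstein_loss_np(y_true, y_pred):
    """Mean of label times score, mirroring the Keras loss above."""
    return np.mean(y_true * y_pred)


scores_real = np.array([2.0, 1.0])    # made-up critic outputs for real images
scores_fake = np.array([-1.0, -3.0])  # made-up critic outputs for generated images

loss_real = wasserstein_loss_np(np.ones(2), scores_real)
loss_fake = wasserstein_loss_np(-np.ones(2), scores_fake)
print(loss_real, loss_fake)  # 1.5 2.0
```

Because the label only flips the sign of the score, minimizing this loss pushes the scores of the two classes in opposite directions.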

### 2.3 Discriminator

In [13]:
def add_discriminator_block(old_model, n_input_layers=3, filters=128):
    # weight initialization
    init = RandomNormal(stddev=1.0)
    # get shape of existing model
    in_shape = list(old_model.input.shape)
    # define new input shape as double the size
    input_shape = (in_shape[-2] * 2, in_shape[-2] * 2, in_shape[-1])
    in_image = Input(shape=input_shape)
    # define new input processing layer
    d = EqualizedConv2D(filters, (1, 1), padding='same', kernel_initializer=init)(in_image)
    d = LeakyReLU(alpha=0.2)(d)
    # define new block
    d = EqualizedConv2D(filters, (3, 3), padding='same', kernel_initializer=init)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = EqualizedConv2D(filters, (3, 3), padding='same', kernel_initializer=init)(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = AveragePooling2D()(d)
    block_new = d
    # skip the input, 1x1 and activation for the old model
    for i in range(n_input_layers, len(old_model.layers)):
        d = old_model.layers[i](d)
    # define straight-through model
    model1 = Model(in_image, d)
    # compile model
    model1.compile(loss=wasserstein_loss, optimizer=Adam(lr=PARAMETERS.learning_rate,
                                                         beta_1=PARAMETERS.adam_beta1,
                                                         beta_2=PARAMETERS.adam_beta2,
                                                         epsilon=PARAMETERS.adam_epsilon))
    # downsample the new larger image
    downsample = AveragePooling2D()(in_image)
    # connect old input processing to downsampled new input
    block_old = old_model.layers[1](downsample)
    block_old = old_model.layers[2](block_old)
    # fade in output of old model input layer with new input
    d = WeightedSum()([block_old, block_new])
    # skip the input, 1x1 and activation for the old model
    for i in range(n_input_layers, len(old_model.layers)):
        d = old_model.layers[i](d)
    # define straight-through model
    model2 = Model(in_image, d)
    # compile model
    model2.compile(loss=wasserstein_loss, optimizer=Adam(lr=PARAMETERS.learning_rate,
                                                         beta_1=PARAMETERS.adam_beta1,
                                                         beta_2=PARAMETERS.adam_beta2,
                                                         epsilon=PARAMETERS.adam_epsilon))
    return [model1, model2]


# define the discriminator models for each image resolution
def define_discriminator(n_blocks, input_shape=(4, 4, 3)):
    # weight initialization
    init = RandomNormal(stddev=1.0)
    model_list = list()
    # base model input
    in_image = Input(shape=input_shape)
    # conv 1x1
    d = EqualizedConv2D(128, (1, 1), padding='same', kernel_initializer=init)(in_image)
    d = LeakyReLU(alpha=0.2)(d)
    # conv 3x3 (output block)
    d = MinibatchStdev()(d)
    d = EqualizedConv2D(128, (3, 3), padding='same', kernel_initializer=init)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # conv 4x4
    d = EqualizedConv2D(128, (4, 4), padding='same', kernel_initializer=init)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # dense output layer
    d = Flatten()(d)
    out_class = Dense(1)(d)
    # define model
    model = Model(in_image, out_class)
    # compile model
    model.compile(loss=wasserstein_loss, optimizer=Adam(lr=PARAMETERS.learning_rate,
                                                        beta_1=PARAMETERS.adam_beta1,
                                                        beta_2=PARAMETERS.adam_beta2,
                                                        epsilon=PARAMETERS.adam_epsilon))
    # store model
    model_list.append([model, model])
    # create submodels
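    for i in range(1, n_blocks):
        # note: this value is computed but never passed on, so every new
        # block uses the default of 128 filters
        filters = 2**(10-i)
        if filters > 512:
            filters = 512
        if filters < 16:
            filters = 16
        # get prior model without the fade-in
        old_model = model_list[i - 1][0]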
        # create new model for next resolution
        models = add_discriminator_block(old_model)
        # store model
        model_list.append(models)
    return model_list

### 2.4 Generator

In [14]:
# add a generator block
def add_generator_block(old_model, filters=256):
    # weight initialization
    init = RandomNormal(stddev=1.0)
    # get the end of the last block
    block_end = old_model.layers[-2].output
    # upsample, and define new block
    upsampling = UpSampling2D()(block_end)
    g = EqualizedConv2D(filters, (3, 3), padding='same', kernel_initializer=init)(upsampling)
    g = PixelNormalization()(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = EqualizedConv2D(filters, (3, 3), padding='same', kernel_initializer=init)(g)
    g = PixelNormalization()(g)
    g = LeakyReLU(alpha=0.2)(g)
    # add new output layer
    out_image = EqualizedConv2D(3, (1, 1), padding='same', kernel_initializer=init)(g)
    model1 = Model(old_model.input, out_image)
    # get the output layer from old model
    out_old = old_model.layers[-1]
    # connect the upsampling to the old output layer
    out_image2 = out_old(upsampling)
    # define new output image as the weighted sum of the old and new models
    merged = WeightedSum()([out_image2, out_image])
    model2 = Model(old_model.input, merged)
    return [model1, model2]


# define generator models
def define_generator(latent_dim, n_blocks, in_dim=4):
    # weight initialization
    init = RandomNormal(stddev=1.0)
    model_list = list()
    # base model latent input
    in_latent = Input(shape=(latent_dim,))
    # linear scale up to activation maps
    g = Dense(128 * in_dim * in_dim, kernel_initializer=init)(in_latent)
    g = Reshape((in_dim, in_dim, 128))(g)
    # conv 3x3, input block
    g = EqualizedConv2D(128, (3, 3), padding='same', kernel_initializer=init)(g)
    g = PixelNormalization()(g)
    g = LeakyReLU(alpha=0.2)(g)
    # conv 3x3
    g = EqualizedConv2D(128, (3, 3), padding='same', kernel_initializer=init)(g)
    g = PixelNormalization()(g)
    g = LeakyReLU(alpha=0.2)(g)
    # conv 1x1, output block
    out_image = EqualizedConv2D(3, (1, 1), padding='same', kernel_initializer=init)(g)
    model = Model(in_latent, out_image)
    # store model
    model_list.append([model, model])
    # create submodels
    for i in range(1, n_blocks):
        # note: this value is computed but never passed on, so every new
        # block uses the default filter count of add_generator_block
        filters = 2**(10-i)
        if filters > 512:
            filters = 512
        if filters < 16:
            filters = 16
        # get prior model without the fade-in
        old_model = model_list[i - 1][0]
        # create new model for next resolution
        models = add_generator_block(old_model)
        # store model
        model_list.append(models)
    return model_list

### 2.5 Composite Model

In [15]:
# define composite models for training generators via discriminators
def define_composite(discriminators, generators):
    model_list = list()
    # create composite models
    for i in range(len(discriminators)):
        g_models, d_models = generators[i], discriminators[i]
        # straight-through model
        d_models[0].trainable = False

        model1 = Sequential()
        model1.add(g_models[0])
        model1.add(d_models[0])
        model1.compile(loss=wasserstein_loss, optimizer=Adam(lr=PARAMETERS.learning_rate,
                                                             beta_1=PARAMETERS.adam_beta1,
                                                             beta_2=PARAMETERS.adam_beta2,
                                                             epsilon=PARAMETERS.adam_epsilon))
        # fade-in model
        d_models[1].trainable = False

        model2 = Sequential()
        model2.add(g_models[1])
        model2.add(d_models[1])
        model2.compile(loss=wasserstein_loss, optimizer=Adam(lr=PARAMETERS.learning_rate,
                                                             beta_1=PARAMETERS.adam_beta1,
                                                             beta_2=PARAMETERS.adam_beta2,
                                                             epsilon=PARAMETERS.adam_epsilon))
        # store
        model_list.append([model1, model2])
    return model_list

## 3 Training

### 3.1 Helper functions

In [18]:
def deprocess(x):
    x = np.clip(x, -1, 1)
    return (x + 1) / 2.0

In [19]:
def summarize_performance(status, g_model, latent_dim, n_samples=25, save_models=True):
    # devise name
    gen_shape = g_model.output_shape
    name = '%03dx%03d-%s' % (gen_shape[1], gen_shape[2], status)
    # generate images
    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # plot generated images on a square grid
    square = int(np.sqrt(n_samples))
    for i in range(square * square):
        plt.subplot(square, square, 1 + i)
        plt.axis('off')
        # map pixel values from [-1,1] to [0,1]
        plt.imshow(deprocess(X[i]))
    # save plot to file
    filename1 = 'plot_%s.jpg' % (name)
    plt.savefig(filename1)
    plt.close()
    if save_models:
        # save the generator model
        os.makedirs('models', exist_ok=True)
        filename2 = 'model_%s.h5' % (name)
        g_model.save(F"models/{filename2}")
        print('>Saved: %s and %s' % (filename1, filename2))

In [25]:
# load dataset
def load_real_samples(filename):
    # read the compressed archive written in section 1
    data = np.load(filename)
    # extract numpy array
    X = data['arr_0']
    # convert from ints to floats
    X = X.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X

In [26]:
# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    # select images
    X = dataset[ix]
    # generate class labels (1 for real)
    y = np.ones((n_samples, 1))
    return X, y

In [20]:
def generate_latent_points(latent_dim, n_samples):
    # generate points in the latent space
    x_input = np.random.randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

In [21]:
def generate_fake_samples(generator, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    X = generator.predict(x_input)
    # create class labels
    y = -np.ones((n_samples, 1))
    return X, y

In [22]:
def update_fadein(models, step, n_steps):
    # calculate current alpha (linear from 0 to 1)
    alpha = step / float(n_steps - 1)
    # update the alpha for each model
    for model in models:
        for layer in model.layers:
            if isinstance(layer, WeightedSum):
                backend.set_value(layer.alpha, alpha)
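
To see the schedule that `update_fadein` produces, the sketch below evaluates the same formula for a toy run of five steps. `fadein_alpha` is a hypothetical helper that isolates the alpha computation. Note that the formula requires more than one step, since it divides by `n_steps - 1`.

```python
def fadein_alpha(step, n_steps):
    """Linear ramp from 0 at the first step to 1 at the last step."""
    return step / float(n_steps - 1)


n_steps = 5
alphas = [fadein_alpha(step, n_steps) for step in range(n_steps)]
print(alphas)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```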

In [23]:
def show_images(generated_images, suffix=""):
    n_images = len(generated_images)
    rows = 4
    cols = n_images // rows

    # create a single figure with a rows x cols grid of axes
    _, axs = plt.subplots(rows, cols, figsize=(cols, rows))
    axs = axs.flatten()
    for im, ax in zip(generated_images, axs):
        img = deprocess(im)
        ax.imshow(img, cmap='gray')
    plt.savefig(F'prev_{suffix}.jpg')
    plt.show()

In [24]:
class PARAMETERS(object):
    n_blocks = 6
    latent_dim = 128
    n_batch = [16, 16, 16, 8, 4, 4]
    # n_epochs = [8, 8, 64, 64, 128, 128, 256]
    # n_epochs = [8, 8, 8, 8, 16, 32, 32]
    # n_epochs = [512, 1024, 2048, 5000, 5000, 5000, 5000]  # For AF faces, ap at 800k, like in the paper
    n_epochs = [5, 8, 8, 10, 10, 10]  # For faces, ap at 800k, like in the paper
    # n_epochs = [1, 1, 1, 1, 1, 1, 1]
    learning_rate = 0.0001
    adam_beta1 = 0.0
    adam_beta2 = 0.99
    adam_epsilon = 1e-8
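
To get a feeling for these settings, the sketch below computes the number of training steps per growth stage with the same arithmetic that `train_epochs` uses. It assumes the dataset holds the 10000 faces extracted in section 1 and that the same epoch counts are used for a given stage. `training_steps` is a hypothetical helper.

```python
def training_steps(n_images, n_batch, n_epochs):
    """Number of training iterations: batches per epoch times epochs."""
    bat_per_epo = int(n_images / n_batch)
    return bat_per_epo * n_epochs


n_batch = [16, 16, 16, 8, 4, 4]
n_epochs = [5, 8, 8, 10, 10, 10]
steps = [training_steps(10000, b, e) for b, e in zip(n_batch, n_epochs)]
print(steps)  # [3125, 5000, 5000, 12500, 25000, 25000]
```

The smaller batch sizes at higher resolutions lead to more steps per epoch there.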

In [27]:
def train_epochs(g_model, d_model, gan_model, dataset, n_epochs, n_batch, latent_dim, fadein=False):
    # calculate the number of batches per training epoch
    bat_per_epo = int(dataset.shape[0] / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # calculate the size of half a batch of samples
    half_batch = int(n_batch / 2)
    # timestamp of the last preview
    t0 = time.time()
    # manually enumerate training steps
    for i in range(n_steps):
        # update alpha for all WeightedSum layers when fading in new blocks
        if fadein:
            update_fadein([g_model, d_model, gan_model], i, n_steps)
        # prepare real and fake samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        # update discriminator model
        d_loss1 = d_model.train_on_batch(X_real, y_real)
        d_loss2 = d_model.train_on_batch(X_fake, y_fake)
        # update the generator via the discriminator's error
        z_input = generate_latent_points(latent_dim, n_batch)
        y_real2 = np.ones((n_batch, 1))
        g_loss = gan_model.train_on_batch(z_input, y_real2)
        # summarize loss on this batch
        print(f'\r>{i + 1}/{n_steps}, d1={d_loss1:.3f}, d2={d_loss2:.3f} g={g_loss:.3f}', end="")
        if time.time() - t0 > 30:
            # save preview of images every 30 s.
            t0 = time.time()
            try:
                summarize_performance('fresh_batch_preview', g_model, latent_dim, save_models=False)
            except Exception:
                # a failed preview should not interrupt training
                pass
    print("")


# train the generator and discriminator
def train(g_models, d_models, gan_models, dataset, latent_dim, e_norm, e_fadein, n_batch):
    # scale_dataset(images, new_shape) must be defined before train() is called
    # fit the baseline model
    g_normal, d_normal, gan_normal = g_models[0][0], d_models[0][0], gan_models[0][0]
    # scale dataset to appropriate size
    gen_shape = g_normal.output_shape
    scaled_dataset = scale_dataset(dataset, gen_shape[1:])
    print('Scaled Data', gen_shape)
    # train normal or straight-through models
    train_epochs(g_normal, d_normal, gan_normal, scaled_dataset, e_norm[0], n_batch[0], latent_dim)
    try:
        summarize_performance('tuned', g_normal, latent_dim)
    except Exception:
        pass

    # process each level of growth
    for i in range(1, len(g_models)):
        # retrieve models for this level of growth
        [g_normal, g_fadein] = g_models[i]
        [d_normal, d_fadein] = d_models[i]
        [gan_normal, gan_fadein] = gan_models[i]
        # scale dataset to appropriate size
        gen_shape = g_normal.output_shape
        scaled_dataset = scale_dataset(dataset, gen_shape[1:])
        print('Scaled Data', gen_shape)
        # train fade-in models for next level of growth
        train_epochs(g_fadein, d_fadein, gan_fadein, scaled_dataset, e_fadein[i], n_batch[i], latent_dim, True)
        summarize_performance('faded', g_fadein, latent_dim)
        # train normal or straight-through models
        train_epochs(g_normal, d_normal, gan_normal, scaled_dataset, e_norm[i], n_batch[i], latent_dim)
        summarize_performance('tuned', g_normal, latent_dim)
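
The `train` function calls `scale_dataset`, which is not defined in the cells above. The following is a minimal sketch of what such a helper could look like, not the original implementation. It assumes a channels-last batch of images and downsamples by averaging equally sized pixel blocks, which only works when the target size evenly divides the source size (true for 128x128 images and targets of 4, 8, ..., 128).

```python
import numpy as np


def scale_dataset(images, new_shape):
    """Downsample a batch of images (n, height, width, channels) to
    new_shape = (new_height, new_width, channels) by block averaging."""
    n, height, width, channels = images.shape
    new_h, new_w = new_shape[0], new_shape[1]
    if height % new_h or width % new_w:
        raise ValueError("target size must evenly divide the source size")
    fh, fw = height // new_h, width // new_w
    # split each image into new_h x new_w blocks of fh x fw pixels
    blocks = images.reshape(n, new_h, fh, new_w, fw, channels)
    # average the pixels inside every block
    return blocks.mean(axis=(2, 4))


demo = np.arange(2 * 4 * 4 * 1, dtype='float32').reshape(2, 4, 4, 1)
small = scale_dataset(demo, (2, 2, 1))
print(small.shape)  # (2, 2, 2, 1)
```

Block averaging was chosen here because it needs no additional dependency. A library resize function with a different interpolation method would be an equally valid choice.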