# pix2pix in Tensorflow
This notebook implements the network architecture and the training process of Pix2Pix, an approach for generating new images from given input images.
The architecture is based on the conditional GAN (Generative Adversarial Network). In conventional GAN architectures only the generator receives the input, while the conditional GAN additionally feeds the input image to the discriminator.

A major strength of the Pix2Pix approach is that it performs well on a wide range of image-to-image translation problems.

In [1]:
# common packages

from tensorflow.keras.layers import Layer
from tensorflow.keras import Model
import tensorflow as tf
import numpy as np
import math

## Structure of Pix2Pix
Following the GAN architecture, Pix2Pix is split into two parts: the generator and the discriminator.

As the name suggests, the generator generates an image from a given input image. Its output is the final output of the model.

The discriminator, in contrast, classifies images (each paired with its input image) as real or fake. Together with the L1 loss, its output forms the loss function of the generator.

During training, the discriminator tries to maximize the loss of the generator, while the generator tries to minimize it.
To do so, the discriminator minimizes its loss for correctly classifying the real image as real (1) and the generated image as fake (0). The generator minimizes the negative logarithm of the discriminator's output for the generated image plus a weighted L1 loss between the generated image and the target image. The L1 loss is included to improve the quality of the output image; the paper shows these improvements.
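The two objectives can be sketched in NumPy as follows. This is a minimal illustration rather than the notebook's TensorFlow code: the function names are hypothetical, the discriminator outputs are assumed to be sigmoid probabilities, and `l1_weight` defaults to 100, the weight used by the paper's authors (the `Pix2pix` class later in this notebook uses `600 / output_dim` instead).

```python
import numpy as np

EPS = 1e-12  # small constant that guards against log(0)


def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy with real pairs labeled 1 and generated pairs labeled 0."""
    return np.mean(-np.log(d_real + EPS) - np.log(1.0 - d_fake + EPS))


def generator_loss(d_fake, target, generated, l1_weight=100.0):
    """GAN term that pushes D(x, G(x)) towards 1, plus a weighted L1 term."""
    gan_term = np.mean(-np.log(d_fake + EPS))
    # mean absolute difference between the target image and the generated image
    l1_term = np.mean(np.abs(target - generated))
    return gan_term + l1_weight * l1_term
```

A confident discriminator (0.9 for the real pair, 0.1 for the fake pair) gets a lower loss than an undecided one (0.5 for both), and a generator that reproduces the target exactly while fooling the discriminator gets a loss close to zero.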

## Common layers
The paper defines the following types of layers, which we implement before taking a look at the models:
- C(k)
    - Convolution (k filters; 4x4; stride 2)
    - BatchNorm (in testing and training)
    - ReLU
- CD(k)
    - Convolution (k filters; 4x4; stride 2)
    - BatchNorm (in testing and training)
    - Dropout (50% rate)
    - ReLU
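The note "BatchNorm (in testing and training)" refers to the paper's choice to normalize with the statistics of the current batch at test time as well, instead of the running averages collected during training; with a batch size of 1 the paper calls this instance normalization. The NumPy sketch below shows what this normalization computes for an NHWC tensor. The function name is hypothetical, the learned scale and offset are reduced to scalars for brevity, and `eps=1e-3` matches the default epsilon of `tf.keras.layers.BatchNormalization`.

```python
import numpy as np


def batch_norm_batch_stats(x, gamma=1.0, beta=0.0, eps=1e-3):
    """Normalize an NHWC batch with statistics computed from the batch itself.

    Mean and variance are taken per channel over the batch axis and both
    spatial axes, so no running averages are involved.
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```

In Keras, calling a `BatchNormalization` layer with `training=True` selects these batch statistics; the same flag also enables the `Dropout` of CD(k).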

In [2]:
class C(Layer):
    """This layer represents the C(k) layer described in the pix2pix paper. The activation function 
        is a parameter to allow the use of different activation functions like ReLU and leaky ReLU for 
        encoder and decoder. The sampling parameter selects a strided convolution ('down', halves the 
        height and width) or a strided transposed convolution ('up', doubles the height and width)."""
        
    def __init__(self, k, activation=None, sampling='down', batchnorm=True):
        super(C, self).__init__()
        if sampling == 'up':
            self.conv = tf.keras.layers.Conv2DTranspose(k, kernel_size=4, strides=2, activation=None, padding='same')
        elif sampling == 'down':
            self.conv = tf.keras.layers.Conv2D(k, kernel_size=4, strides=2, activation=None, padding='same')
        else:
            raise ValueError('illegal sampling mode: "' + str(sampling) + '"')
            
        self.batchnorm = None
        if batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()
            
        self.activation = activation
        
    def call(self, x, training=None):
        x = self.conv(x)
        
        # forward the training flag so that BatchNorm uses batch statistics while training
        if self.batchnorm is not None:
            x = self.batchnorm(x, training=training)
        
        if self.activation is not None:
            x = self.activation(x)
            
        return x

class CD(C):
    """This layer represents the CD(k) layer described in the pix2pix paper: a C(k) layer with 
        50% dropout applied after the batch normalization. The parameters are the same as for C."""
    
    def __init__(self, k, activation=None, sampling='down', batchnorm=True):
        super(CD, self).__init__(k, activation, sampling, batchnorm)
        self.dropout = tf.keras.layers.Dropout(rate=0.5)

    def call(self, x, training=None):
        x = self.conv(x)
        
        if self.batchnorm is not None:
            x = self.batchnorm(x, training=training)
        
        # dropout is only active when training=True
        x = self.dropout(x, training=training)
        
        if self.activation is not None:
            x = self.activation(x)
            
        return x

## Define the discriminator
The paper compares discriminators with different patch sizes. According to the paper, the 70x70 discriminator offers a good balance between output quality and training time.

The following lists show the architectures of the different discriminators using the layers defined earlier. All convolutions used here downsample the image, and all but the first convolution layer use batchnorm.

**16 x 16:**
- C64
- C128 
- conv to 1d
- sigmoid

**70 x 70:**
- C64 
- C128 
- C256 
- C512 
- conv to 1d 
- sigmoid
    
**286 x 286:**
- C64 
- C128 
- C256 
- C512 
- C512
- C512 
- conv to 1d 
- sigmoid
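The sizes 16x16, 70x70 and 286x286 are the receptive fields of the discriminators: the input patch that each output value sees. The sketch below computes them from (kernel size, stride) pairs. The helper names are hypothetical, and `patchgan_layers` assumes a stride layout that we believe matches the authors' reference code but that the lists above do not state: stride 2 for all C layers except the last one, and stride 1 for the last C layer and for the output convolution, all with 4x4 kernels.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of a stack of convolutions.

    layers is a list of (kernel_size, stride) pairs, ordered from input to output.
    """
    field, jump = 1, 1
    for kernel_size, stride in layers:
        field += (kernel_size - 1) * jump  # widen by (k - 1) steps of the current jump
        jump *= stride                     # input-pixel distance between neighbouring outputs
    return field


def patchgan_layers(n_stride2):
    """Assumed layout: n_stride2 C layers with stride 2, one C layer with stride 1,
    then the stride-1 convolution to a single output channel; all kernels 4x4."""
    return [(4, 2)] * n_stride2 + [(4, 1), (4, 1)]
```

`patchgan_layers(1)`, `patchgan_layers(3)` and `patchgan_layers(5)` contain 2, 4 and 6 C layers and give receptive fields of 16, 70 and 286. If instead every C layer uses stride 2, followed by a stride-1 output convolution, the four-layer stack sees 94x94 pixels (`receptive_field([(4, 2)] * 4 + [(4, 1)])`), so the stride layout determines the actual patch size.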

In [3]:
# 16 x 16 discriminator:
class Discriminator16(Model):
    def __init__(self):
        super(Discriminator16, self).__init__()
        self.conv1 = C(k=64, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down", batchnorm=False)
        self.conv2 = C(k=128, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        
        # "conv to 1d": a convolution to a single channel followed by a sigmoid. This yields one
        # real/fake score per image patch instead of a single score for the whole image.
        self.out = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='same', activation=tf.keras.activations.sigmoid)
        
    def call(self, x, y):
        """Calls the discriminator with input x and generator output y"""
        x = tf.keras.layers.concatenate([x, y])
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.out(x)
        return x

In [4]:
# 70 x 70 discriminator:
class Discriminator70(Model):
    def __init__(self):
        super(Discriminator70, self).__init__()
        self.conv1 = C(k=64, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down", batchnorm=False)
        self.conv2 = C(k=128, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.conv3 = C(k=256, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.conv4 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        
        # "conv to 1d": a convolution to a single channel followed by a sigmoid. This yields one
        # real/fake score per image patch instead of a single score for the whole image.
        self.out = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='same', activation=tf.keras.activations.sigmoid)
        
    def call(self, x, y):
        """Calls the discriminator with input x and generator output y"""
        x = tf.keras.layers.concatenate([x, y])
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.out(x)
        return x

In [5]:
# 286 x 286 discriminator:
class Discriminator286(Model):
    def __init__(self):
        super(Discriminator286, self).__init__()
        self.conv1 = C(k=64, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down", batchnorm=False)
        self.conv2 = C(k=128, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.conv3 = C(k=256, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.conv4 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.conv5 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.conv6 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        
        # "conv to 1d": a convolution to a single channel followed by a sigmoid. This yields one
        # real/fake score per image patch instead of a single score for the whole image.
        self.out = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding='same', activation=tf.keras.activations.sigmoid)
        
    def call(self, x, y):
        """Calls the discriminator with input x and generator output y"""
        x = tf.keras.layers.concatenate([x, y])
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.out(x)
        return x

## Define the Generator model
The generator is based on the widely used encoder-decoder structure, so it consists of an encoder and a decoder.
As for the discriminator, the following lists describe the encoder and the decoder of the generator.

**encoder: (Leaky ReLU (slope = 0.2))**
- C64
- C128
- C256
- C512
- C512
- C512
- C512
- C512

All convolutions of the encoder downsample the image, while the decoder convolutions upsample it. As in the discriminator, all encoder convolutions except the first one apply batchnorm.

**UNet decoder: (ReLU)**
- CD512
- CD1024
- CD1024
- C1024
- C1024
- C512
- C256
- C128
- reduction to output channels
- tanh

All decoder convolutions upsample the image and apply batchnorm.
As stated above, we use the U-Net structure for the generator. This means that skip connections concatenate the channels of the i-th layer with those of the (n - i)-th layer, where n is the total number of layers of the encoder-decoder.
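The skip connections only work if each decoder layer meets an encoder output of the same spatial size. The pure-Python sketch below traces (spatial size, channels) through the encoder and decoder lists above for a square input. The helper name is hypothetical; it assumes the concatenation order used in the `Generator` cell (decoder output first, then the skip connection) and treats the numbers in the lists as the filter counts of each convolution.

```python
def unet_shapes(size=256,
                enc_filters=(64, 128, 256, 512, 512, 512, 512, 512),
                dec_filters=(512, 1024, 1024, 1024, 1024, 512, 256, 128)):
    """Trace (spatial size, channels) through the U-Net described above.

    Returns the encoder outputs, the input of each decoder layer (after the skip
    concatenation) and the output of the last decoder layer, before the final
    convolution to the output channels.
    """
    n = len(enc_filters)
    if size % 2 ** n != 0:
        raise ValueError("input size must be divisible by 2 ** number of encoder layers")

    encoder = []
    for k in enc_filters:
        size //= 2                      # every encoder C(k) halves height and width
        encoder.append((size, k))

    decoder_inputs = []
    size, channels = encoder[-1]        # the bottleneck feeds the first decoder layer
    for i, k in enumerate(dec_filters):
        if i > 0:
            # skip connection: concatenate with the encoder output of the same size
            skip_size, skip_channels = encoder[n - 1 - i]
            assert skip_size == size
            channels += skip_channels
        decoder_inputs.append((size, channels))
        size, channels = size * 2, k    # every decoder C(k)/CD(k) doubles height and width
    return encoder, decoder_inputs, (size, channels)
```

For a 256x256 input the bottleneck is 1x1 with 512 channels, the last decoder layer receives 128x128 features with 256 + 64 = 320 channels, and its 256x256x128 output is then mapped to the output channels. Inputs whose side length is not a multiple of 2^8 = 256 cannot be matched up by the skip connections.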

In [6]:
class Generator(Model):
    def __init__(self, output_dim=3):
        super(Generator, self).__init__()
        
        # encoder:
        self.enc_conv1 = C(k=64, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down", batchnorm=False)
        self.enc_conv2 = C(k=128, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.enc_conv3 = C(k=256, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.enc_conv4 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.enc_conv5 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.enc_conv6 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.enc_conv7 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        self.enc_conv8 = C(k=512, activation=tf.keras.layers.LeakyReLU(alpha=0.2), sampling="down")
        
        # decoder
        self.dec_conv1 = CD(k=512, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv2 = CD(k=1024, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv3 = CD(k=1024, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv4 = C(k=1024, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv5 = C(k=1024, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv6 = C(k=512, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv7 = C(k=256, activation=tf.keras.activations.relu, sampling="up")
        self.dec_conv8 = C(k=128, activation=tf.keras.activations.relu, sampling="up")
        
        self.out = tf.keras.layers.Conv2D(output_dim, kernel_size=3, strides=1, activation=tf.keras.activations.tanh, padding='same')
        
    def call(self, x):
        # encoder
        x1 = self.enc_conv1(x)
        x2 = self.enc_conv2(x1)
        x3 = self.enc_conv3(x2)
        x4 = self.enc_conv4(x3)
        x5 = self.enc_conv5(x4)
        x6 = self.enc_conv6(x5)
        x7 = self.enc_conv7(x6)
        x8 = self.enc_conv8(x7)
        
        #decoder
        x = self.dec_conv1(x8)
        x = self.dec_conv2(tf.keras.layers.concatenate([x, x7]))
        x = self.dec_conv3(tf.keras.layers.concatenate([x, x6]))
        x = self.dec_conv4(tf.keras.layers.concatenate([x, x5]))
        x = self.dec_conv5(tf.keras.layers.concatenate([x, x4]))
        x = self.dec_conv6(tf.keras.layers.concatenate([x, x3]))
        x = self.dec_conv7(tf.keras.layers.concatenate([x, x2]))
        x = self.dec_conv8(tf.keras.layers.concatenate([x, x1]))
        
        # map to output_dim channels (3 for RGB) with a tanh activation
        x = self.out(x)
        return x

## Define the combined pix2pix model
After describing the overall structure and defining the models for the discriminator and the generator, we define one model that combines both components and implements the training procedure, to ease its use later on. To do so, we derive from the Keras base class Model, which provides the basic model functionality.

In [7]:
class Pix2pix(Model):
    """This model implements the Pix2Pix neural network to convert one image into another. 
    It uses a U-Net generator and a 70x70 discriminator by default.
    """
    def __init__(self, discriminator=None, output_dim=3):
        """
        Args:
            discriminator: instance of the discriminator to use (defaults to a new 70x70 discriminator)
            output_dim: number of output channels of the generator (defaults to 3 for RGB)
        """
        super(Pix2pix, self).__init__()
        self.output_dim = output_dim
        
        self.g = Generator(output_dim=output_dim)
        # create the default discriminator per instance; a default argument value would be shared by all instances
        self.d = discriminator if discriminator is not None else Discriminator70()
        
        # one optimizer per sub-model, since the discriminator and the generator are updated separately
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
    
    def _disc_loss(self, disc_real, disc_fake):
        """Calculates the loss for the discriminator based on the discriminator output of the real and fake image.
        Real pairs are labeled 1 and generated pairs 0; the small constant avoids log(0).
        Args:
            disc_real: discriminator output for real image
            disc_fake: discriminator output for generated (fake) image
        """
        return tf.reduce_mean(-tf.math.log(disc_real + 1e-16) - tf.math.log(1 - disc_fake + 1e-16))
    
    def _gen_loss(self, y, generated, disc_fake):
        """Calculates the loss for the generator based on the target output, the generated image 
        and the discriminator output for the generated image.
        Args:
            y: dataset output (target image)
            generated: generated output
            disc_fake: discriminator output for generated image
        """
        gan_loss = tf.reduce_mean(-tf.math.log(disc_fake + 1e-16))
        l1_loss = tf.reduce_mean(tf.abs(y - generated))
        # weight of the L1 term: the paper uses lambda = 100, this notebook uses 600 / output_dim
        return gan_loss + (600 / self.output_dim) * l1_loss
    
    def split_dataset(self, x, y, validation_split=0.0):
        """Splits the dataset into a training and a test dataset. The returned datasets are float32 numpy arrays.
        Args:
            x: input images of the dataset
            y: output (target) images of the dataset
            validation_split: fraction of the dataset used for testing. If set to 0.1, the first 10% of the given 
                samples form the test dataset and the remaining 90% the training dataset.
        """
        num_test = math.floor(x.shape[0] * validation_split)

        test_x = np.array(x[:num_test], dtype=np.float32)
        test_y = np.array(y[:num_test], dtype=np.float32)

        train_x = np.array(x[num_test:], dtype=np.float32)
        train_y = np.array(y[num_test:], dtype=np.float32)
        
        return ((train_x, train_y), (test_x, test_y))
    
    def predict_on_batch(self, x):
        """Returns the output of the generator for the given batch x.
        Args:
            x: input images of batch
        """
        return self.g(x)
    
    def predict(self, x, batch_size=None, verbose=0, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False):
        """Generates the output for an input dataset x. The data is expected to be a numpy array and is split 
        up into batches of batch_size (defaults to 1).
        Args:
            x: numpy array of input images
            batch_size: size of the batches the inputs are split up into
            verbose: [unused]
            steps: maximum number of batches to predict. If x contains more batches, only the first steps batches are processed
            callbacks: list of callbacks that should be called at the given events
            max_queue_size: [unused]
            workers: [unused]
            use_multiprocessing: [unused]
        """
        
        # generate dataset
        dataset = tf.data.Dataset.from_tensor_slices(x)
        dataset = dataset.batch(batch_size=1 if batch_size is None else batch_size)
        
        # call callbacks for prediction start
        if callbacks is not None:
            for callback in callbacks:
                callback.on_predict_begin()
        
        outputs = []
        
        # run for all batches in dataset
        for n, x1 in enumerate(dataset):
            
            # call callbacks for batch begin
            if callbacks is not None:
                for callback in callbacks:
                    callback.on_predict_batch_begin(n)
                    
            # predict for current batch and collect the resulting images
            outputs.append(self.predict_on_batch(x1).numpy())
            
            # call callbacks for batch end
            if callbacks is not None:
                for callback in callbacks:
                    callback.on_predict_batch_end(n)
            
            # stop after `steps` batches
            if steps is not None and n + 1 >= steps:
                break
        
        # call callbacks for prediction end 
        if callbacks is not None:
            for callback in callbacks:
                callback.on_predict_end()
                
        return np.concatenate(outputs, axis=0) if outputs else np.array([])
    
    def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None, validation_freq=1, max_queue_size=10, workers=1, use_multiprocessing=False):
        """Trains the pix2pix model including the discriminator and the generator according to the training strategy described above.
        Args:
            x: numpy array containing all input images
            y: numpy array containing all output images of the dataset
            batch_size: size of the batches while training (defaults to 1)
            epochs: number of epochs to train
            verbose: [unused]
            callbacks: list of callbacks that should be called at the given events
            validation_split: fraction of the dataset used for testing. If set to 0.1, the first 10% of the given 
                samples form the test dataset and the remaining 90% the training dataset.
            validation_data: tuple (x, y) used as test dataset. This can be used instead of validation_split
            shuffle: whether to shuffle the training data in every epoch
            class_weight: [unused]
            sample_weight: [unused]
            initial_epoch: epoch at which to start training
            steps_per_epoch: maximum number of batches per epoch
            validation_steps: [unused]
            validation_freq: [unused]
            max_queue_size: [unused]
            workers: [unused]
            use_multiprocessing: [unused]
        """
        
        # get or generate validation data
        if validation_data is not None:
            train_x = np.asarray(x, dtype=np.float32)
            train_y = np.asarray(y, dtype=np.float32)
            
            test_x = np.asarray(validation_data[0], dtype=np.float32)
            test_y = np.asarray(validation_data[1], dtype=np.float32)
        else:
            ((train_x, train_y), (test_x, test_y)) = self.split_dataset(x, y, validation_split=validation_split)
        has_validation = test_x.shape[0] > 0
        
        # generate tf datasets from data
        train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
        test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))

        if shuffle:
            train_dataset = train_dataset.shuffle(buffer_size=50000)

        train_dataset = train_dataset.batch(batch_size=1 if batch_size is None else batch_size)
        test_dataset = test_dataset.batch(10000)
        
        # prepare callbacks
        if callbacks is not None:
            params = {
                'batch_size': batch_size,
                'epochs': epochs,
                'steps': steps_per_epoch,
                'samples': train_x.shape[0],
                'verbose': verbose,
                'do_validation': has_validation,
                'metrics': []
            }
            for callback in callbacks:
                callback.set_params(params)
                callback.set_model(self)
        
        # call callbacks for train begin
        if callbacks is not None:
            for callback in callbacks:
                callback.on_train_begin()
        
        # train
        for epoch in range(initial_epoch, epochs):
            # call callbacks for epoch begin
            if callbacks is not None:
                for callback in callbacks:
                    callback.on_epoch_begin(epoch)
                    
            sum_disc_loss = 0.0
            sum_gen_loss = 0.0
            num_batches = 0
            
            for n, (batch_x, batch_y) in enumerate(train_dataset):
                # call callbacks for batch begin
                if callbacks is not None:
                    for callback in callbacks:
                        callback.on_train_batch_begin(n, logs={
                            'size': batch_x.shape[0], 
                            'batch': n
                        })
                        
                losses = self.train_on_batch(batch_x, batch_y)
                sum_disc_loss += losses[0]
                sum_gen_loss += losses[1]
                num_batches += 1
                
                # call callbacks for batch end
                if callbacks is not None:
                    for callback in callbacks:
                        callback.on_train_batch_end(n, logs={
                            'size': batch_x.shape[0], 
                            'batch': n,
                            'discriminator_loss': losses[0],
                            'generator_loss': losses[1],
                        })
                
                # stop after steps_per_epoch batches
                if steps_per_epoch is not None and num_batches >= steps_per_epoch:
                    break
                    
            # call callbacks for epoch end
            if callbacks is not None:
                epoch_logs = {
                    'discriminator_loss': sum_disc_loss / max(num_batches, 1),
                    'generator_loss': sum_gen_loss / max(num_batches, 1),
                }
                if has_validation:
                    # unweighted mean of the losses of the validation batches
                    val_losses = np.mean([self.test_on_batch(tx, ty) for tx, ty in test_dataset], axis=0)
                    epoch_logs['val_discriminator_loss'] = val_losses[0]
                    epoch_logs['val_generator_loss'] = val_losses[1]
                
                for callback in callbacks:
                    callback.on_epoch_end(epoch, logs=epoch_logs)
        
        # call callbacks for train end
        if callbacks is not None:
            for callback in callbacks:
                callback.on_train_end()
    
    def train_on_batch(self, x, y, sample_weight=None, class_weight=None, reset_metrics=True):
        """Trains the discriminator and the generator on one batch and returns both losses.
        Args:
            x: inputs of current batch
            y: target outputs of current batch
            sample_weight: [unused]
            class_weight: [unused]
            reset_metrics: [unused]
        """
        
        # record the forward pass on two tapes; each is differentiated w.r.t. one sub-model
        with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
            # training=True: BatchNorm uses batch statistics and dropout is active
            generated = self.g(x, training=True)
            disc_real = self.d(x, y, training=True)
            disc_fake = self.d(x, generated, training=True)
            
            disc_loss = self._disc_loss(disc_real, disc_fake)
            gen_loss = self._gen_loss(y, generated, disc_fake)
        
        # update the discriminator
        gradients = disc_tape.gradient(disc_loss, self.d.trainable_variables)
        self.d_optimizer.apply_gradients(zip(gradients, self.d.trainable_variables))
        
        # update the generator
        gradients = gen_tape.gradient(gen_loss, self.g.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gradients, self.g.trainable_variables))
        
        return np.array([disc_loss.numpy(), gen_loss.numpy()])
        
    def test_on_batch(self, x, y, sample_weight=None, reset_metrics=True):
        """Calculates the discriminator and generator losses for the given batch.
        Args:
            x: inputs of batch
            y: target outputs of batch
            sample_weight: [unused]
            reset_metrics: [unused]
        """
        generated = self.g(x)
        disc_real = self.d(x, y)
        disc_fake = self.d(x, generated)

        disc_loss = self._disc_loss(disc_real, disc_fake)
        gen_loss = self._gen_loss(y, generated, disc_fake)
        
        return np.array([disc_loss.numpy(), gen_loss.numpy()])
    
    def call(self, x):
        """Takes input image x and returns the generator output and the discriminator output for it.
        """
        generated = self.g(x)
        discriminator = self.d(x, generated)
        
        return (generated, discriminator)