## Imports

In [None]:
from random import shuffle
from sys import stderr
import warnings
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import numpy as np

In [None]:
# Show every physical device TensorFlow can see; the list is displayed as this cell's output.
tf.config.list_physical_devices()

In [None]:
# Silence library warnings to keep the notebook output readable.
warnings.filterwarnings("ignore")

# Run on the first GPU when TensorFlow reports one, otherwise fall back to the CPU.
if tf.config.list_physical_devices('GPU'):
    device = '/GPU:0'
else:
    device = '/CPU:0'
print(device)

# Uncomment when running on Google Colab to mount Google Drive.
# from google.colab import drive
# drive.mount('/content/drive')

## Loading Dataset
* The _Universal Image Embeddings (512x512)_ dataset with 130k+ images is used, along with preprocessed _Vincent Van Gogh art_ and _Cartoon - Family Guy_ images
* only 256 random images are used for training (plus 16 more held out for testing) to keep computation fast - adjust with the `data_size` and `test_size` variables
* `img256` -> _np.array_ with input images generated by resizing original images to **256x256** size (size of the _CycleGAN_ output)
* `img512` -> _np.array_ with original images of **512x512** size

In [None]:
img256 = []  # model inputs: every image resized to 256x256
img512 = []  # targets: every image at its original 512x512 size
data_size = 256  # number of training pairs
test_size = 16  # number of held-out pairs (used for the sample upscaling at the end)

paths = []
directories = ["../dataset/Familyguy", "../dataset/VincentVanGogh_standardized"]

for dir_name in directories:
    paths.extend(os.path.join(dir_name, entry) for entry in os.listdir(dir_name))

shuffle(paths)
paths = paths[0:data_size + test_size]

for filename in tqdm(paths):
    # The context manager closes the underlying file handle. convert("RGB") guarantees
    # three channels: grayscale, RGBA or palette images would otherwise break the
    # stacking and the reshape to (..., 3) below.
    with Image.open(filename) as img:
        rgb = img.convert("RGB")
    if rgb.size != (512, 512):
        # Fail with a clear message instead of a cryptic reshape/stacking error later on.
        raise ValueError(f"{filename}: expected a 512x512 image, got {rgb.size[0]}x{rgb.size[1]}")
    img512.append(np.array(rgb))
    img256.append(np.array(rgb.resize((256, 256))))

img256 = np.array(img256, dtype='float32')
img512 = np.array(img512, dtype='float32')

# Scale pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output.
img256 = (img256 / 127.5) - 1
img512 = (img512 / 127.5) - 1

# Split into train/test sets; each sample becomes its own batch of size 1.
img256_train = img256[0:data_size].reshape(-1, 1, 256, 256, 3)
img512_train = img512[0:data_size].reshape(-1, 1, 512, 512, 3)

img256_test = img256[data_size:].reshape(-1, 1, 256, 256, 3)
img512_test = img512[data_size:].reshape(-1, 1, 512, 512, 3)

print(img256_train.shape)
print(img512_train.shape)

print(img256_test.shape)
print(img512_test.shape)

### Instance Normalization layer initialization
`tf.keras.layers.BatchNormalization` has been replaced with my custom `InstanceNormalization` layer (a `tf.keras.layers.Layer` subclass) to stay consistent with the other GANs in this project (the discriminator still contains one `BatchNormalization` layer).

In [None]:
class InstanceNormalization(layers.Layer):
    """Instance normalization with a learnable per-channel scale (gamma) and offset (beta).

    Every sample is normalized on its own over axes 1 and 2 (height and width for the
    channels-last images used in this notebook), then scaled by ``gamma`` and shifted by
    ``beta``, both of length ``input_shape[-1]``.

    Args:
        epsilon: small constant added to the variance to avoid division by zero.
        gamma_initializer: initializer for gamma (name, config dict or Initializer instance).
        beta_initializer: initializer for beta (name, config dict or Initializer instance).
    """

    def __init__(self, epsilon: float = 1e-5, gamma_initializer="ones", beta_initializer="zeros", **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.beta, self.gamma = None, None
        self.epsilon = epsilon
        # Resolve to Initializer objects so get_config() can serialize them and
        # from_config() can rebuild them from the serialized dicts.
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)

    def build(self, input_shape):
        self.gamma = self.add_weight(
            shape=(input_shape[-1],),
            initializer=self.gamma_initializer,
            trainable=True,
            name="gamma"
        )
        self.beta = self.add_weight(
            shape=(input_shape[-1],),
            initializer=self.beta_initializer,
            trainable=True,
            name="beta"
        )
        super(InstanceNormalization, self).build(input_shape)

    def call(self, inputs):
        # Statistics per sample and per channel, over axes 1 and 2 only.
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return self.gamma * normalized + self.beta

    def get_config(self) -> dict:
        """Return the constructor arguments, including both initializers, for re-creation."""
        config = super(InstanceNormalization, self).get_config()
        config.update({
            "epsilon": self.epsilon,
            "gamma_initializer": keras.initializers.serialize(self.gamma_initializer),
            "beta_initializer": keras.initializers.serialize(self.beta_initializer),
        })
        return config

### Upsample and Downsample Layers
The model is built from downsampling and upsampling blocks, which are defined below.

In [None]:
output_channels = 3  # one channel each for Red, Green and Blue


def downsample(filters: int, size: int, instance_norm: bool = True):
    """Build an encoder block: stride-2 Conv2D -> optional InstanceNormalization -> LeakyReLU.

    The stride-2 "same" convolution halves height and width.

    Args:
        filters: number of convolution filters (output channels).
        size: convolution kernel size.
        instance_norm: whether to insert InstanceNormalization after the convolution.

    Returns:
        A ``keras.Sequential`` holding the block's layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding="same",
                      kernel_initializer=kernel_init, use_bias=False),
    ]
    if instance_norm:
        block_layers.append(InstanceNormalization(gamma_initializer=norm_gamma_init))
    block_layers.append(layers.LeakyReLU())
    return keras.Sequential(block_layers)


def upsample(filters: int, size: int, dropout: bool = False):
    """Build a decoder block: stride-2 Conv2DTranspose -> InstanceNormalization -> optional Dropout(0.5) -> ReLU.

    The stride-2 "same" transposed convolution doubles height and width.

    Args:
        filters: number of transposed-convolution filters (output channels).
        size: kernel size.
        dropout: whether to insert a Dropout(0.5) layer before the activation.

    Returns:
        A ``keras.Sequential`` holding the block's layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2, padding="same",
                               kernel_initializer=kernel_init, use_bias=False),
        InstanceNormalization(gamma_initializer=norm_gamma_init),
    ]
    if dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())
    return keras.Sequential(block_layers)

## Generator model initialization
The generator's goal is to generate images that can be mistaken for real ones. Here it is a U-Net model consisting of three main components:
* **Downsampling blocks** -> halve the spatial resolution of the image input at each step until it reaches 1x1 (a tensor of shape (bs, 1, 1, 512))
* **Upsampling blocks** -> convert output of downsampling blocks back to image output
* **Skip connections** -> provide connections between downsampling and upsampling blocks at each level
Generator and Discriminator constantly compete against each other in order to improve the quality of generated images.

In [None]:
class Generator:
    """U-Net generator that maps a (256, 256, 3) image to a (512, 512, 3) image.

    The encoder (``down_stack``) halves height and width eight times (256 -> 1). The decoder
    (``up_stack``) doubles them seven times (1 -> 128), each time concatenating the encoder
    output of the same resolution (skip connection). A final stride-4 transposed
    convolution with tanh activation produces the 512x512 RGB output.
    """

    def __init__(self):
        self.initializer = tf.random_normal_initializer(0., 0.02)
        self.inputs = layers.Input([256, 256, 3])
        self.outputs = None
        self.model = None
        self.down_stack = None
        self.up_stack = None
        self.prepare_stacks()

    def __call__(self, inputs, training: bool = True):
        """Run the generator, building the Keras model on first use."""
        model = self.model if self.model else self.build()
        return model(inputs, training=training)

    def prepare_stacks(self):
        """Create the encoder and decoder blocks. Shapes below are (bs, height, width, channels)."""
        # (filters, instance norm) -> outputs: (bs, 128, 128, 32), (bs, 64, 64, 64),
        # (bs, 32, 32, 128), (bs, 16, 16, 256), (bs, 8, 8, 512), (bs, 4, 4, 512),
        # (bs, 2, 2, 512), (bs, 1, 1, 512)
        encoder_spec = [(32, False), (64, True), (128, True), (256, True),
                        (512, True), (512, True), (512, True), (512, True)]
        self.down_stack = [downsample(filters, 4, instance_norm=use_norm)
                           for filters, use_norm in encoder_spec]

        # (filters, dropout) -> shapes after concatenating the skip connection in build():
        # (bs, 2, 2, 1024), (bs, 4, 4, 1024), (bs, 8, 8, 1024), (bs, 16, 16, 768),
        # (bs, 32, 32, 384), (bs, 64, 64, 192), (bs, 128, 128, 96)
        decoder_spec = [(512, True), (512, True), (512, True), (512, False),
                        (256, False), (128, False), (64, False)]
        self.up_stack = [upsample(filters, 4, dropout=use_dropout)
                         for filters, use_dropout in decoder_spec]

    def build(self) -> keras.Model:
        """Wire the blocks into a functional Keras model, store it on ``self.model`` and return it."""
        # (bs, 128, 128, 96) -> (bs, 512, 512, 3) via a stride-4 transposed convolution.
        head = layers.Conv2DTranspose(output_channels, 4, strides=4, padding='same',
                                      kernel_initializer=self.initializer,
                                      activation='tanh')

        x = self.inputs
        encoder_outputs = []
        for block in self.down_stack:
            x = block(x)
            encoder_outputs.append(x)

        # The deepest (1x1) output is the bottleneck, not a skip; pair the rest deepest-first.
        for block, skip in zip(self.up_stack, encoder_outputs[-2::-1]):
            x = block(x)
            x = layers.Concatenate()([x, skip])

        self.outputs = head(x)
        self.model = tf.keras.Model(inputs=self.inputs, outputs=self.outputs)
        return self.model

### Resize layer initialization
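Calling `tf.image.resize` directly on a symbolic Keras input returns a plain TensorFlow tensor, which cannot be concatenated with other layers' outputs in the functional model. That's where my custom `ResizeLayer` (a `tf.keras.layers.Layer` subclass wrapping `tf.image.resize`) comes in handy.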

In [None]:
class ResizeLayer(layers.Layer):
    """Keras layer that wraps ``tf.image.resize`` so it can be used in a functional model.

    Args:
        target_size: output size as (height, width).
        method: resize method name passed to ``tf.image.resize`` (default ``'bicubic'``).
    """

    def __init__(self, target_size, method='bicubic', **kwargs):
        super(ResizeLayer, self).__init__(**kwargs)
        self.target_size = target_size
        self.method = method

    def call(self, inputs):
        return tf.image.resize(inputs, self.target_size, method=self.method)

    def get_config(self) -> dict:
        """Return the constructor arguments so the layer can be serialized and re-created."""
        config = super(ResizeLayer, self).get_config()
        config.update({"target_size": tuple(self.target_size), "method": self.method})
        return config

## Discriminator Model initialization
The discriminator model is a PatchGAN. Its output is a 30x30x1 grid of scores (one per image patch), each indicating how real that patch of the target image looks given the input image.
It determines whether an image is real (the actual bicubic-upscaled target) or generated (by our model).

In [None]:
class Discriminator:
    """Conditional PatchGAN discriminator.

    Receives the 256x256 input image (bicubically resized to 512x512 inside the model)
    together with a 512x512 candidate image, and returns a (bs, 30, 30, 1) map of raw
    scores, one per patch. The last convolution has no activation.
    """

    def __init__(self):
        self.initializer = tf.random_normal_initializer(0., 0.02)
        self.inputs = layers.Input(shape=[256, 256, 3], name='input_image')
        self.target = layers.Input(shape=[512, 512, 3], name='target_image')
        self.outputs = None
        self.model = None

    def build(self):
        """Assemble the functional Keras model, store it on ``self.model`` and return it."""
        upscaled = ResizeLayer(target_size=(512, 512), method='bicubic')(self.inputs)
        x = layers.concatenate([upscaled, self.target])  # (bs, 512, 512, 6)

        # Four stride-2 blocks: 512 -> 256 -> 128 -> 64 -> 32 (no normalization on the first).
        for filters, use_norm in ((32, False), (64, True), (128, True), (256, True)):
            x = downsample(filters, 4, instance_norm=use_norm)(x)  # ends at (bs, 32, 32, 256)

        x = layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
        x = layers.Conv2D(512, 4, strides=1, kernel_initializer=self.initializer,
                          use_bias=False)(x)  # (bs, 31, 31, 512)
        # NOTE(review): the only BatchNormalization in the model; the downsample blocks use
        # InstanceNormalization. Switching it would change the variables stored in checkpoints.
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)
        x = layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)
        patch_scores = layers.Conv2D(1, 4, strides=1,
                                     kernel_initializer=self.initializer)(x)  # (bs, 30, 30, 1)

        self.outputs = patch_scores
        self.model = tf.keras.Model(inputs=[self.inputs, self.target], outputs=self.outputs)
        return self.model

## Generator and Discriminator initialization  

In [None]:
# Build both networks on the selected device; only the resulting Keras models are kept.
with tf.device(device):
    generator, discriminator = Generator().build(), Discriminator().build()

## Visualizing the models
Requires `Graphviz` to be installed!
> If you want to export the models to _.png_, you can uncomment the second line in cells below.

In [None]:
# Render the generator architecture (needs Graphviz); the diagram is shown as the cell output.
keras.utils.plot_model(generator, show_shapes=True, show_layer_names=True)
# Uncomment to also export the diagram as a PNG file:
# keras.utils.plot_model(generator, to_file="../docs/imgs/patchgan_generator.png", show_shapes=True, show_layer_names=True)

In [None]:
# Render the discriminator architecture (needs Graphviz); the diagram is shown as the cell output.
keras.utils.plot_model(discriminator, show_shapes=True, show_layer_names=True)
# Uncomment to also export the diagram as a PNG file:
# keras.utils.plot_model(discriminator, to_file="../docs/imgs/patchgan_discriminator.png", show_shapes=True, show_layer_names=True)

## Generator loss function
* **L1 loss** -> mean absolute error between generated and target image to make generated images structurally similar to target images
* **GAN loss** -> binary cross entropy loss of discriminator's output on generated images and array of ones
> total_loss = GAN_loss + (lambda * L1)

In [None]:
# Binary cross-entropy on raw discriminator scores (from_logits=True). reduction="none" keeps
# one loss value per patch instead of averaging them. The string form is accepted by Keras 2
# (it equals Reduction.NONE) and by Keras 3, where the Reduction enum may not be available.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction="none")


def patch_generator_loss(disc_generated_output, gen_output, target, _lambda: int = 500) -> tf.Tensor:
    """Generator objective: adversarial BCE per patch plus a weighted L1 reconstruction term.

    Args:
        disc_generated_output: discriminator scores (logits) for the generated image.
        gen_output: generated image.
        target: ground-truth image with the same shape as ``gen_output``.
        _lambda: weight of the L1 term.

    Returns:
        Unreduced per-patch loss: ``gan_loss`` keeps its patch dimensions and the scalar
        L1 term is broadcast onto it.
    """
    # The generator wants the discriminator to label its patches as real (ones).
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return gan_loss + (_lambda * l1_loss)

## Discriminator loss function
* **Real loss** -> sigmoid cross entropy loss of real image output and array of ones
* **Generated loss** -> sigmoid cross entropy loss of generated image output and array of zeros
> total_loss = real_loss + gen_loss

In [None]:
def patch_discriminator_loss(disc_real_output, disc_gen_output) -> tf.Tensor:
    """Per-patch discriminator loss: real patches should score as ones, generated ones as zeros.

    Args:
        disc_real_output: discriminator scores (logits) for the real target image.
        disc_gen_output: discriminator scores (logits) for the generated image.

    Returns:
        Element-wise sum of the two unreduced cross-entropy terms.
    """
    loss_on_real = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(disc_gen_output), disc_gen_output)
    total = loss_on_real + loss_on_fake
    return total

## Average loss function
This function calculates the average loss of the given batch

In [None]:
def avg_loss(loss_tensor: "tf.Tensor") -> float:
    """Return the mean of all loss values in a batch of unreduced (e.g. per-patch) losses.

    Args:
        loss_tensor: unreduced loss such as a (batch, rows, cols) per-patch map, given as a
            tensor, NumPy array or nested sequence.

    Returns:
        The average over every element of every sample in the batch, as a Python float.

    Raises:
        ValueError: if ``loss_tensor`` contains no values.
    """
    values = np.asarray(loss_tensor, dtype='float64')
    if values.size == 0:
        raise ValueError("avg_loss() received an empty loss tensor")
    # Average over the whole batch, not only the first sample.
    return float(values.mean())

## Defining optimizers
Adam's adaptive learning rate, efficient handling of noisy and sparse gradients, and ease of use make it a great choice for the complex and dynamic training process of GANs.

In [None]:
# Same Adam hyperparameters for both networks; each gets its own instance so their
# optimizer state is tracked separately.
patch_generator_optimizer, patch_discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5) for _ in range(2)
)

## Checkpoint saver
Due to long training times, it is useful to save model checkpoints every 20 epochs (when `fit` is called with `make_checkpoints=True`) in case of a sudden power outage or other unexpected events.

In [None]:
checkpoint_dir = "./training_checkpoints"
# Every save writes files named after this prefix plus a save counter.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Track both models together with their optimizers so training can be resumed.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=patch_generator_optimizer,
    discriminator_optimizer=patch_discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

## Image generation function
Main function responsible for generating images with the trained model and displaying them next to the input and target images.

In [None]:
def _to_uint8(pixels):
    """Round float pixel values to the nearest integer and clamp them into a uint8 image."""
    return np.clip(np.rint(np.asarray(pixels, dtype='float32')), 0, 255).astype('uint8')


def generate_images(model, inp, tar):
    """Upscale ``inp`` with ``model`` and plot input, prediction and target side by side.

    Args:
        model: generator whose output is reshaped to (1, 512, 512, 3) and is assumed to be
            in [-1, 1] (TODO confirm for models other than this notebook's generator).
        inp: input batch with pixel values in [0, 255]; only ``inp[0]`` is displayed.
        tar: target batch with pixel values in [0, 255]; only ``tar[0]`` is displayed.
    """
    inp_normalized = (inp / 127.5) - 1
    # training=True runs the same forward pass as during training (Dropout stays active).
    pred = model(inp_normalized, training=True)
    # Round instead of truncating so float error cannot drop a pixel by one intensity level.
    pred = _to_uint8((np.asarray(pred) + 1) * 127.5).reshape((1, 512, 512, 3))
    display_list = [_to_uint8(inp[0]), pred[0], _to_uint8(tar[0])]
    title_list = ['Input (256x256)', 'Upscaled with PatchGAN', 'Target (bicubic - 512x512)']
    plt.figure(figsize=(20, 20))

    for position, (title, picture) in enumerate(zip(title_list, display_list), start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(picture)
        plt.axis('off')

    plt.show()

## Train functions

In [None]:
# noinspection PyUnusedLocal
@tf.function
def train_step(inp, tar, epoch):
    """Run one adversarial update of both networks on a single (input, target) batch.

    ``epoch`` is accepted but unused. Gradients are taken of the unreduced per-patch
    losses. Returns the mean generator and discriminator losses as scalar tensors.
    """
    with tf.device(device):
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake = generator(inp, training=True)

            # The discriminator judges the input paired with the real target and with the fake.
            real_scores = discriminator([inp, tar], training=True)
            fake_scores = discriminator([inp, fake], training=True)

            g_loss = patch_generator_loss(fake_scores, fake, tar)
            d_loss = patch_discriminator_loss(real_scores, fake_scores)

        g_vars = generator.trainable_variables
        d_vars = discriminator.trainable_variables
        g_grads = g_tape.gradient(g_loss, g_vars)
        d_grads = d_tape.gradient(d_loss, d_vars)

        patch_generator_optimizer.apply_gradients(zip(g_grads, g_vars))
        patch_discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

        return tf.reduce_mean(g_loss), tf.reduce_mean(d_loss)


def fit(inp_train, tar_train, epochs: int, make_checkpoints: bool = False):
    """Train the generator/discriminator pair, then export the generator weights.

    Args:
        inp_train: iterable of input batches (here (1, 256, 256, 3) arrays in [-1, 1]).
        tar_train: iterable of matching target batches (here (1, 512, 512, 3) arrays).
        epochs: number of passes over the paired data.
        make_checkpoints: when True, save a checkpoint every 20 epochs, plus a final one
            if ``epochs`` is not a multiple of 20.

    Raises:
        ValueError: if the training data yields no (input, target) pairs.
    """
    weights_path = "../models/patch_gan.weights.h5"
    with tf.device(device):
        for epoch in range(epochs):
            # Pass the epoch to the tf.function as a tensor: a new Python int every epoch
            # would make TensorFlow retrace train_step once per epoch.
            epoch_tensor = tf.constant(epoch, dtype=tf.int64)
            gen_sum, disc_sum, steps = 0.0, 0.0, 0
            for inp, tar in tqdm(zip(inp_train, tar_train)):
                gen_loss, disc_loss = train_step(inp, tar, epoch_tensor)
                gen_sum += gen_loss
                disc_sum += disc_loss
                steps += 1
            if steps == 0:
                raise ValueError("fit() received no training samples")

            # Report the mean loss over the epoch rather than the last step's loss.
            print(f"Epoch {epoch + 1}/{epochs} -> gen_loss={float(gen_sum) / steps:.4f}, "
                  f"disc_loss={float(disc_sum) / steps:.4f}\n", file=stderr)

            if make_checkpoints and (epoch + 1) % 20 == 0:  # checkpoints every 20 epochs
                checkpoint.save(file_prefix=checkpoint_prefix)

        if make_checkpoints and epochs % 20 != 0:  # save final state as a checkpoint
            checkpoint.save(file_prefix=checkpoint_prefix)
        # Create the export directory if needed so saving the weights does not fail.
        os.makedirs(os.path.dirname(weights_path), exist_ok=True)
        generator.save_weights(weights_path)

In [None]:
# generator.save_weights("../models/patch_gan.weights.h5")  # in case export above fails due to for example lack of memory

## Training the model
I've tried training for many different numbers of epochs, and the best results were achieved after exactly 60 epochs.
Higher values tended to cause overfitting, which occasionally produced strange artifacts in the upscaled images.

In [None]:
# checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# 60 epochs gave the best results; longer training tended to overfit and add artifacts.
train_epochs = 60
fit(inp_train=img256_train, tar_train=img512_train, epochs=train_epochs)

## Sample upscaling

In [None]:
# Compare each held-out input with its upscaled prediction and its real 512x512 target
# (the targets were previously never used: the 256x256 input was shown as the target).
for inp_image, tar_image in zip(img256_test, img512_test):
    generate_images(generator, (inp_image + 1) * 127.5, (tar_image + 1) * 127.5)