Import libraries.

In [4]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import string
import random

Define helper functions to allow image processing.

In [5]:
def random_crop(image, width, height):
  """Cut a randomly positioned window out of a 3-channel image.

  Args:
    image: image tensor to crop; the crop size below fixes 3 channels.
    width: width of the returned window.
    height: height of the returned window.

  Returns:
    The result of `tf.image.random_crop` with size `[height, width, 3]`.
  """
  # tf.image.random_crop expects the size in (height, width, channels) order.
  target_size = [height, width, 3]
  return tf.image.random_crop(image, size=target_size)

In [6]:
def random_jitter(image, width, height):
  """Apply resize -> random crop -> random horizontal flip to an image.

  Args:
    image: image tensor to augment.
    width: width of the random crop taken after resizing.
    height: height of the random crop taken after resizing.

  Returns:
    The augmented image tensor.
  """
  # Enlarge to a fixed 286 x 286 canvas (nearest-neighbour) so that the crop
  # below can land at a random offset.
  enlarged = tf.image.resize(
      image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # Take a random height x width window out of the enlarged image.
  window = random_crop(enlarged, width, height)

  # Mirror left/right at random.
  return tf.image.random_flip_left_right(window)

In [7]:
def _get_norm_layer(norm):
    if norm == 'none':
        return lambda: lambda x: x
    elif norm == 'batch_norm':
        return keras.layers.BatchNormalization
    elif norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    elif norm == 'layer_norm':
        return keras.layers.LayerNormalization

The generator is built using the ResNet architecture. ResNet stands for Residual Network and is a deep learning model which alleviates the vanishing-gradient problem in deep networks. It introduces skip connections, which provide an alternate shortcut path for the gradient to flow through. Moreover, if a particular layer hurts the performance of the ResNet, it will be skipped by regularization.

The building block of a ResNet is called a residual block. The activation of a residual block layer is fast-forwarded to a deeper layer in the neural network.

An image convolution is an element-wise multiplication of two matrices followed by a summation of obtained elements. Image is a multi-dimensional matrix which has a certain width and height. It also has a depth of 3 - one channel for each color - red, green or blue.

A kernel is a smaller matrix which sits on top of the image and slides from left to right and top to bottom, applying a convolution at each coordinate of the original image: the kernel stops at each location, examines the neighborhood of pixels around its center, performs the convolution, and stores the output value (kernel output) in the output image at the same coordinates as the center of the kernel. An odd kernel size is used to ensure there is a valid integer coordinate at the center.

Pixels located on the border of the image are not in the center of any sliding window - this implies there is a decrease in spatial dimension of the image. To ensure that output image has the same dimensions as input image, padding is applied. Reflect mode of padding reuses the contents of current row or column for padding the values. It duplicates the image values along the borders but in reverse order, hence the name 'reflect mode'.

*keras.layers.Conv2D* creates a convolution kernel which is convolved with the layer input to produce a tensor of outputs. The first parameter to this function is the number of filters that the convolutional layer will learn. The second argument is the kernel size. Padding 'valid' means no padding, and as a result spatial dimensions are allowed to reduce. Padding 'same' is used to indicate that the dimensions of the input image should be preserved. The addition of a bias vector is controlled by the use_bias argument.

In [19]:
def ResnetGenerator(input_shape=(256, 256, 3),
                    output_channels=3,
                    dim=64,
                    n_downsamplings=2,
                    n_blocks=9,
                    norm='instance_norm'):
    """Build the ResNet-style generator as a functional Keras model.

    Layout: a 7x7 reflect-padded stem convolution, `n_downsamplings` stride-2
    convolutions (filter count doubled each time), `n_blocks` residual blocks,
    `n_downsamplings` stride-2 transposed convolutions (filter count halved
    each time), and a 7x7 reflect-padded output convolution followed by tanh.

    Args:
        input_shape: shape of one input sample, passed to `keras.Input`.
        output_channels: number of filters of the final convolution.
        dim: number of filters of the stem convolution.
        n_downsamplings: number of stride-2 down- and up-sampling stages.
        n_blocks: number of residual blocks in the middle of the network.
        norm: normalization name understood by `_get_norm_layer`.

    Returns:
        A `keras.Model` mapping the input tensor to the tanh-activated output.
    """
    Norm = _get_norm_layer(norm)

    def _reflect_pad(t, pad):
        # Pad only the two middle axes; the first and last axes are untouched.
        return tf.pad(t, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')

    def _residual_block(x):
        # Keep the channel count of the incoming tensor so that the skip
        # connection and the convolved branch can be added together.
        channels = x.shape[-1]

        # Reflect-pad by 1 + 3x3 'valid' conv: the padding offsets the
        # border pixels the unpadded convolution would otherwise drop.
        h = _reflect_pad(x, 1)
        h = keras.layers.Conv2D(channels, 3, padding='valid', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

        h = _reflect_pad(h, 1)
        h = keras.layers.Conv2D(channels, 3, padding='valid', use_bias=False)(h)
        h = Norm()(h)

        # Skip connection: no activation after the second normalization.
        return keras.layers.add([x, h])

    # 0: input
    h = inputs = keras.Input(shape=input_shape)

    # 1: stem (reflect-pad by 3 + 7x7 'valid' conv)
    h = _reflect_pad(h, 3)
    h = keras.layers.Conv2D(dim, 7, padding='valid', use_bias=False)(h)
    h = Norm()(h)
    h = tf.nn.relu(h)

    # 2: downsampling, doubling the filter count at every stage
    for _ in range(n_downsamplings):
        dim *= 2
        h = keras.layers.Conv2D(dim, 3, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

    # 3: residual blocks
    for _ in range(n_blocks):
        h = _residual_block(h)

    # 4: upsampling, halving the filter count at every stage
    for _ in range(n_downsamplings):
        dim //= 2
        h = keras.layers.Conv2DTranspose(dim, 3, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

    # 5: output projection; tanh bounds the values to [-1, 1]
    h = _reflect_pad(h, 3)
    h = keras.layers.Conv2D(output_channels, 7, padding='valid')(h)
    h = tf.tanh(h)

    return keras.Model(inputs=inputs, outputs=h)

The discriminator uses the *leaky_relu* function, which stands for leaky rectified linear unit activation. It is a commonly used activation function which speeds up training.

In [12]:
def ConvDiscriminator(input_shape=(256, 256, 3),
                      dim=64,
                      n_downsamplings=3,
                      norm='instance_norm'):
    """Build the convolutional discriminator as a functional Keras model.

    Layout: one stride-2 4x4 convolution without normalization, then
    `n_downsamplings - 1` stride-2 stages and one stride-1 stage (each
    conv -> norm -> leaky ReLU, filter count doubled but capped at `dim * 8`),
    and finally a single-filter 4x4 convolution with no activation.

    Args:
        input_shape: shape of one input sample, passed to `keras.Input`.
        dim: number of filters of the first convolution.
        n_downsamplings: total number of stride-2 convolutions.
        norm: normalization name understood by `_get_norm_layer`.

    Returns:
        A `keras.Model` mapping the input tensor to the single-channel output.
    """
    Norm = _get_norm_layer(norm)
    filter_cap = dim * 8
    filters = dim

    # Input and first stage (no normalization on this one).
    inputs = keras.Input(shape=input_shape)
    x = keras.layers.Conv2D(filters, 4, strides=2, padding='same')(inputs)
    x = tf.nn.leaky_relu(x, alpha=0.2)

    # Remaining stride-2 stages followed by one stride-1 stage; they share the
    # same conv -> norm -> leaky ReLU structure and differ only in the stride.
    stride_plan = [2] * (n_downsamplings - 1) + [1]
    for stride in stride_plan:
        filters = min(filters * 2, filter_cap)
        x = keras.layers.Conv2D(
            filters, 4, strides=stride, padding='same', use_bias=False)(x)
        x = Norm()(x)
        x = tf.nn.leaky_relu(x, alpha=0.2)

    # Single-filter output convolution, left without an activation.
    outputs = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)

    return keras.Model(inputs=inputs, outputs=outputs)

The purpose of loss functions is to compute the quantity that a model should seek to minimize during training. *tf.keras.losses.BinaryCrossentropy* computes the cross-entropy loss between true labels and predicted labels. Cross entropy loss is a metric used to measure how well a classification model in machine learning performs. Cross-entropy loss increases as the predicted probability diverges from the actual label.

In [13]:
# Weight applied to the cycle-consistency loss and, halved, to the identity loss.
LAMBDA = 10
# Binary cross-entropy configured with from_logits=True, i.e. the predictions
# handed to it are raw scores rather than probabilities.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In CycleGAN, there is no paired data to train on, hence there is no guarantee that the input x and the target y pair are meaningful during training. Thus in order to enforce that the network learns the correct mapping, the cycle consistency loss is calculated. Cycle consistency means the result should be as close to the original input as possible.

Identity loss is added to help preserve tint. It states that when given an image of the target class, a generator should return that same image.

In [14]:
def discriminator_loss(real, generated):
  """Return the discriminator's adversarial loss.

  Args:
    real: discriminator output for real images (scored against all-ones).
    generated: discriminator output for generated images (scored against
      all-zeros).

  Returns:
    Half of the sum of the two `loss_obj` terms.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)

  return (loss_on_real + loss_on_fake) * 0.5

In [15]:
def generator_loss(generated):
  """Return the generator's adversarial loss.

  Args:
    generated: discriminator output for generated images.

  Returns:
    `loss_obj` of `generated` against an all-ones target of the same shape,
    i.e. the loss is low when generated images are scored as real.
  """
  real_labels = tf.ones_like(generated)
  return loss_obj(real_labels, generated)

In [16]:
def calc_cycle_loss(real_image, cycled_image):
  """Return the cycle-consistency loss between an image and its round trip.

  Args:
    real_image: the original image tensor.
    cycled_image: the image tensor to compare against `real_image`.

  Returns:
    The mean absolute difference of the two tensors, scaled by `LAMBDA`.
  """
  difference = real_image - cycled_image
  mean_abs_error = tf.reduce_mean(tf.abs(difference))

  return LAMBDA * mean_abs_error

In [17]:
def identity_loss(real_image, same_image):
  """Return the identity loss between an image and its same-domain mapping.

  Args:
    real_image: the original image tensor.
    same_image: the image tensor to compare against `real_image`.

  Returns:
    The mean absolute difference of the two tensors, scaled by
    `LAMBDA * 0.5`.
  """
  difference = real_image - same_image
  mean_abs_error = tf.reduce_mean(tf.abs(difference))
  return LAMBDA * 0.5 * mean_abs_error

Generate and plot images.

In [18]:
def generate_images(model, test_input, epoch = ''):
  """Run `model` on `test_input`, then plot and save input vs. prediction.

  The first element of the input batch and of the prediction are shown side
  by side. The figure is written to `output/<10 random chars>.png` (the
  directory is created when missing), displayed, and then closed.

  Args:
    model: callable applied to `test_input` to obtain the prediction.
    test_input: batch of images; only element 0 is displayed. The plot
      rescales values with `* 0.5 + 0.5`, so inputs are presumably in
      [-1, 1] -- confirm against the data pipeline.
    epoch: optional epoch label appended to the prediction title. The
      default `''` (or `None`) omits it; `0` is shown as a valid epoch.
  """
  import os  # local import: only needed to make sure the output folder exists

  prediction = model(test_input)

  fig = plt.figure(figsize=(12, 12))

  titles = ['Input Image', 'Predicted Image']
  # Compare explicitly instead of relying on truthiness so that epoch 0 is
  # still labelled.
  if epoch is not None and epoch != '':
      titles[1] = titles[1] + ', epoch = ' + str(epoch)

  panels = [test_input[0], prediction[0]]
  for position, (image, title) in enumerate(zip(panels, titles), start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(image * 0.5 + 0.5)
    plt.axis('off')

  # savefig does not create missing directories, so make sure it exists.
  os.makedirs('output', exist_ok=True)
  plt.savefig('output/' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=10)) + '.png', bbox_inches='tight')
  plt.show()
  # Release the figure so repeated calls (e.g. once per epoch) do not
  # accumulate open figures.
  plt.close(fig)