# Reimplementation: CycleGAN (2017)

This implementation focuses on unpaired image-to-image translation from photographs to Monet paintings and vice versa. The CycleGAN code was built with reference to [this implementation](https://github.com/LynnHo/CycleGAN-Tensorflow-2), as well as the [original CycleGAN implementation](https://junyanz.github.io/CycleGAN/). This notebook was executed with Colab Pro. With the free Colab version, (1) the notebook may not run to completion due to usage limits, or (2) training may take considerably longer.

## 1 Load the data and prepare the notebook

The dataset is taken from [here](https://www.kaggle.com/competitions/gan-getting-started/data). It consists of 300 Monet paintings of 256 x 256 pixels and 7028 photographs of the same size.

The dataset is copied from the user's Google Drive into the local Colab storage. For this code to work, the dataset must be uploaded to your Drive.

To use the trained model, you can also load the checkpoints (after 200 epochs) into the local storage. They are restored automatically when `train()` is called.

### 1.1 Connect to Drive
This notebook expects the data as a zip file in Google Drive, which is not publicly accessible. To reproduce it, download the dataset from [here](https://www.kaggle.com/competitions/gan-getting-started/data) and store it as a zip file named data.zip in your Google Drive. When extracted to /content/, the archive should contain a folder *'data'* with a subfolder *'monet'* (containing a subfolder *'monet_jpg'* with all Monet paintings) and a subfolder *'photo'* (containing a subfolder *'photo_jpg'* with all photographs), because the code reads from 'data/monet' and 'data/photo'. The extra level of subfolders is needed since `flow_from_directory` and `image_dataset_from_directory` treat every subfolder as a class. All pictures should be 256 x 256 pixels in jpg format.
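Before training, it can help to confirm that the archive was extracted with the layout expected by `flow_from_directory('data/monet')` and `flow_from_directory('data/photo')`. The helper below is a hypothetical sketch, not part of the original notebook: it counts the `.jpg` files in each class subfolder, so the Monet folder should report 300 images.

```python
import os

def count_images(root):
    """Hypothetical helper: map each class subfolder of `root` to its number of .jpg files."""
    counts = {}
    for sub in sorted(os.listdir(root)):
        sub_path = os.path.join(root, sub)
        if os.path.isdir(sub_path):
            # Only count JPEG files; stray files in the folder are ignored
            counts[sub] = sum(1 for f in os.listdir(sub_path)
                              if f.lower().endswith('.jpg'))
    return counts

# Usage in Colab after unzipping (working directory /content):
# print(count_images('data/monet'))
# print(count_images('data/photo'))
```

Files placed directly in 'data/monet' (outside a subfolder) are not counted, mirroring the fact that `flow_from_directory` only reads images from class subfolders.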

In [None]:
# Connect to Google drive
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Get the dataset by unzipping from drive to local Colab storage
!unzip '/content/gdrive/MyDrive/data.zip' -d '/content/'

### 1.2 Load the latest checkpoint
If you have a checkpoint from an earlier training saved in your drive, you can also load it here.

In [None]:
# Get the checkpoint for the trained model
# !unzip '/content/gdrive/MyDrive/tf_checkpoints-final.zip' -d '/content/'

### 1.3 Installations and imports

In [None]:
# install the missing library (needed for instance normalization layer)
!pip install tensorflow-addons

In [None]:
!pip install pytorch-fid

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers, optimizers
from tensorflow.keras.optimizers import schedules
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import numpy as np
import time as t
import os
import cv2

## 2 The Models

A CycleGAN consists of two generators (photo to Monet and Monet to photo) and two discriminators (one per domain). The generator uses residual blocks, so a ResNet block is implemented as well. The generator also needs down- and upsampling blocks, which are implemented as separate networks to make the structure of the network easier to follow.

### 2.1 Helper networks: Resnet, Downsampler, and Upsampler

In [None]:
class Resnet(Model):

    def __init__(self, dim):
        """
        A ResNet block for CycleGAN consists of a convolution, followed by
        instance normalization, ReLU activation, another convolution and
        instance normalization. The block input is added to this output (skip
        connection), which creates the residual characteristic.

        dim: number of filters (must equal the number of input channels so
             that the addition is valid)
        """
        super(Resnet, self).__init__()

        self.res_conv1 = layers.Conv2D(filters=dim, 
                                       kernel_size=3, 
                                       padding='valid', 
                                       use_bias=False)
        self.res_norm1 = tfa.layers.InstanceNormalization()
        self.res_relu = layers.ReLU()

        self.res_conv2 = layers.Conv2D(filters=dim, 
                                       kernel_size=3, 
                                       padding='valid', 
                                       use_bias=False)
        self.res_norm2 = tfa.layers.InstanceNormalization()

    @tf.function
    def call(self, input, training=None):
        """
        Call function for the residual block.

        input: input feature map of shape (batch, height, width, dim); inside
               the generator this is 32 x 32 x 512 for 256 x 256 images
        training: training flag - whether the network is training or not
        """

        x = tf.pad(input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = self.res_conv1(x)
        x = self.res_norm1(x)
        x = self.res_relu(x)
        x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        x = self.res_conv2(x)
        x = self.res_norm2(x)

        return layers.add([input, x])


class Downsampler(Model):

    def __init__(self, dim, kernel_size=3, strides=2, leaky=False):
        """
        A downsampler block with the default stride of 2 halves the height and
        width of its input and sets the number of channels to `dim`. It
        consists of a convolutional layer and instance normalization, followed
        by either leaky ReLU or ReLU activation.

        dim: number of filters for convolution
        kernel_size: kernel size for convolution
        strides: strides for convolution
        leaky: whether to apply leaky relu or relu as activation
        """
        super(Downsampler, self).__init__()

        self.down_conv = layers.Conv2D(filters=dim, 
                                       kernel_size=kernel_size, 
                                       strides=strides, 
                                       padding='same', 
                                       use_bias=False)
        self.down_norm = tfa.layers.InstanceNormalization()
        if leaky:
          self.down_relu = layers.LeakyReLU()
        else:
          self.down_relu = layers.ReLU()

    @tf.function
    def call(self, input):
        """
        Call function for the downsampler blocks.

        input: input feature map with an arbitrary number of channels
        """

        x = self.down_conv(input)
        x = self.down_norm(x)
        x = self.down_relu(x)

        return x


class Upsampler(Model):

    def __init__(self, dim):
        """
        An upsampler block upsamples a given image to double its height and 
        width. An upsampler block consists of a transposed convolution, 
        instance normalization, followed by relu activation.

        dim: number of filters for convolution
        """
        super(Upsampler, self).__init__()

        self.up_convtp = layers.Conv2DTranspose(filters=dim, 
                                                kernel_size=3, 
                                                strides=2, 
                                                padding='same', 
                                                use_bias=False)
        self.up_norm = tfa.layers.InstanceNormalization()
        self.up_relu = layers.ReLU()

    @tf.function
    def call(self, input):
        """
        Call function for the upsampler blocks.

        input: input feature map with an arbitrary number of channels
        """

        x = self.up_convtp(input)
        x = self.up_norm(x)
        x = self.up_relu(x)

        return x
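The `Resnet` block pads by one pixel with `tf.pad(..., mode='REFLECT')` before each 3 x 3 convolution with `padding='valid'`, so the spatial size is unchanged and the skip connection `layers.add([input, x])` is well defined. The pure-Python sketch below is only an illustration (not TensorFlow code; the function names are hypothetical): it mimics reflect padding on one row of pixels and computes the output length of a 'valid' convolution. Like TensorFlow's `REFLECT` mode, it mirrors around the edge pixel without repeating it and requires the padding to be smaller than the input length.

```python
def reflect_pad(row, pad):
    """Mirror `row` around its first and last elements (edge not repeated)."""
    if pad >= len(row):
        raise ValueError("REFLECT padding must be smaller than the input length")
    left = row[1:pad + 1][::-1]     # elements after the first one, mirrored
    right = row[-pad - 1:-1][::-1]  # elements before the last one, mirrored
    return left + row + right

def valid_conv_size(n, kernel_size):
    """Output length of a stride-1 convolution with padding='valid'."""
    return n - kernel_size + 1

row = [1, 2, 3, 4]
padded = reflect_pad(row, 1)
print(padded)                           # [2, 1, 2, 3, 4, 3]
print(valid_conv_size(len(padded), 3))  # 4, i.e. the original length
```

The same reasoning explains the 3-pixel reflect padding around the 7 x 7 convolutions in the generator: 256 + 2 * 3 - 7 + 1 = 256.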

### 2.2 The Generator

In [None]:
class ResnetGenerator(Model):

    def __init__(self, 
                 inp_shape=(256, 256, 3), 
                 output_channels=3):
        """
        This ResNet generator consists of a convolution followed by instance
        normalization and ReLU activation, three downsampler blocks, seven
        ResNet (residual) blocks, three upsampler blocks, and a final
        convolutional layer that reduces the number of channels to
        `output_channels`, followed by tanh activation.

        inp_shape: the shape of the input images
        output_channels: the number of channels of the output images
        """
        super(ResnetGenerator, self).__init__()

        self.input_layer = layers.Input(shape=inp_shape)
        self.block1_conv = layers.Conv2D(filters=64, 
                                         kernel_size=7, 
                                         padding='valid', 
                                         use_bias=False)
        self.block1_norm = tfa.layers.InstanceNormalization()
        self.block1_relu = layers.ReLU()

        self.block2_down1 = Downsampler(dim=128)
        self.block2_down2 = Downsampler(dim=256)
        self.block2_down3 = Downsampler(dim=512)

        self.block3_resnet1 = Resnet(dim=512)
        self.block3_resnet2 = Resnet(dim=512)
        self.block3_resnet3 = Resnet(dim=512)
        self.block3_resnet4 = Resnet(dim=512)
        self.block3_resnet5 = Resnet(dim=512)
        self.block3_resnet6 = Resnet(dim=512)
        self.block3_resnet7 = Resnet(dim=512)
        #self.block3_resnet8 = Resnet(dim=512)
        #self.block3_resnet9 = Resnet(dim=512)

        self.block4_up1_1 = Upsampler(dim=256)
        self.block4_up1 = Upsampler(dim=128)
        self.block4_up2 = Upsampler(dim=64)

        self.out_conv = layers.Conv2D(output_channels, 
                                      kernel_size=7, 
                                      padding='valid')
        self.out_tanh = layers.Activation(tf.nn.tanh)

    @tf.function
    def call(self, input, training=False):
        """
        Call function for the ResnetGenerator.
        
        input: a 256 x 256 pixel input image
        training: training flag - whether the network is training or not
        """

        #x = self.input_layer(input)
        x = tf.pad(input, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
        x = self.block1_conv(x)
        x = self.block1_norm(x)
        x = self.block1_relu(x)

        x = self.block2_down1(x)
        x = self.block2_down2(x)
        x = self.block2_down3(x)

        x = self.block3_resnet1(x)
        x = self.block3_resnet2(x)
        x = self.block3_resnet3(x)
        x = self.block3_resnet4(x)
        x = self.block3_resnet5(x)
        x = self.block3_resnet6(x)
        x = self.block3_resnet7(x)
        #x = self.block3_resnet8(x)
        #x = self.block3_resnet9(x)

        x = self.block4_up1_1(x)
        x = self.block4_up1(x)
        x = self.block4_up2(x)

        x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
        x = self.out_conv(x)
        x = self.out_tanh(x)

        return x
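To see how the feature-map size evolves through `ResnetGenerator`, the bookkeeping below traces height/width and channel count for the default arguments. It is a pure-Python sketch (hypothetical helper names, not part of the notebook) that assumes Keras' size rules: `padding='same'` convolutions give `ceil(n / stride)`, and `padding='same'` transposed convolutions give `n * stride`.

```python
import math

def conv_same(n, stride):
    """Spatial size after a convolution with padding='same'."""
    return math.ceil(n / stride)

def conv_transpose_same(n, stride):
    """Spatial size after a transposed convolution with padding='same'."""
    return n * stride

def generator_shapes(size=256, n_res=7):
    """Trace (block, height/width, channels) through the ResnetGenerator layout."""
    shapes = []
    n = size + 2 * 3 - 7 + 1          # reflect pad 3, then 7 x 7 'valid' conv
    shapes.append(('block1', n, 64))
    for dim in (128, 256, 512):       # three Downsampler blocks, stride 2
        n = conv_same(n, 2)
        shapes.append(('down', n, dim))
    for _ in range(n_res):            # residual blocks keep the shape
        shapes.append(('resnet', n, 512))
    for dim in (256, 128, 64):        # three Upsampler blocks, stride 2
        n = conv_transpose_same(n, 2)
        shapes.append(('up', n, dim))
    n = n + 2 * 3 - 7 + 1             # reflect pad 3, 7 x 7 'valid' output conv
    shapes.append(('out', n, 3))
    return shapes

for name, n, c in generator_shapes():
    print(name, n, 'x', n, 'x', c)
```

The residual blocks therefore operate on 32 x 32 x 512 feature maps, and the three upsamplers restore the 256 x 256 resolution before the output convolution maps to three channels.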

### 2.3 The Discriminator

The discriminator is a PatchGAN discriminator: instead of a single real/fake decision per image, it outputs a grid of predictions, each of which judges whether the corresponding image patch is real or fake.

In [None]:
class Discriminator(Model):

    def __init__(self, inp_shape=(256, 256, 3)):
        """
        The PatchDiscriminator processes an input image with a convolutional
        layer and leaky ReLU activation, followed by three downsampler blocks
        with leaky ReLU activation each, and a final convolutional output layer
        that produces one prediction per patch.

        inp_shape: image input shape
        """
        super(Discriminator, self).__init__()

        self.input_layer = keras.Input(shape=inp_shape)
        self.conv_1 = layers.Conv2D(filters=64,
                                    kernel_size=4,
                                    strides=2,
                                    input_shape=inp_shape,
                                    padding='same',
                                    activation=None)
        self.leaky_relu_1 = layers.LeakyReLU()

        self.down_1 = Downsampler(dim=128, kernel_size=4, leaky=True)
        self.down_2 = Downsampler(dim=256, kernel_size=4, leaky=True)
        self.down_3 = Downsampler(dim=512, kernel_size=4, strides=1, leaky=True)

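        # No activation here: the losses use BinaryCrossentropy(from_logits=True),
        # so the discriminator must output raw logits rather than probabilities
        self.out = layers.Conv2D(filters=1,
                                 kernel_size=4,
                                 strides=1,
                                 padding='same',
                                 activation=None)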

    @tf.function
    def call(self, input, training=False):
        """
        Call function for the PatchDiscriminator
        
        input: input image
        training: training flag - whether the network is training or not
        """

        x = self.conv_1(input)
        x = self.leaky_relu_1(x)

        x = self.down_1(x)
        x = self.down_2(x)
        x = self.down_3(x)

        x = self.out(x)

        return x
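How large are the "patches" judged by this discriminator? The pure-Python sketch below (hypothetical helpers, not part of the notebook) takes the kernel sizes and strides of the five convolutions in `Discriminator` and computes the size of the output grid and the receptive field of one output value, using the standard recursion `rf = (rf - 1) * stride + kernel` from the output back to the input.

```python
import math

# (kernel_size, stride) of each convolution in Discriminator, input to output:
# conv_1, down_1, down_2, down_3, out
LAYERS = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

def output_grid(size, layers):
    """Spatial size of the prediction map; every conv uses padding='same'."""
    for _, stride in layers:
        size = math.ceil(size / stride)
    return size

def receptive_field(layers):
    """Number of input pixels (per side) that influence one output value."""
    rf = 1
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

print(output_grid(256, LAYERS))  # 32
print(receptive_field(LAYERS))   # 70
```

For a 256 x 256 input, the discriminator therefore makes 32 x 32 overlapping decisions, each based on a 70 x 70 pixel region, which corresponds to the 70 x 70 PatchGAN of the original CycleGAN.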

## 3 Training the network

### 3.1 Training functions
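The objectives minimized by `train_generator` and `train_discriminator` can be written out as follows. This is a summary read off the code, with our own notation: $G$ = `GEN_P2M`, $F$ = `GEN_M2P`, $D_M$ = `DISC_M`, $D_P$ = `DISC_P`, photos $x$, Monet paintings $y$, $\lambda_{cyc}$ = `cycle_loss_weight` (10) and $\lambda_{id}$ = `identity_loss_weight` (5). $\mathrm{BCE}$ is binary cross-entropy on logits and $\lVert\cdot\rVert_1$ denotes the mean absolute error. As in the code, the fake terms of the discriminator loss carry an extra factor of 1/2.

$$
\mathcal{L}_G = \lambda_{cyc}\left(\lVert x - F(G(x))\rVert_1 + \lVert y - G(F(y))\rVert_1\right) + \lambda_{id}\left(\lVert x - F(x)\rVert_1 + \lVert y - G(y)\rVert_1\right) + \mathrm{BCE}\left(1, D_M(G(x))\right) + \mathrm{BCE}\left(1, D_P(F(y))\right)
$$

$$
\mathcal{L}_D = \tfrac{1}{2}\left[\mathrm{BCE}\left(1, D_P(x)\right) + \tfrac{1}{2}\,\mathrm{BCE}\left(0, D_P(F(y))\right) + \mathrm{BCE}\left(1, D_M(y)\right) + \tfrac{1}{2}\,\mathrm{BCE}\left(0, D_M(G(x))\right)\right]
$$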

In [None]:
BATCH_SIZE = 6
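With a batch size of 6, one epoch in `train()` is defined by the 300 Monet paintings rather than by the larger photo set, so each epoch also sees only 300 of the photographs. The arithmetic below is a sketch that mirrors expressions used later in the notebook (the break condition in `train()` and the `LinearDecay` arguments for the 200-epoch run); it uses lowercase names so it does not overwrite `BATCH_SIZE`.

```python
n_monet = 300      # Monet paintings in the dataset
batch_size = 6     # same value as BATCH_SIZE

steps_per_epoch = n_monet // batch_size
# train() breaks once i >= 300 / BATCH_SIZE - 1, i.e. after steps 0 ... 49
last_step = next(i for i in range(10_000) if i >= (n_monet / batch_size) - 1)

epochs = 200
total_steps = epochs * steps_per_epoch        # passed to LinearDecay
decay_start = epochs // 2 * steps_per_epoch   # linear decay begins here
print(steps_per_epoch, last_step + 1, total_steps, decay_start)  # 50 50 10000 5000
```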

In [None]:
def norm(img):
    """
    Normalizes an image from the range 0 to 255 to the range -1 to 1.

    img: the image to process
    """
    return (img / 255.) * 2 - 1

def post(img):
    """
    Postprocesses an image from the range -1 to 1 to the range 0 to 1.

    img: the image to process
    """
    return (img * 0.5 + 0.5)

@tf.function
def train_generator(photo, monet, cycle_loss_weight=10, identity_loss_weight=5):
    """
    Generator training step with loss computation.

    :param photo: a photo from dataset
    :param monet: a monet image from dataset
    :param cycle_loss_weight: weight for cycle consistency loss
    :param identity_loss_weight: weight for identity loss
    """

    bce = tf.losses.BinaryCrossentropy(from_logits=True)
    mae = tf.losses.MeanAbsoluteError()

    with tf.GradientTape() as t:

        P2M = GEN_P2M(photo, training=True) # from photo to monet
        M2P = GEN_M2P(monet, training=True) # from monet to photo
        P2M2P = GEN_M2P(P2M, training=True) # from photo to monet back to photo
        M2P2M = GEN_P2M(M2P, training=True) # from monet to photo to monet
        P2P = GEN_M2P(photo, training=True) # photo to photo
        M2M = GEN_P2M(monet, training=True) # monet to monet

        x = DISC_M(P2M, training=True)
        y = DISC_P(M2P, training=True)

        disc_loss_P2M = bce(tf.ones_like(x), x)
        disc_loss_M2P = bce(tf.ones_like(y), y)

        cycle_loss_P2M2P = mae(photo, P2M2P) 
        cycle_loss_M2P2M = mae(monet, M2P2M)

        identity_loss_P2P = mae(photo, P2P)
        identity_loss_M2M = mae(monet, M2M)

        generator_loss = (cycle_loss_P2M2P + cycle_loss_M2P2M) * cycle_loss_weight + \
                         (identity_loss_P2P + identity_loss_M2M) * identity_loss_weight + \
                         (disc_loss_P2M + disc_loss_M2P)

    # Update gradients
    gradients = t.gradient(generator_loss, 
                           GEN_P2M.trainable_variables + GEN_M2P.trainable_variables)
    GEN_ADAM.apply_gradients(zip(gradients, 
                                 GEN_P2M.trainable_variables + GEN_M2P.trainable_variables))

    return P2M, M2P, generator_loss


@tf.function
def train_discriminator(photo, monet, p2m, m2p):
    """
    Discriminator training step with loss computation.

    :param photo: a photo from dataset
    :param monet: a monet image from dataset
    :param p2m: generated Monet images (output of GEN_P2M)
    :param m2p: generated photos (output of GEN_M2P)
    """

    bce = tf.losses.BinaryCrossentropy(from_logits=True)

    with tf.GradientTape() as t:
        orig_photo = DISC_P(photo, training=True)
        orig_monet = DISC_M(monet, training=True)
        false_photo = DISC_P(m2p, training=True)
        false_monet = DISC_M(p2m, training=True)

        orig_photo_loss = bce(tf.ones_like(orig_photo), orig_photo)
        false_photo_loss = bce(tf.zeros_like(false_photo), false_photo) * 0.5

        orig_monet_loss = bce(tf.ones_like(orig_monet), orig_monet)
        false_monet_loss = bce(tf.zeros_like(false_monet), false_monet) * 0.5

        discriminator_loss = ((orig_photo_loss + false_photo_loss) + (orig_monet_loss + false_monet_loss)) * 0.5

    gradients = t.gradient(discriminator_loss, 
                           DISC_P.trainable_variables + DISC_M.trainable_variables)
    DISC_ADAM.apply_gradients(zip(gradients, 
                                  DISC_P.trainable_variables + DISC_M.trainable_variables))

    return discriminator_loss


def train_single_step(photo, monet):
    """
    Function to combine training of generator and discriminator

    :param photo: a photo from the dataset
    :param monet: a monet image from the dataset
    """

    false_monet, false_photo, g_loss = train_generator(photo, monet)
    d_loss = train_discriminator(photo, monet, false_monet, false_photo)

    return {'Generator loss': g_loss,
            'Discriminator loss': d_loss}


def train(start_epoch, epochs, augment=True):
  """
  Full training function

  start_epoch: epoch from which to start training (only interesting when using 
    checkpoints)
  epochs: number of epochs to train for
  augment: augmentation flag - whether to augment images using the 
    ImageDataGenerator or not
  """

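  # If a checkpoint is available, restore it (checkpoint.restore(None) does not
  # raise an error, so check for an existing checkpoint explicitly)
  if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")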

  generator_losses = []
  discriminator_losses = []

  # Create ImageDataGenerators based on the augment-flag. If augment is set to 
  # true, allow slight image augmentations like rotation, width and height 
  # shift, horizontal flip, etc.
  if augment:
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=0.2,  # in degrees, i.e. only a very slight rotation
        width_shift_range=0.1,
        height_shift_range=0.1,
        brightness_range=(0.8, 1.2),
        horizontal_flip=True,
        fill_mode='reflect',
    )
  else:
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator()

  # Create data flow using both datasets (photo and monet) and given batch size
  train_generator_photo = train_datagen.flow_from_directory(
      'data/photo',
      batch_size=BATCH_SIZE,
      shuffle=True
  )

  train_generator_monet = train_datagen.flow_from_directory(
      'data/monet',
      batch_size=BATCH_SIZE,
      shuffle=True
  )

  generator_losses_average = []
  discriminator_losses_average = []

  # Train for a given number of epochs
  for epoch in np.arange(start_epoch, epochs):

    print('--- EPOCH ' + str(epoch) + ' ---')
    starttime = t.time()
    generator_losses = []
    discriminator_losses = []

    # Every 10th epoch, set the paths for saving generated images to track the
    # FID score over time (the saving code below is commented out)
    if epoch % 10 == 0:
      save_gen = True
      photo_path = 'fip-images/generated_photos_epoch' + str(epoch)
      monet_path = 'fip-images/generated_monets_epoch' + str(epoch)
    else:
      save_gen = False

    for i, ((p, p_l), (m, m_l)) in enumerate(zip(train_generator_photo, 
                                                 train_generator_monet)):

      # Train the network and keep track of the losses
      losses = train_single_step(norm(p), norm(m))
      generator_losses.append(losses['Generator loss'])
      discriminator_losses.append(losses['Discriminator loss'])

      # ---- UNCOMMENT IF INTERESTED IN STORING IMAGES FOR POTENTIAL FID 
      # TRACKING ----
      # Save the generated photos and monets
      # if save_gen:
      #   if not os.path.exists('fip-images'):
      #       os.makedirs('fip-images')

      #   # Save the generated photos
      #   if not os.path.exists(photo_path):
      #       os.makedirs(photo_path)

      #   id2 = GEN_M2P(norm(m), training=False)
      #   for idx in range(BATCH_SIZE):
      #     img_photo = post(id2[idx])
      #     cv2.imwrite(photo_path + '/' + str((i*BATCH_SIZE)+idx) + '.png', 
      #                 cv2.cvtColor(np.array(img_photo * 255).astype(np.uint8), 
      #                              cv2.COLOR_RGB2BGR))

      #   # Save the generated Monets
      #   if not os.path.exists(monet_path):
      #       os.makedirs(monet_path)

      #   id = GEN_P2M(norm(p), training=False)
      #   for idx in range(BATCH_SIZE):
      #     img_monet = post(id[idx])
      #     cv2.imwrite(monet_path + '/' + str((i*BATCH_SIZE)+idx) + '.png', 
      #                 cv2.cvtColor(np.array(img_monet * 255).astype(np.uint8), 
      #                              cv2.COLOR_RGB2BGR))


      # Save an image every 10 steps to track the progress using visual results
      # (ignore the very first iteration per epoch as it nearly corresponds to 
      # the last step of the previous epoch)
      if i % 10 == 0 and i != 0:

          if not os.path.exists('images'):
            os.makedirs('images')

          # Save original photograph and generated monet
          fig, ax = plt.subplots(1, 2, figsize=(10, 4))
          id = GEN_P2M(norm(p), training=False)
          img_monet = post(id[0])
          ax[0].imshow(p[0] / 255.)
          ax[0].axis('off')
          ax[1].imshow(img_monet)
          ax[1].axis('off')
          plt.savefig('images/' + str(epoch) + '_' + str(i) + '_photo2monet.png')
          plt.close(fig)

          # Save original monet and generated photograph
          fig, ax = plt.subplots(1, 2, figsize=(10, 4))
          id2 = GEN_M2P(norm(m), training=False)
          img_photo = post(id2[0])
          ax[0].imshow(m[0] / 255.)
          ax[0].axis('off')
          ax[1].imshow(img_photo)
          ax[1].axis('off')
          plt.savefig('images/' + str(epoch) + '_' + str(i) + '_monet2photo.png')
          plt.close(fig)

          # Also record the average losses since the last snapshot (about 10 steps)
          generator_losses_average.append(np.mean(np.array(generator_losses)))
          discriminator_losses_average.append(np.mean(np.array(discriminator_losses)))
          generator_losses = []
          discriminator_losses = []
      
      # Break after one pass over the 300 Monet paintings (300 / BATCH_SIZE
      # steps), since the ImageDataGenerator flows yield batches indefinitely,
      # then save a checkpoint and plot the averaged losses to track progress
      if i >= ((300 / BATCH_SIZE) - 1):
        save_path = manager.save()
        print('... finished after ' + str(t.time() - starttime) + ' seconds')
        _, ax = plt.subplots(1, 2, figsize=(20, 4))
        ax[0].plot(generator_losses_average, 'blue')
        ax[0].set_title('Generator loss')
        ax[1].plot(discriminator_losses_average, 'orange')
        ax[1].set_title('Discriminator loss')
        plt.show()
        break

  return generator_losses_average, discriminator_losses_average
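Both training steps construct `tf.losses.BinaryCrossentropy(from_logits=True)`, which applies the sigmoid internally, so the discriminator output is expected to be a raw logit. The pure-Python sketch below is an illustration with hypothetical function names: it implements the numerically stable logit form of binary cross-entropy, checks it against the probability form, and shows what happens if an already sigmoid-activated value is passed as a "logit". In that case the loss for a real label can never drop below log(1 + e^-1), roughly 0.31, however confident the prediction is.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_from_logits(label, logit):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def bce_from_probability(label, p):
    """Binary cross-entropy on a probability in (0, 1)."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

z = 4.0  # a confident "real" logit
print(bce_from_logits(1, z))           # same value as bce_from_probability(1, sigmoid(z))
print(bce_from_logits(1, sigmoid(z)))  # sigmoid applied twice: stays above log(1 + e^-1)
```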

In [None]:
# Learning rate decay

class LinearDecay(schedules.LearningRateSchedule):
    """
    Keeps the learning rate fixed while `step` < `step_decay`, then decays it
    linearly to zero at `total_steps`.
    """

    def __init__(self, initial_learning_rate, total_steps, step_decay):
        super(LinearDecay, self).__init__()
        self._initial_learning_rate = initial_learning_rate
        self._steps = total_steps
        self._step_decay = step_decay
        self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate,
                                                 trainable=False,
                                                 dtype=tf.float32)

    def __call__(self, step):
        # Optimizers may pass an integer step counter; cast it so that the
        # comparison and arithmetic with the float-valued step counts work
        step = tf.cast(step, tf.float32)
        self.current_learning_rate.assign(tf.cond(
            step >= self._step_decay,
            true_fn=lambda: self._initial_learning_rate * (1 - 1 / (self._steps - self._step_decay) * (step - self._step_decay)),
            false_fn=lambda: self._initial_learning_rate
        ))
        return self.current_learning_rate
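The schedule can be inspected without TensorFlow. The plain-Python function below is a sketch that mirrors the `LinearDecay` formula (it is not the class used for training); the values correspond to the 200-epoch run, with 10000 total steps and decay starting at step 5000. Like the original class, it does not clamp at zero, which is fine here because training stops at the last step before `total_steps`.

```python
def linear_decay(step, initial_learning_rate, total_steps, step_decay):
    """Constant rate until step_decay, then linear decay to zero at total_steps."""
    if step < step_decay:
        return initial_learning_rate
    return initial_learning_rate * (1 - (step - step_decay) / (total_steps - step_decay))

# Learning rate at a few points of the 200-epoch run
for step in (0, 5000, 7500, 9999):
    print(step, linear_decay(step, 2e-4, 10000, 5000))
```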

### 3.2 Initial training for 200 epochs

In [None]:
IMG_HEIGHT = 256
IMG_WIDTH = 256

# Initialize networks
GEN_P2M = ResnetGenerator()
GEN_M2P = ResnetGenerator()
DISC_P = Discriminator()
DISC_M = Discriminator()

# Set number of epochs to train for
EPOCHS = 200

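# Set learning rate scheduler: constant for the first EPOCHS / 2 epochs, then
# linear decay to zero (300 / BATCH_SIZE steps per epoch)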
LR_INITIAL = 2e-4
LR_G = LinearDecay(LR_INITIAL, 
                   EPOCHS * (300 / BATCH_SIZE), 
                   EPOCHS / 2 * (300 / BATCH_SIZE))
LR_D = LinearDecay(LR_INITIAL, 
                   EPOCHS * (300 / BATCH_SIZE), 
                   EPOCHS / 2 * (300 / BATCH_SIZE))

# Set optimizers
GEN_ADAM = tf.keras.optimizers.Adam(learning_rate=LR_G, 
                                    beta_1=0.5, 
                                    beta_2=0.999)
DISC_ADAM = tf.keras.optimizers.Adam(learning_rate=LR_D, 
                                     beta_1=0.5, 
                                     beta_2=0.999)

# Define where checkpoints are stored (only the four networks are saved, not
# the optimizers)
checkpoint = tf.train.Checkpoint(**dict(GEN_P2M=GEN_P2M, 
                                        GEN_M2P=GEN_M2P, 
                                        DISC_P=DISC_P, 
                                        DISC_M=DISC_M))
manager = tf.train.CheckpointManager(checkpoint, 
                                     './tf_checkpoints', 
                                     max_to_keep=3)

In [None]:
START_EPOCH = 0
gen_loss, disc_loss = train(START_EPOCH, EPOCHS, augment=True)

### 3.3 Retraining without data augmentation for 100 epochs

In [None]:
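# Re-create the networks; train() restores their weights from the latest
# checkpoint in ./tf_checkpoints, so this run continues from the 200-epoch model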
GEN_P2M = ResnetGenerator()
GEN_M2P = ResnetGenerator()
DISC_P = Discriminator()
DISC_M = Discriminator()

# Set number of epochs to train for
EPOCHS = 100

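# Set learning rate scheduler: constant for the first EPOCHS / 2 epochs, then
# linear decay to zero (300 / BATCH_SIZE steps per epoch)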
LR_INITIAL = 2e-5
LR_G = LinearDecay(LR_INITIAL, 
                   EPOCHS * (300 / BATCH_SIZE), 
                   EPOCHS / 2 * (300 / BATCH_SIZE))
LR_D = LinearDecay(LR_INITIAL, 
                   EPOCHS * (300 / BATCH_SIZE), 
                   EPOCHS / 2 * (300 / BATCH_SIZE))

# Set optimizers
GEN_ADAM = tf.keras.optimizers.Adam(learning_rate=LR_G, 
                                    beta_1=0.5, 
                                    beta_2=0.999)
DISC_ADAM = tf.keras.optimizers.Adam(learning_rate=LR_D, 
                                     beta_1=0.5, 
                                     beta_2=0.999)

# Define where checkpoints are stored (only the four networks are saved, not
# the optimizers)
checkpoint = tf.train.Checkpoint(**dict(GEN_P2M=GEN_P2M, 
                                        GEN_M2P=GEN_M2P, 
                                        DISC_P=DISC_P, 
                                        DISC_M=DISC_M))
manager = tf.train.CheckpointManager(checkpoint, 
                                     './tf_checkpoints', 
                                     max_to_keep=3)

In [None]:
START_EPOCH = 0
gen_loss, disc_loss = train(START_EPOCH, EPOCHS, augment=False)

## 4 Results

In [None]:
def load_data_from_dir(data_dir, subset):
    """
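    Loads images from a given directory, resized to IMG_HEIGHT x IMG_WIDTH,
    in batches of one image.

    data_dir: name of the directory
    subset: currently unused, since the validation split is commented out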

    # Load data
    dataset = tf.keras.utils.image_dataset_from_directory(data_dir,
                                                          # validation_split=0.2,
                                                          # subset=subset,
                                                          seed=123,
                                                          shuffle=True,
                                                          image_size=(IMG_HEIGHT, IMG_WIDTH),
                                                          batch_size=1)

    return dataset

In [None]:
def create_and_save_results(test_data, dir_name, p2m=True):
  """
  Generates domain transfer from the original to the target domain between monet 
  images and photographs.
  
  test_data: data as tensors
  dir_name: name of dir where to save the generated samples
  p2m: whether the transfer direction is photo to monet, or the opposite 
    direction
  """

  if not os.path.exists(dir_name):
    os.makedirs(dir_name)

  if not os.path.exists(dir_name + '/generated'):
    os.makedirs(dir_name + '/generated')

  if not os.path.exists(dir_name + '/original'):
    os.makedirs(dir_name + '/original')
  
  for i, img in enumerate(test_data):
    
    # Do not use more than a maximal number of images
    # if i > 1000:
    #   break

    # Generate either photo or Monet painting and postprocess the generated 
    # image to the correct range (0 to 1) for saving
    if p2m:
      generated = GEN_P2M(norm(img[0]), training=False)
    else:
      generated = GEN_M2P(norm(img[0]), training=False)
    generated_postprocessed = post(generated)

    # Save images:
    cv2.imwrite(dir_name + '/original/' + str(i) + '.png', 
                cv2.cvtColor(np.array(img[0][0]).astype(np.uint8), 
                             cv2.COLOR_RGB2BGR))
    cv2.imwrite(dir_name + '/generated/' + str(i) + '.png', 
                cv2.cvtColor(np.array(generated_postprocessed[0] * 255).astype(np.uint8), 
                             cv2.COLOR_RGB2BGR))
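`create_and_save_results` maps generator outputs from [-1, 1] back to [0, 255] via `post(...) * 255` and `astype(np.uint8)`, and swaps the channel order with `cv2.COLOR_RGB2BGR` because OpenCV's `imwrite` expects BGR images. The per-pixel sketch below (pure Python, hypothetical helper name) reproduces that conversion; `astype(np.uint8)` truncates rather than rounds, which the sketch imitates with `int()`.

```python
def tanh_pixel_to_bgr_uint8(rgb):
    """Convert one RGB pixel with values in [-1, 1] to a BGR pixel in 0..255."""
    channels = []
    for v in rgb:
        scaled = (v * 0.5 + 0.5) * 255   # post(), then scale to 0..255
        channels.append(int(scaled))     # truncation, like astype(np.uint8)
    return channels[::-1]                # RGB -> BGR, like cv2.COLOR_RGB2BGR

print(tanh_pixel_to_bgr_uint8([1.0, 0.0, -1.0]))  # [0, 127, 255]
```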

In [None]:
# Load the original data (the full datasets, since no validation split is used)
test_monets = load_data_from_dir('data/monet', subset='validation')
test_photos = load_data_from_dir('data/photo', subset='validation')

In [None]:
# Photo to Monet
create_and_save_results(test_photos, 'Photo2Monet', p2m=True)

# Monet to photo
create_and_save_results(test_monets, 'Monet2Photo', p2m=False)

In [None]:
# ---- UNCOMMENT TO UPLOAD GENERATED IMAGES TO DRIVE AS ZIP FILE ----
#!zip -r '/content/gdrive/MyDrive/Photo2Monet-2022-08-17.zip' '/content/Photo2Monet'
#!zip -r '/content/gdrive/MyDrive/Monet2Photo-2022-08-17.zip' '/content/Monet2Photo'

## 5 Evaluation
The quantitative evaluation uses the Fréchet Inception Distance (FID), a common metric for assessing generative adversarial networks (lower is better). The following cells compute the FID between the original datasets (Monet paintings vs. photographs), between original and generated photographs, and between original and generated Monet paintings. FID is computed with [this](https://github.com/mseitzer/pytorch-fid) implementation.
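FID fits a Gaussian to the Inception features of each image set and computes $\lVert\mu_1-\mu_2\rVert^2 + \mathrm{Tr}(\Sigma_1+\Sigma_2-2(\Sigma_1\Sigma_2)^{1/2})$. In one dimension this reduces to $(\mu_1-\mu_2)^2 + (\sigma_1-\sigma_2)^2$. The toy sketch below (hypothetical helper, not a replacement for pytorch-fid) applies the one-dimensional formula to plain numbers instead of 2048-dimensional Inception features, only to illustrate how the distance reacts to shifted means and different spreads. Like pytorch-fid, it uses the sample (not population) variance.

```python
import math
import statistics

def fid_1d(a, b):
    """Frechet distance between 1-D Gaussians fitted to samples a and b."""
    mu_a, mu_b = statistics.mean(a), statistics.mean(b)
    var_a, var_b = statistics.variance(a), statistics.variance(b)  # sample variance
    return (mu_a - mu_b) ** 2 + var_a + var_b - 2 * math.sqrt(var_a * var_b)

print(fid_1d([0, 1, 2, 3], [0, 1, 2, 3]))  # close to 0: identical distributions
print(fid_1d([0, 1, 2, 3], [5, 6, 7, 8]))  # close to 25: same spread, means differ by 5
```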

In [None]:
# FID calculation of original datasets
!python -m pytorch_fid '/content/Photo2Monet/original' '/content/Monet2Photo/original'

In [None]:
# FID calculation of generated photographs
!python -m pytorch_fid '/content/Photo2Monet/original' '/content/Monet2Photo/generated'

In [None]:
# FID calculation of generated Monets
!python -m pytorch_fid '/content/Monet2Photo/original' '/content/Photo2Monet/generated'