# Using CycleGAN for Image Translation

Using the CycleGAN implementation by https://hardikbansal.github.io/CycleGANBlog/

For more information, check out [TensorFlow](https://www.tensorflow.org/tutorials/generative/cyclegan) and [Keras](https://keras.io/examples/generative/cyclegan/) CycleGAN documentation pages.

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

import os
import time

# Pick a distribution strategy: use the TPU when one can be reached,
# otherwise fall back to TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # Best-effort fallback: any failure while setting up the TPU selects the
    # default strategy. Catching Exception instead of using a bare `except`
    # lets KeyboardInterrupt and SystemExit propagate.
    print('TPU not available, using the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Lets tf.data choose the parallelism for dataset.map calls.
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)
tf.compat.v1.enable_eager_execution()

Number of replicas: 1
2.2.0


In [2]:
# LOADING THE FILE NAMES
# Each image domain is stored as a set of TFRecord shards; collect their paths.

MONET_FILENAMES = tf.io.gfile.glob(
    '../input/monet-gan-getting-started/monet_tfrec/*.tfrec'
)
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(
    '../input/monet-gan-getting-started/photo_tfrec/*.tfrec'
)
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

Monet TFRecord Files: 5
Photo TFRecord Files: 20


In [3]:
IMAGE_SIZE = [256, 256]  # given


def decode_image(image):
    """Decode a JPEG byte string into a float32 RGB tensor scaled to [-1, 1].

    The result is reshaped to [*IMAGE_SIZE, 3].
    """
    pixels = tf.image.decode_jpeg(image, channels=3)  # 3 channels = RGB, which is given
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    All three features are parsed as scalar strings; only "image" is used.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames, labeled=True, ordered=False):
    """Return a dataset of decoded images read from the given TFRecord files.

    `labeled` and `ordered` are accepted but not used by this implementation.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [4]:
# LOADING THE DATA IMAGES
# Both domains are loaded the same way and served one image per batch.
monet_ds, photo_ds = (
    load_dataset(_filenames, labeled=True).batch(1)
    for _filenames in (MONET_FILENAMES, PHOTO_FILENAMES)
)

# Generator
3 Components used in the Generator
1. Encoder - 3 Conv. Layers
1. Transformer
1. Decoder - 2 Deconv. layers + Conv. Layer

## Encoder
* inputimg = input image [256, 256, 3]
* num_features = number of filters in the first layer of generator
* window width, window height = the sliding window dimensions that slides across the images
* stride width, stride height = the shift by which the window slides with


    def general_conv2d(inputc, num_features=64, window_height=7, window_width=7, stride_height=1, stride_width=1,
                       stddev=0.02, padding="SAME", name="conv2d"):

        with tf.variable_scope(name):
            conv = tf.contrib.layers.conv2d(inputc, num_features, [window_width, window_height], [stride_width, stride_height],
                                            padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                                            biases_initializer=tf.constant_initializer(0.0))

As you can see, the number of features increases with each conv layer. The encoder compresses a 256 x 256 image with only 3 channels into a 64 x 64 tensor with 256 features.

## Transformation
We need to remember not to deviate too much from the original image, since the translation should still preserve some trace of the original image. Hence we transform using ResNet layers.

The resnet layer consists of 2 conv layers and a function to add:
1. The input goes through the first conv layer
1. The tensor from the first layer goes to the second
1. The second tensor is added to the original input and returned

    def build_resnet_block(input_res, num_features):

        out_res_1 = general_conv2d(input_res, num_features,
                                   window_width=3,
                                   window_height=3,
                                   stride_width=1,
                                   stride_height=1)
        out_res_2 = general_conv2d(out_res_1, num_features,
                                   window_width=3,
                                   window_height=3,
                                   stride_width=1,
                                   stride_height=1)
        return (out_res_2 + input_res)

## Decoding 
Now we have to turn the tensor back into a viable image. Hence, the tensor needs to be unpacked from [64, 64, 256] into a [256, 256, 3] image.

## Final Generator

In [5]:
# important numbers
gen_filter = 64 # Number of filters in the first layer of the generator
dis_filter = 64 # Number of filters in the first layer of the discriminator
batch_size = 1 # Number of images per batch
pool_size = 50 # Pool size (train() allocates its fake-image buffers with 50 slots)
img_width = 256 # Input image will be of width 256
img_height = 256 # Input image will be of height 256
channels = 3 # RGB format

In [6]:
def build_resnet_block(input_res, num_features):
    """Return input_res plus the output of two stacked 3x3, stride-1 convolutions.

    Both convolutions go through general_conv2d with num_features filters.
    """
    conv_kwargs = dict(window_width=3,
                       window_height=3,
                       stride_width=1,
                       stride_height=1)

    hidden = general_conv2d(input_res, num_features, **conv_kwargs)
    residual = general_conv2d(hidden, num_features, **conv_kwargs)
    return residual + input_res

In [23]:
# Ask TensorFlow to run tf.function-decorated code eagerly instead of as a traced graph.
tf.config.experimental_run_functions_eagerly(True)

In [14]:
def general_conv2d(inputconv, num_features=64, window_height=7, window_width=7, stride_height=1, stride_width=1, stddev=0.02, padding=None, name="conv2d", do_norm=True, do_relu=True):
    """Apply a 2-D convolution with optional normalisation and ReLU.

    Args:
        inputconv: 4-D input tensor; its last axis is read as the number of
            input channels and the strides below assume NHWC layout.
        num_features: number of output channels (filters).
        window_height, window_width: kernel size.
        stride_height, stride_width: stride along the height / width axes.
        stddev: standard deviation of the truncated-normal weight init.
        padding: 'SAME' or 'VALID'; None keeps the previous behaviour ('SAME').
        name: variable scope in which the layer's variables are created.
        do_norm: normalise the convolution output with a learned scale/offset.
        do_relu: apply ReLU to the result.

    Returns:
        The output tensor of the layer.

    NOTE(review): the kernel is created with tf.Variable, so a new kernel is
    made on every call, whereas the bias/scale/beta variables come from
    get_variable with AUTO_REUSE and are shared by every call that produces
    the same variable name. Confirm this mix is intended.
    """
    in_channels = inputconv.shape[3]
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        # Honour the stddev argument (it used to be hard-coded to 0.5).
        w = tf.Variable(tf.compat.v1.truncated_normal(
            [window_height, window_width, in_channels, num_features], stddev=stddev))
        # NHWC strides are [batch, height, width, channels].
        conv = tf.nn.conv2d(inputconv, filters=w,
                            strides=[1, stride_height, stride_width, 1],
                            padding=padding or 'SAME')
        biases = tf.compat.v1.get_variable('b_'+str(num_features),[num_features],initializer=tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, biases)
        if do_norm:
            dims = conv.shape
            suffix = str(dims[3]-dims[1])+'_'+str(num_features)
            scale = tf.compat.v1.get_variable('scale_'+suffix,shape=[dims[1],dims[2],dims[3]],initializer=tf.constant_initializer(1))
            beta = tf.compat.v1.get_variable('beta_'+suffix,shape=[dims[1],dims[2],dims[3]],initializer=tf.constant_initializer(0))
            # Normalise each image over its spatial axes. Taking the moments
            # over the batch axis gives zero variance and mean == conv for a
            # batch of one image, which collapses the output to `beta`.
            conv_mean, conv_var = tf.nn.moments(conv, [1, 2], keepdims=True)
            conv = tf.nn.batch_normalization(conv, conv_mean, conv_var, beta, scale, 0.001)
        if do_relu:
            conv = tf.nn.relu(conv)
    return conv

In [8]:
def build_generator(input_gen):
    """Build the generator: conv encoder, residual blocks, upsampling decoder.

    Args:
        input_gen: image batch tensor; the shape comments below assume an
            input of shape [batch, 256, 256, 3].

    Returns:
        A 3-channel tensor with the input's spatial size, passed through
        tanh so it lies in [-1, 1], the range decode_image produces.
    """
    # Encoding

    # the first Conv layer, with the input info
    # returns a tensor with shape = [256, 256, 64]
    o_c1 = general_conv2d(input_gen, num_features=gen_filter, window_width=7, window_height=7, stride_width=1, stride_height=1)
    # the second Conv layer, with the previous tensor as input
    # returns a tensor with shape = [128, 128, 128]
    o_c2 = general_conv2d(o_c1, num_features=gen_filter*2, window_width=3, window_height=3, stride_width=2, stride_height=2)
    # the third Conv layer, with the previous tensor as input
    # returns a tensor with shape = [64, 64, 256]
    o_enc_A = general_conv2d(o_c2, num_features=gen_filter*4, window_width=3, window_height=3, stride_width=2, stride_height=2)

    # Transformation

    # the input is the last tensor from the encoder; it is chained through
    # 6 resnet blocks, which preserve the [64, 64, 256] shape.
    # The width is tied to gen_filter so it always matches the encoder output.
    o_enc_B = o_enc_A
    for _ in range(6):
        o_enc_B = build_resnet_block(o_enc_B, num_features=gen_filter*4)

    # Decoding
    # Stride-2 convolutions would shrink the tensor further, so each decoder
    # stage doubles the spatial size first and then applies a stride-1 conv.
    # returns a tensor with shape = [128, 128, 128]
    o_d1 = tf.keras.layers.UpSampling2D(size=2)(o_enc_B)
    o_d1 = general_conv2d(o_d1, num_features=gen_filter*2, window_width=3, window_height=3, stride_width=1, stride_height=1)
    # returns a tensor with shape = [256, 256, 64]
    o_d2 = tf.keras.layers.UpSampling2D(size=2)(o_d1)
    o_d2 = general_conv2d(o_d2, num_features=gen_filter, window_width=3, window_height=3, stride_width=1, stride_height=1)
    # Output layer: no normalisation or ReLU, because a ReLU output could
    # never reproduce the negative half of the [-1, 1] pixel range.
    # returns a tensor with shape = [256, 256, 3]
    gen_B = general_conv2d(o_d2, num_features=3, window_width=7, window_height=7, stride_width=1, stride_height=1, do_norm=False, do_relu=False)

    return tf.nn.tanh(gen_B)

# Discriminator
Has multiple Convolutional Layers

In [9]:
def build_discriminator(input_disc):
    """Build the discriminator: four stride-2 conv layers and a decision layer.

    Args:
        input_disc: image batch tensor to judge.

    Returns:
        A 1-channel decision map. The train() losses compare it against
        1 (real) and 0 (fake).
    """
    # Positional arguments after the filter count are:
    # window_height, window_width, stride_height, stride_width.
    o_c1 = general_conv2d(input_disc, dis_filter, 7, 7, 2, 2)
    o_c2 = general_conv2d(o_c1, dis_filter*2, 7, 7, 2, 2)
    o_enc_A = general_conv2d(o_c2, dis_filter*4, 7, 7, 2, 2)
    o_c4 = general_conv2d(o_enc_A, dis_filter*8, 7, 7, 2, 2)

    # making decision
    # The decision layer is left un-normalised and without ReLU: normalising
    # the score map would remove the very level the losses try to push
    # towards 0 or 1, and ReLU would stop gradients for negative scores.
    decision = general_conv2d(o_c4, 1, 7, 7, 1, 1, 0.02, do_norm=False, do_relu=False)
    return decision

# Model

In [10]:
# Switch to graph mode for the placeholder / session based code in the cells that follow.
tf.compat.v1.disable_eager_execution()

In [34]:
lr = 0.0002
def train():
    """Build the CycleGAN graph, its losses and optimizers, then run training.

    NOTE(review): the recorded run of this function stopped with
    "ValueError: No variables to optimize.". None of the scope or variable
    names visible in this file contain 'd_A', 'g_A', 'd_B' or 'g_B', so the
    four name-filtered variable lists below look empty - confirm, and give
    each network its own named scope.

    NOTE(review): the session section references names that are not defined
    in the visible part of this file: queue_length_A, queue_length_B,
    photos_ds (the dataset defined earlier is photo_ds), summary_op, imsave,
    max_images and fake_image_pool.
    """
    # One real image per domain is fed at a time.
    input_A = tf.compat.v1.placeholder(tf.float32, [1, 256, 256, 3], name="input_A")
    input_B = tf.compat.v1.placeholder(tf.float32, [1, 256, 256, 3], name="input_B")

    # Generated images that are fed back to the discriminators.
    fake_pool_A = tf.compat.v1.placeholder(tf.float32, [None, 256, 256, 3], name="fake_pool_A")
    fake_pool_B = tf.compat.v1.placeholder(tf.float32, [None, 256, 256, 3], name="fake_pool_B")

    global_step = tf.Variable(0, name="global_step", trainable=False)
    num_fake_inputs = 0
    
    with tf.compat.v1.variable_scope("Model") as scope:
        # NOTE(review): both generator calls (and both discriminator calls)
        # go through the same builder with no per-network scope - verify
        # which variables each call creates or shares.
        fake_B = build_generator(input_A)
        fake_A = build_generator(input_B)
        rec_A = build_discriminator(input_A)
        rec_B = build_discriminator(input_B)
        scope.reuse_variables()
        fake_rec_A = build_discriminator(fake_A)
        fake_rec_B = build_discriminator(fake_B)
        cyc_A = build_generator(fake_B)
        cyc_B = build_generator(fake_A)
        scope.reuse_variables()
        fake_pool_rec_A = build_discriminator(fake_pool_A)
        fake_pool_rec_B = build_discriminator(fake_pool_B)
        
    # Loss functions
    
    # cyclic loss function to not lose the original input too much
    cyc_loss = tf.reduce_mean(tf.abs(input_A-cyc_A)) + tf.reduce_mean(tf.abs(input_B-cyc_B))
    
    # adversarial term of the generator losses: the discriminator output on
    # generated images should be close to 1
    d_loss_A1 = tf.reduce_mean(tf.compat.v1.squared_difference(fake_rec_A,1))
    d_loss_B1 = tf.reduce_mean(tf.compat.v1.squared_difference(fake_rec_B,1))
    
    # loss for the generator: cycle loss weighted by 10 plus the adversarial term
    g_loss_A = cyc_loss*10 + d_loss_B1
    g_loss_B = cyc_loss*10 + d_loss_A1
    
    # loss for the discriminator: pooled fakes towards 0, real images towards 1
    d_loss_A = (tf.reduce_mean(tf.compat.v1.square(fake_pool_rec_A)) + tf.reduce_mean(tf.compat.v1.squared_difference(rec_A,1)))/2.0
    d_loss_B = (tf.reduce_mean(tf.compat.v1.square(fake_pool_rec_B)) + tf.reduce_mean(tf.compat.v1.squared_difference(rec_B,1)))/2.0
    
    # This local placeholder shadows the module-level `lr` constant.
    lr = tf.compat.v1.placeholder(tf.float32, shape=[], name="lr")
    
    optimizer = tf.compat.v1.train.AdamOptimizer(lr)
    
    model_vars = tf.compat.v1.trainable_variables()
    
    # Split the trainable variables per network by substring match on the name.
    d_A_vars = [var for var in model_vars if 'd_A' in var.name]
    g_A_vars = [var for var in model_vars if 'g_A' in var.name]
    d_B_vars = [var for var in model_vars if 'd_B' in var.name]
    g_B_vars = [var for var in model_vars if 'g_B' in var.name]
    
    # optimising 
    d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
    d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
    g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
    g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)
    
    for var in model_vars: print(var.name)
        
    # Summary Variables
    # NOTE(review): from here on several calls (tf.summary.scalar,
    # tf.global_variables_initializer, tf.train.Saver, tf.Session,
    # tf.summary.FileWriter, tf.assign) drop the tf.compat.v1 prefix used
    # above; the notebook reports TensorFlow 2.2.0 - confirm they resolve.
    g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
    g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
    d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
    d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    

        
        # Loading images into the tensors
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # NOTE(review): queue_length_A / queue_length_B are not defined in
        # the visible file, and num_files_A / num_files_B are never read.
        num_files_A = sess.run(queue_length_A)
        num_files_B = sess.run(queue_length_B)

        # Not used again in this function.
        images_A = []
        images_B = []

        # 50 is pool size
        fake_images_A = np.zeros((50,1,256, 256, 3))
        fake_images_B = np.zeros((50,1,256, 256, 3))

        # Buffers for 10 input images per domain.
        A_input = np.zeros((10,1,256, 256, 3))
        B_input = np.zeros((10,1,256, 256, 3))

        # NOTE(review): monet_ds is a tf.data dataset, not a tensor - confirm
        # that sess.run can fetch from it.
        for i in range(10): 
            image_tensor = sess.run(monet_ds)
            A_input[i] = image_tensor.reshape((1,256, 256, 3))

        # NOTE(review): `photos_ds` is not defined; the dataset created
        # earlier in the file is named `photo_ds`.
        for i in range(10):
            image_tensor = sess.run(photos_ds)
            B_input[i] = image_tensor.reshape((1,256, 256, 3))

        coord.request_stop()
        coord.join(threads)

        writer = tf.summary.FileWriter("./output/2")
        check_dir = "./output/checkpoints/"

        # Runs from the stored global_step up to (excluding) 1, so at most
        # one epoch.
        for epoch in range(sess.run(global_step),1):
            print ("In the epoch ", epoch)
            saver.save(sess,os.path.join(check_dir,"cyclegan"),global_step=epoch)

            # Constant learning rate for the first 100 epochs, then a linear decay.
            if(epoch < 100) :
                curr_lr = 0.0002
            else:
                curr_lr = 0.0002 - 0.0002*(epoch-100)/100

            # NOTE(review): summary_op and imsave are not defined in the visible file.
            summary_str, cyc_A_temp = sess.run([summary_op, cyc_A],feed_dict={input_A:A_input[0], input_B:B_input[0]})
            imsave("./output/output_"+str(epoch)+".jpg",((cyc_A_temp[0]+1)*127.5).astype(np.uint8))
            imsave("./output/input.jpg",((A_input[0][0]+1)*127.5).astype(np.uint8))

            writer.add_summary(summary_str, epoch)
            
        # NOTE(review): this loop sits outside the epoch loop but reads
        # `epoch` and `curr_lr`, which only exist if that loop ran at least
        # once; max_images is not defined in the visible file.
        for ptr in range(0,max_images):
            print("In the iteration ",ptr)
            print("Starting",time.time()*1000.0)
            
            # Optimizing the G_A network
            _, fake_B_temp, summary_str = sess.run([g_A_trainer, fake_B, g_A_loss_summ],feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
            writer.add_summary(summary_str, epoch*max_images + ptr)
            print("After gA", time.time()*1000.0)

            # NOTE(review): fake_image_pool is not defined in the visible file.
            fake_B_temp1 = fake_image_pool(num_fake_inputs, fake_B_temp, fake_images_B)

            # Optimizing the D_B network
            
            _, summary_str = sess.run([d_B_trainer, d_B_loss_summ],feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr, fake_pool_B:fake_B_temp1})
            writer.add_summary(summary_str, epoch*max_images + ptr)
            print("After dB", time.time()*1000.0)

            # Optimizing the G_B network
            
            _, fake_A_temp, summary_str = sess.run([g_B_trainer, fake_A, g_B_loss_summ],feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
            writer.add_summary(summary_str, epoch*max_images + ptr)
            print("After gB", time.time()*1000.0)

            fake_A_temp1 = fake_image_pool(num_fake_inputs, fake_A_temp, fake_images_A)
            
            # Optimizing the D_A network
            
            _, summary_str = sess.run([d_A_trainer, d_A_loss_summ],feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr, fake_pool_A:fake_A_temp1})
            writer.add_summary(summary_str, epoch*max_images + ptr)
            print("After dA", time.time()*1000.0)
            num_fake_inputs+=1
            
            # NOTE(review): the three statements below repeat the three
            # above, so the D_A summary is written twice and
            # num_fake_inputs grows by 2 per iteration - looks like a
            # copy/paste duplicate.
            writer.add_summary(summary_str, epoch*max_images + ptr)
            print("After dA", time.time()*1000.0)
            num_fake_inputs+=1
            
                        
        sess.run(tf.assign(global_step, epoch + 1))
    # NOTE(review): this runs after the `with` block that owns `sess` has exited.
    writer.add_graph(sess.graph)

In [25]:
# Check which execution mode is active; the recorded output of this cell is True.
tf.executing_eagerly()

True

In [35]:
# Run the training routine. NOTE(review): the recorded run of this cell
# stopped with "ValueError: No variables to optimize.".
train()

ValueError: No variables to optimize.

In [None]:
# Placeholders for one batch from each image domain, laid out as
# [batch, height, width, channels]. The channel count comes from the
# `channels` constant; `img_layer` is not defined anywhere in this notebook.
input_A = tf.compat.v1.placeholder(tf.float32, [batch_size, img_height, img_width, channels], name="input_A")
input_B = tf.compat.v1.placeholder(tf.float32, [batch_size, img_height, img_width, channels], name="input_B")

In [None]:
def build_model(input_A, input_B):
    """Wire two generators and two discriminators together and create the trainers.

    NOTE(review): this looks like an unfinished draft:
    - build_generator and build_discriminator, as defined in this file, take
      a single argument, so the `name` argument passed here would be rejected;
    - optimizer, the four losses and the four variable lists are not defined
      in this function or at module level in the visible file;
    - nothing is returned.
    """
    # Translate each domain into the other and score the real inputs.
    gen_B = build_generator(input_A, name="generator_AtoB")
    gen_A = build_generator(input_B, name="generator_BtoA")
    dec_A = build_discriminator(input_A, name="discriminator_A")
    dec_B = build_discriminator(input_B, name="discriminator_B")

    # Score the generated images and map them back to their source domain.
    dec_gen_A = build_discriminator(gen_A, "discriminator_A")
    dec_gen_B = build_discriminator(gen_B, "discriminator_B")
    cyc_A = build_generator(gen_B, "generator_BtoA")
    cyc_B = build_generator(gen_A, "generator_AtoB")
    
    
    
    # One minimize op per network (see the note above about undefined names).
    d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
    d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
    g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
    g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)


In [None]:
# NOTE(review): draft training loop. sess, num_images, A_input, B_input,
# the input/lr placeholders and the *_trainer ops are not defined at module
# level in the visible file (most of them are locals of train()).
for epoch in range(0,100):
    # Define the learning rate schedule. The learning rate is kept
    # constant up to 100 epochs and then slowly decayed.
    # NOTE(review): with range(0,100) the condition below is always true,
    # so the decay branch is never reached.
    if(epoch < 100) :
        curr_lr = 0.0002
    else:
        curr_lr = 0.0002 - 0.0002*(epoch-100)/100

    # Running the training loop for all batches
    for ptr in range(0,num_images):

        # Train generator G_A->B
        _, gen_B_temp = sess.run([g_A_trainer, gen_B],
                                 feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

        # gen_B_temp is meant for computing the error when training D_B.
        # NOTE(review): it is not passed to the call below.
        _ = sess.run([d_B_trainer],
                     feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

        # Same for G_B->A and D_A as follows
        _, gen_A_temp = sess.run([g_B_trainer, gen_A],
                                 feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})
        _ = sess.run([d_A_trainer],
                     feed_dict={input_A:A_input[ptr], input_B:B_input[ptr], lr:curr_lr})

In [None]:
def random_crop(image):
    """Return a random crop of `image` with shape [*IMAGE_SIZE, 3].

    The crop size comes from the IMAGE_SIZE constant that decode_image also
    uses; IMG_HEIGHT / IMG_WIDTH are not defined in this notebook.
    """
    return tf.image.random_crop(image, size=[*IMAGE_SIZE, 3])

# upsizing and then cropping the image randomly
def random_jitter(image):
    """Augment an image: enlarge it, take a random crop, then randomly mirror it."""
    # resizing to 286 x 286 x 3
    enlarged = tf.image.resize(
        image,
        [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    # random crop back down (see random_crop), followed by random mirroring
    return tf.image.random_flip_left_right(random_crop(enlarged))

In [None]:
def preprocess_image_train(image, label):
    """Apply the training-time augmentation to `image`; `label` is not used."""
    return random_jitter(image)

In [None]:
# to view an example
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))

# Draw the first image of each batch side by side; pixel values are mapped
# from [-1, 1] back to [0, 1] for display.
for _position, _title, _batch in (
    (121, 'Photo', example_photo),
    (122, 'Monet', example_monet),
):
    plt.subplot(_position)
    plt.title(_title)
    plt.imshow(_batch[0] * 0.5 + 0.5)

# Build the generator

We'll be using a UNET architecture for our CycleGAN. To build our generator, let's first define our `downsample` and `upsample` methods.

The `downsample`, as the name suggests, reduces the 2D dimensions, the width and height, of the image by the stride. The stride is the length of the step the filter takes. Since the stride is 2, the filter is applied to every other pixel, hence reducing the width and height by a factor of 2.

We'll be using an instance normalization instead of batch normalization. As the instance normalization is not standard in the TensorFlow API, we'll use the layer from TensorFlow Add-ons.

In [None]:
OUTPUT_CHANNELS = 3


def downsample(filters, size, apply_instancenorm=True):
    """Return a Sequential block: stride-2 Conv2D, optional instance norm, LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_instancenorm: when true, an InstanceNormalization layer is
            inserted between the convolution and the activation.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=kernel_init, use_bias=False),
    ]
    if apply_instancenorm:
        block_layers.append(
            tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)

`Upsample` does the opposite of downsample and increases the dimensions of the image. `Conv2DTranspose` does basically the opposite of a `Conv2D` layer.

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Return a Sequential block: stride-2 Conv2DTranspose, instance norm, optional dropout, ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: when true, a Dropout(0.5) layer is inserted before
            the activation.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2,
                               padding='same',
                               kernel_initializer=kernel_init,
                               use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())

    return keras.Sequential(block_layers)

Let's build our generator!

The generator first downsamples the input image and then upsamples it while establishing long skip connections. Skip connections are a way to help bypass the vanishing gradient problem by concatenating the output of a layer to multiple layers instead of only one. Here we concatenate the output of each downsample layer to the matching upsample layer in a symmetrical fashion.

In [None]:
def Generator():
    """Build the U-Net style generator as a Keras model for 256x256x3 images.

    Eight downsample blocks are followed by seven upsample blocks; each
    upsample output is concatenated with the matching downsample output,
    and a final transposed convolution with tanh produces the image.
    """
    inputs = layers.Input(shape=[256,256,3])

    # (filters, apply_instancenorm) per encoder stage; bs = batch size.
    # Output shapes run from (bs, 128, 128, 64) down to (bs, 1, 1, 512).
    encoder_spec = [(64, False), (128, True), (256, True)] + [(512, True)] * 5
    down_stack = [
        downsample(n_filters, 4, apply_instancenorm=use_norm)
        for n_filters, use_norm in encoder_spec
    ]

    # (filters, apply_dropout) per decoder stage; after concatenation the
    # shapes run from (bs, 2, 2, 1024) up to (bs, 128, 128, 128).
    decoder_spec = [(512, True)] * 3 + [(512, False), (256, False), (128, False), (64, False)]
    up_stack = [
        upsample(n_filters, 4, apply_dropout=use_dropout)
        for n_filters, use_dropout in decoder_spec
    ]

    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                  activation='tanh') # (bs, 256, 256, 3)

    # Downsampling through the model, remembering every stage's output
    features = inputs
    stage_outputs = []
    for down in down_stack:
        features = down(features)
        stage_outputs.append(features)

    # Upsampling and establishing the skip connections: the innermost output
    # is skipped and the remaining ones are paired deepest-first.
    for up, skip in zip(up_stack, stage_outputs[-2::-1]):
        features = layers.Concatenate()([up(features), skip])

    return keras.Model(inputs=inputs, outputs=last(features))

# Build the discriminator

The discriminator takes in the input image and classifies it as real or fake (generated). Instead of outputting a single node, the discriminator outputs a smaller 2D image, with higher pixel values indicating a real classification and lower values indicating a fake classification.

In [None]:
def Discriminator():
    """Build the discriminator as a Keras model mapping a 256x256x3 image to a 2-D score map."""
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three downsample blocks: (bs, 128, 128, 64) -> (bs, 64, 64, 128)
    # -> (bs, 32, 32, 256); only the first one skips instance norm.
    features = inp
    for n_filters, use_norm in ((64, False), (128, True), (256, True)):
        features = downsample(n_filters, 4, use_norm)(features)

    features = layers.ZeroPadding2D()(features) # (bs, 34, 34, 256)
    features = layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=initializer,
                             use_bias=False)(features) # (bs, 31, 31, 512)
    features = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(features)
    features = layers.LeakyReLU()(features)
    features = layers.ZeroPadding2D()(features) # (bs, 33, 33, 512)

    scores = layers.Conv2D(1, 4, strides=1,
                           kernel_initializer=initializer)(features) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=scores)

In [None]:
# Create all four networks under the distribution strategy.
with strategy.scope():
    # monet_generator transforms photos to Monet-esque paintings;
    # photo_generator transforms Monet paintings to be more like photos.
    monet_generator, photo_generator = Generator(), Generator()

    # monet_discriminator differentiates real and generated Monet paintings;
    # photo_discriminator differentiates real and generated photos.
    monet_discriminator, photo_discriminator = Discriminator(), Discriminator()

Since our generators are not trained yet, the generated Monet-esque photo does not show what is expected at this point.

In [None]:
# Run the example photo through the (still untrained) Monet generator.
to_monet = monet_generator(example_photo)

# Show input and output side by side, mapping [-1, 1] back to [0, 1].
for _column, (_title, _batch) in enumerate(
    (("Original Photo", example_photo), ("Monet-esque Photo", to_monet)),
    start=1,
):
    plt.subplot(1, 2, _column)
    plt.title(_title)
    plt.imshow(_batch[0] * 0.5 + 0.5)
plt.show()

# Build the CycleGAN model

We will subclass a `tf.keras.Model` so that we can run `fit()` later to train our model. During the training step, the model transforms a photo to a Monet painting and then back to a photo. The difference between the original photo and the twice-transformed photo is the cycle-consistency loss. We want the original photo and the twice-transformed photo to be similar to one another.

The losses are defined in the next section.

In [None]:
class CycleGan(keras.Model):
    """Keras model bundling two generators and two discriminators.

    `train_step` is overridden so that the regular `fit()` loop performs one
    joint update of all four networks per batch.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        """Keep references to the four networks and the loss weight.

        `lambda_cycle` is forwarded as the third argument to both the cycle
        loss function and the identity loss function during training.
        """
        super().__init__()
        self.m_gen, self.p_gen = monet_generator, photo_generator
        self.m_disc, self.p_disc = monet_discriminator, photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Register one optimizer per network plus the four loss callables.

        Call shapes used by `train_step`:
          gen_loss_fn(disc_output_on_fake)
          disc_loss_fn(disc_output_on_real, disc_output_on_fake)
          cycle_loss_fn(real, cycled, lambda_cycle)
          identity_loss_fn(real, same, lambda_cycle)
        """
        super().compile()
        self.m_gen_optimizer, self.p_gen_optimizer = m_gen_optimizer, p_gen_optimizer
        self.m_disc_optimizer, self.p_disc_optimizer = m_disc_optimizer, p_disc_optimizer
        self.gen_loss_fn, self.disc_loss_fn = gen_loss_fn, disc_loss_fn
        self.cycle_loss_fn, self.identity_loss_fn = cycle_loss_fn, identity_loss_fn

    def train_step(self, batch_data):
        """Perform one update of every network on a (Monet, photo) batch pair.

        Returns a dict holding the total loss of each generator and the loss
        of each discriminator for this step.
        """
        real_monet, real_photo = batch_data
        cycle_weight = self.lambda_cycle

        # persistent=True because gradient() is called four times on this tape.
        with tf.GradientTape(persistent=True) as tape:
            # Round trip: photo -> Monet -> photo.
            generated_monet = self.m_gen(real_photo, training=True)
            rebuilt_photo = self.p_gen(generated_monet, training=True)

            # Round trip: Monet -> photo -> Monet.
            generated_photo = self.p_gen(real_monet, training=True)
            rebuilt_monet = self.m_gen(generated_photo, training=True)

            # Each generator applied to an image already in its target domain.
            identity_monet = self.m_gen(real_monet, training=True)
            identity_photo = self.p_gen(real_photo, training=True)

            # Discriminator outputs on real inputs ...
            monet_score_real = self.m_disc(real_monet, training=True)
            photo_score_real = self.p_disc(real_photo, training=True)

            # ... and on generated inputs.
            monet_score_fake = self.m_disc(generated_monet, training=True)
            photo_score_fake = self.p_disc(generated_photo, training=True)

            # Adversarial part of the generator losses.
            monet_adv_loss = self.gen_loss_fn(monet_score_fake)
            photo_adv_loss = self.gen_loss_fn(photo_score_fake)

            # Cycle-consistency loss, summed over both directions and shared
            # by both generators.
            monet_cycle_loss = self.cycle_loss_fn(real_monet, rebuilt_monet, cycle_weight)
            photo_cycle_loss = self.cycle_loss_fn(real_photo, rebuilt_photo, cycle_weight)
            total_cycle_loss = monet_cycle_loss + photo_cycle_loss

            # Full generator objectives: adversarial + cycle + identity.
            total_monet_gen_loss = (
                monet_adv_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_monet, identity_monet, cycle_weight)
            )
            total_photo_gen_loss = (
                photo_adv_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_photo, identity_photo, cycle_weight)
            )

            # Discriminator objectives.
            monet_disc_loss = self.disc_loss_fn(monet_score_real, monet_score_fake)
            photo_disc_loss = self.disc_loss_fn(photo_score_real, photo_score_fake)

        # (loss, network, optimizer) triples: generators first, then discriminators.
        update_plan = (
            (total_monet_gen_loss, self.m_gen, self.m_gen_optimizer),
            (total_photo_gen_loss, self.p_gen, self.p_gen_optimizer),
            (monet_disc_loss, self.m_disc, self.m_disc_optimizer),
            (photo_disc_loss, self.p_disc, self.p_disc_optimizer),
        )

        # Phase 1: compute every gradient before any variable is modified.
        all_gradients = [
            tape.gradient(loss, network.trainable_variables)
            for loss, network, _ in update_plan
        ]

        # Phase 2: let each optimizer apply the gradients of its own network.
        for gradients, (_, network, optimizer) in zip(all_gradients, update_plan):
            optimizer.apply_gradients(zip(gradients, network.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

# Define loss functions

The discriminator loss function below compares the discriminator's outputs for real images to a matrix of 1s and its outputs for fake images to a matrix of 0s. A perfect discriminator outputs all 1s for real images and all 0s for fake images. The discriminator loss is the average of the real and generated losses.

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Return half the sum of the real-vs-ones and generated-vs-zeros losses.

        Both inputs are treated as logits (`from_logits=True`) and the binary
        cross-entropy is left unreduced (`Reduction.NONE`).
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )

        # Outputs for real images should look like 1s, for generated images like 0s.
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.5

The generator wants to fool the discriminator into thinking the generated image is real. With a perfect generator, the discriminator outputs only 1s for generated images. Thus, the generator loss compares the discriminator's output for the generated image to a matrix of 1s.

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Return the unreduced binary cross-entropy of `generated` against all ones.

        `generated` is treated as logits (`from_logits=True`).
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        all_ones = tf.ones_like(generated)
        return bce(all_ones, generated)

We want our original photo and the twice-transformed photo to be similar to one another. Thus, we calculate the cycle-consistency loss by finding the average of their absolute difference.

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return LAMBDA times the mean absolute difference of the two images."""
        absolute_error = tf.abs(real_image - cycled_image)
        mean_absolute_error = tf.reduce_mean(absolute_error)
        return LAMBDA * mean_absolute_error

The identity loss compares an image with the output of its own generator (i.e. a photo passed through the photo generator). If the photo generator is given a photo as input, we want it to return the same image, because the input is already a photo. The identity loss therefore measures the difference between the generator's input and its output.

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return (LAMBDA * 0.5) times the mean absolute difference of the two images."""
        # Half the weight that is passed in, applied to the mean absolute error.
        half_weight = LAMBDA * 0.5
        mean_absolute_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return half_weight * mean_absolute_error

# Train the CycleGAN

Let's compile our model. Since we used `tf.keras.Model` to build our CycleGAN, we can just use the `fit` function to train our model.

In [None]:
with strategy.scope():
    # One independent Adam optimizer per network, all with the same settings.
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

In [None]:
with strategy.scope():
    # Wire the four networks into a single trainable model ...
    cycle_gan_model = CycleGan(
        monet_generator=monet_generator,
        photo_generator=photo_generator,
        monet_discriminator=monet_discriminator,
        photo_discriminator=photo_discriminator,
    )

    # ... and attach one optimizer per network together with the loss functions.
    cycle_gan_model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

In [None]:
# Pair each Monet batch with a photo batch; train_step unpacks them in that order.
paired_training_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
cycle_gan_model.fit(paired_training_ds, epochs=25)

# Visualize our Monet-esque photos

In [None]:
# Five rows: input photo on the left, generator output on the right.
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for i, img in enumerate(photo_ds.take(5)):
    # Take the first image of the batch and rescale it to 8-bit pixel values.
    prediction = monet_generator(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    row_content = ((img, "Input Photo"), (prediction, "Monet-esque"))
    for axis, (picture, caption) in zip(ax[i], row_content):
        axis.imshow(picture)
        axis.set_title(caption)
        axis.axis("off")
plt.show()

# Create submission file

In [None]:
# Import the Image submodule explicitly: a bare `import PIL` does not load
# `PIL.Image`, which the next cell uses via `PIL.Image.fromarray`.
import PIL.Image

# Create the output directory; exist_ok makes the cell safe to re-run.
os.makedirs("../images", exist_ok=True)

In [None]:
# Translate every photo batch and save the first image of each batch as
# ../images/<n>.jpg, numbering the files from 1.
for i, img in enumerate(photo_ds, start=1):
    prediction = monet_generator(img, training=False)[0].numpy()
    # Rescale the generator output to 8-bit pixel values for PIL.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    im = PIL.Image.fromarray(prediction)
    im.save(f"../images/{i}.jpg")

In [None]:
import shutil

# Zip the contents of /kaggle/images into /kaggle/working/images.zip.
# NOTE(review): images are written to the relative path "../images"; this
# matches /kaggle/images only if the working directory is /kaggle/working.
shutil.make_archive(
    base_name="/kaggle/working/images",
    format='zip',
    root_dir="/kaggle/images",
)