# Import libraries

In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt

### Tensorflow dependencies ###
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1
from tensorflow.keras.losses import BinaryCrossentropy

### Hyperparameters and tensor dimensions shared by the cells below ###
# presumably the optimizer learning rate — it is not consumed in this section
lr = 1e-4
# number of examples per batch; passed to Dataset.batch() for every split
batch_size = 64
# presumably the number of training epochs — it is not consumed in this section
epochs = 100
# (height, width, channels): random-crop size and discriminator input shape
img_shape = (112, 112, 3)
# length of the vector fed to the generators' Input layer
latent_dim = 100

### Discriminators and Generators weights path ###
g_weights_path = "weights/g.weights.hdf5" # G : X -> Y
f_weights_path = "weights/f.weights.hdf5" # F : Y -> X
x_weights_path = "weights/x.weights.hdf5" # discriminate X and F(Y)
y_weights_path = "weights/y.weights.hdf5" # discriminate Y and G(X)

# Loading and processing data

In [5]:
# Fetch the horse2zebra splits; `as_supervised=True` makes every element an
# (image, label) pair, which is why the preprocess functions accept a label.
dataset, metadata = tfds.load(
    'cycle_gan/horse2zebra',
    with_info=True,
    as_supervised=True,
)

# The "A" splits are bound to the horse names and the "B" splits to the zebra names.
train_horses = dataset['trainA']
train_zebras = dataset['trainB']
test_horses = dataset['testA']
test_zebras = dataset['testB']

def random_crop(image):
    """Return a randomly positioned window of `image`.

    The window height and width are taken from the module-level `img_shape`;
    the channel count is fixed at 3.
    """
    crop_size = [img_shape[0], img_shape[1], 3]
    return tf.image.random_crop(image, size=crop_size)

# normalizing the images to [-1, 1]
def normalize(image):
    """Cast `image` to float32 and rescale it as ``x / 127.5 - 1``.

    For pixel values in [0, 255] the result lies in [-1, 1].
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

def random_jitter(image):
    """Training augmentation: resize, take a random crop, mirror at random.

    Returns the augmented image; pixel values are not rescaled here.
    """
    # Bring the image to 286 x 286 (nearest-neighbour) so the crop below can
    # land on different regions from call to call.
    enlarged = tf.image.resize(
        image,
        [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )

    # Random window of the size defined by `img_shape`, then a random
    # left/right flip.
    window = random_crop(enlarged)
    return tf.image.random_flip_left_right(window)

def preprocess_image_train(image, label):
    """Dataset map function for the training splits: augment, then normalize.

    `label` is accepted because the dataset yields (image, label) pairs; it is
    not used, and only the processed image is returned.
    """
    augmented = random_jitter(image)
    return normalize(augmented)

def preprocess_image_test(image, label):
    """Dataset map function for the test splits: normalize only.

    `label` is accepted because the dataset yields (image, label) pairs; it is
    not used, and only the normalized image is returned.
    """
    # NOTE(review): unlike the training path, no resize or crop happens here, so
    # test images keep whatever spatial size the dataset provides — confirm
    # that size matches `img_shape` before feeding them to the models.
    return normalize(image)

# Training splits: augmentation + normalization, then fixed-size batches.
train_horses = (
    train_horses
    .map(preprocess_image_train,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
)

train_zebras = (
    train_zebras
    .map(preprocess_image_train,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
)

# Test splits: normalization only, then fixed-size batches.
test_horses = (
    test_horses
    .map(preprocess_image_test,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
)

test_zebras = (
    test_zebras
    .map(preprocess_image_test,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
)


# Model architectures

## 1. Generators G and F

In [3]:
def make_generator(name):
    """Build a generator that maps a `latent_dim` vector to an image.

    Args:
        name: Suffix used in the Keras model name (``Generator_<name>``).

    Returns:
        An uncompiled Keras `Model`. Its output is 112 x 112 (7 doubled four
        times) with `img_shape[-1]` channels, bounded to [-1, 1] by tanh.
    """
    # NOTE(review): the weight-path comments describe G and F as mappings
    # between image domains, yet the input here is a latent vector rather
    # than an image — confirm this is intended.
    latent = Input(shape=(latent_dim,))

    # Project the latent vector and fold it into a 7 x 7 x 256 feature map.
    h = Dense(7*7*256, use_bias=False)(latent)
    h = LeakyReLU(alpha=0.2)(h)
    h = Reshape(target_shape=(7,7,256))(h)

    # Four stride-2 transposed convolutions, each doubling height and width:
    # 7 -> 14 -> 28 -> 56 -> 112, while the channel count shrinks.
    for n_filters in (128, 64, 32, 16):
        h = Conv2DTranspose(n_filters, kernel_size=(4,4), strides=(2,2),
                            padding='same', use_bias=False)(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)

    # Project to the image channels; tanh keeps the output in [-1, 1].
    image = Conv2D(img_shape[-1], kernel_size=(4,4), padding='same',
                   use_bias=False, activation='tanh')(h)

    return Model(inputs=latent, outputs=image, name=f'Generator_{name}')

G = make_generator('G')
F = make_generator('F')

# `Model.summary()` prints the layer table itself and returns None, so it is
# called directly; wrapping it in print() would add a stray "None" line.
G.summary()
F.summary()

Model: "Generator_G"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense (Dense)                (None, 12544)             1254400   
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 12544)             0         
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 128)       524288    
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 128)       

## 2. Discriminators X and Y

In [4]:
def make_discriminator(name):
    """Build a convolutional discriminator that scores an image with one logit.

    Args:
        name: Suffix used in the Keras model name (``Discriminator_<name>``).

    Returns:
        An uncompiled Keras `Model` taking inputs of shape `img_shape` and
        producing a single unactivated value per image (a raw logit).
    """
    inputs = Input(shape=img_shape)

    x = inputs
    # Four stride-2 stages, each halving height and width while the filter
    # count doubles. Every stage uses the same LeakyReLU slope (0.2) and the
    # same dropout rate so the stack is uniform.
    for n_filters in (16, 32, 64, 128):
        x = Conv2D(n_filters, kernel_size=(4,4), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Dropout(0.3)(x)

    x = Flatten()(x)
    # No activation: the output is a raw logit.
    x = Dense(1)(x)

    model = Model(inputs=inputs, outputs=x, name=f'Discriminator_{name}')
    return model

X = make_discriminator("X")
Y = make_discriminator("Y")

# `Model.summary()` prints the layer table itself and returns None, so it is
# called directly; wrapping it in print() would add a stray "None" line.
X.summary()
Y.summary()

Model: "Discriminator_X"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 112, 112, 3)]     0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 56, 56, 16)        784       
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 56, 56, 16)        0         
_________________________________________________________________
dropout (Dropout)            (None, 56, 56, 16)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 32)        8224      
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 28, 28, 32)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 28, 28, 32)    

# Define loss functions 

In [5]:
# Shared binary cross-entropy used by the adversarial losses below.
# `from_logits=True`: the discriminators end in Dense(1) with no activation,
# so their outputs are raw logits rather than probabilities.
bce = BinaryCrossentropy(from_logits=True)
LAMBDA = 10.0 # scale factor for cycle consistency and identity loss functions

def generator_loss(D_fake):
    """Adversarial loss for a generator.

    Binary cross-entropy between an all-ones target and `D_fake`, the
    discriminator's output on generated samples.
    """
    real_targets = tf.ones_like(D_fake, dtype=tf.float32)
    return bce(real_targets, D_fake)

def discriminator_loss(D_real, D_fake):
    """Adversarial loss for a discriminator.

    Sum of two binary cross-entropy terms: `D_real` against an all-ones
    target and `D_fake` against an all-zeros target.
    """
    loss_on_real = bce(tf.ones_like(D_real, dtype=tf.float32), D_real)
    loss_on_fake = bce(tf.zeros_like(D_fake, dtype=tf.float32), D_fake)
    return loss_on_real + loss_on_fake

def cycle_consistency_loss(cycled_images, real_images):
    """Mean absolute difference between cycled and real images, scaled by LAMBDA."""
    mean_abs_error = tf.reduce_mean(tf.abs(cycled_images - real_images))
    return LAMBDA * mean_abs_error

def identity_loss(same_images, real_images):
    """Mean absolute difference between `same_images` and `real_images`, scaled by LAMBDA * 0.5."""
    mean_abs_error = tf.reduce_mean(tf.abs(same_images - real_images))
    return LAMBDA * 0.5 * mean_abs_error

# Define training loop