# Import libraries

In [7]:
import os
import numpy as np
import matplotlib.pyplot as plt

### Tensorflow dependencies ###
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1
from tensorflow.keras.losses import BinaryCrossentropy

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
lr = 1e-4                  # learning rate passed to every Adam optimizer in the training cell
batch_size = 1             # batch size used by every train/test pipeline
epochs = 100               # number of passes train() makes over the zipped training sets
img_shape = (256, 256, 3)  # (height, width, channels): random-crop size; generators use img_shape[-1] as output channels
latent_dim = 100           # NOTE(review): not used by the active G/F models in the visible code; confirm before removing

# Weight files, one per network.
# NOTE(review): not referenced in the visible cells -- presumably meant for save_weights/load_weights; confirm.
g_weights_path = "weights/g.weights.hdf5"  # generator G : X -> Y
f_weights_path = "weights/f.weights.hdf5"  # generator F : Y -> X
x_weights_path = "weights/x.weights.hdf5"  # discriminator for X: real X vs F(Y)
y_weights_path = "weights/y.weights.hdf5"  # discriminator for Y: real Y vs G(X)

# Loading and processing data

In [8]:
# as_supervised=True makes every element an (image, label) pair; the preprocessing
# functions below accept and ignore the label.
dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)

# Domain X = horses (the "A" splits), domain Y = zebras (the "B" splits).
train_horses = dataset['trainA']
train_zebras = dataset['trainB']
test_horses = dataset['testA']
test_zebras = dataset['testB']

def random_crop(image):
    """Crop `image` at a random offset to img_shape's height x width, keeping 3 channels."""
    crop_height, crop_width = img_shape[0], img_shape[1]
    return tf.image.random_crop(image, size=[crop_height, crop_width, 3])

# normalizing the images to [-1, 1]
def normalize(image):
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

def random_jitter(image, resize_to=(286, 286)):
    """Augment one training image: upscale, random-crop back to img_shape, random mirror.

    Args:
        image: a single (unbatched) image tensor with 3 channels.
        resize_to: (height, width) the image is resized to before cropping. It must be
            at least img_shape[:2] so the crop fits. Defaults to the original 286 x 286.

    Returns:
        A tensor of shape (img_shape[0], img_shape[1], 3), i.e. 256 x 256 x 3 with the
        current configuration.
    """
    # Nearest-neighbour resizing avoids interpolating new pixel values.
    image = tf.image.resize(image, list(resize_to),
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Random crop back down to img_shape (not 112 x 112 as an older comment claimed).
    image = random_crop(image)

    # Random left/right mirroring.
    image = tf.image.random_flip_left_right(image)

    return image

def preprocess_image_train(image, label):
    """Training-time transform for one (image, label) element: random jitter, then scale to [-1, 1].

    `label` is unused; it is accepted only because dataset elements are (image, label) pairs.
    """
    return normalize(random_jitter(image))

def preprocess_image_test(image, label):
    """Test-time transform: no augmentation, only rescale pixels to [-1, 1]. `label` is ignored."""
    normalized_image = normalize(image)
    return normalized_image

# Number of elements held in the shuffle buffer of each training pipeline.
SHUFFLE_BUFFER = 1000


def _make_pipeline(ds, preprocess_fn, shuffle=False):
    """Map `preprocess_fn` over `ds` in parallel, optionally shuffle first, then batch and prefetch.

    Args:
        ds: a tf.data.Dataset of (image, label) elements.
        preprocess_fn: function (image, label) -> image applied to every element.
        shuffle: when True, reshuffle the element order on every iteration (i.e. every epoch).

    Returns:
        The transformed, batched (batch_size) and prefetched dataset.
    """
    if shuffle:
        # Without this, the two unpaired training sets would be zipped into the exact same
        # horse/zebra pairs, in the same order, every epoch. Shuffling before map() keeps the
        # buffer filled with the raw images rather than the float32 copies made by normalize().
        ds = ds.shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Prefetch overlaps input preparation with the training step.
    return ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


train_horses = _make_pipeline(train_horses, preprocess_image_train, shuffle=True)
train_zebras = _make_pipeline(train_zebras, preprocess_image_train, shuffle=True)

test_horses = _make_pipeline(test_horses, preprocess_image_test)
test_zebras = _make_pipeline(test_zebras, preprocess_image_test)


# Model architectures

## 1. Generators G and F

In [9]:
# Generators for the two translation directions:
#   G : X -> Y  (horse -> zebra)
#   F : Y -> X  (zebra -> horse)
# Both are the pix2pix U-Net from tensorflow_examples with instance normalization,
# built with img_shape[-1] as its first argument.
#
# The earlier commented-out `make_generator` (a Dense/Conv2DTranspose stack that turned a
# latent vector into a 112x112 image) was removed: it was dead code and, taking a latent
# vector instead of an image, could not act as a CycleGAN image-to-image generator.
G = pix2pix.unet_generator(img_shape[-1], norm_type='instancenorm')
F = pix2pix.unet_generator(img_shape[-1], norm_type='instancenorm')

# Model.summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" line to the output.
G.summary()
F.summary()

Model: "functional_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
sequential_36 (Sequential)      (None, None, None, 6 3072        input_3[0][0]                    
__________________________________________________________________________________________________
sequential_37 (Sequential)      (None, None, None, 1 131328      sequential_36[0][0]              
__________________________________________________________________________________________________
sequential_38 (Sequential)      (None, None, None, 2 524800      sequential_37[0][0]              
_______________________________________________________________________________________

## 2. Discriminators D_x and D_y

In [10]:
# Discriminators:
#   D_x separates real X (horses) from F(Y)
#   D_y separates real Y (zebras) from G(X)
# Both are the pix2pix discriminator with instance normalization. target=False builds it
# with a single image input (no conditioning target image).
#
# The earlier commented-out `make_discriminator` (Conv2D/Dropout stack ending in
# Flatten + Dense(1)) was dead code and has been removed.
D_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
D_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

# Model.summary() prints the table itself and returns None, so call it directly
# instead of printing its (None) return value.
D_x.summary()
D_y.summary()

Model: "functional_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_image (InputLayer)     [(None, None, None, 3)]   0         
_________________________________________________________________
sequential_66 (Sequential)   (None, None, None, 64)    3072      
_________________________________________________________________
sequential_67 (Sequential)   (None, None, None, 128)   131328    
_________________________________________________________________
sequential_68 (Sequential)   (None, None, None, 256)   524800    
_________________________________________________________________
zero_padding2d_4 (ZeroPaddin (None, None, None, 256)   0         
_________________________________________________________________
conv2d_45 (Conv2D)           (None, None, None, 512)   2097152   
_________________________________________________________________
instance_normalization_64 (I (None, None, None, 512) 

# Define loss functions 

In [11]:
# Discriminator outputs are raw logits, hence from_logits=True.
bce = BinaryCrossentropy(from_logits=True)
# Weight of the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10.0


def generator_loss(D_fake):
    """Adversarial loss for a generator: BCE between the discriminator's logits on fakes and "real" (1)."""
    target_real = tf.ones_like(D_fake, dtype=tf.float32)
    return bce(target_real, D_fake)


def discriminator_loss(D_real, D_fake):
    """Discriminator loss: BCE(real logits vs 1) + BCE(fake logits vs 0), summed (not averaged)."""
    target_real = tf.ones_like(D_real, dtype=tf.float32)
    target_fake = tf.zeros_like(D_fake, dtype=tf.float32)
    return bce(target_real, D_real) + bce(target_fake, D_fake)


def cycle_consistency_loss(cycled_images, real_images):
    """LAMBDA-weighted mean absolute error between round-tripped images and the originals."""
    mean_abs_error = tf.reduce_mean(tf.abs(cycled_images - real_images))
    return LAMBDA * mean_abs_error


def identity_loss(same_images, real_images):
    """Mean absolute error between a generator's output on its own target domain and its input, weighted by LAMBDA / 2."""
    mean_abs_error = tf.reduce_mean(tf.abs(same_images - real_images))
    return LAMBDA * 0.5 * mean_abs_error

# Define training loop

In [13]:
### Preparing optimizers ###
g_opt = Adam(lr=lr, beta_1=0.5, amsgrad=True)
f_opt = Adam(lr=lr, beta_1=0.5, amsgrad=True)
x_opt = Adam(lr=lr, beta_1=0.5, amsgrad=True)
y_opt = Adam(lr=lr, beta_1=0.5, amsgrad=True)

@tf.function
def train_step(X, Y):
    ### F : Y -> X, G : X -> Y ###
    ### X : separate X and F(Y), Y : separate Y and G(X) ###
    with tf.GradientTape(persistent=True) as tape:
        # For binary crossentropy 
        fake_y = G(X, training=True)
        fake_x = F(Y, training=True)
        
        # For cycle consistency losses 
        cycled_y = F(fake_y, training=True)
        cycled_x = G(fake_x, training=True)
        
        # For identity losses 
        same_x   = F(X, training=True)
        same_y   = G(Y, training=True)
        
        D_X_real = D_x(X, training=True)
        D_Y_real = D_y(Y, training=True)
        D_X_fake = D_x(fake_x, training=True)
        D_Y_fake = D_y(fake_y, training=True)
        
        # Training generators
        bce_g = generator_loss(D_Y_fake)
        bce_f = generator_loss(D_X_fake)
        
        total_cycle_loss = cycle_consistency_loss(cycled_x, X) + cycle_consistency_loss(cycled_y, Y)
        
        identity_g = identity_loss(same_y, Y)
        identity_f = identity_loss(same_x, X)
        
        total_loss_g = bce_g + total_cycle_loss + identity_g
        total_loss_f = bce_f + total_cycle_loss + identity_f
        
        # Training discriminators 
        total_loss_Dx = discriminator_loss(D_X_real, D_X_fake)
        total_loss_Dy = discriminator_loss(D_Y_real, D_Y_fake)
        
    # Calculate gradients 
    grad_g = tape.gradient(total_loss_g, G.trainable_variables)
    grad_f = tape.gradient(total_loss_f, F.trainable_variables)
    grad_Dx = tape.gradient(total_loss_Dx, D_x.trainable_variables)
    grad_Dy = tape.gradient(total_loss_Dy, D_y.trainable_variables)
    
    # Apply gradients on trainable variables
    g_opt.apply_gradients(zip(grad_g, G.trainable_variables))
    f_opt.apply_gradients(zip(grad_f, F.trainable_variables))
    x_opt.apply_gradients(zip(grad_Dx, D_x.trainable_variables))
    y_opt.apply_gradients(zip(grad_Dy, D_y.trainable_variables))
    
    return total_loss_g, total_loss_f, total_loss_Dx, total_loss_Dy
    
def train(datasetX, datasetY):
    for i in range(epochs):
        batch_id = 0
        for X, Y in tf.data.Dataset.zip((datasetX, datasetY)):
            batch_id += 1
            g_loss, f_loss, x_loss, y_loss = train_step(X, Y)
            print('[INFO] Batch #%d | Epoch #%d, G_loss = %.2f, F_loss = %.2f, X_loss = %.2f, Y_loss = %.2f'
                  % (batch_id, i+1 , g_loss.numpy(), f_loss.numpy(), x_loss.numpy(), y_loss.numpy()))
            
train(train_horses, train_zebras)

[INFO] Batch #1 | Epoch #1, G_loss = 8.28, F_loss = 7.35, X_loss = 1.37, Y_loss = 1.49
[INFO] Batch #2 | Epoch #1, G_loss = 9.93, F_loss = 10.43, X_loss = 1.97, Y_loss = 1.80
[INFO] Batch #3 | Epoch #1, G_loss = 12.37, F_loss = 12.73, X_loss = 1.71, Y_loss = 1.51
[INFO] Batch #4 | Epoch #1, G_loss = 14.87, F_loss = 15.31, X_loss = 1.31, Y_loss = 1.57
[INFO] Batch #5 | Epoch #1, G_loss = 11.74, F_loss = 11.67, X_loss = 1.51, Y_loss = 1.40
[INFO] Batch #6 | Epoch #1, G_loss = 7.81, F_loss = 8.15, X_loss = 1.44, Y_loss = 1.43
[INFO] Batch #7 | Epoch #1, G_loss = 11.55, F_loss = 11.71, X_loss = 1.27, Y_loss = 1.31
[INFO] Batch #8 | Epoch #1, G_loss = 11.01, F_loss = 10.94, X_loss = 1.46, Y_loss = 1.45
[INFO] Batch #9 | Epoch #1, G_loss = 8.91, F_loss = 8.36, X_loss = 1.36, Y_loss = 1.53
[INFO] Batch #10 | Epoch #1, G_loss = 7.34, F_loss = 7.22, X_loss = 1.51, Y_loss = 1.49
[INFO] Batch #11 | Epoch #1, G_loss = 10.93, F_loss = 11.39, X_loss = 1.35, Y_loss = 1.27
[INFO] Batch #12 | Epoch #1,

[INFO] Batch #93 | Epoch #1, G_loss = 10.58, F_loss = 10.32, X_loss = 1.28, Y_loss = 1.48
[INFO] Batch #94 | Epoch #1, G_loss = 8.49, F_loss = 8.63, X_loss = 1.49, Y_loss = 1.37
[INFO] Batch #95 | Epoch #1, G_loss = 12.28, F_loss = 12.46, X_loss = 1.10, Y_loss = 1.46
[INFO] Batch #96 | Epoch #1, G_loss = 14.14, F_loss = 13.78, X_loss = 1.47, Y_loss = 1.36
[INFO] Batch #97 | Epoch #1, G_loss = 12.09, F_loss = 12.15, X_loss = 1.12, Y_loss = 1.32
[INFO] Batch #98 | Epoch #1, G_loss = 11.84, F_loss = 11.03, X_loss = 1.43, Y_loss = 1.41
[INFO] Batch #99 | Epoch #1, G_loss = 11.16, F_loss = 10.61, X_loss = 1.24, Y_loss = 1.29
[INFO] Batch #100 | Epoch #1, G_loss = 12.13, F_loss = 11.33, X_loss = 1.02, Y_loss = 1.27
[INFO] Batch #101 | Epoch #1, G_loss = 9.24, F_loss = 8.71, X_loss = 0.93, Y_loss = 1.28
[INFO] Batch #102 | Epoch #1, G_loss = 9.68, F_loss = 9.38, X_loss = 0.99, Y_loss = 1.23
[INFO] Batch #103 | Epoch #1, G_loss = 9.95, F_loss = 10.30, X_loss = 1.42, Y_loss = 1.32
[INFO] Batch 

[INFO] Batch #184 | Epoch #1, G_loss = 12.27, F_loss = 12.46, X_loss = 1.27, Y_loss = 1.45
[INFO] Batch #185 | Epoch #1, G_loss = 12.57, F_loss = 11.97, X_loss = 1.34, Y_loss = 1.42
[INFO] Batch #186 | Epoch #1, G_loss = 9.82, F_loss = 9.84, X_loss = 1.63, Y_loss = 1.51
[INFO] Batch #187 | Epoch #1, G_loss = 11.12, F_loss = 11.29, X_loss = 1.46, Y_loss = 1.46
[INFO] Batch #188 | Epoch #1, G_loss = 12.30, F_loss = 11.51, X_loss = 1.04, Y_loss = 1.30
[INFO] Batch #189 | Epoch #1, G_loss = 11.18, F_loss = 10.40, X_loss = 1.53, Y_loss = 1.43
[INFO] Batch #190 | Epoch #1, G_loss = 12.22, F_loss = 11.79, X_loss = 1.31, Y_loss = 1.35
[INFO] Batch #191 | Epoch #1, G_loss = 9.92, F_loss = 8.89, X_loss = 1.42, Y_loss = 1.40
[INFO] Batch #192 | Epoch #1, G_loss = 9.91, F_loss = 10.12, X_loss = 1.57, Y_loss = 1.44
[INFO] Batch #193 | Epoch #1, G_loss = 8.61, F_loss = 8.60, X_loss = 1.27, Y_loss = 1.52
[INFO] Batch #194 | Epoch #1, G_loss = 8.30, F_loss = 9.11, X_loss = 1.53, Y_loss = 1.47
[INFO] B

KeyboardInterrupt: 