In [3]:
import os
import io
import cv2
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.initializers import RandomNormal
import matplotlib.pyplot as plt
from tqdm import tqdm
import tensorflow.keras.backend as K
from sklearn.preprocessing import MinMaxScaler
import tensorflow_addons as tfa
from IPython import display
autotune = tf.data.AUTOTUNE

## Read the data

In [20]:
# Collect the image paths of both domains (trainA and trainB).
# os.listdir returns entries in arbitrary order, so the names are sorted
# before truncating; otherwise the 1000-file subset could differ between runs.
X1 = sorted(os.listdir("/kaggle/input/cyclegan/horse2zebra/horse2zebra/trainA"))
X1 = [os.path.join("/kaggle/input/cyclegan/horse2zebra/horse2zebra/trainA", x) for x in X1][:1000]

X2 = sorted(os.listdir("/kaggle/input/cyclegan/horse2zebra/horse2zebra/trainB"))
X2 = [os.path.join("/kaggle/input/cyclegan/horse2zebra/horse2zebra/trainB", x) for x in X2][:1000]

In [21]:
#read images
def read_images(file_paths):
    """Load images from disk as normalized RGB float arrays.

    Each file is read with OpenCV, converted from BGR to RGB, resized to
    128x128, cast to float32 and scaled from [0, 255] to [-1, 1].

    :param file_paths: iterable of image file paths
    :return: numpy array stacking all loaded images
    :raises ValueError: if a file cannot be read or decoded as an image
    """
    imgs = []
    for file_path in tqdm(file_paths):
        img = cv2.imread(file_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise ValueError("Could not read image file: {}".format(file_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (128,128))
        img = img.astype(np.float32)
        # Map pixel values from [0, 255] to [-1, 1].
        img = (img - 127.5) / 127.5
        imgs.append(img)
    return np.array(imgs)

# Load both path lists into memory, first X1 and then X2.
van, real = [read_images(paths) for paths in (X1, X2)]

def denormalize_img(img):
    """Convert an image from the [-1, 1] range back to uint8 pixels.

    Values are mapped with ``img * 127.5 + 127.5`` and clipped to [0, 255]
    before the cast, so inputs slightly outside [-1, 1] saturate instead of
    wrapping around when converted to uint8.

    :param img: numpy array of normalized pixel values
    :return: uint8 numpy array of the same shape
    """
    pixels = (img * 127.5) + 127.5
    return np.clip(pixels, 0, 255).astype(np.uint8)


def mini_batches_(X, Y, batch_size=64):
    """
    Build paired mini-batches of loaded images from two path lists.

    The number of batches is derived from the length of X only; a trailing
    partial batch is dropped.

    :param X: sequence of image file paths (first element of each pair)
    :param Y: sequence of image file paths (second element of each pair)
    :param batch_size: number of paths per batch
    :return:
    list of [batch_x, batch_y] pairs as produced by read_images

    """
    full_batches = len(X) // batch_size
    paired = []
    for batch_idx in tqdm(range(full_batches)):
        start = batch_idx * batch_size
        stop = start + batch_size
        images_x = read_images(X[start:stop])
        images_y = read_images(Y[start:stop])
        paired.append([images_x, images_y])
    return paired
#gan_dataset = mini_batches_(X1, X2, batch_size=1)

100%|██████████| 1000/1000 [00:04<00:00, 234.81it/s]
100%|██████████| 1000/1000 [00:03<00:00, 254.51it/s]


In [6]:
# Inspect the shape of `van`.
van.shape

(600, 128, 128, 3)

In [None]:
#read images
def read_images(file_paths):
    """Load images from disk as normalized RGB float arrays.

    Each file is read with OpenCV, converted from BGR to RGB, resized to
    128x128, cast to float32 and scaled from [0, 255] to [-1, 1].

    :param file_paths: iterable of image file paths
    :return: numpy array stacking all loaded images
    :raises ValueError: if a file cannot be read or decoded as an image
    """
    imgs = []
    for file_path in file_paths:
        img = cv2.imread(file_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise ValueError("Could not read image file: {}".format(file_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (128,128))
        img = img.astype(np.float32)
        # Map pixel values from [0, 255] to [-1, 1].
        img = (img - 127.5) / 127.5
        imgs.append(img)
    return np.array(imgs)

def denormalize_img(img):
    """Convert an image from the [-1, 1] range back to uint8 pixels.

    Values are mapped with ``img * 127.5 + 127.5`` and clipped to [0, 255]
    before the cast, so inputs slightly outside [-1, 1] saturate instead of
    wrapping around when converted to uint8.

    :param img: numpy array of normalized pixel values
    :return: uint8 numpy array of the same shape
    """
    pixels = (img * 127.5) + 127.5
    return np.clip(pixels, 0, 255).astype(np.uint8)

def mini_batches_(X, batch_size=64):

    """Build mini-batches of loaded images from a list of file paths.

    A trailing partial batch is dropped.

    :param X: sequence of image file paths
    :param batch_size: number of paths per batch
    :return:
    list with one read_images result per batch"""


    full_batches = len(X) // batch_size
    loaded = []
    for batch_idx in tqdm(range(full_batches)):
        start = batch_idx * batch_size
        loaded.append(read_images(X[start:start + batch_size]))
    return loaded
# Rebind van/real to lists of single-image batches, X1 first and then X2.
van, real = [mini_batches_(paths, batch_size=1) for paths in (X1, X2)]

In [None]:
# Pair the i-th element of `van` with the i-th element of `real`.
# (Zipping the single tuple (van, real) would instead yield 1-tuples.)
zip(van, real)

In [None]:
#Visualize the examples
# Show the first ten samples of each set side by side, one figure per index.
for sample_idx in range(10):
    fig = plt.figure(figsize=(9, 9))

    # Left panel: sample from `van` (slot 1 of a 2x3 grid).
    left_ax = fig.add_subplot(2, 3, 1)
    left_ax.set_title("\n\nVangogh")
    left_ax.imshow(denormalize_img(van[sample_idx]))
    left_ax.axis('off')

    # Right panel: sample from `real` (slot 2 of the same grid).
    right_ax = fig.add_subplot(2, 3, 2)
    right_ax.imshow(denormalize_img(real[sample_idx]))
    right_ax.set_title("\n\nReal Photo")
    right_ax.axis('off')

## CycleGAN

In [22]:
from tensorflow.keras import layers
def conv2d(x, filter_size, filters, stride = 1, padding = 'same'):
    """Apply a square-kernel Conv2D layer to ``x``.

    The kernel is initialized from a normal distribution with mean 0.0 and
    standard deviation 0.02.

    :param x: input tensor
    :param filter_size: side length of the square kernel
    :param filters: number of output channels
    :param stride: stride applied along both spatial axes
    :param padding: padding mode passed to Conv2D
    :return: output of the convolution layer
    """
    weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    conv_layer = tf.keras.layers.Conv2D(
        filters,
        (filter_size, filter_size),
        strides=(stride, stride),
        padding=padding,
        kernel_initializer=weight_init,
    )
    return conv_layer(x)

def batch_norm(x):
    """Apply a new BatchNormalization layer (default settings) to ``x``."""
    norm_layer = tf.keras.layers.BatchNormalization()
    return norm_layer(x)

def instance_norm(x):
    """Apply a new InstanceNormalization layer to ``x``.

    The layer's gamma is initialized from a normal distribution with
    mean 0.0 and standard deviation 0.02.
    """
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    norm_layer = tfa.layers.InstanceNormalization(gamma_initializer = gamma_init)
    return norm_layer(x)

def relu(x):
    """Apply the Keras relu activation with a negative slope of 0.2.

    Note: because alpha is non-zero, negative inputs are scaled by 0.2
    rather than zeroed.
    """
    negative_slope = 0.2
    activated = tf.keras.activations.relu(x, alpha = negative_slope)
    return activated

def leaky_relu(x):
    """Apply a new LeakyReLU layer with alpha 0.2 to ``x``."""
    activation_layer = tf.keras.layers.LeakyReLU(alpha = 0.2)
    return activation_layer(x)

class ReflectionPadding(tf.keras.layers.Layer):
    """Layer that pads the two middle axes of a 4-D tensor by reflection.

    :param dim_padding: pair ``(p0, p1)``; axis 2 is padded by ``p0`` on each
        side and axis 1 by ``p1`` on each side. The first and last axes are
        left untouched.
    :param kwargs: standard Keras Layer arguments (e.g. ``name``), forwarded
        to the base class so the layer can be rebuilt from its config.
    """

    def __init__(self, dim_padding = (1,1), **kwargs):
        super(ReflectionPadding, self).__init__(**kwargs)
        # Stored as a tuple so the value is immutable and serializable.
        self.padding = tuple(dim_padding)

    def call(self, tensor):
        """Return ``tensor`` padded in REFLECT mode."""
        x = tf.pad(tensor, [[0,0], [self.padding[1], self.padding[1]], [self.padding[0], self.padding[0]], [0,0]],
                  mode = "REFLECT")
        return x

    def get_config(self):
        """Return the layer config, including the padding sizes.

        Without this, a model containing the layer cannot be serialized or
        cloned, because the constructor argument would be lost.
        """
        config = super(ReflectionPadding, self).get_config()
        config.update({"dim_padding": self.padding})
        return config

### CycleGAN Network

In [23]:
class CycleGAN(tf.keras.Model):
    """CycleGAN with two generators and two discriminators.

    Using the naming of ``train_step`` (X = first input, Y = second input):

    * ``gen_a`` translates X -> Y and ``gen_b`` translates Y -> X.
    * ``disc_a`` scores X images and ``disc_b`` scores Y images.

    All four networks are built for 128x128x3 inputs. Adversarial losses are
    mean squared error; cycle and identity losses are mean absolute error.
    """

    # Weight of the cycle-consistency loss.
    lambda_cycle = 10.0
    # Extra factor applied (on top of lambda_cycle) to the identity loss.
    lambda_identity = 0.5

    def __init__(self):
        super(CycleGAN, self).__init__()

        self.gen_a = self.generator()
        self.gen_b = self.generator()

        self.disc_a = self.discriminator()
        self.disc_b = self.discriminator()

        # Kept as a public attribute for backward compatibility. Layers create
        # their own initializer instances instead of sharing this one, since
        # reusing one unseeded initializer across layers is discouraged.
        self.kernel_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

        self.mse = tf.keras.losses.MeanSquaredError()
        self.cycle_loss = tf.keras.losses.MeanAbsoluteError()
        self.identity_loss = tf.keras.losses.MeanAbsoluteError()

    def resnet_block(self, prev_input):
        """Residual block: two padded 3x3 conv stages added to the input.

        Each stage is reflection padding -> 3x3 valid conv -> instance norm ->
        activation, keeping the channel count of ``prev_input``.

        :param prev_input: input tensor of the block
        :return: ``prev_input`` plus the output of the two stages
        """
        x = prev_input
        dims = x.shape[-1]

        #block1
        #using reflection padding as described in the paper
        x = ReflectionPadding()(x)
        x = conv2d(x, 3, dims, stride = 1, padding = "valid")
        x = instance_norm(x)
        x = relu(x)

        #block 2
        x = ReflectionPadding()(x)
        x = conv2d(x, 3, dims, stride = 1, padding = "valid")
        x = instance_norm(x)
        x = relu(x)

        out = tf.keras.layers.add([prev_input, x])
        return out

    def conv_transpose_block(self, x, filter_size, filters, stride = 2 , padding = 'same' , norm = True, activation = True):
        """Upsampling block built around a Conv2DTranspose layer.

        :param x: input tensor
        :param filter_size: kernel size passed to Conv2DTranspose
        :param filters: number of output channels
        :param stride: stride applied along both spatial axes
        :param padding: padding mode passed to Conv2DTranspose
        :param norm: apply instance normalization after the convolution
        :param activation: apply the ``relu`` helper after normalization
        :return: output tensor of the block
        """
        # Same N(0, 0.02) kernel initialization as the conv2d helper, so all
        # convolution kernels of the network are initialized consistently.
        x = tf.keras.layers.Conv2DTranspose(
            filters,
            filter_size,
            strides = (stride,stride),
            padding = padding,
            kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
        )(x)
        if norm:
            x = instance_norm(x)
        if activation:
            x = relu(x)
        return x


    #define generator
    def generator(self, resnet_blocks = 6, conv_blocks = 2, conv_transpose_block = 2):
        """Build a generator model for 128x128x3 inputs.

        Architecture: 7x7 conv stem (64 filters) -> ``conv_blocks`` stride-2
        convs (filters doubled each time) -> ``resnet_blocks`` residual blocks
        -> ``conv_transpose_block`` upsampling blocks (filters halved each
        time) -> 7x7 conv to 3 channels with tanh.

        :param resnet_blocks: number of residual blocks
        :param conv_blocks: number of downsampling convolutions
        :param conv_transpose_block: number of upsampling blocks
        :return: ``tf.keras.Model`` mapping an image to a tanh-activated image
        """
        filters = 64
        #input layer
        input_layer = tf.keras.layers.Input(shape = (128,128,3))
        #generator architecture
        x = ReflectionPadding(dim_padding = (3,3))(input_layer)
        x = conv2d(x,7,filters, padding = 'valid')
        x = instance_norm(x)
        x = relu(x)

        #downsampling
        for _ in range(conv_blocks):
            filters*=2
            x = conv2d(x,3,filters,stride = 2)
            x = instance_norm(x)
            x = relu(x)

        #resnet block
        for _ in range(resnet_blocks):
            x = self.resnet_block(x)

        # upsampling block
        for _ in range(conv_transpose_block):
            filters//=2
            x = self.conv_transpose_block(x, 3, filters, stride = 2)

        #output block
        x = ReflectionPadding(dim_padding = (3,3))(x)
        x = conv2d(x, 7, 3, padding = 'valid')
        out = tf.keras.layers.Activation('tanh')(x)

        model = tf.keras.Model(inputs = input_layer, outputs = out)
        return model

    def discriminator(self):
        """Build a discriminator model for 128x128x3 inputs.

        Three stride-2 4x4 convs (64, 128, 256 filters) and one stride-1 4x4
        conv (512 filters), each followed by LeakyReLU (instance norm on all
        but the first), then a final 1-channel 4x4 conv without activation.

        :return: ``tf.keras.Model`` producing a single-channel score map
        """
        filters = 64
        input_layer = tf.keras.layers.Input(shape = (128,128,3))
        #conv1
        x = conv2d(input_layer, 4,filters,stride = 2, padding = 'same')
        x = leaky_relu(x)
        #conv2 c128s2f4
        x = conv2d(x, 4, filters*2, stride = 2, padding = 'same' )
        x = instance_norm(x)
        x = leaky_relu(x)
        #c256
        x = conv2d(x,4,filters*4,stride = 2, padding = 'same')
        x = instance_norm(x)
        x = leaky_relu(x)
        #c512
        x = conv2d(x, 4,filters*8, stride =1, padding = 'same')
        x = instance_norm(x)
        x = leaky_relu(x)

        out = conv2d(x, 4, 1, stride = 1, padding ='same')
        model = tf.keras.Model(input_layer, out)
        return model

    #define losses
    def gen_adversarial_loss(self, disc_fake_images):
        """MSE between the discriminator output on fakes and an all-ones target.

        :param disc_fake_images: discriminator output for generated images
        :return: scalar loss tensor
        """
        loss = self.mse(tf.ones_like(disc_fake_images), disc_fake_images)
        return loss

    def disc_adversarial_loss(self, real_images, gen_fake_images):
        """Discriminator loss: real outputs -> ones, fake outputs -> zeros.

        :param real_images: discriminator output for real images
        :param gen_fake_images: discriminator output for generated images
        :return: average of the two MSE terms
        """
        real_images = self.mse(tf.ones_like(real_images), real_images)
        fake_images = self.mse(tf.zeros_like(gen_fake_images), gen_fake_images)
        loss = (real_images + fake_images) * 0.5
        return loss

    def compile(self, gen_a_opt, gen_b_opt, disc_a_opt, disc_b_opt):
        """Store one optimizer per network and (re)create the loss objects.

        :param gen_a_opt: optimizer for ``gen_a``
        :param gen_b_opt: optimizer for ``gen_b``
        :param disc_a_opt: optimizer for ``disc_a``
        :param disc_b_opt: optimizer for ``disc_b``
        """
        super(CycleGAN, self).compile()
        self.genAopt = gen_a_opt
        self.genBopt = gen_b_opt

        self.discAopt = disc_a_opt
        self.discBopt = disc_b_opt

        self.gen_loss = self.gen_adversarial_loss
        self.disc_loss = self.disc_adversarial_loss

        self.cycle_loss = tf.keras.losses.MeanAbsoluteError()
        self.identity_loss = tf.keras.losses.MeanAbsoluteError()

    def train_step(self, gan_dataset):
        """Run one optimization step for both generators and discriminators.

        :param gan_dataset: pair ``(batch_x, batch_y)`` of image batches
        :return: dict with the generator and discriminator losses
        """
        batch_x, batch_y = gan_dataset
        # Persistent tape: four separate gradient computations follow.
        with tf.GradientTape(persistent = True) as tape:
            # Translations: X -> Y with gen_a, Y -> X with gen_b.
            fake_y = self.gen_a(batch_x, training = True)
            fake_x = self.gen_b(batch_y, training = True)
            # Cycle: X -> Y -> X
            cycled_x= self.gen_b(fake_y, training = True)
            # Cycle: Y -> X -> Y
            cycled_y = self.gen_a(fake_x, training = True)

            # Identity mapping: each generator fed an image of its target side.
            same_y = self.gen_a(batch_y, training = True)
            same_x = self.gen_b(batch_x, training = True)

            #get disc output
            disc_real_x  = self.disc_a(batch_x, training = True)
            disc_fake_x  = self.disc_a(fake_x, training = True)

            disc_real_y  = self.disc_b(batch_y, training = True)
            disc_fake_y  = self.disc_b(fake_y, training = True)

            #calculating generator loss
            gen_a_loss  = self.gen_adversarial_loss(disc_fake_y)
            gen_b_loss  = self.gen_adversarial_loss(disc_fake_x)

            #calculate identity loss
            gen_a_identity = self.identity_loss(batch_y, same_y) * self.lambda_cycle * self.lambda_identity
            gen_b_identity = self.identity_loss(batch_x, same_x) * self.lambda_cycle * self.lambda_identity

            #cycle loss
            gen_a_cycle = self.cycle_loss(batch_y, cycled_y) * self.lambda_cycle
            gen_b_cycle = self.cycle_loss(batch_x, cycled_x) * self.lambda_cycle

            #Total disc loss
            disc_a_loss = self.disc_adversarial_loss(disc_real_x, disc_fake_x)
            disc_b_loss = self.disc_adversarial_loss(disc_real_y, disc_fake_y)

            #total gen loss
            gen_a_loss = gen_a_loss + gen_a_identity + gen_a_cycle
            gen_b_loss = gen_b_loss + gen_b_identity + gen_b_cycle
        gen_a_grad = tape.gradient(gen_a_loss, self.gen_a.trainable_variables)
        gen_b_grad = tape.gradient(gen_b_loss, self.gen_b.trainable_variables)

        #discriminator weight gradients
        disc_a_grad = tape.gradient(disc_a_loss, self.disc_a.trainable_variables)
        disc_b_grad = tape.gradient(disc_b_loss, self.disc_b.trainable_variables)

        # A persistent tape holds its resources until it is dropped; release
        # it as soon as all gradients have been computed.
        del tape

        #weight updation
        self.genAopt.apply_gradients(zip(gen_a_grad, self.gen_a.trainable_variables))
        self.genBopt.apply_gradients(zip(gen_b_grad, self.gen_b.trainable_variables))

        self.discAopt.apply_gradients(zip(disc_a_grad, self.disc_a.trainable_variables))
        self.discBopt.apply_gradients(zip(disc_b_grad, self.disc_b.trainable_variables))

        return {
            "G_loss": gen_a_loss,
            "F_loss": gen_b_loss,
            "D_X_loss": disc_a_loss,
            "D_Y_loss": disc_b_loss}
        

# Build the model and give each of the four networks its own Adam optimizer.
gan = CycleGAN()
gan.compile(
    gen_a_opt=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_b_opt=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_a_opt=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_b_opt=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
)

In [24]:
class GANMonitor(tf.keras.callbacks.Callback):
    """A callback to generate and save images after each epoch.

    :param num_img: number of random samples to translate and plot per epoch
    :param images: optional indexable collection of normalized images to
        sample from; when omitted, the notebook-level ``van`` array is used
    """

    def __init__(self, num_img=4, images=None):
        super().__init__()
        self.num_img = num_img
        self.images = images

    def on_epoch_end(self, epoch, logs=None):
        """Plot input/translation pairs and save each translation as a PNG."""
        # Resolved at call time so the default follows the current `van`.
        source = van if self.images is None else self.images
        # squeeze=False keeps `ax` two-dimensional even when num_img == 1.
        fig, ax = plt.subplots(self.num_img, 2, figsize=(12, 12), squeeze=False)
        for i in range(self.num_img):
            # Sample from the whole collection instead of a fixed index range.
            r = np.random.randint(0, len(source))
            img = np.expand_dims(source[r], axis = 0)
            prediction = self.model.gen_a(img, training = False).numpy()

            ax[i, 0].imshow(denormalize_img(img[0]))
            ax[i, 1].imshow(denormalize_img(prediction[0]))
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = tf.keras.preprocessing.image.array_to_img(prediction[0])
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        # Close this epoch's figure explicitly so figures do not accumulate.
        plt.close(fig)
        
# Callback instance with default settings; passed to gan.fit via `callbacks`.
plotter = GANMonitor()

In [None]:
# Train for 100 epochs on one (van, real) pair per step, plotting samples
# through the `plotter` callback.
gan.fit(
    van,
    real,
    batch_size=1,
    epochs=100,
    callbacks=[plotter],
)