# AI-Generated Zoology: An implementation of Image-to-Image Translation

### Main file

- Project for Pattern Recognition, COMP 473
-  Use of Python 3.8 (see requirement.txt)

#### Members
+ Sandra Buchen (2631798)
+ Nigel Yong Sao Young (40089856) 
+ Dan Raileanu (40019882) 
+ Inés Gonzalez Pepe (40095696) 
+ Marc Vicuna (40079109)

This main file has the capacity to run the project from top to bottom. Read the markdown cells for more information.

## Import TensorFlow and other libraries

In [1]:
import os
import time
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import clear_output
from PIL import Image
import random
import shutil
tf.__version__ 

# Random seed for reproducibility
tf.random.set_seed(0)

'2.2.1'

## Instantiate the constants
These constants were chosen for our implementation, they may vary on other applications.

In [2]:
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
OUTPUT_CHANNELS = 3
PATH = os.path.join(os.getcwd(), 'data/')
print(PATH)

c:\Users\Dan\source\AI-Generated-Zoology\data/


## Utility functions
Used to setup the training/test data.<br>

In [3]:
def get_concat_h(im1, im2):
    """stacks two images horizontally"""
    dst = Image.new('RGB', (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst


def resize_combine(raw_path, sketch_path, result_path, img_size):
    """resizes and merges matching images from raw_path and sketch_path and outputs new image in result_path"""
    #make sure that RESULT_PATH is an empty dir
    if os.path.isdir(result_path):
        assert len(os.listdir(result_path)) == 0, result_path + " is not empty, clear it first"
    else:
        os.mkdir(result_path)

    raw_dir = os.listdir(raw_path)
    sketch_dir = os.listdir(sketch_path)
    assert len(raw_dir) == len(sketch_dir), '{} and {} have different number of files'.format(raw_path, sketch_path)
    result_dir = zip(raw_dir, sketch_dir)


    for raw_file, sketch_file in result_dir:
        assert raw_file.split('.')[0] == sketch_file.split('.')[0], 'raw_file: ' + raw_file + "  and sketch_file: " + sketch_file\
            +"  don't have matching names. Maybe clear " + sketch_path + " and run PhotoSketch model again"
        raw_img = Image.open('{}{}'.format(raw_path, raw_file))
        raw_img = raw_img.resize(img_size)
        sketch_img = Image.open('{}{}'.format(sketch_path, sketch_file))
        sketch_img = sketch_img.resize(img_size)
        
        merged_img = get_concat_h(raw_img, sketch_img)
        merged_img.save('{}{}{}'.format(result_path, raw_file.split('.')[0], '.jpeg'))




def shuffle_into_train_test(input_path, train_path, test_path, test_ratio, clear_dir=False):
    """
    shuffles samples randomly into train and test set
    set clear_dir to True if you want it to clear any existing files in the train/test folders
    """
    assert os.path.isdir(input_path), input_path + ' does not exist'
    assert test_ratio >=0 and test_ratio<=1, "TEST_RATIO must be between 0 and 1"
    for dir_path in [train_path, test_path]:
        if os.path.isdir(dir_path):
            if clear_dir and len(os.listdir(dir_path)) != 0:
                for filename in os.listdir(dir_path):
                    os.remove(dir_path+filename)
            else:
                assert len(os.listdir(dir_path)) == 0, dir_path + " is not empty, clear it first or set CLEAR_DIR to True"
        else:
            os.mkdir(dir_path)


    input_dir = os.listdir(input_path)
    random.shuffle(input_dir)

    for i,filename in enumerate(input_dir):
        if i <= test_ratio * len(input_dir):
            shutil.copy(input_path+filename, test_path+filename)
        else:
            shutil.copy(input_path+filename, train_path+filename)



# Pix2Pix model
Run cell below to load all functions for pix2pix model

In [4]:
class Model:
    """
    #####################
        Pix2Pix model  
    #####################
    """
    def downsample(self, filters, size, apply_batchnorm=True):
        """Downsampling, implementation of the encoder."""
        initializer = tf.random_normal_initializer(0., 0.02)
        result = tf.keras.Sequential()
        # 1st layer, Conv
        result.add( tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
        # 2nd layer, Batchnorm
        if apply_batchnorm:
            result.add(tf.keras.layers.BatchNormalization())
        # 3rd layer, Leaky ReLU
        result.add(tf.keras.layers.LeakyReLU())
        
        return result


    def upsample(self, filters, size, apply_dropout=False):
        """Upsampling, implementation of the decoder."""
        initializer = tf.random_normal_initializer(0., 0.02)
        result = tf.keras.Sequential()
        # 1st layer, Conv
        result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
        # 2nd layer, Batchnorm
        result.add(tf.keras.layers.BatchNormalization())
        # 3rd layer, Dropout (Randomization)
        if apply_dropout:
            result.add(tf.keras.layers.Dropout(0.5))
        # 4th layer, regular ReLU
        result.add(tf.keras.layers.ReLU())
        return result


    def Generator(self):
        """Defining the generator"""
        # Downsampling stack
        down_stack = [
            self.downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
            self.downsample(128, 4), # (bs, 64, 64, 128)
            self.downsample(256, 4), # (bs, 32, 32, 256)
            self.downsample(512, 4), # (bs, 16, 16, 512)
            self.downsample(512, 4), # (bs, 8, 8, 512)
            self.downsample(512, 4), # (bs, 4, 4, 512)
            self.downsample(512, 4), # (bs, 2, 2, 512)
            self.downsample(512, 4), # (bs, 1, 1, 512)
        ]
        # Upsampling stack
        up_stack = [
            self.upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
            self.upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
            self.upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
            self.upsample(512, 4), # (bs, 16, 16, 1024)
            self.upsample(256, 4), # (bs, 32, 32, 512)
            self.upsample(128, 4), # (bs, 64, 64, 256)
            self.upsample(64, 4), # (bs, 128, 128, 128)
        ]
        # Initialization
        initializer = tf.random_normal_initializer(0., 0.02)
        last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, 
                                            strides=2, 
                                            padding='same',
                                            kernel_initializer=initializer,
                                            activation='tanh') # (bs, 256, 256, 3)
        concat = tf.keras.layers.Concatenate() 
        
        inputs = tf.keras.layers.Input(shape=[None,None,3])
        x = inputs
        # Downsampling through the model
        skips = []
        for down in down_stack:
            x = down(x)
            skips.append(x)
        
        skips = reversed(skips[:-1])
        # Upsampling and establishing the skip connections
        for up, skip in zip(up_stack, skips):
            x = up(x)
            x = concat([x, skip])
        
        x = last(x)
        
        return tf.keras.Model(inputs=inputs, outputs=x)


    def Discriminator(self):
        """Defining the discriminator"""
        # Initialization
        initializer = tf.random_normal_initializer(0., 0.02)
        inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
        tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
        x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
        
        # Downsampling blocks instantiation
        down1 = self.downsample(64, 4, False)(x) # (bs, 128, 128, 64)
        down2 = self.downsample(128, 4)(down1) # (bs, 64, 64, 128)
        down3 = self.downsample(256, 4)(down2) # (bs, 32, 32, 256)
        
        zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
        # 1st layer, Conv
        conv = tf.keras.layers.Conv2D(512, 4, strides=1, 
                                    kernel_initializer=initializer, 
                                    use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
        # 2nd layer, Batchnorm
        batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
        # 3rd layer, Leaky ReLU
        leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
        zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
        # 4th layer, Conv
        last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                    kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
        
        return tf.keras.Model(inputs=[inp, tar], outputs=last)






    """Model constructor"""
    def __init__(self):
        #setup generator and discriminator
        self.generator = self.Generator()
        self.discriminator = self.Discriminator()
        # General loss instantiation
        self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        # Optimizers instantiation
        self.generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        # Datasets
        self.train_dataset = None
        self.test_dataset = None
        self.predict_dataset = None



    def load_input_target(self, image_file):
        """Loads the the image_file as an input_image and a real_image representing the target"""
        #read
        image = tf.io.read_file(image_file)
        image = tf.image.decode_jpeg(image)
        
        #reformat to  return input/target separetely
        w = tf.shape(image)[1]
        w = w // 2
        real_image = image[:, :w, :]
        input_image = image[:, w:, :]
        
        input_image = tf.cast(input_image, tf.float32)
        real_image = tf.cast(real_image, tf.float32)

        return input_image, real_image


    def load_input(self, image_file):
        """ Loads the image_file as a single input_image"""
        #read
        image = tf.io.read_file(image_file)
        image = tf.image.decode_jpeg(image)
        
        input_image = tf.cast(image, tf.float32)

        return input_image


    def resize(self, input_image, real_image, height, width):
        """Resize the images to given height,width"""
        input_image = tf.image.resize(input_image, [height, width], 
                                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        real_image = tf.image.resize(real_image, [height, width], 
                                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        
        return input_image, real_image


    def random_crop(self, input_image, real_image):
        """Cropping the images using Tensorflow's utility functions"""        
        stacked_image = tf.stack([input_image, real_image], axis=0)
        cropped_image = tf.image.random_crop( stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
        
        return cropped_image[0], cropped_image[1]


    def normalize(self, input_image, real_image):
        """Normalizing the images to [-1, 1]"""
        input_image = (input_image / 127.5) - 1
        real_image  = (real_image   / 127.5) - 1
        
        return input_image, real_image


    @tf.function()
    def random_jitter(self, input_image, real_image):
        """
        As mentioned in the paper, we apply random jittering and mirroring to the training dataset.
        *In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`
        *In random mirroring, the image is randomly flipped horizontally i.e left to right.
        """
        # resizing to 286 x 286 x 3
        input_image, real_image = self.resize(input_image, real_image, 286, 286)
        
        # randomly cropping to 256 x 256 x 3
        input_image, real_image = self.random_crop(input_image, real_image)
        if tf.random.uniform(()) > 0.5:
            # random mirroring
            input_image = tf.image.flip_left_right(input_image)
            real_image = tf.image.flip_left_right(real_image)
        
        return input_image, real_image

    def load_image_train(self, image_file):
        """Loading the image with heavier preprocessing, use of random jitter and normalization"""
        input_image, real_image = self.load_input_target(image_file)
        input_image, real_image = self.random_jitter(input_image, real_image)
        input_image, real_image = self.normalize(input_image, real_image)
        
        return input_image, real_image


    def load_image_test(self, image_file):
        """Loading the image with heavier preprocessing, adapted to testing"""
        
        input_image, real_image = self.load_input_target(image_file)
        input_image, real_image = self.resize(input_image, real_image,  IMG_HEIGHT, IMG_WIDTH) 
        input_image, real_image = self.normalize(input_image, real_image)
        
        return input_image, real_image


    def load_image_predict(self, image_file):
        """Loading the image with heavier preprocessing, adapted for predicting without target image"""
        input_image= self.load_input(image_file)
        input_image, real_image = self.resize(input_image, input_image,  IMG_HEIGHT, IMG_WIDTH) 
        input_image, real_image = self.normalize(input_image, input_image)
        
        return input_image

    def load_train(self, train_path):
        """Pipeline setup for training"""
        self.train_dataset = tf.data.Dataset.list_files(PATH+train_path+'*.jpeg')
        self.train_dataset = self.train_dataset.shuffle(BUFFER_SIZE)
        self.train_dataset = self.train_dataset.map(self.load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        self.train_dataset = self.train_dataset.batch(1)


    def load_test(self, test_path):
        """Pipeline setup for testing"""
        self.test_dataset = tf.data.Dataset.list_files(PATH+test_path+'*.jpeg')
        self.test_dataset = self.test_dataset.shuffle(BUFFER_SIZE)
        self.test_dataset = self.test_dataset.map(self.load_image_test)
        self.test_dataset = self.test_dataset.batch(1)


    def load_predict(self, pred_path):
        """Pipeline setup for predictions without target image"""
        self.predict_dataset = tf.data.Dataset.list_files(PATH+pred_path+'*.jpeg')
        self.predict_dataset = self.predict_dataset.map(self.load_image_test)
        self.predict_dataset = self.predict_dataset.batch(1)


    def discriminator_loss(self, disc_real_output, disc_generated_output):
        """Defining the discriminator loss"""        
        real_loss = self.loss_object(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = self.loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
        total_disc_loss = real_loss + generated_loss
        
        return total_disc_loss


    def generator_loss(self, disc_generated_output, gen_output, target):
        """Defining the generator loss"""
        LAMBDA = 100
        gan_loss = self.loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
        
        # mean absolute error
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
        total_gen_loss = gan_loss + (LAMBDA * l1_loss)
        
        return total_gen_loss


    def generate_images(self, model, test_input, tar=None, savePath = None):
        """
        Generates images based on the current model.
        The training=True is intentional here since we want the batch 
        statistics while running the model on the test dataset. If we 
        use training=False, we will get the accumulated statistics 
        learned from the training dataset (which we don't want).
        """     
        # Prediction
        prediction = model(test_input, training=True)

        # Plotting
        plt.figure(figsize=(15,15))
        
        # With ground truth or not
        if tar != None:
            display_list = [test_input[0], tar[0], prediction[0]]
            title = ['Input Image', 'Ground Truth', 'Predicted Image']
        else:
            display_list = [test_input[0], prediction[0]]
            title = ['Input Image', 'Predicted Image']

        for i in range(len(display_list)):
            plt.subplot(1, len(display_list), i+1)
            plt.title(title[i])
            # getting the pixel values between [0, 1] to plot it.
            plt.imshow(display_list[i] * 0.5 + 0.5)
            plt.axis('off')
        if savePath != None:
            plt.savefig(savePath)
        plt.show()


    @tf.function
    def train_step(self, input_image, target):
        """Trains on a single image"""
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = self.generator(input_image, training=True)
            
            disc_real_output = self.discriminator([input_image, target], training=True)
            disc_generated_output = self.discriminator([input_image, gen_output], training=True)
            gen_loss = self.generator_loss(disc_generated_output, gen_output, target)
            disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
        
        generator_gradients = gen_tape.gradient(gen_loss, 
                                            self.generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss, 
                                            self.discriminator.trainable_variables)
        
        self.generator_optimizer.apply_gradients(zip(generator_gradients, 
                                            self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, 
                                                self.discriminator.trainable_variables))



    def train(self, epochs):
        assert self.train_dataset != None, "No train_dataset present. First call load_train()"
        """Trains on all the train_dataset samples for epochs number of times"""
        for epoch in range(epochs):
            start = time.time()
            
            for input_image, target in self.train_dataset:
                self.train_step(input_image, target)
            
            clear_output(wait=True)
            for inp, tar in self.test_dataset.take(1):
                self.generate_images(self.generator, inp, tar)
            # Saving checkpoint for the model every 100 epochs
            if (epoch + 1) % 100 == 0: self.save_checkpoint()
            # Output to console prediction on some train data for each epoch
            print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))

    
    def test(self, num):
        assert self.test_dataset != None, "No test_dataset present. First call load_test()"
        for i, (inp, tar) in enumerate(self.test_dataset.take(num)):
                self.generate_images(self.generator, inp, tar, "{}{}{}.jpeg".format(PATH, 'results/', i))


    def predict(self, num):
        assert self.predict_dataset != None, "No predict_dataset present. First call load_predict()"
        for i, inp in enumerate(self.predict_dataset.take(num)):
                self.generate_images(self.generator, inp, tar=None, savePath="{}{}{}.jpeg".format(PATH, 'results/', i))


    def show_Generator_Discriminator(self, num):
        for inp, tar in self.train_dataset.take(num):
            fig,axs = plt.subplots(1,4, figsize=(12,3))

            axs[0].imshow(inp[0])
            axs[0].set_title("Input")

            axs[1].imshow(tar[0])
            axs[1].set_title("Real")

            gen_output = self.generator(inp, training=False)
            plt.figure()
            axs[2].imshow(gen_output[0,...]);
            axs[2].set_title("Generator")

            disc_out = self.discriminator([inp, gen_output], training=False)
            axs[3].imshow(disc_out[0,...,-1], cmap='RdBu_r')
            axs[3].set_title("Discriminator")


    def set_checkpoint_dir(self, path):
        """Set path for loading/saving checkpoints"""
        self.checkpoint_dir = path
        self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt")
        if not os.path.isdir(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)
        self.checkpoint = tf.train.Checkpoint(generator_optimizer=self.generator_optimizer,
                                        discriminator_optimizer=self.discriminator_optimizer,
                                        generator=self.generator,
                                        discriminator=self.discriminator)


    def save_checkpoint(self):
        """Save checkpoint in checkpoint_dir folder"""
        self.checkpoint.save(file_prefix = checkpoint_prefix)
                

    def load_checkpoint(self):
        """Restoring the latest checkpoint in checkpoint_dir"""
        self.checkpoint.restore(tf.train.latest_checkpoint(self.checkpoint_dir))

# Training model from scratch (optional)
To speed up the process, only a dozen pictures have been placed in data/raw-img/ directory with their coresponding sketch in data/sketch-img
Run the cells below for a demo on how our model trained

In [5]:
#resize and combine the raw data with the sketches
resize_combine(raw_path="data/raw-img/", sketch_path="data/sketch-img/", result_path="data/resized_combined/", img_size=(256,256))
#shuffle data into training and testing
shuffle_into_train_test(input_path="data/resized_combined/", train_path="data/train/", test_path="data/test/", test_ratio=0.2, clear_dir=True)

AssertionError: data/resized_combined/ is not empty, clear it first

In [None]:
#instanciate a model and load train/test paths
DemoModel = Model()
DemoModel.load_train("train/")
DemoModel.load_test("test/")

In [None]:
DemoModel.train(epochs=1)

In [None]:
DemoModel.show_Generator_Discriminator(3)

In [None]:
DemoModel.test(3)

# RESULTS

## DOGS

In [None]:
DogsModel = Model()
DogsModel.load_train("train(dogs_small)/")#loading smaller version of the dogs training set for demo
DogsModel.load_test("test(dogs)/")
#DogsModel.load_predict("custom/Dog/")

In [None]:
#takes about 30sec
DogsModel.train(3)

In [None]:
DogsModel.test(5)

In [None]:
DogsModel.show_Generator_Discriminator(5)

### Load checkpoint for 500 epochs of training on whole dataset

In [None]:
DogsModel.set_checkpoint_dir('training_checkpoints/dogs/')
DogsModel.load_checkpoint()

In [None]:
DogsModel.test(5)

## CATS

In [None]:
CatsModel = Model()
CatsModel.load_train("train(cats_small)/")#loading smaller version of the cats training set for demo
CatsModel.load_test("test(cats)/")
#CatsModel.load_predict("custom/Cat/")

In [None]:
#takes about 30sec
CatsModel.train(3)

In [None]:
CatsModel.show_Generator_Discriminator(5)

In [None]:
CatsModel.test(5)

### Load checkpoint for 500 epochs of training on whole dataset

In [None]:
CatsModel.set_checkpoint_dir('training_checkpoints/cats/')
CatsModel.load_checkpoint()

In [None]:
CatsModel.test(5)

# Components detailed information

In [None]:
# Loads the the image_file as an input_image and a real_image representing the target
def load_input_target(image_file):

    #read
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)
    
    #reformat
    w = tf.shape(image)[1]
    w = w // 2
    real_image = image[:, :w, :]
    input_image = image[:, w:, :]
    
    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)
    
    return input_image, real_image

In [None]:
# Loads the image_file as a single input_image
def load_input(image_file):

   #read
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)
    
    input_image = tf.cast(image, tf.float32)

    return input_image

### Testing loading, IO test
IO is important for this project. Make sure you have already downloaded the data. 
See README.md if there is any issue. This test should display the first image of the dataset.

In [None]:
# # Loading image
# inp, re = load(PATH+'train/1.jpeg')
# # casting to int for matplotlib to show the image
# plt.figure()
# plt.imshow(inp/255)
# plt.figure()
# plt.imshow(re/255)

In [None]:
# Resizing the image
def resize(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width], 
                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width], 
                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    
    return input_image, real_image

In [None]:
# Cropping the image using Tensorflow's utility functions
def random_crop(input_image, real_image):
    
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop( stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    
    return cropped_image[0], cropped_image[1]

In [None]:
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
    
    input_image = (input_image / 127.5) - 1
    real_image  = (real_image   / 127.5) - 1
    
    return input_image, real_image

In [None]:
# Implementing the random jitter (see below for details)
@tf.function()
def random_jitter(input_image, real_image):
    # resizing to 286 x 286 x 3
    input_image, real_image = resize(input_image, real_image, 286, 286)
    
    # randomly cropping to 256 x 256 x 3
    input_image, real_image = random_crop(input_image, real_image)
    if tf.random.uniform(()) > 0.5:
        # random mirroring
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)
    
    return input_image, real_image

### Testing Random Jitter visually
Random jitter is a small, low-cost preprocessing step used 
in the context of Image-to-Image translation for natural images,
insensitive to pixel shift and mirroring. Using the prior 
knowledge of natural images, it encourages better generalizability 
of the model. <br>
Random jittering as described in the paper is to:
* Resize an image to bigger height and width
* Randomnly crop to the original size
* Randomnly flip the image horizontally 

In [None]:
"""
# Plotting 4 times the same image with random_jitter applied
plt.figure(figsize=(6, 6))
for i in range(4):
    rj_inp, rj_re = random_jitter(inp, re)  
    plt.subplot(2, 2, i+1)
    plt.imshow(rj_inp/255.0)
    plt.axis('off')
plt.show()
"""

In [None]:
# Loading the image with heavier preprocessing, use of random jitter and normalization
def load_image_train(image_file):
    
    input_image, real_image = load_input_target(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)
    
    return input_image, real_image

In [None]:
# Loading the image with heavier preprocessing, adapted to testing
def load_image_test(image_file):
    
    input_image, real_image = load_input_target(image_file)
    input_image, real_image = resize(input_image, real_image,  IMG_HEIGHT, IMG_WIDTH) 
    input_image, real_image = normalize(input_image, real_image)
    
    return input_image, real_image

In [None]:
# Loading the image with heavier preprocessing, adapted to predicting without target image
def load_image_predict(image_file):
    input_image= load_input(image_file)
    input_image, real_image = resize(input_image, input_image,  IMG_HEIGHT, IMG_WIDTH) 
    input_image, real_image = normalize(input_image, input_image)
    
    return input_image

## Input Pipeline
Setting up the input Pipeline for training. Make sure all your data is directly in the train directory, in jpeg.

In [None]:
# Pipeline setup for training
train_dataset = tf.data.Dataset.list_files(PATH+'train(cats_small)/*.jpeg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(1)

In [None]:
# Pipeline setup for training
test_dataset = tf.data.Dataset.list_files(PATH+'test(cats)/*.jpeg')
test_dataset = test_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)

## Build the Generator
  * The architecture of generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net).

In [None]:
# Downsampling, implementation of the encoder.
def downsample(filters, size, apply_batchnorm=True):
    
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    # 1st layer, Conv
    result.add( tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    # 2nd layer, Batchnorm
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    # 3rd layer, Leaky ReLU
    result.add(tf.keras.layers.LeakyReLU())
    
    return result

In [None]:
# Upsampling, implementation of the decoder.
def upsample(filters, size, apply_dropout=False):
    
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    # 1st layer, Conv
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    # 2nd layer, Batchnorm
    result.add(tf.keras.layers.BatchNormalization())
    # 3rd layer, Dropout (Randomization)
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    # 4th layer, regular ReLU
    result.add(tf.keras.layers.ReLU())
    return result

In [None]:
# Defining the generator
def Generator():
    # Downsampling stack
    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]
    # Upsampling stack
    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]
    # Initialization
    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, 
                                         strides=2, 
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, 3)
    concat = tf.keras.layers.Concatenate() 
    
    inputs = tf.keras.layers.Input(shape=[None,None,3])
    x = inputs
    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    
    skips = reversed(skips[:-1])
    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])
    
    x = last(x)
    
    return tf.keras.Model(inputs=inputs, outputs=x)

### Testing the Generator on 1 image
You should be able to see an image composed of noise, with the trace of your first edge image. Ignore the warning if there is any.

In [None]:
generator = Generator()
#gen_output = generator(inp[tf.newaxis,...], training=False)
#plt.imshow(gen_output[0,...]);

## Build the Discriminator
  * The Discriminator is a PatchGAN.
  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU).
  * The shape of the output after the last layer is (batch_size, 30, 30, 1).
  * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).
  * Discriminator receives 2 inputs.
    * Input image and the target image, which it should classify as real.
    * Input image and the generated image (output of generator), which it should classify as fake. 
    * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`).

In [None]:
# Defining the discriminator
def Discriminator():
    
    # Initialization
    initializer = tf.random_normal_initializer(0., 0.02)
    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
    
    # Downsampling blocks instantiation
    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)
    
    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    # 1st layer, Conv
    conv = tf.keras.layers.Conv2D(512, 4, strides=1, 
                                kernel_initializer=initializer, 
                                use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
    # 2nd layer, Batchnorm
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    # 3rd layer, Leaky ReLU
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
    # 4th layer, Conv
    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
    
    return tf.keras.Model(inputs=[inp, tar], outputs=last)

### Testing the Discriminator on 1 image
You should be able to see an image composed of noise, with the trace of your first edge image. Ignore the warning if there is any.

In [None]:
discriminator = Discriminator()
#disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)
#plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
#plt.colorbar();

To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).

## Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes 2 inputs; **real images, generated images**
  * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**
  * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**
  * Then the total_loss is the sum of real_loss and the generated_loss
  
* **Generator loss**
  * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The formula to calculate the total generator $loss = GAN_{loss} + \lambda * L1_{loss}$, where $\lambda = 100$. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004).

In [None]:
# General loss instantiation
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
# Defining the discriminator loss
def discriminator_loss(disc_real_output, disc_generated_output):
    
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_disc_loss = real_loss + generated_loss
    
    return total_disc_loss

In [None]:
# Defining the generator loss
def generator_loss(disc_generated_output, gen_output, target):
    LAMBDA = 100
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    
    # mean absolute error
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + (LAMBDA * l1_loss)
    
    return total_gen_loss

In [None]:
# Instantiating optimizers
generator_optimizer      = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

## Checkpoints (Object-based saving)
Creates a new checkpoint, for your new model. <br>
Do not modify the format of the checkpoint. If you do, modify the following cell corresponding to your new format.

In [None]:
#Directory
checkpoint_dir = './training_checkpoints'
if not os.path.isdir(checkpoint_dir):
    os.mkdir(checkpoint_dir)
# Loading
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Training

* We start by iterating over the dataset
* The generator gets the input image and we get a generated output.
* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.
* Next, we calculate the generator and the discriminator loss.
* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.


## Generate Images

* After training, its time to generate some images!
* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output we expect.
* Last step is to plot the predictions and **voila!**

In [None]:
# Generates images based on the current model
def generate_images(model, test_input, tar=None, savePath = None):
    """
    The training=True is intentional here since we want the batch 
    statistics while running the model on the test dataset. If we 
    use training=False, we will get the accumulated statistics 
    learned from the training dataset (which we don't want).
    """
    
    # Prediction
    prediction = model(test_input, training=True)

    # Plotting
    plt.figure(figsize=(15,15))
    
    # With ground truth or not
    if tar != None:
        display_list = [test_input[0], tar[0], prediction[0]]
        title = ['Input Image', 'Ground Truth', 'Predicted Image']
    else:
        display_list = [test_input[0], prediction[0]]
        title = ['Input Image', 'Predicted Image']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
    if savePath != None:
        plt.savefig(savePath)
        plt.close()
    else:
        plt.show()

In [None]:
# Trains on a single image. The function depends on the instantiation of 
# many functions and objects, make sure you ran through all cells.
@tf.function
def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        gen_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    
    generator_gradients = gen_tape.gradient(gen_loss, 
                                          generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, 
                                          discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(generator_gradients, 
                                          generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, 
                                              discriminator.trainable_variables))

In [None]:
# Trains on all the dataset
def train(dataset, epochs):  
    
    for epoch in range(epochs):
        start = time.time()
        
        for input_image, target in dataset:
            train_step(input_image, target)
        
        clear_output(wait=True)
        for inp, tar in test_dataset.take(1):
            generate_images(generator, inp, tar)
        # Saving (checkpoint) the model every 20 epochs
        if (epoch + 1) % 20 == 0: checkpoint.save(file_prefix = checkpoint_prefix)
        # Output to console. Trust me, it takes a while. Always good to have some sign of life.
        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))

### Testing the Model on 1 image
You should be able to see the edge image, the real image and the predicted image, composed of noise, with the trace of your first edge image.

In [None]:
train(train_dataset, 1)

## Restore the latest checkpoint and test
Loads the last checkpoint, to load the trained model. <br>
Verify you have downloaded the lastest checkpoint. This is the trained version of the model. After the data, it should be close to the most expensive file memory-wise.

In [None]:
!ls {checkpoint_dir}

In [None]:
# Restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Testing on the entire test dataset

In [None]:
# Run the trained model on the entire test dataset
for i ,(inp, tar) in enumerate(test_dataset):
    generate_images(generator, inp, tar, "{}{}{}.jpeg".format(PATH, 'results/', i))

## Custom images - Dogs Testing

In [None]:
checkpoint_dir = './training_checkpoints/Dog/'
if not os.path.isdir(checkpoint_dir):
    os.mkdir(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
predict_dataset = tf.data.Dataset.list_files('custom/Dog/*.jpeg')
predict_dataset = predict_dataset.map(load_image_predict)
predict_dataset = predict_dataset.batch(1)

In [None]:
for inp in predict_dataset:
    generate_images(generator, inp)

## Custom images - Cats Testing

In [None]:
checkpoint_dir = './training_checkpoints/Cat/'
if not os.path.isdir(checkpoint_dir):
    os.mkdir(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
predict_dataset = tf.data.Dataset.list_files('custom/Cat/*.jpeg')
predict_dataset = predict_dataset.map(load_image_predict)
predict_dataset = predict_dataset.batch(1)

In [None]:
for inp in predict_dataset:
    generate_images(generator, inp)