# AI-Generated Zoology
## An implementation of Image-to-Image Translation
### Main file
#### For project in Pattern Recognition, COMP 472
#### Due December 10th, 2020
#### Use of Python 3.8 (see requirement.txt)
#### By Sandra Buchen (2631798)
#### Nigel Yong Sao Young (40089856) 
#### Dan Raileanu (40019882) 
#### Inés Gonzalez Pepe (40095696) 
#### Marc Vicuna (40079109)
This main file has the capacity to run the project from top to bottom. Read the markdown cells for more information.

## Import TensorFlow and other libraries

In [1]:
import os
import time
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import clear_output
tf.__version__ 

'2.2.1'

In [None]:
%%javascript
// Make OutputArea._should_scroll always return false so long cell outputs
// (e.g. the per-epoch training figures) are never collapsed into a scroll box.
// NOTE(review): relies on the classic Jupyter Notebook front-end API.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

## Load the dataset

As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.
* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.

## Instantiate the constants
These constants were chosen for our implementation; they may need to be adjusted for other applications.

In [2]:
# Input-pipeline hyper-parameters, chosen for this project (adjust as needed).
BUFFER_SIZE = 400    # size of the tf.data shuffle buffer (counted in file names)
BATCH_SIZE = 1
# Every image is brought to IMG_HEIGHT x IMG_WIDTH pixels, and the generator
# emits OUTPUT_CHANNELS colour channels.
IMG_WIDTH = IMG_HEIGHT = 256
OUTPUT_CHANNELS = 3
# Dataset root inside the current working directory. The trailing '/' matters:
# sub-paths are appended with plain string concatenation (PATH + '...').
PATH = os.path.join(os.getcwd(), 'data/')
print(PATH)

c:\Users\Dan\source\AI-Generated-Zoology\data/


## Utility functions
If you are not interested in the utility functions, simply run all the cells up to the next section heading.
Most of the data manipulation is handled by these utility functions.

In [3]:
# Loads the image_file as an input_image and a real_image representing the target
def load_input_target(image_file):
    """Read a side-by-side JPEG and split it into ``(input_image, real_image)``.

    The file holds two images next to each other. The left half is returned as
    ``real_image`` (the target) and the right half as ``input_image``. Both are
    cast to float32, with pixel values still in [0, 255].

    Args:
        image_file: path (string or string tensor) of a JPEG file.

    Returns:
        Tuple ``(input_image, real_image)`` of float32 tensors with 3 channels.
    """
    # read; channels=3 forces RGB output even for grayscale JPEGs, so every
    # image matches the 3-channel inputs of random_crop and the two models.
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image, channels=3)

    # reformat: split the width in two (with an odd width the extra column
    # ends up in the right half, i.e. the input image).
    w = tf.shape(image)[1] // 2
    real_image = image[:, :w, :]
    input_image = image[:, w:, :]

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image

In [4]:
# Loads the image_file as a single input_image
def load_input(image_file):
    """Read a JPEG file and return it as a single float32 image tensor.

    Used for inference, where the file contains only the input image (no
    side-by-side target). Pixel values stay in [0, 255]; normalization is left
    to the caller (see load_image_predict).

    Args:
        image_file: path (string or string tensor) of a JPEG file.

    Returns:
        float32 tensor of shape (height, width, 3).
    """
    # read; channels=3 forces RGB output even for grayscale JPEGs, matching
    # the 3-channel input the generator expects.
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image, channels=3)

    return tf.cast(image, tf.float32)

### Testing loading, IO test
IO is important for this project. Make sure you have already downloaded the data. 
See README.md if there is any issue. This test should display the first image of the dataset.

In [5]:
# # Loading image (first training pair; path may need adapting to your data folder)
# inp, re = load_input_target(PATH+'train/1.jpeg')
# # dividing by 255 scales the [0, 255] float pixels to [0, 1] so matplotlib can show them
# plt.figure()
# plt.imshow(inp/255)
# plt.figure()
# plt.imshow(re/255)

In [6]:
# Resizing the image
def resize(input_image, real_image, height, width):
    """Resize both images of a pair to ``height`` x ``width``.

    Nearest-neighbour interpolation is used for both images, so no new pixel
    values are blended in.

    Returns:
        Tuple ``(input_image, real_image)`` after resizing.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    resized_input = tf.image.resize(input_image, target_size, method=nearest)
    resized_real = tf.image.resize(real_image, target_size, method=nearest)
    return resized_input, resized_real

In [7]:
# Cropping the image using Tensorflow's utility functions
def random_crop(input_image, real_image):
    """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

    The two images are stacked along a new leading axis, so one random_crop
    call picks a single offset for both and the pair stays aligned.
    """
    crop_size = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    pair = tf.stack([input_image, real_image], axis=0)
    cropped_pair = tf.image.random_crop(pair, size=crop_size)
    return cropped_pair[0], cropped_pair[1]

In [8]:
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
    """Map pixel values of both images linearly from [0, 255] to [-1, 1].

    Returns:
        Tuple ``(input_image, real_image)`` computed as ``x / 127.5 - 1``.
    """
    half_range = 127.5
    return input_image / half_range - 1, real_image / half_range - 1

In [9]:
# Implementing the random jitter (see below for details)
@tf.function()
def random_jitter(input_image, real_image):
    """Apply identical random jitter and mirroring to an image pair.

    Steps: upscale both images to 286 x 286, cut out one shared random
    256 x 256 window, then flip both left-to-right with probability 0.5.
    """
    # Upscale first so that a random crop can move the window around.
    enlarged_input, enlarged_real = resize(input_image, real_image, 286, 286)
    out_input, out_real = random_crop(enlarged_input, enlarged_real)

    # Random mirroring; the condition is a tensor, which autograph converts
    # into graph-level control flow inside this tf.function.
    if tf.random.uniform(()) > 0.5:
        out_input = tf.image.flip_left_right(out_input)
        out_real = tf.image.flip_left_right(out_real)

    return out_input, out_real

### Testing Random Jitter visually
Random jitter is a small, low-cost data-augmentation step used in
image-to-image translation of natural images. The content of a natural image
is essentially unchanged when it is shifted by a few pixels or mirrored, so
training on such variants uses this prior knowledge to encourage better
generalization of the model. <br>
Random jittering as described in the paper is to:
* Resize an image to bigger height and width
* Randomly crop it back to the original size
* Randomly flip the image horizontally

In [10]:
# Disabled visual test: the code is kept inside a string literal so it does not
# run. It needs `inp` / `re` from the (commented-out) IO test cell above.
"""
# Plotting 4 times the same image with random_jitter applied
plt.figure(figsize=(6, 6))
for i in range(4):
    rj_inp, rj_re = random_jitter(inp, re)  
    plt.subplot(2, 2, i+1)
    plt.imshow(rj_inp/255.0)
    plt.axis('off')
plt.show()
"""

"\n# Plotting 4 times the same image with random_jitter applied\nplt.figure(figsize=(6, 6))\nfor i in range(4):\n    rj_inp, rj_re = random_jitter(inp, re)  \n    plt.subplot(2, 2, i+1)\n    plt.imshow(rj_inp/255.0)\n    plt.axis('off')\nplt.show()\n"

In [11]:
# Loading the image with heavier preprocessing, use of random jitter and normalization
def load_image_train(image_file):
    """Load a training pair: split, random jitter/mirroring, scale to [-1, 1].

    Returns:
        Tuple ``(input_image, real_image)``.
    """
    pair = load_input_target(image_file)
    pair = random_jitter(*pair)
    return normalize(*pair)

In [12]:
# Loading the image with light, deterministic preprocessing, adapted to testing
def load_image_test(image_file):
    """Load an evaluation pair without random augmentation.

    Both halves are resized to IMG_HEIGHT x IMG_WIDTH and scaled to [-1, 1].

    Returns:
        Tuple ``(input_image, real_image)``.
    """
    pair = load_input_target(image_file)
    pair = resize(*pair, IMG_HEIGHT, IMG_WIDTH)
    return normalize(*pair)

In [13]:
# Loading the image with light preprocessing, adapted to predicting without target image
def load_image_predict(image_file):
    """Load a single image for inference (no target image available).

    The image is resized to IMG_HEIGHT x IMG_WIDTH and scaled to [-1, 1] with
    the same helpers used by the train/test loaders. Those helpers work on
    image pairs, so the image is passed twice and the duplicate result is
    discarded.

    Returns:
        The preprocessed input image tensor.
    """
    input_image = load_input(image_file)
    input_image, _ = resize(input_image, input_image, IMG_HEIGHT, IMG_WIDTH)
    input_image, _ = normalize(input_image, input_image)

    return input_image

## Input Pipeline
Setting up the input pipeline for training. Make sure all training images are JPEG files with a `.jpeg` extension placed directly in `data/train(cats)/` (test images go in `data/test(cats)/`).

In [14]:
# Pipeline setup for training: list the JPEG files, shuffle their order,
# decode + augment in parallel, batch with BATCH_SIZE, and prefetch so the
# next batch is prepared while the current training step runs.
train_dataset = tf.data.Dataset.list_files(PATH+'train(cats)/*.jpeg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [15]:
# Pipeline setup for testing: no augmentation, only resize + normalization.
# The file list is shuffled, so test_dataset.take(1) (used by train() for its
# per-epoch preview) gives a random test image, and the iteration order is not
# stable between passes.
test_dataset = tf.data.Dataset.list_files(PATH+'test(cats)/*.jpeg')
test_dataset = test_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)

## Build the Generator
  * The architecture of generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net).

In [16]:
# Downsampling, implementation of the encoder.
def downsample(filters, size, apply_batchnorm=True):
    """Build one encoder block: Conv2D (stride 2) -> [BatchNorm] -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the conv.

    Returns:
        tf.keras.Sequential that halves the spatial size ('same' padding).
    """
    stack = [
        tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

In [17]:
# Upsampling, implementation of the decoder.
def upsample(filters, size, apply_dropout=False):
    """Build one decoder block: Conv2DTranspose (stride 2) -> BatchNorm -> [Dropout 0.5] -> ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: insert a Dropout(0.5) layer after the batch norm.

    Returns:
        tf.keras.Sequential that doubles the spatial size ('same' padding).
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                        kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                        use_bias=False),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.5))
    stack.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stack)

In [18]:
# Defining the generator
def Generator():
    """Build the pix2pix generator, a U-Net style encoder/decoder.

    The encoder has 8 ``downsample`` blocks, each halving H and W. The decoder
    has 7 ``upsample`` blocks, each doubling H and W. The output of each decoder
    block is concatenated with the mirrored encoder activation (skip
    connection). A final stride-2 Conv2DTranspose with tanh restores the input
    resolution.

    Returns:
        tf.keras.Model mapping a (bs, H, W, 3) batch to a
        (bs, H, W, OUTPUT_CHANNELS) batch with values in [-1, 1] (tanh).
        H and W are left unspecified in the Input. Because there are 8
        stride-2 reductions, they must be multiples of 256 for the skip
        shapes to line up. The shape comments below assume 256 x 256.
    """
    # Downsampling stack
    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]
    # Upsampling stack (shapes are after concatenation with the matching skip)
    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]
    # Initialization
    initializer = tf.random_normal_initializer(0., 0.02)
    # Output layer: back to full resolution; tanh keeps values in [-1, 1],
    # the same range the images are normalized to.
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 256, 256, 3)
    # A single (weightless) Concatenate layer is reused for every skip connection.
    concat = tf.keras.layers.Concatenate()

    inputs = tf.keras.layers.Input(shape=[None,None,3])
    x = inputs
    # Downsampling through the model, remembering every encoder activation
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Drop the 1x1 bottleneck (it is the decoder's input, not a skip) and
    # reverse so the deepest remaining activation meets the first decoder block.
    skips = reversed(skips[:-1])
    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

### Testing the Generator on 1 image
You should be able to see an image composed of noise, with the trace of your first edge image. Ignore the warning if there is any.

In [19]:
# Build the generator once; this object is trained, checkpointed and used for
# predictions in the cells below.
generator = Generator()
# Optional smoke test (needs `inp` from the commented-out IO test cell):
#gen_output = generator(inp[tf.newaxis,...], training=False)
#plt.imshow(gen_output[0,...]);

## Build the Discriminator
  * The Discriminator is a PatchGAN.
  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU).
  * The shape of the output after the last layer is (batch_size, 30, 30, 1).
  * Each of the 30x30 output values classifies a 70x70 patch of the input image (such an architecture is called a PatchGAN).
  * Discriminator receives 2 inputs.
    * Input image and the target image, which it should classify as real.
    * Input image and the generated image (output of generator), which it should classify as fake. 
    * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`).

In [20]:
# Defining the discriminator
def Discriminator():
    """Build the PatchGAN discriminator.

    Takes two image batches: ``input_image`` and ``target_image`` (either the
    real target or the generator's output). It concatenates them along the
    channel axis and returns a grid of raw scores, one per input patch. The
    scores are logits, because the last Conv2D has no activation. For
    256 x 256 inputs the grid has shape (bs, 30, 30, 1).

    Returns:
        tf.keras.Model with inputs ``[inp, tar]`` and the score grid as output.
    """

    # Initialization
    initializer = tf.random_normal_initializer(0., 0.02)
    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)

    # Downsampling blocks instantiation
    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    # Pad by 1 on each side, then apply stride-1 convolutions without padding
    # (Keras default 'valid'); each 4x4 kernel shrinks the map by 3.
    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    # 1st layer, Conv
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
    # 2nd layer, Batchnorm
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    # 3rd layer, Leaky ReLU
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
    # 4th layer, Conv: one unbounded score (logit) per patch, no activation
    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

### Testing the Discriminator on 1 image
The (commented-out) test plots the discriminator's output for one (input, generated) pair: a 30x30 map of per-patch scores, which is meaningless before training. Ignore the warning if there is any.

In [21]:
# Build the discriminator once; this object is trained and checkpointed below.
discriminator = Discriminator()
# Optional smoke test (needs `inp` and `gen_output` from the cells above):
# shows the 30x30 grid of patch scores for one (input, generated) pair.
#disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)
#plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
#plt.colorbar();

To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1611.07004).

## Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes 2 inputs; **real images, generated images**
  * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones (since these are the real images)**
  * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**
  * Then the total_loss is the sum of real_loss and the generated_loss
  
* **Generator loss**
  * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The formula to calculate the total generator $loss = GAN_{loss} + \lambda * L1_{loss}$, where $\lambda = 100$. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004).

In [22]:
# General loss instantiation: binary cross-entropy shared by both loss functions.
# from_logits=True because the discriminator's last Conv2D has no activation,
# so it outputs raw scores rather than probabilities.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [23]:
# Defining the discriminator loss
def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator loss: BCE pushing real pairs to 1 and generated pairs to 0.

    Args:
        disc_real_output: discriminator logits for (input, target) pairs.
        disc_generated_output: discriminator logits for (input, generated) pairs.

    Returns:
        Scalar tensor, the sum of the real and the generated terms.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return loss_on_real + loss_on_fake

In [24]:
# Defining the generator loss
def generator_loss(disc_generated_output, gen_output, target, l1_weight=100):
    """Total generator loss: adversarial term + weighted L1 reconstruction term.

    Args:
        disc_generated_output: discriminator logits for (input, generated) pairs.
        gen_output: images produced by the generator.
        target: ground-truth images matching ``gen_output``.
        l1_weight: weight (lambda) of the L1 term. Defaults to 100, the value
            used throughout this notebook.

    Returns:
        Scalar tensor ``gan_loss + l1_weight * l1_loss``.
    """
    # The generator wants its outputs to be labelled as real (ones).
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # Mean absolute error keeps the output structurally close to the target.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    return gan_loss + (l1_weight * l1_loss)

In [25]:
# Instantiating optimizers: one Adam per network, both with learning rate 2e-4
# and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

## Checkpoints (Object-based saving)
Creates a new checkpoint, for your new model. <br>
Do not modify the format of the checkpoint. If you do, modify the following cell corresponding to your new format.

In [None]:
# Directory holding the checkpoints of the current training run.
checkpoint_dir = './training_checkpoints'
# makedirs(exist_ok=True) creates any missing parents and avoids the
# check-then-create race of isdir() + mkdir().
os.makedirs(checkpoint_dir, exist_ok=True)
# File-name prefix used by checkpoint.save() inside train().
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Object-based checkpoint tracking both models and both optimizers.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Training

* We start by iterating over the dataset
* The generator gets the input image and we get a generated output.
* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.
* Next, we calculate the generator and the discriminator loss.
* Then, we calculate the gradients of each loss with respect to the trainable variables of the corresponding network (generator or discriminator) and apply them with that network's optimizer.


## Generate Images

* After training, it's time to generate some images!
* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output we expect.
* Last step is to plot the predictions and **voila!**

In [None]:
# Generates images based on the current model
def generate_images(model, test_input, tar=None, savePath=None):
    """Run ``model`` on ``test_input`` and plot the result next to the input.

    The training=True is intentional here since we want the batch
    statistics while running the model on the test dataset. If we
    use training=False, we will get the accumulated statistics
    learned from the training dataset (which we don't want).
    It also keeps the decoder's Dropout layers active.

    Args:
        model: the generator (Keras model) to run.
        test_input: batch of normalized input images in [-1, 1]; only the
            first element of the batch is plotted.
        tar: optional batch of ground-truth images. When given, a
            "Ground Truth" panel is drawn between input and prediction.
        savePath: optional file path. When given, the figure is written there
            and closed; otherwise it is displayed.
    """
    # Prediction
    prediction = model(test_input, training=True)

    # Plotting
    plt.figure(figsize=(15, 15))

    # With ground truth or not ('is not None': never compare tensors with !=)
    if tar is not None:
        display_list = [test_input[0], tar[0], prediction[0]]
        title = ['Input Image', 'Ground Truth', 'Predicted Image']
    else:
        display_list = [test_input[0], prediction[0]]
        title = ['Input Image', 'Predicted Image']

    for i, (image, label) in enumerate(zip(display_list, title)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(label)
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')

    if savePath is not None:
        plt.savefig(savePath)
        # Close the figure so saving many images in a loop does not keep
        # every figure alive in memory.
        plt.close()
    else:
        plt.show()

In [None]:
# Trains on a single image. The function depends on the instantiation of
# many functions and objects, make sure you ran through all cells.
@tf.function
def train_step(input_image, target):
    """Run one optimization step of both networks on one batch.

    Uses the notebook globals ``generator``, ``discriminator``, their
    optimizers, ``generator_loss`` and ``discriminator_loss``.

    Args:
        input_image: batch of normalized input images.
        target: batch of matching normalized ground-truth images.
    """
    # Both tapes record the same forward pass; each tape is later
    # differentiated only with respect to its own network's variables.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # Score the real (input, target) pair and the fake (input, generated) pair.
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        gen_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # All gradients are computed before any variable is updated.
    generator_gradients = gen_tape.gradient(gen_loss,
                                          generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                          discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

In [None]:
# Trains on all the dataset
def train(dataset, epochs, save_every=20):
    """Train the generator and discriminator for ``epochs`` passes over ``dataset``.

    After each epoch the notebook output is cleared, one (input, target) pair
    from the global ``test_dataset`` is plotted with the current prediction,
    and the epoch duration is printed.

    Args:
        dataset: tf.data.Dataset yielding (input_image, target) batches.
        epochs: number of passes over ``dataset``.
        save_every: positive int. A checkpoint is saved every ``save_every``
            epochs (default 20) and always after the final epoch, so a short
            run such as ``train(train_dataset, 1)`` still leaves a checkpoint
            to restore.

    Depends on the notebook globals ``train_step``, ``test_dataset``,
    ``generator``, ``generate_images``, ``checkpoint`` and ``checkpoint_prefix``.
    """
    for epoch in range(epochs):
        start = time.time()

        for input_image, target in dataset:
            train_step(input_image, target)

        # Replace the previous epoch's preview with a fresh one.
        clear_output(wait=True)
        for inp, tar in test_dataset.take(1):
            generate_images(generator, inp, tar)

        epochs_done = epoch + 1
        if epochs_done % save_every == 0 or epochs_done == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)
        # Output to console. Training is slow, so report progress every epoch.
        print('Time taken for epoch {} is {} sec\n'.format(epochs_done, time.time() - start))

### Testing the Model on 1 image
You should be able to see the edge image, the real image and the predicted image, composed of noise, with the trace of your first edge image.

In [None]:
# Smoke run: a single epoch over the training set.
train(dataset=train_dataset, epochs=1)

## Restore the latest checkpoint and test
Loads the last checkpoint, to load the trained model. <br>
Verify that you have downloaded the latest checkpoint; it contains the trained version of the model. Apart from the dataset, it should be the largest file in the project.

In [None]:
# List the saved checkpoint files. os.listdir is used instead of the `!ls`
# shell escape, which is not available in the Windows command shell.
for checkpoint_file in sorted(os.listdir(checkpoint_dir)):
    print(checkpoint_file)

In [None]:
# Restoring the latest checkpoint in checkpoint_dir. restore(None) silently does
# nothing, so report a missing checkpoint instead of continuing with untrained weights.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    print('No checkpoint found in {}; the models keep their current weights.'.format(checkpoint_dir))
checkpoint.restore(latest_checkpoint)

## Testing on the entire test dataset

In [None]:
# Run the trained model on the entire test dataset and save one figure per
# test batch as <PATH>/results/<index>.jpeg. test_dataset is shuffled, so an
# index does not correspond to a fixed input file across runs.
results_dir = os.path.join(PATH, 'results')
# savefig does not create directories, so make sure the output folder exists.
os.makedirs(results_dir, exist_ok=True)
for i, (inp, tar) in enumerate(test_dataset):
    generate_images(generator, inp, tar, os.path.join(results_dir, '{}.jpeg'.format(i)))
    # Release the saved figure so memory does not grow with the dataset size.
    plt.close('all')

## Custom images - Dogs Testing

In [None]:
# Switch to the checkpoints of the model trained on dogs. checkpoint_dir and
# checkpoint_prefix are globals, also read by train() when it saves.
checkpoint_dir = './training_checkpoints/Dog/'
# makedirs also creates './training_checkpoints' when this cell runs first;
# os.mkdir would raise FileNotFoundError in that case.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# restoring the latest checkpoint in checkpoint_dir; restore(None) silently
# does nothing, so report a missing checkpoint explicitly.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    print('No checkpoint found in {}; the models keep their current weights.'.format(checkpoint_dir))
checkpoint.restore(latest_checkpoint)

In [None]:
# Inference pipeline for the custom dog images (single images, no target half).
predict_dataset = (
    tf.data.Dataset.list_files('custom/Dog/*.jpeg')
    .map(load_image_predict)
    .batch(1)
)

In [None]:
# Show every custom dog image next to the generator's translation
# (there is no ground truth for these images).
for inp in predict_dataset:
    generate_images(model=generator, test_input=inp)

## Custom images - Cats Testing

In [None]:
# Switch to the checkpoints of the model trained on cats. checkpoint_dir and
# checkpoint_prefix are globals, also read by train() when it saves.
checkpoint_dir = './training_checkpoints/Cat/'
# makedirs also creates './training_checkpoints' when this cell runs first;
# os.mkdir would raise FileNotFoundError in that case.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# restoring the latest checkpoint in checkpoint_dir; restore(None) silently
# does nothing, so report a missing checkpoint explicitly.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    print('No checkpoint found in {}; the models keep their current weights.'.format(checkpoint_dir))
checkpoint.restore(latest_checkpoint)

In [None]:
# Inference pipeline for the custom cat images (single images, no target half).
predict_dataset = (
    tf.data.Dataset.list_files('custom/Cat/*.jpeg')
    .map(load_image_predict)
    .batch(1)
)

In [None]:
# Show every custom cat image next to the generator's translation
# (there is no ground truth for these images).
for inp in predict_dataset:
    generate_images(model=generator, test_input=inp)