<a href="https://colab.research.google.com/github/ghaiszaher/Foggy-CycleGAN/blob/master/Foggy_CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CycleFoggyGAN

## Set up the input pipeline

In [0]:
!pip install git+https://github.com/tensorflow/examples.git

In [0]:
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Turn off the tensorflow_datasets progress bars for the rest of the notebook.
tfds.disable_progress_bar()
# Passed as `num_parallel_calls` to the dataset `map` calls below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

## Input Pipeline

This tutorial trains a model to translate images from domain A to images of domain B. You can find this dataset and similar ones [here](https://www.tensorflow.org/datasets/datasets#cycle_gan).

As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.

This is similar to what was done in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#load_the_dataset)

* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.

In [0]:
# Shuffle buffer size used by every dataset pipeline below.
BUFFER_SIZE = 1000
# Number of images per batch.
BATCH_SIZE = 1
# Spatial size of the random crop taken in random_crop().
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [0]:
def random_crop(image):
  """Return a random IMG_HEIGHT x IMG_WIDTH crop of `image` with 3 channels."""
  target_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
  return tf.image.random_crop(image, size=target_shape)

In [0]:
def normalize(image):
  """Rescale `image` as 2 * x - 1, which maps the range [0, 1] onto [-1, 1]."""
  return (image * 2) - 1

In [0]:
def denormalize(image):
  """Rescale `image` as (x + 1) / 2, which maps [-1, 1] back onto [0, 1]."""
  return (image + 1) / 2.

In [0]:
def random_jitter(image):
  """Augment `image`: enlarge, take a random crop, then randomly mirror it."""
  # Enlarge to 286 x 286 so the crop below can start at a random offset.
  enlarged = tf.image.resize(
      image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # Crop back down to IMG_HEIGHT x IMG_WIDTH x 3.
  cropped = random_crop(enlarged)

  # Flip left-to-right at random.
  return tf.image.random_flip_left_right(cropped)

In [0]:
def rgb_to_hsv(image):
  """Convert an RGB image to HSV via `tf.image.rgb_to_hsv`."""
  return tf.image.rgb_to_hsv(image)

In [0]:
def hsv_to_hsl(image):
  """Convert an HSV image to HSL.

  Args:
    image: tensor whose last axis starts with the (h, s, v) channels, in that
      order. Values are presumably in [0, 1], as produced by
      `tf.image.rgb_to_hsv` -- TODO confirm against callers.

  Returns:
    Tensor with the (h, s, l) channels concatenated on the last axis.
  """
  hue = image[..., 0:1]
  sat_hsv = image[..., 1:2]
  value = image[..., 2:3]

  lightness = value * (1 - sat_hsv / 2)

  # HSL saturation is (v - l) / min(l, 1 - l), defined as 0 where the
  # denominator is 0 (l == 0 or l == 1). divide_no_nan returns exactly that 0.
  # The previous tf.where(..., 0., a / b) form gave the same forward values but
  # still evaluated a / b everywhere, so NaN gradients could leak out of the
  # unselected branch.
  sat_hsl = tf.math.divide_no_nan(
      value - lightness, tf.math.minimum(lightness, 1 - lightness))

  return tf.concat([hue, sat_hsl, lightness], axis=-1)

In [0]:
def hsl_to_hsv(image):
  """Convert an HSL image to HSV.

  Args:
    image: tensor whose last axis starts with the (h, s, l) channels, in that
      order. Values are presumably in [0, 1] -- TODO confirm against callers.

  Returns:
    Tensor with the (h, s, v) channels concatenated on the last axis.
  """
  hue = image[..., 0:1]
  sat_hsl = image[..., 1:2]
  lightness = image[..., 2:3]

  value = lightness + sat_hsl * tf.math.minimum(lightness, 1 - lightness)

  # HSV saturation is 2 * (1 - l / v), defined as 0 where v == 0.
  # divide_no_nan keeps the unselected tf.where branch finite, so its gradient
  # cannot turn into NaN; the forward values match the plain l / v form.
  safe_ratio = tf.math.divide_no_nan(lightness, value)
  sat_hsv = tf.where(value == 0, 0., 2 * (1 - safe_ratio))

  return tf.concat([hue, sat_hsv, value], axis=-1)

In [0]:
def output_to_rgb(image):
  """Return `image` unchanged, as the value handed to `plt.imshow`.

  This is a single hook for converting model outputs before plotting. It is
  currently the identity; the denormalize() variant is disabled.
  """
  # Disabled alternative: return denormalize(image)
  return image

In [0]:
def preprocess_image_train(image, label):
  """Cast a training image to float32, divide by 255 and apply random jitter.

  `label` is accepted because the datasets are loaded with
  `as_supervised=True`; it is not used.
  """
  scaled = tf.cast(image, tf.float32) / 255.
  # normalize() to [-1, 1] is deliberately not applied here.
  return random_jitter(scaled)

In [0]:
def preprocess_image_test(image, label):
  """Cast a test image to float32 and divide by 255, with no augmentation.

  `label` is accepted because the datasets are loaded with
  `as_supervised=True`; it is not used.
  """
  # normalize() to [-1, 1] is deliberately not applied here.
  return tf.cast(image, tf.float32) / 255.

In [0]:
# Load the horse2zebra CycleGAN dataset. `as_supervised=True` yields
# (image, label) pairs, which is why the preprocess functions take a `label`.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

# Domain A and domain B each come with their own train and test splits.
train_A, train_B = dataset['trainA'], dataset['trainB']
test_A, test_B = dataset['testA'], dataset['testB']

In [0]:
# Training sets: cache the raw images *before* the random augmentation, so
# every epoch draws fresh crops and flips. Caching after the map would store
# the first epoch's jittered images and replay those for the rest of training.
train_A = train_A.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_B = train_B.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Test sets: preprocessing has no randomness, so caching its output is safe
# and avoids redoing the cast/scale on every pass.
test_A = test_A.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_B = test_B.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

In [0]:
# Pull one batch from each training set to use as preview images below.
sample_A = next(iter(train_A))
sample_B = next(iter(train_B))

In [0]:
# Left panel: first image of the domain-A sample batch.
plt.subplot(121)
plt.title('A')
plt.imshow(output_to_rgb(sample_A[0]))

# Right panel: the same image after one more pass through random_jitter().
plt.subplot(122)
plt.title('A with random jitter')
plt.imshow(output_to_rgb(random_jitter(sample_A[0])))

In [0]:
# Left panel: first image of the domain-B sample batch.
plt.subplot(121)
plt.title('B')
plt.imshow(output_to_rgb(sample_B[0]))

# Right panel: the same image after one more pass through random_jitter().
plt.subplot(122)
plt.title('B with random jitter')
plt.imshow(output_to_rgb(random_jitter(sample_B[0])))

## Build Generator

In [0]:
# Channel count for the stock pix2pix builders; in the visible cells it is
# only referenced by the commented-out lines below.
OUTPUT_CHANNELS = 3

# Stock pix2pix models, disabled in favour of the Generator() and
# Discriminator() builders defined in this notebook.
# generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
# generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
# discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [0]:
def downsample(filters, size, apply_batchnorm=True):
  """Build a stride-2 encoder block: Conv2D -> [BatchNorm] -> LeakyReLU.

  Args:
    filters: number of convolution filters (output channels).
    size: convolution kernel size.
    apply_batchnorm: whether to insert BatchNormalization after the conv.

  Returns:
    A `tf.keras.Sequential` that halves the spatial size ('same' padding,
    stride 2).
  """
  conv = tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [conv]
  if apply_batchnorm:
    block_layers.append(tf.keras.layers.BatchNormalization())
  block_layers.append(tf.keras.layers.LeakyReLU())

  return tf.keras.Sequential(block_layers)

In [0]:
# Smoke test: run one sample image, with a batch axis added, through a
# downsample block and print the resulting shape.
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(sample_A[0], 0))
print (down_result.shape)

In [0]:
def upsample(filters, size, apply_dropout=False):
  """Build a stride-2 decoder block: Conv2DTranspose -> BatchNorm -> [Dropout] -> ReLU.

  Args:
    filters: number of transposed-convolution filters (output channels).
    size: kernel size.
    apply_dropout: whether to insert Dropout(0.5) after the batch norm.

  Returns:
    A `tf.keras.Sequential` that doubles the spatial size ('same' padding,
    stride 2).
  """
  deconv = tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False)

  block_layers = [deconv, tf.keras.layers.BatchNormalization()]
  if apply_dropout:
    block_layers.append(tf.keras.layers.Dropout(0.5))
  block_layers.append(tf.keras.layers.ReLU())

  return tf.keras.Sequential(block_layers)

In [0]:
# Smoke test: feed the downsampled result through an upsample block and print
# the resulting shape.
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)

In [0]:
def gauss_blur_model(input_shape, kernel_size=19, sigma=5, **kwargs):
  """Build a fixed (non-trainable) Gaussian-blur Keras model.

  The input is reflect-padded by (kernel_size - 1) // 2 pixels on each spatial
  side and then filtered by a depthwise convolution with 'valid' padding, so
  the output keeps the input's spatial size. Every channel is blurred with the
  same kernel.

  Args:
    input_shape: per-sample input shape with channels last, e.g.
      [256, 256, 1]; the last entry is used as the channel count.
    kernel_size: side length of the square Gaussian kernel; must be odd.
    sigma: standard deviation of the Gaussian.
    **kwargs: forwarded to `tf.keras.Model` (e.g. `name`).

  Returns:
    A `tf.keras.Model` mapping an image batch to its blurred version.

  Raises:
    ValueError: if `kernel_size` is even.
  """
  import numpy as np

  if kernel_size % 2 == 0:
    raise ValueError("kernel size should be an odd number")

  def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
    """Return a normalised 2D Gaussian mask.

    Should give the same result as MATLAB's
    fspecial('gaussian', [shape], [sigma]).
    """
    # https://stackoverflow.com/questions/55643675/how-do-i-implement-gaussian-blurring-layer-in-keras
    # https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python/17201686#17201686
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
    # Zero out entries that are negligible relative to the peak.
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
      h /= sumh
    return h

  class SymmetricPadding2D(tf.keras.layers.Layer):
    """Pad the two spatial axes of a 4D tensor using tf.pad's "REFLECT" mode.

    NOTE(review): despite the class name, the mode used is "REFLECT", not
    "SYMMETRIC"; kept as-is to preserve the existing outputs.
    """
    # Source: https://stackoverflow.com/a/55210905/11394663

    def __init__(self, output_dim, padding=(1, 1),
                 data_format="channels_last", **kwargs):
      # Fail at construction time instead of hitting an unbound `pad` in call().
      if data_format not in ("channels_last", "channels_first"):
        raise ValueError(
            "data_format must be 'channels_last' or 'channels_first', "
            "got {!r}".format(data_format))
      # `output_dim` is kept for interface compatibility; the padded shape is
      # derived from the input shape and `padding`.
      self.output_dim = output_dim
      self.data_format = data_format
      # Stored as a tuple so the (previously mutable) default is never shared.
      self.padding = tuple(padding)
      super(SymmetricPadding2D, self).__init__(**kwargs)

    def build(self, input_shape):
      super(SymmetricPadding2D, self).build(input_shape)

    def call(self, inputs):
      spatial = [[p, p] for p in self.padding]
      # Compare strings by value; `is` only tests object identity.
      if self.data_format == "channels_last":
        # (batch, rows, cols, channels)
        pad = [[0, 0]] + spatial + [[0, 0]]
      else:
        # (batch, channels, rows, cols)
        pad = [[0, 0], [0, 0]] + spatial
      return tf.pad(inputs, tf.constant(pad), "REFLECT")

    def compute_output_shape(self, input_shape):
      # Each padded axis grows by twice its padding; unknown dims stay None.
      shape = list(input_shape)
      first_spatial_axis = 1 if self.data_format == "channels_last" else 2
      for offset, p in enumerate(self.padding):
        axis = first_spatial_axis + offset
        if shape[axis] is not None:
          shape[axis] = shape[axis] + 2 * p
      return tuple(shape)

  gauss_inputs = tf.keras.layers.Input(shape=input_shape)

  # Build the depthwise kernel with shape (k, k, in_channels, 1).
  kernel_weights = matlab_style_gauss2D(
      shape=(kernel_size, kernel_size), sigma=sigma)
  in_channels = input_shape[-1]
  kernel_weights = np.expand_dims(kernel_weights, axis=-1)
  # Apply the same filter to all the input channels.
  kernel_weights = np.repeat(kernel_weights, in_channels, axis=-1)
  # Trailing axis of size 1 for shape compatibility with the layer's weights.
  kernel_weights = np.expand_dims(kernel_weights, axis=-1)

  gauss_layer = tf.keras.layers.DepthwiseConv2D(
      kernel_size, use_bias=False, padding='valid')
  p = (kernel_size - 1) // 2
  x = SymmetricPadding2D(0, padding=[p, p])(gauss_inputs)
  x = gauss_layer(x)

  # The weights can only be set once the layer has been called (built) above.
  gauss_layer.set_weights([kernel_weights])
  gauss_layer.trainable = False
  return tf.keras.Model(inputs=gauss_inputs, outputs=x, **kwargs)

In [0]:
def Generator():
  """Build the generator: a U-Net that predicts a one-channel transmission map.

  The encoder/decoder produces a sigmoid map `t` of shape (bs, 256, 256, 1),
  which is then Gaussian-blurred. The model output is

      inputs * t + (1 - t)

  so each pixel moves toward 1.0 as `t` approaches 0.

  Returns:
    A `tf.keras.Model` taking and returning (bs, 256, 256, 3) images.
  """
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  # Encoder: each block halves the resolution, 256 -> 1. Only the first block
  # skips batch normalisation.
  encoder_filters = [64, 128, 256, 512, 512, 512, 512, 512]
  down_stack = [
      downsample(n_filters, 4, apply_batchnorm=(idx > 0))
      for idx, n_filters in enumerate(encoder_filters)
  ]

  # Decoder: each block doubles the resolution, 1 -> 128. Only the first three
  # blocks use dropout.
  decoder_filters = [512, 512, 512, 512, 256, 128, 64]
  up_stack = [
      upsample(n_filters, 4, apply_dropout=(idx < 3))
      for idx, n_filters in enumerate(decoder_filters)
  ]

  # Final transposed conv back to full resolution: (bs, 256, 256, 1) in (0, 1).
  last = tf.keras.layers.Conv2DTranspose(
      1, 4,
      strides=2,
      padding='same',
      name='transmission_layer',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      activation='sigmoid')

  # Downsampling through the model, keeping every intermediate activation.
  features = inputs
  encoder_outputs = []
  for down in down_stack:
    features = down(features)
    encoder_outputs.append(features)

  # Upsampling with skip connections. The bottleneck (last encoder output) is
  # not a skip; the remaining activations are paired deepest-first.
  for up, skip in zip(up_stack, encoder_outputs[-2::-1]):
    features = up(features)
    features = tf.keras.layers.Concatenate()([features, skip])

  transmission = last(features)
  # Blur the raw transmission map.
  transmission = gauss_blur_model(
      [256, 256, 1], name="gauss_blur")(transmission)

  # Compose the output image: inputs * t + (1 - t).
  attenuated = tf.keras.layers.multiply([inputs, transmission])
  one_minus_t = tf.keras.layers.Lambda(
      lambda x: 1 - x, name='transmission_invert')(transmission)
  composed = tf.keras.layers.add([attenuated, one_minus_t])

  return tf.keras.Model(inputs=inputs, outputs=composed)

In [0]:
# Two independent generators: generator_g is applied to domain-A images and
# generator_f to domain-B images in the cells below.
generator_g = Generator()
generator_f = Generator()
# Render the generator architecture as a diagram.
tf.keras.utils.plot_model(generator_g, show_shapes=True, dpi=64)
# Stock pix2pix discriminators, disabled in favour of Discriminator() below.
# discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
# discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [0]:
# Preview what the generators produce for the sample batches.
to_B = generator_g(sample_A)
to_A = generator_f(sample_B)
plt.figure(figsize=(8, 8))
# NOTE(review): `contrast` is not referenced by any visible cell; it is kept
# only so the notebook's global state is unchanged.
contrast = 8

imgs = [sample_A, to_B, sample_B, to_A]
title = ['A', 'To B', 'B', 'To A']

# Inputs and generated images are drawn the same way, so no per-panel branch
# is needed.
for i, (batch, panel_title) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i+1)
  plt.title(panel_title)
  plt.imshow(output_to_rgb(batch[0]))
plt.show()

## Build Discriminator

In [0]:
def Discriminator():
  """Build the discriminator: a convolutional net scoring image patches.

  Returns:
    A `tf.keras.Model` mapping (bs, 256, 256, 3) images to a (bs, 30, 30, 1)
    map of raw scores. The last layer has no activation; the losses below
    treat its output as logits.
  """
  init = tf.random_normal_initializer(0., 0.02)

  image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')

  # Three stride-2 blocks: 256 -> 128 -> 64 -> 32. The first has no batch norm.
  x = image
  for n_filters, use_batchnorm in ((64, False), (128, True), (256, True)):
    x = downsample(n_filters, 4, use_batchnorm)(x)

  x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
  x = tf.keras.layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=init,
                             use_bias=False)(x)  # (bs, 31, 31, 512)
  x = tf.keras.layers.BatchNormalization()(x)
  x = tf.keras.layers.LeakyReLU()(x)

  x = tf.keras.layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)
  patch_scores = tf.keras.layers.Conv2D(
      1, 4, strides=1, kernel_initializer=init)(x)  # (bs, 30, 30, 1)

  return tf.keras.Model(inputs=image, outputs=patch_scores)

In [0]:
# One discriminator per domain: discriminator_x scores domain-A images and
# discriminator_y scores domain-B images in the cells below.
discriminator_x = Discriminator()
discriminator_y = Discriminator()

In [0]:
# Render the discriminator architecture as a diagram.
tf.keras.utils.plot_model(discriminator_x, show_shapes=True, dpi=64)

In [0]:
# Show each discriminator's score map for the first image of the matching
# sample batch.
plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real B?')
plt.imshow(discriminator_y(sample_B)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real A?')
plt.imshow(discriminator_x(sample_A)[0, ..., -1], cmap='RdBu_r')

plt.show()

## Loss functions

In [0]:
# Weight of the cycle-consistency loss; the identity loss uses half of it.
LAMBDA = 10

In [0]:
# Shared adversarial loss; from_logits=True because the discriminator's last
# layer has no activation.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [0]:
def discriminator_loss(real, generated):
  """Return the discriminator loss for one real and one generated batch.

  Args:
    real: discriminator output for real images (target: all ones).
    generated: discriminator output for generated images (target: all zeros).

  Returns:
    Half the sum of the two binary cross-entropy terms.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
  return (loss_on_real + loss_on_generated) * 0.5

In [0]:
def generator_loss(generated):
  """Return the adversarial loss of a generator.

  `generated` is the discriminator's output for generated images; the loss is
  the binary cross-entropy against an all-ones target.
  """
  all_real = tf.ones_like(generated)
  return loss_obj(all_real, generated)

In [0]:
def calc_cycle_loss(real_image, cycled_image):
  """Return LAMBDA times the mean absolute error between the two images."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

In [0]:
def identity_loss(real_image, same_image):
  """Return 0.5 * LAMBDA times the mean absolute error between the two images."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_error

In [0]:
# One Adam optimizer per network, all with learning rate 2e-4 and beta_1 0.5.
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

## Checkpoints

In [0]:
# from google.colab import drive
# drive.mount('/content/drive')

In [0]:
# Directory for training checkpoints. The commented path is the Google Drive
# location, to be used together with the drive.mount cell.
# checkpoint_path = "./drive/My Drive/Colab Notebooks/CycleGAN/tf_learn/06-train/"
checkpoint_path = "./06-train/"

# Track all four networks and their optimizers so training can be resumed.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

# max_to_keep=1: only the most recent checkpoint is retained.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')
else:
  print('No checkpoint found.')

## Training 

In [0]:
# Number of passes over the zipped training sets in the training loop.
EPOCHS = 40

In [0]:
def generate_images(modelA, test_inputA, modelB, test_inputB):
  """Plot two input batches next to their translations in a 2 x 2 grid.

  Args:
    modelA: model applied to `test_inputA` (shown as 'To B').
    test_inputA: batch of domain-A images; only the first image is drawn.
    modelB: model applied to `test_inputB` (shown as 'To A').
    test_inputB: batch of domain-B images; only the first image is drawn.
  """
  predictionA = modelA(test_inputA)
  predictionB = modelB(test_inputB)

  plt.figure(figsize=(12, 12))

  panels = (
      ('A', test_inputA[0]),
      ('To B', predictionA[0]),
      ('B', test_inputB[0]),
      ('To A', predictionB[0]),
  )
  for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.title(caption)
    # output_to_rgb converts the image into what imshow should display.
    plt.imshow(output_to_rgb(picture))
    plt.axis('off')
  plt.show()

In [0]:
# Preview the generators on the sample batches before training starts.
generate_images(generator_g, sample_A, generator_f, sample_B)

In [0]:
@tf.function
def train_step(real_x, real_y):
  """Run one optimisation step on both generators and both discriminators.

  Generator G translates X -> Y and generator F translates Y -> X.

  Args:
    real_x: batch of real domain-X images.
    real_y: batch of real domain-Y images.
  """
  # The tape is persistent because four separate gradient() calls read from it.
  with tf.GradientTape(persistent=True) as tape:
    # Forward cycle: X -> Y -> X.
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Backward cycle: Y -> X -> Y.
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # Each generator applied to an image of its own target domain, for the
    # identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # Generator losses = adversarial + cycle (both directions) + identity.
    adversarial_g = generator_loss(disc_fake_y)
    adversarial_f = generator_loss(disc_fake_x)

    cycle_term = (calc_cycle_loss(real_x, cycled_x)
                  + calc_cycle_loss(real_y, cycled_y))

    total_gen_g_loss = adversarial_g + cycle_term + identity_loss(real_y, same_y)
    total_gen_f_loss = adversarial_f + cycle_term + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # (loss, network, optimizer) triples, in the order they are updated.
  updates = (
      (total_gen_g_loss, generator_g, generator_g_optimizer),
      (total_gen_f_loss, generator_f, generator_f_optimizer),
      (disc_x_loss, discriminator_x, discriminator_x_optimizer),
      (disc_y_loss, discriminator_y, discriminator_y_optimizer),
  )

  # Compute every gradient first, so none of them sees updated weights.
  all_gradients = [
      tape.gradient(loss, network.trainable_variables)
      for loss, network, _ in updates
  ]

  # Then apply them, one optimizer per network.
  for gradients, (_, network, optimizer) in zip(all_gradients, updates):
    optimizer.apply_gradients(zip(gradients, network.trainable_variables))

In [0]:
# Number of batches per epoch; not known until the first epoch has finished.
length = "Unknown"
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  # Preview the generators on one batch from each test set before the epoch.
  # NOTE(review): test_A/test_B are built with .shuffle(), so this is
  # presumably a different image each epoch, not a fixed sample -- confirm
  # whether sample_A/sample_B were intended here.
  for A, B in zip(test_A.take(1), test_B.take(1)):
    generate_images(generator_g, A, generator_f, B)
  # Pair up batches from the two training sets.
  # NOTE: this rebinds `dataset`, which earlier held the dict from tfds.load.
  dataset = tf.data.Dataset.zip((train_A, train_B))
  for image_x, image_y in dataset:
    # print(image_x.shape, image_y.shape)
    train_step(image_x, image_y)
    # Progress report every 10 steps.
    if(n%10==0):
      print ('{}/{}'.format(n,length))
    n+=1
  # Remember the step count so later epochs can print "n/total".
  length = n

  clear_output(wait=True)

  # Save a checkpoint after every epoch.
  ckpt_save_path = ckpt_manager.save()
  print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                        ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

In [0]:
# Replace the preview samples with one batch from each test set.
sample_A = next(iter(test_A))
sample_B = next(iter(test_B))

In [0]:
# Inspect the transmission map that generator_g produces for a test image.
gen = generator_g
image = sample_A
# Check Transmission map values
# model1 exposes the raw sigmoid output of 'transmission_layer' (before blur).
model1 = tf.keras.Model(inputs = gen.inputs, outputs = gen.get_layer('transmission_layer').output)
t = model1(image)
# model2 exposes 'transmission_invert', i.e. 1 - (blurred transmission).
model2 = tf.keras.Model(inputs = gen.inputs, outputs = gen.get_layer('transmission_invert').output)
tgauss = model2(image)
# t = tf.image.resize(t, [32, 32],
#                           method=tf.image.ResizeMethod.GAUSSIAN)
# t = tf.image.resize(t, [256, 256],
#                           method=tf.image.ResizeMethod.GAUSSIAN)
plt.figure(figsize=(12,12)) 
# Top-left: 1 - raw transmission.
plt.subplot(2,2,1)                       
plt.imshow((1-t[0]).numpy().squeeze(), cmap='gray')
# Top-right: 1 - blurred transmission, as used inside the generator.
plt.subplot(2,2,2)                       
plt.imshow((tgauss[0]).numpy().squeeze(), cmap='gray')
# Bottom-left: the input image.
plt.subplot(2,2,3)                       
plt.imshow((image).numpy().squeeze(), cmap='gray')
# Bottom-right: the generator output for that image.
plt.subplot(2,2,4)                       
plt.imshow((gen(image)).numpy().squeeze(), cmap='gray')
# generate_images(generator_g, sample_A, generator_f, sample_B)

In [0]:
# Print the layer-by-layer summary of the inspected generator.
gen.summary()