# CycleGAN
CycleGAN is a method that can capture the characteristics of one image domain and figure out how these characteristics could be translated into another image domain, all in the absence of any paired training examples. It uses a cycle consistency loss to enable training without the need for paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domains.

In [2]:
"""
Installing tensorflow_examples package which will enable importing of 
Generator & Discriminator
"""

# Install the TensorFlow examples repository straight from GitHub
# (-q keeps pip's output quiet).
!pip install -q git+https://github.com/tensorflow/examples.git
# Upgrade (-U) tensorboard, which is used further down for loss visualization.
!pip install -q -U tensorboard

  Building wheel for tensorflow-examples (setup.py) ... done


In [0]:
# Importing libraries
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

from datetime import datetime
from tensorflow_examples.models.pix2pix import pix2pix

%matplotlib inline

In [0]:
# Global vars & hyper-params
# Lets tf.data pick the level of parallelism for the dataset map calls
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Size of the shuffle buffer used by the input pipelines
BUFFER_SIZE = 1000
# Number of images per training batch
BATCH_SIZE = 1
# Spatial size of the images fed to the networks (random-crop size for
# training images, resize target for test images)
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Number of channels produced by the generators
OUTPUT_CHANNELS = 3
# Weight applied to the cycle loss (and, halved, to the identity loss).
# NOTE(review): the CycleGAN paper uses lambda = 10 -- confirm 100 is intended.
LAMBDA = 100
# Number of passes over the training data
EPOCHS = 50

## Utility Functions

In [0]:
def show_results(horse, gen_zebra, zebra, gen_horse, save_fig=False):
  """
  Plots the two real images next to their translations in a 2 x 2 grid.

  Args:
    horse: real horse image (assumed to be normalized to [-1, 1]).
    gen_zebra: zebra image generated from `horse`.
    zebra: real zebra image (assumed to be normalized to [-1, 1]).
    gen_horse: horse image generated from `zebra`.
    save_fig: when True, the figure is also written to
      "<YYYYmmdd-HHMMSS>.png" in the current working directory.
  """
  titles = ["Horse", "To Zebra", "Zebra", "To Horse"]
  images = [horse, gen_zebra, zebra, gen_horse]
  # The size has to be given when the figure is created; assigning
  # `fig.figsize` afterwards only sets an unused attribute.
  fig = plt.figure(figsize=(10, 10))
  contrast = 8
  for i, (title, image) in enumerate(zip(titles, images)):
    plt.subplot(2, 2, i + 1)
    plt.title(title)
    # Pixel values should be in the range 0-1 for float values
    if i % 2 == 1:
      # Generated images (2nd and 4th panel) get their contrast boosted
      plt.imshow(image * 0.5 * contrast + 0.5)
    else:
      plt.imshow(image * 0.5 + 0.5)
    plt.axis("off")

  # Save before showing: some backends release the figure once it has been
  # displayed, so writing the file first is the safe order.
  if save_fig:
    f_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    fig.savefig(f"{f_name}.png")

  plt.show()

In [0]:
def resize_image(image, height, width):
  """
  Resizes `image` to `height` x `width` using nearest-neighbor sampling.
  """
  target_size = [height, width]
  return tf.image.resize(
      image, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

In [0]:
def get_random_crop(image):
  """
  Returns a randomly positioned IMG_HEIGHT x IMG_WIDTH x 3 crop of `image`.
  """
  crop_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
  return tf.image.random_crop(image, crop_shape)

In [0]:
def normalize_img(image):
  """
  Casts the image to float32 and rescales it with x / 127.5 - 1, which maps
  pixel values from [0, 255] to [-1, 1].
  """
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1

In [0]:
def preprocess_image(image):
  """
  This function will apply random jittering, cropping, mirroring and 
  normalization. 

  Args:
    image: a single image tensor with 3 channels.

  Returns:
    A float32 image of shape IMG_HEIGHT x IMG_WIDTH x 3, normalized by
    normalize_img, randomly cropped and randomly mirrored.
  """

  # Normalizing images
  normalized_img = normalize_img(image)

  # Resizing images to 286 X 286 X 3
  resized_img = resize_image(normalized_img, 286, 286)

  # Cropping images to 256 X 256 X 3
  cropped_img = get_random_crop(resized_img)

  # Applying random mirroring. `flip_left_right` would mirror *every* image,
  # which is no augmentation at all; the random variant flips only some.
  processed_img = tf.image.random_flip_left_right(cropped_img)

  return processed_img

In [0]:
def prepare_training_data(image, label):
  """
  Dataset map function for training examples: the label is ignored and the
  image is passed through preprocess_image.
  """
  augmented_img = preprocess_image(image)
  return augmented_img

In [0]:
def prepare_test_data(image, label):
  """
  Dataset map function for test examples: the label is ignored, the image is
  resized to IMG_HEIGHT x IMG_WIDTH and then normalized (no augmentation).
  """
  return normalize_img(resize_image(image, IMG_HEIGHT, IMG_WIDTH))

In [0]:
def discriminator_loss(real_out, fake_out, loss_cal):
  """
  Discriminator loss: outputs for real images are compared against a target
  of ones, outputs for generated images against a target of zeros, and the
  sum of both terms is halved.

  `loss_cal` is called as loss_cal(target, prediction).
  """
  loss_on_real = loss_cal(tf.ones_like(real_out), real_out)
  loss_on_fake = loss_cal(tf.zeros_like(fake_out), fake_out)
  combined = loss_on_real + loss_on_fake

  return combined * 0.5

In [0]:
def generator_loss(fake_out, loss_cal):
  """
  Generator (adversarial) loss: the discriminator output for generated images
  is compared against a target of ones via loss_cal(target, prediction).
  """
  real_labels = tf.ones_like(fake_out)
  return loss_cal(real_labels, fake_out)

In [0]:
def cycle_loss(real_img, cycled_img):
  """
  Cycle-consistency loss: mean absolute difference between the original image
  and its round-trip reconstruction, scaled by LAMBDA.
  """
  mean_abs_diff = tf.reduce_mean(tf.abs(real_img - cycled_img))
  return mean_abs_diff * LAMBDA

In [0]:
def identity_loss(real_img, same_img):
  """
  Identity loss: mean absolute difference between an image and the result of
  feeding it to the generator that targets its own domain, scaled by
  0.5 * LAMBDA.
  """
  mean_abs_diff = tf.reduce_mean(tf.abs(real_img - same_img))
  return mean_abs_diff * 0.5 * LAMBDA

### Losses:


* GAN Loss: This is a typical GAN loss. 
* Cycle Loss: This loss helps both generators learn meaningful mappings: an image translated to the other domain and back again should match the original.
* Identity Loss: This loss is helpful when we want to preserve the color composition between input and output. Sometimes we need to leave unchanged the things that are common to both the input and the output images.



In [0]:
@tf.function
def train(real_x, real_y, gen_g, gen_f, desc_x, desc_y, gen_g_opt, gen_f_opt, 
          desc_x_opt, desc_y_opt, loss_cal, sum_writer, epoch):
  """
  Runs one optimization step for both generators and both discriminators and
  logs the resulting losses.

  Args:
    real_x: batch of real images from domain X (horses in this notebook).
    real_y: batch of real images from domain Y (zebras in this notebook).
    gen_g: generator applied to X images to produce Y images.
    gen_f: generator applied to Y images to produce X images.
    desc_x: discriminator scoring real and generated X images.
    desc_y: discriminator scoring real and generated Y images.
    gen_g_opt: optimizer updating gen_g's trainable variables.
    gen_f_opt: optimizer updating gen_f's trainable variables.
    desc_x_opt: optimizer updating desc_x's trainable variables.
    desc_y_opt: optimizer updating desc_y's trainable variables.
    loss_cal: loss callable forwarded to generator_loss / discriminator_loss.
    sum_writer: summary writer that receives the four loss scalars.
    epoch: value used as the `step` of the logged scalars.

  NOTE(review): if `epoch` is passed as a Python int, tf.function traces a
  new graph for every distinct value -- consider passing a tensor instead.
  """
  # persistent=True because tape.gradient is called four times below
  with tf.GradientTape(persistent=True) as tape:
    # Forward cycle: X -> Y -> X'
    fake_y = gen_g(real_x, training=True)
    cycled_x = gen_f(fake_y, training=True)

    # Backward cycle: Y -> X -> Y'
    fake_x = gen_f(real_y, training=True)
    cycled_y = gen_g(fake_x, training=True)

    # Identity check (no changing)
    # Each generator is fed an image that already belongs to its output domain
    same_y = gen_g(real_y, training=True)
    same_x = gen_f(real_x, training=True)

    # Forward passes for discriminator network X
    disc_x_real_out = desc_x(real_x, training=True)
    disc_x_fake_out = desc_x(fake_x, training=True)

    # Forward passes for discriminator network Y
    disc_y_real_out = desc_y(real_y, training=True)
    disc_y_fake_out = desc_y(fake_y, training=True)

    # GAN loss
    # G is judged by D_Y on its fake Y images, F by D_X on its fake X images
    gen_g_loss = generator_loss(disc_y_fake_out, loss_cal)
    gen_f_loss = generator_loss(disc_x_fake_out, loss_cal)

    desc_x_loss = discriminator_loss(disc_x_real_out, disc_x_fake_out, loss_cal)
    desc_y_loss = discriminator_loss(disc_y_real_out, disc_y_fake_out, loss_cal)

    # Cycle loss
    # Sum of both directions; the same total is added to both generator losses
    total_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y)

    # Identity loss
    gen_g_identity_loss = identity_loss(real_y, same_y)
    gen_f_identity_loss = identity_loss(real_x, same_x)

    # Total generator losses
    total_gen_g_loss = gen_g_loss + total_cycle_loss + gen_g_identity_loss
    total_gen_f_loss = gen_f_loss + total_cycle_loss + gen_f_identity_loss

  # Calculating gradients
  # Each loss is differentiated only w.r.t. the variables of its own network
  gen_g_gradients = tape.gradient(total_gen_g_loss, gen_g.trainable_variables)
  gen_f_gradients = tape.gradient(total_gen_f_loss, gen_f.trainable_variables)

  desc_x_gradients = tape.gradient(desc_x_loss, desc_x.trainable_variables)
  desc_y_gradients = tape.gradient(desc_y_loss, desc_y.trainable_variables)

  # Apply gradients to optimize weights
  gen_g_opt.apply_gradients(zip(gen_g_gradients, gen_g.trainable_variables))
  gen_f_opt.apply_gradients(zip(gen_f_gradients, gen_f.trainable_variables))

  desc_x_opt.apply_gradients(zip(desc_x_gradients, desc_x.trainable_variables))
  desc_y_opt.apply_gradients(zip(desc_y_gradients, desc_y.trainable_variables))

  # Log the four losses to the summary writer.
  # NOTE(review): the tags mix the prefixes 'desc_' and 'disc_'; and because
  # step=epoch, every call made with the same `epoch` logs at the same step.
  with sum_writer.as_default():
    tf.summary.scalar('gen_g_loss', total_gen_g_loss, step=epoch)
    tf.summary.scalar('gen_f_loss', total_gen_f_loss, step=epoch)
    tf.summary.scalar('desc_x_loss', desc_x_loss, step=epoch)
    tf.summary.scalar('disc_y_loss', desc_y_loss, step=epoch)

## Preparing Input Pipeline
In this tutorial we will train a model that learns to translate images of horses into images of zebras.

In [0]:
# Load the horse2zebra dataset together with its metadata. With
# as_supervised=True each example is an (image, label) pair, which is the
# two-argument form the prepare_*_data map functions expect.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

# Split 'A' holds the horse images and split 'B' the zebra images.
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

In [0]:
# Training pipelines: cache the raw examples *before* the random augmentation
# so that a fresh jitter/crop is drawn every epoch. Caching after the map
# would replay the first epoch's augmented images for the rest of training.
train_horses = train_horses.cache().map(
    prepare_training_data, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    prepare_training_data, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Test pipelines: preprocessing is deterministic here, so the preprocessed
# images themselves can be cached.
test_horses = test_horses.map(
    prepare_test_data, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    prepare_test_data, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

## Let Training Begin
There are 2 generators (G and F) and 2 discriminators (X and Y) being trained here. 

* Generator G learns to transform image X to image Y.  (G: X -> Y)
* Generator F learns to transform image Y to image X.  (F: Y -> X)
* Discriminator D_X learns to differentiate between image X and generated image X (F(Y)).
* Discriminator D_Y learns to differentiate between image Y and generated image Y (G(X)).

In [0]:
# Two U-Net generators from the pix2pix example package, one per translation
# direction, each producing OUTPUT_CHANNELS channels with instance norm.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

# One discriminator per image domain.
# NOTE(review): target=False presumably means the discriminator takes a
# single image instead of an (input, target) pair -- confirm against
# pix2pix.discriminator.
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [0]:
# One Adam optimizer per network, all configured identically
# (learning rate 2e-4, beta_1 = 0.5).
gen_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
gen_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

disc_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [0]:
# Preparing TensorBoard for visualization
log_dir = "logs/"
# Each run writes its summaries to its own timestamped sub-directory of logs/fit/
summary_writer = tf.summary.create_file_writer(
    log_dir + "fit/" + datetime.now().strftime("%Y%m%d-%H%M%S"))

# Load the TensorBoard notebook extension and point it at the log directory
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [0]:
# Pair up one horse batch with one zebra batch per training step.
zipped_dataset = tf.data.Dataset.zip((train_horses, train_zebras))
# The discriminators' outputs are treated as raw logits by the loss.
loss_cal = tf.keras.losses.BinaryCrossentropy(from_logits=True)
for epoch in range(EPOCHS):
  # Pass the epoch as a tensor: a Python int argument would make the
  # @tf.function-decorated train() trace a new graph for every epoch value.
  epoch_step = tf.constant(epoch, dtype=tf.int64)
  for horse_img, zebra_img in zipped_dataset:
    train(horse_img, zebra_img, generator_g, generator_f, discriminator_x,
          discriminator_y, gen_g_optimizer, gen_f_optimizer,
          disc_x_optimizer, disc_y_optimizer, loss_cal, summary_writer,
          epoch_step)

  # Every 10th epoch, visualize (and save) a translation of one test pair.
  if (epoch+1) % 10 == 0:
    for horse, zebra in zip(test_horses.take(1), test_zebras.take(1)):
      gen_zebra = generator_g(horse)
      # Zebra -> horse is generator F's direction (F: Y -> X), not G's.
      gen_horse = generator_f(zebra)
      show_results(horse[0], gen_zebra[0], zebra[0], gen_horse[0], save_fig=True)

  print(f"Epoch {epoch} done.")