# CycleGAN [with horse2zebra dataset]

* `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks`, [arXiv:1703.10593](https://arxiv.org/abs/1703.10593)
  * Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros

* This code is written for TensorFlow 2.0.
* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers) and [`tf.losses`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/losses)

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display
import urllib.request
import zipfile

import tensorflow as tf
from tensorflow.keras import layers

sys.path.append(os.path.dirname(os.path.abspath('.')))
from utils.image_utils import *
from utils.ops import *

os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setting hyperparameters

In [None]:
# Training Flags (hyperparameter configuration)
model_name = 'cyclegan'
# Used below as the checkpoint directory
train_dir = os.path.join('train', model_name, 'exp1')

# Learning-rate schedule: constant for the first phase, linearly decayed afterwards
constant_lr_epochs = 100
decay_lr_epochs = 100
max_epochs = constant_lr_epochs + decay_lr_epochs
save_model_epochs = 20   # save a checkpoint every N epochs
print_steps = 50         # print losses and sample images every N global steps
save_images_epochs = 5   # save sample images every N epochs
batch_size = 1
learning_rate_D = 2e-4   # initial learning rate of the discriminator optimizer
learning_rate_G = 2e-4   # initial learning rate of the generator optimizer

BUFFER_SIZE = 10000      # not referenced by the code visible in this notebook
IMG_SIZE = 256           # side length of the (square) network input
assert IMG_SIZE in [128, 256]  # Generator only accepts these two sizes
LAMBDA = 10              # weight of the cycle-consistency term in generator_loss

## Load the dataset

You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/). 
This script source is borrowed from [original CycleGAN github repo.](https://github.com/junyanz/CycleGAN/blob/master/datasets/download_dataset.sh)


As mentioned in the [paper](https://arxiv.org/abs/1703.10593) we apply random jittering and mirroring to the training dataset.
* In random jittering, the image is resized to 286 x 286 and then randomly cropped to 256 x 256
* In random mirroring, the image is randomly flipped horizontally i.e left to right.

In [None]:
# Names of the datasets listed by the original CycleGAN download script.
DATASETS = ["ae_photos",
            "apple2orange",
            "summer2winter_yosemite",
            "horse2zebra",
            "monet2photo",
            "cezanne2photo",
            "ukiyoe2photo",
            "vangogh2photo",
            "maps",
            "cityscapes",
            "facades",
            "iphone2dslr_flower"]

dataset_name = "horse2zebra"
#dataset_name = "cityscapes"

url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + dataset_name + '.zip'
datasets_path = '../datasets'
# exist_ok avoids the check-then-create race of isdir() followed by makedirs()
os.makedirs(datasets_path, exist_ok=True)
zipfile_path = os.path.join(datasets_path, dataset_name + '.zip')

# Download dataset
if not os.path.isfile(zipfile_path):
  # Download under a temporary name and rename only on success, so that an
  # interrupted download is never mistaken for a complete archive later on.
  partial_path = zipfile_path + '.part'
  urllib.request.urlretrieve(url=url, filename=partial_path)
  os.replace(partial_path, zipfile_path)
  print('download done')
else:
  print('zipfile already exists')

# Extract zipfile
PATH = os.path.join(datasets_path, dataset_name)
if not os.path.isdir(PATH):
  # The context manager closes the archive even if extraction fails
  with zipfile.ZipFile(zipfile_path, 'r') as zip_ref:
    zip_ref.extractall(datasets_path)
  print('zipfile extract done')
else:
  print('zipfile already extracted')

## Set up dataset with `tf.data`

### Image augmentation

In [None]:
def load(image_file):
  """Read a JPEG file and return it as a float32 image tensor.

  Args:
    image_file: path of the JPEG file to read.

  Returns:
    The decoded image cast to `tf.float32`; pixel values are not rescaled.
  """
  raw_bytes = tf.io.read_file(image_file)
  # channels=3 is requested on purpose so the output always has 3 channels
  decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
  return tf.cast(decoded, tf.float32)

In [None]:
# Load one sample image per domain (element 1 of the glob result)
imageA = load(glob.glob(os.path.join(PATH, 'trainA/*.jpg'))[1])
imageB = load(glob.glob(os.path.join(PATH, 'trainB/*.jpg'))[1])
# load() returns float32 pixel values, so scale by 1/255 before handing them to matplotlib
plt.figure()
plt.imshow(imageA/255.0)
plt.figure()
plt.imshow(imageB/255.0)

In [None]:
def resize(input_image, height, width):
  """Resize `input_image` to `height` x `width` using nearest-neighbor sampling."""
  target_size = [height, width]
  return tf.image.resize(input_image, target_size,
                         method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

In [None]:
def random_crop(input_image):
  """Cut a random IMG_SIZE x IMG_SIZE x 3 window out of `input_image`."""
  crop_shape = [IMG_SIZE, IMG_SIZE, 3]
  return tf.image.random_crop(input_image, size=crop_shape)

In [None]:
# Rescale images to [-1, 1]
def normalize(input_image):
  """Clip pixel values to [0, 255] and map them linearly onto [-1, 1]."""
  clipped = tf.clip_by_value(input_image, 0.0, 255.0)
  return (clipped / 127.5) - 1

In [None]:
@tf.function()
def random_jitter(input_image):
  """Apply random jittering and random mirroring to one image.

  The image is enlarged (to 286 px per side when IMG_SIZE is 256, otherwise
  to 145 px), randomly cropped back to IMG_SIZE x IMG_SIZE and flipped
  left-right whenever a uniform random draw is greater than 0.5.
  """
  enlarged_size = 286 if IMG_SIZE == 256 else 145

  # enlarge, then take a random IMG_SIZE x IMG_SIZE crop
  jittered = random_crop(resize(input_image, enlarged_size, enlarged_size))

  # random mirroring
  if tf.random.uniform(()) > 0.5:
    jittered = tf.image.flip_left_right(jittered)

  return jittered

In [None]:
# The four images below show the effect of random jittering.
# Random jittering as described in the paper is to
# 1. Resize an image to a bigger height and width
# 2. Randomly crop it back to the target size
# 3. Randomly flip the image horizontally

plt.figure(figsize=(6, 6))
for i in range(4):
  rj_imageA = random_jitter(imageA)
  plt.subplot(2, 2, i+1)
  plt.imshow(rj_imageA/255.0)
  plt.axis('off')
plt.show()

In [None]:
def load_image_train(image_file):
  """Load a training image, apply random jitter and normalize it."""
  return normalize(random_jitter(load(image_file)))

In [None]:
def load_image_test(image_file):
  """Load a test image, resize it to IMG_SIZE x IMG_SIZE and normalize it."""
  resized = resize(load(image_file), IMG_SIZE, IMG_SIZE)
  return normalize(resized)

### Input pipeline

* Use tf.data to create batches, map(do preprocessing) and shuffle the dataset

In [None]:
# Training pipeline for domain X (trainA): list files -> shuffle -> preprocess -> batch.
# A fixed num_parallel_calls is used because tf.data.experimental.AUTOTUNE
# caused an out-of-memory error.
N_trainX = len(glob.glob(os.path.join(PATH, 'trainA/*.jpg')))
trainX_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'trainA/*.jpg'))
    .shuffle(N_trainX)
    .map(load_image_train, num_parallel_calls=16)
    .batch(batch_size, drop_remainder=True))

In [None]:
# Training pipeline for domain Y (trainB): list files -> shuffle -> preprocess -> batch.
# A fixed num_parallel_calls is used because tf.data.experimental.AUTOTUNE
# caused an out-of-memory error.
N_trainY = len(glob.glob(os.path.join(PATH, 'trainB/*.jpg')))
trainY_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'trainB/*.jpg'))
    .shuffle(N_trainY)
    .map(load_image_train, num_parallel_calls=16)
    .batch(batch_size, drop_remainder=True))

In [None]:
# Test pipeline for domain X (testA).
# The files are shuffled so that a different image is drawn each time a sample
# is taken to display the progress of the model.
N_testX = len(glob.glob(os.path.join(PATH, 'testA/*.jpg')))
testX_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'testA/*.jpg'))
    .shuffle(N_testX*3)
    .map(load_image_test)
    .batch(batch_size, drop_remainder=True))

In [None]:
# Test pipeline for domain Y (testB).
# The files are shuffled so that a different image is drawn each time a sample
# is taken to display the progress of the model.
N_testY = len(glob.glob(os.path.join(PATH, 'testB/*.jpg')))
testY_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'testB/*.jpg'))
    .shuffle(N_testY*3)
    .map(load_image_test)
    .batch(batch_size, drop_remainder=True))

In [None]:
# Report how many image files were found in each split
print("number of examples in trainA: {}".format(N_trainX))
print("number of examples in trainB: {}".format(N_trainY))
print("number of examples in testA: {}".format(N_testX))
print("number of examples in testB: {}".format(N_testY))
# The training loop zips both training datasets, so an epoch ends with the shorter one
N = min(N_trainX, N_trainY)
print("number of examples in one epoch: {}".format(N))

## Write the generator and discriminator models

### Generator
  * The architecture of generator is based on [Johnson's architecture](https://arxiv.org/abs/1603.08155).
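  * Conv block in the generator is (Conv -> InstanceNorm -> ReLU); batch normalization can be selected instead with `apply_norm='batch'`
  * Res block in the generator is (Pad -> Conv -> InstanceNorm -> ReLU -> Pad -> Conv -> InstanceNorm -> add X), using reflection padding
  * ConvTranspose block in the generator is (Transposed Conv -> InstanceNorm -> ReLU)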

In [None]:
class InstanceNormalization(layers.Layer):
  """Instance normalization for rank-4 tensors (image data).

  Every sample is normalized over axes 1 and 2, separately for each entry
  of the last (channel) axis, then scaled by `gamma` and shifted by `beta`.
  Both are trainable and hold one value per channel.

  Args:
    epsilon: small constant added to the variance before the square root.
    **kwargs: forwarded to `layers.Layer` (for example `name`).
  """
  def __init__(self, epsilon=1e-5, **kwargs):
    super(InstanceNormalization, self).__init__(**kwargs)
    self.epsilon = epsilon

  def build(self, input_shape):
    """Create `gamma` and `beta` with one entry per input channel.

    Raises:
      ValueError: if the last dimension of `input_shape` is unknown.
    """
    # dimension_value() returns a plain int (or None) whether the shape
    # holds ints or Dimension objects.
    channels = tf.compat.dimension_value(tf.TensorShape(input_shape)[-1])
    if channels is None:
      raise ValueError('The last (channel) dimension of the inputs to '
                       'InstanceNormalization must be defined.')
    # Pass an explicit 1-D shape instead of a bare scalar
    param_shape = (channels,)
    self.gamma = self.add_weight(name='gamma',
                                 shape=param_shape,
                                 initializer='ones',
                                 trainable=True)
    self.beta = self.add_weight(name='beta',
                                shape=param_shape,
                                initializer='zeros',
                                trainable=True)
    # Make sure to call the `build` method at the end
    super(InstanceNormalization, self).build(input_shape)

  def call(self, inputs):
    """Normalize `inputs` per sample and per channel, then scale and shift."""
    # Reduce over axes 1 and 2 only; batch and channel axes are kept
    mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
    normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
    return self.gamma * normalized + self.beta

  def get_config(self):
    """Return the layer configuration including `epsilon`."""
    config = super(InstanceNormalization, self).get_config()
    config['epsilon'] = self.epsilon
    return config

In [None]:
class Conv(tf.keras.Model):
  """Conv2D followed by an optional normalization and an optional activation.

  Args:
    filters: number of output channels of the convolution.
    size: side length of the square kernel.
    strides: convolution stride.
    padding: padding mode passed to `layers.Conv2D`.
    activation: one of 'relu', 'tanh' or 'none'.
    apply_norm: one of 'batch', 'instance' or 'none'.
    norm_momentum: momentum of the batch normalization (used for 'batch' only).
    norm_epsilon: epsilon of the batch normalization (used for 'batch' only).
  """
  def __init__(self, filters, size, strides=1, padding='same', activation='relu',
               apply_norm='instance', norm_momentum=0.9, norm_epsilon=1e-5):
    super(Conv, self).__init__()
    assert apply_norm in ['batch', 'instance', 'none']
    assert activation in ['relu', 'tanh', 'none']
    self.apply_norm = apply_norm
    self.activation = activation

    # The convolution only carries a bias when no normalization follows it
    self.conv = layers.Conv2D(filters=filters,
                              kernel_size=(size, size),
                              strides=strides,
                              padding=padding,
                              kernel_initializer=tf.random_normal_initializer(0., 0.02),
                              use_bias=(self.apply_norm == 'none'))

    if self.apply_norm == 'instance':
      self.instancenorm = InstanceNormalization()
    elif self.apply_norm == 'batch':
      self.batchnorm = layers.BatchNormalization(momentum=norm_momentum,
                                                 epsilon=norm_epsilon)

  def call(self, x, training):
    """Run convolution, then the configured normalization and activation."""
    out = self.conv(x)

    # normalization (only batch normalization depends on `training`)
    if self.apply_norm == 'instance':
      out = self.instancenorm(out)
    elif self.apply_norm == 'batch':
      out = self.batchnorm(out, training=training)

    # activation
    if self.activation == 'relu':
      return tf.nn.relu(out)
    if self.activation == 'tanh':
      return tf.nn.tanh(out)
    return out

In [None]:
class ResBlock(tf.keras.Model):
  """Residual block: two reflection-padded 'valid' Conv blocks plus a skip connection.

  The first Conv block uses a ReLU activation, the second one has none, and
  the block input is added to the result. The reflection padding is derived
  from the kernel size so that the spatial size is preserved.

  Args:
    filters: number of output channels of both convolutions.
    size: side length of the square kernel; must be odd.

  Raises:
    ValueError: if `size` is even, since symmetric padding cannot then
      preserve the spatial size needed for the residual addition.
  """
  def __init__(self, filters, size):
    super(ResBlock, self).__init__()
    if size % 2 == 0:
      raise ValueError('ResBlock requires an odd kernel size, got {}'.format(size))
    # A 'valid' convolution with an odd kernel of side k shrinks each spatial
    # dimension by k - 1, so pad (k - 1) / 2 on each side (1 for k == 3).
    self.pad_size = size // 2
    self.conv1 = Conv(filters, size, padding='valid', activation='relu')
    self.conv2 = Conv(filters, size, padding='valid', activation='none')

  def call(self, x, training):
    """Return `x` plus the output of the two padded Conv blocks."""
    paddings = [[0, 0],
                [self.pad_size, self.pad_size],
                [self.pad_size, self.pad_size],
                [0, 0]]
    conv = tf.pad(x, paddings, 'REFLECT')
    conv = self.conv1(conv, training)
    conv = tf.pad(conv, paddings, 'REFLECT')
    conv = self.conv2(conv, training)

    return x + conv

In [None]:
class ConvTranspose(tf.keras.Model):
  """Stride-2 transposed convolution followed by normalization and ReLU.

  Args:
    filters: number of output channels.
    size: side length of the square kernel.
    apply_norm: 'instance' or 'batch'.
    norm_momentum: momentum of the batch normalization (used for 'batch' only).
    norm_epsilon: epsilon of the batch normalization (used for 'batch' only).
  """
  def __init__(self, filters, size,
               apply_norm='instance', norm_momentum=0.9, norm_epsilon=1e-5):
    super(ConvTranspose, self).__init__()
    assert apply_norm in ['batch', 'instance']
    self.apply_norm = apply_norm

    # No bias: a normalization layer always follows the transposed convolution
    self.up_conv = layers.Conv2DTranspose(filters=filters,
                                          kernel_size=(size, size),
                                          strides=2,
                                          padding='same',
                                          kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                          use_bias=False)

    if self.apply_norm == 'instance':
      self.instancenorm = InstanceNormalization()
    elif self.apply_norm == 'batch':
      self.batchnorm = layers.BatchNormalization(momentum=norm_momentum,
                                                 epsilon=norm_epsilon)

  def call(self, x, training):
    """Upsample `x`, normalize the result and apply ReLU."""
    upsampled = self.up_conv(x)
    if self.apply_norm == 'instance':
      normalized = self.instancenorm(upsampled)
    else:
      normalized = self.batchnorm(upsampled, training=training)

    return tf.nn.relu(normalized)

In [None]:
class Generator(tf.keras.Model):
  """Generator: 7x7 Conv, two stride-2 Convs, residual blocks, two upsamplings, 7x7 Conv.

  Nine residual blocks (`res1` .. `res9`) are used when `inputs_shape` is
  256, otherwise five (`res1` .. `res5`).
  NOTE(review): the CycleGAN paper describes 6 residual blocks for 128 x 128
  images; this implementation uses 5 - confirm that this is intended.

  Args:
    inputs_shape: side length of the input images, 128 or 256.
  """
  def __init__(self, inputs_shape=256):
    super(Generator, self).__init__()
    assert inputs_shape in [128, 256]
    self.inputs_shape = inputs_shape
    self.conv = Conv(32, 7, padding='valid') # c7s1-32
    self.down1 = Conv(64, 3, 2)  # d64
    self.down2 = Conv(128, 3, 2) # d128

    # res1, res2, ...: one R128 block per attribute, created in order
    for index in range(1, self._num_res_blocks() + 1):
      setattr(self, 'res{}'.format(index), ResBlock(128, 3))

    self.up1 = ConvTranspose(64, 3) # u64
    self.up2 = ConvTranspose(32, 3) # u32
    self.last = Conv(3, 7, padding='valid', activation='tanh') # c7s1-3

  def _num_res_blocks(self):
    """Return how many residual blocks this generator uses."""
    return 9 if self.inputs_shape == 256 else 5

  def call(self, x, training):
    """Translate a batch of images; shape comments assume 256 x 256 inputs."""
    border = [[0, 0], [3, 3], [3, 3], [0, 0]]

    # x shape == (bs, 256, 256, 3)
    hidden = tf.pad(x, border, 'REFLECT')               # (bs, 262, 262, 3)
    hidden = self.conv(hidden, training=training)       # (bs, 256, 256, 32)
    hidden = self.down1(hidden, training=training)      # (bs, 128, 128, 64)
    hidden = self.down2(hidden, training=training)      # (bs, 64, 64, 128)

    # every residual block keeps the shape (bs, 64, 64, 128)
    for index in range(1, self._num_res_blocks() + 1):
      res_block = getattr(self, 'res{}'.format(index))
      hidden = res_block(hidden, training=training)

    hidden = self.up1(hidden, training=training)        # (bs, 128, 128, 64)
    hidden = self.up2(hidden, training=training)        # (bs, 256, 256, 32)
    hidden = tf.pad(hidden, border, 'REFLECT')          # (bs, 262, 262, 32)

    # generated images shape: (bs, 256, 256, 3)
    return self.last(hidden, training=training)

In [None]:
# Create two generators, one per translation direction
generator_X2Y = Generator(inputs_shape=IMG_SIZE) # This generator_X2Y corresponds to function G: X -> Y in paper's notation
generator_Y2X = Generator(inputs_shape=IMG_SIZE) # This generator_Y2X corresponds to function F: Y -> X in paper's notation

In [None]:
# Test for Generator(): run both generators once on a single-image batch
# NOTE(review): imageA / imageB come straight from load() and have not been
# passed through normalize(), and the generator output is a tanh activation,
# so the plot below is only a smoke test - confirm whether that is intended.
fake_imageB = generator_X2Y(imageA[tf.newaxis, ...], training=False)
fake_imageA = generator_Y2X(imageB[tf.newaxis, ...], training=False)
plt.imshow(fake_imageB[0, ...])

### Discriminator
  * The Discriminator is a PatchGAN.
  * Each downsampling block in the discriminator is (Conv -> InstanceNorm -> Dropout (optional) -> Leaky ReLU)
  * The shape of the output after the last layer is (batch_size, 32, 32, 1) for 256 x 256 inputs
  * Each element of the 32 x 32 output classifies a 70 x 70 patch of the input image (such an architecture is called a PatchGAN).
  * Shape of the input travelling through the generator and the discriminator is in the comments in the code.

To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1703.10593).

In [None]:
class DiscDownsample(tf.keras.Model):
  """Discriminator block: Conv2D, optional normalization, optional dropout, leaky ReLU.

  Args:
    filters: number of output channels of the convolution.
    size: side length of the square kernel.
    strides: convolution stride.
    apply_norm: one of 'batch', 'instance' or 'none'.
    norm_momentum: momentum of the batch normalization (used for 'batch' only).
    norm_epsilon: epsilon of the batch normalization (used for 'batch' only).
    apply_dropout: whether a Dropout(0.5) layer is applied before the activation.
  """
  def __init__(self, filters, size, strides=2,
               apply_norm='instance', norm_momentum=0.9, norm_epsilon=1e-5,
               apply_dropout=True):
    super(DiscDownsample, self).__init__()
    assert apply_norm in ['batch', 'instance', 'none']
    self.apply_norm = apply_norm
    self.apply_dropout = apply_dropout

    # The convolution only carries a bias when no normalization follows it
    self.conv = layers.Conv2D(filters=filters,
                              kernel_size=(size, size),
                              strides=strides,
                              padding='same',
                              kernel_initializer=tf.random_normal_initializer(0., 0.02),
                              use_bias=(self.apply_norm == 'none'))

    if self.apply_norm == 'instance':
      self.instancenorm = InstanceNormalization()
    elif self.apply_norm == 'batch':
      self.batchnorm = layers.BatchNormalization(momentum=norm_momentum,
                                                 epsilon=norm_epsilon)

    if self.apply_dropout:
      self.dropout = layers.Dropout(0.5)

  def call(self, x, training):
    """Run convolution, normalization, dropout (if enabled) and leaky ReLU."""
    out = self.conv(x)

    # normalization
    if self.apply_norm == 'instance':
      out = self.instancenorm(out)
    elif self.apply_norm == 'batch':
      out = self.batchnorm(out, training=training)

    # dropout comes before the activation
    if self.apply_dropout:
      out = self.dropout(out, training=training)

    return tf.nn.leaky_relu(out)

In [None]:
class Discriminator(tf.keras.Model):
  """Discriminator: four DiscDownsample blocks followed by a 1-channel Conv.

  The final Conv block has neither normalization nor activation, so the
  model returns raw logits.
  """
  def __init__(self):
    super(Discriminator, self).__init__()
    self.down1 = DiscDownsample(filters=64, size=4, apply_dropout=False)              # C64
    self.down2 = DiscDownsample(filters=128, size=4)                                  # C128
    self.down3 = DiscDownsample(filters=256, size=4)                                  # C256
    self.down4 = DiscDownsample(filters=512, size=4, strides=1, apply_dropout=False)  # C512
    self.last = Conv(filters=1, size=4, strides=1, activation='none', apply_norm='none')  # last

  def call(self, x, training):
    """Return the logit map for a batch of images.

    For inputs of shape (bs, 256, 256, 3) the intermediate shapes are
    (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256) ->
    (bs, 32, 32, 512) and the output is (bs, 32, 32, 1).
    """
    features = x
    for block in (self.down1, self.down2, self.down3, self.down4, self.last):
      features = block(features, training=training)

    return features

In [None]:
# Create two discriminators, one per image domain
discriminator_X = Discriminator() # This discriminator_X corresponds to function D_X in paper's notation
discriminator_Y = Discriminator() # This discriminator_Y corresponds to function D_Y in paper's notation

In [None]:
# Test for Discriminator(): run both discriminators once on a single-image batch
# disc_out is overwritten by the second call, so only the output of
# discriminator_Y is plotted below.
disc_out = discriminator_X(imageA[tf.newaxis,...], training=False)
disc_out = discriminator_Y(imageB[tf.newaxis,...], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

## Model summary

In [None]:
# Print the layers and parameter counts of the X -> Y generator
generator_X2Y.summary()

In [None]:
# Print the layers and parameter counts of the domain-Y discriminator
discriminator_Y.summary()

## Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes 2 inputs: the discriminator logits for real images and for generated images
  * real_loss compares the real-image logits with an array of ones (since these are the real images)
  * generated_loss compares the generated-image logits with an array of zeros (since these are the fake images)
  * By default (`use_lsgan=True`) both terms are least-squares (LSGAN) losses, i.e. the mean squared error between the sigmoid of the logits and the labels; with `use_lsgan=False` they are sigmoid cross entropy losses
  * Then the total_loss is the sum of real_loss and the generated_loss
* **Generator loss**
  * The adversarial term is the same GAN loss applied to the generated-image logits and an array of ones.
  * The paper also includes a cycle-consistency loss, which is the L1 loss (mean absolute error) between an input image and its reconstruction after being translated to the other domain and back.
  * This encourages the two generators to be approximate inverses of each other, so that a translated image keeps the content of its input.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * cycle_loss, where LAMBDA = 10. This value was decided by the authors of the paper.

### Define loss functions

In [None]:
# Shared loss objects used by the loss functions below
bce_object = tf.losses.BinaryCrossentropy(from_logits=True)  # standard GAN loss, fed raw logits
mse_object = tf.losses.MeanSquaredError()                    # LSGAN loss
mae_object = tf.losses.MeanAbsoluteError()                   # L1 cycle-consistency loss

In [None]:
def GANLoss(logits, is_real=True, use_lsgan=True):
  """Computes the GAN loss of `logits` against all-ones or all-zeros labels.

  Args:
    logits (`Tensor`): discriminator logits.
    is_real (`bool`): True means `1` labeling, False means `0` labeling.
    use_lsgan (`bool`): True means LSGAN loss (mean squared error between the
                        labels and the sigmoid of the logits), False means
                        standard GAN loss (binary cross entropy on the logits).

  Returns:
    loss (`0-rank Tensor`): the LSGAN loss value or the standard GAN loss value.
  """
  labels = tf.ones_like(logits) if is_real else tf.zeros_like(logits)

  if not use_lsgan:
    return bce_object(labels, logits)

  return mse_object(labels, tf.nn.sigmoid(logits))

In [None]:
def discriminator_loss(real_logits, fake_logits):
  """Return the discriminator loss: real logits labeled "1" plus fake logits labeled "0"."""
  loss_on_real = GANLoss(logits=real_logits, is_real=True)
  loss_on_fake = GANLoss(logits=fake_logits, is_real=False)
  total = loss_on_real + loss_on_fake

  return total

In [None]:
def cycle_consistency_loss(X, X2Y2X):
  """Return the L1 loss (mean absolute error) between `X` and its reconstruction `X2Y2X`."""
  # An L2 variant would use mse_object here instead of mae_object
  return mae_object(X, X2Y2X)

In [None]:
def generator_loss(fake_logits, imagesX, generated_images_X2Y2X):
  """Return the generator loss: adversarial term plus LAMBDA times the cycle term.

  Args:
    fake_logits: discriminator logits of the generated images.
    imagesX: the original input images.
    generated_images_X2Y2X: the reconstruction of `imagesX` after a full cycle.
  """
  # the generator wants its fakes to be classified as real, hence label "1"
  adversarial_term = GANLoss(logits=fake_logits, is_real=True)

  # weighted mean absolute error between the inputs and their reconstruction
  cycle_term = LAMBDA * cycle_consistency_loss(imagesX, generated_images_X2Y2X)

  return adversarial_term + cycle_term

### Define optimizers

In [None]:
# One Adam optimizer is shared by both discriminators and one by both generators;
# train_step later decays their learning rates in place.
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_D, beta_1=0.5)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate_G, beta_1=0.5)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints are written into train_dir; create it if it does not exist yet
checkpoint_dir = train_dir
if not tf.io.gfile.exists(checkpoint_dir):
  tf.io.gfile.makedirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both optimizers, both generators and both discriminators in one checkpoint object
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator_X2Y=generator_X2Y,
                                 generator_Y2X=generator_Y2X,
                                 discriminator_X=discriminator_X,
                                 discriminator_Y=discriminator_Y)

## Define `generate_and_print_or_save` functions

In [None]:
def generate_and_print_or_save_sample_images(inputs_X, inputs_Y,
                                             is_save=False, epoch=None, checkpoint_dir=checkpoint_dir):
  """Run both translation cycles in inference mode and print or save the results.

  For X the cycle is X -> Y -> X (named 'X2Y2X'), for Y it is Y -> X -> Y
  (named 'Y2X2Y'). Each triple (input, translation, reconstruction) is
  handed to `print_or_save_sample_images_pix2pix`.

  Args:
    inputs_X: batch of images from domain X.
    inputs_Y: batch of images from domain Y.
    is_save: forwarded to `print_or_save_sample_images_pix2pix`.
    epoch: forwarded to `print_or_save_sample_images_pix2pix`.
    checkpoint_dir: forwarded to `print_or_save_sample_images_pix2pix`.
  """
  cycles = ((inputs_X, generator_X2Y, generator_Y2X, 'X2Y2X'),
            (inputs_Y, generator_Y2X, generator_X2Y, 'Y2X2Y'))

  for source, forward, backward, cycle_name in cycles:
    translated = forward(source, training=False)
    reconstructed = backward(translated, training=False)
    print_or_save_sample_images_pix2pix(source, translated, reconstructed,
                                        model_name='cyclegan', name=cycle_name,
                                        is_save=is_save, epoch=epoch, checkpoint_dir=checkpoint_dir)

In [None]:
# Keep one fixed test batch per domain for generation (prediction) so that
# it is easier to see the improvement of the CycleGAN across epochs.
for inputs_X, inputs_Y in zip(testX_dataset.take(1), testY_dataset.take(1)):
  const_test_input_X = inputs_X
  const_test_input_Y = inputs_Y

In [None]:
# Show both cycles with the untrained generators on the fixed test batches:
# X -> Y -> X and Y -> X -> Y
generate_and_print_or_save_sample_images(const_test_input_X, const_test_input_Y)

## Training

### Define training one step function

In [None]:
@tf.function()
def train_step(imagesX, imagesY, global_step):
  """Run one optimization step for both generators and both discriminators.

  Args:
    imagesX: batch of images from domain X.
    imagesY: batch of images from domain Y.
    global_step: step counter compared against the learning-rate schedule.

  Returns:
    The tuple (gen_X2Y_loss, gen_Y2X_loss, disc_X_loss, disc_Y_loss).
  """
  # Two tapes record the same forward pass: one is used for the generator
  # gradients, the other for the discriminator gradients.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    # Image generation from one domain to another domain
    generated_images_X2Y = generator_X2Y(imagesX, training=True)  # G: X -> Y
    generated_images_Y2X = generator_Y2X(imagesY, training=True)  # F: Y -> X
    
    # Image generation from one domain via another domain to original domain
    generated_images_X2Y2X = generator_Y2X(generated_images_X2Y, training=True)  # F: Y -> X
    generated_images_Y2X2Y = generator_X2Y(generated_images_Y2X, training=True)  # G: X -> Y

    # Discriminate real images by Discriminator()
    real_logits_X = discriminator_X(imagesX, training=True)  # D_X
    real_logits_Y = discriminator_Y(imagesY, training=True)  # D_Y

    # Discriminate generated (fake) images by Discriminator()
    fake_logits_X2Y = discriminator_Y(generated_images_X2Y, training=True) # D_Y
    fake_logits_Y2X = discriminator_X(generated_images_Y2X, training=True) # D_X

    # Each generator loss combines its adversarial term with the cycle loss of its own input
    gen_X2Y_loss = generator_loss(fake_logits_X2Y, imagesX, generated_images_X2Y2X)
    gen_Y2X_loss = generator_loss(fake_logits_Y2X, imagesY, generated_images_Y2X2Y)
    disc_X_loss = discriminator_loss(real_logits_X, fake_logits_Y2X)
    disc_Y_loss = discriminator_loss(real_logits_Y, fake_logits_X2Y)
    
    # Both generators (and both discriminators) are updated from one summed loss
    total_generator_loss = gen_X2Y_loss + gen_Y2X_loss
    total_discriminator_loss = disc_X_loss + disc_Y_loss

  discriminator_tvars = discriminator_X.trainable_variables + discriminator_Y.trainable_variables
  generator_tvars = generator_X2Y.trainable_variables + generator_Y2X.trainable_variables
  
  gradients_of_discriminator = disc_tape.gradient(total_discriminator_loss, discriminator_tvars)
  gradients_of_generator = gen_tape.gradient(total_generator_loss, generator_tvars)
  
  # Learning rate decay: once constant_lr_epochs worth of steps have passed,
  # every step subtracts (initial learning rate / number of decay steps),
  # i.e. the learning rate decreases linearly over decay_lr_epochs.
  num_steps_per_epoch = N // batch_size
  if global_step > num_steps_per_epoch * constant_lr_epochs:
    decay_step = num_steps_per_epoch * decay_lr_epochs
    discriminator_optimizer.lr.assign_sub(learning_rate_D * 1. / decay_step) # tf.train.polynomial_decay (linear decay)
    generator_optimizer.lr.assign_sub(learning_rate_G * 1. / decay_step) # tf.train.polynomial_decay (linear decay)

  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator_tvars))
  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator_tvars))

  return gen_X2Y_loss, gen_Y2X_loss, disc_X_loss, disc_Y_loss

### Training until max_epochs

In [None]:
print('Start Training.')
num_batches_per_epoch = N // batch_size
# Counts training steps across all epochs; train_step reads it for the learning-rate decay
global_step = tf.Variable(0, trainable=False)

for epoch in range(max_epochs):

  # End of 'for' loop depends on shorter dataset
  for step, (imagesX, imagesY) in enumerate(zip(trainX_dataset, trainY_dataset)):
    # start_time is reset every step, so `duration` below covers a single step
    start_time = time.time()

    gen_X2Y_loss, gen_Y2X_loss, disc_X_loss, disc_Y_loss = train_step(imagesX, imagesY, global_step)
    global_step.assign_add(1)

    # print the losses and result images every print_steps
    if global_step.numpy() % print_steps == 0:
      # fractional epoch count used only for logging
      epochs = epoch + step / float(num_batches_per_epoch)
      duration = time.time() - start_time
      examples_per_sec = batch_size / float(duration)
      display.clear_output(wait=True)
      print("Epochs: {:.2f} lr: {:.3g}, {:.3g}, global_step: {:d} loss_D_X: {:.3g} loss_D_Y: {:.3g} loss_G_X2Y: {:.3g} loss_F_Y2X: {:.3g} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, generator_optimizer.lr.numpy(), discriminator_optimizer.lr.numpy(), global_step.numpy(), disc_X_loss, disc_Y_loss, gen_X2Y_loss, gen_Y2X_loss, examples_per_sec, duration))
      # generate sample images from one test batch per domain;
      # generate_and_print_or_save_sample_images runs the generators
      # with training=False
      for test_inputs_X, test_inputs_Y in zip(testX_dataset.take(1), testY_dataset.take(1)):
        generate_and_print_or_save_sample_images(test_inputs_X, test_inputs_Y)

  # saving the result image files every save_images_epochs,
  # always from the same constant test inputs
  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    generate_and_print_or_save_sample_images(const_test_input_X, const_test_input_Y,
                                             is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

  # saving (checkpoint) the model every save_model_epochs
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix=checkpoint_prefix)
    
print('Training Done.')

In [None]:
# Generate sample images from one test batch per domain after the final epoch
display.clear_output(wait=True)
for test_inputs_X, test_inputs_Y in zip(testX_dataset.take(1), testY_dataset.take(1)):
  generate_and_print_or_save_sample_images(test_inputs_X, test_inputs_Y)

## Restore the latest checkpoint

In [None]:
# Restore the optimizers and models from the most recent checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
# Show the X -> Y -> X sample image for the last epoch
# (display_image is presumably provided by the star-imported utils modules - confirm)
display_image(max_epochs, 'X2Y2X', checkpoint_dir)

In [None]:
# Show the Y -> X -> Y sample image for the last epoch
# (display_image is presumably provided by the star-imported utils modules - confirm)
display_image(max_epochs, 'Y2X2Y', checkpoint_dir)

## Generate a GIF of all the saved images.

In [None]:
# Build one GIF file name per cycle direction and generate the GIFs.
# NOTE(review): generate_gif is not defined in this notebook (presumably it
# comes from the star-imported utils modules); both calls receive the same
# checkpoint_dir - confirm how it selects the images for each direction.
filename1 = model_name + '_' + dataset_name + '_' + 'X2Y2X' + '.gif'
generate_gif(filename1, checkpoint_dir)
filename2 = model_name + '_' + dataset_name + '_' + 'Y2X2Y' + '.gif'
generate_gif(filename2, checkpoint_dir)

In [None]:
# NOTE(review): this opens '<gif name>.png'; presumably generate_gif writes
# such a copy so that the notebook can display the GIF - confirm.
display.Image(filename=filename1 + '.png')

In [None]:
# NOTE(review): this opens '<gif name>.png'; presumably generate_gif writes
# such a copy so that the notebook can display the GIF - confirm.
display.Image(filename=filename2 + '.png')