# CycleGAN [with horse2zebra dataset]

* `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks`, [arXiv:1703.10593](https://arxiv.org/abs/1703.10593)
  * Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros
  
* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) and [`eager execution`](https://www.tensorflow.org/guide/eager).

## Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display
#import wget
import urllib.request
import zipfile

import tensorflow as tf
from tensorflow.keras import layers
# Switch TensorFlow to eager execution (the rest of the notebook relies on it,
# e.g. iterating datasets in Python loops and calling `.numpy()`).
tf.enable_eager_execution()

# Show INFO-level messages emitted through tf.logging.
tf.logging.set_verbosity(tf.logging.INFO)

# Make only the CUDA device with index "0" visible to this process.
# NOTE(review): this is set after TensorFlow has been imported -- confirm it
# still takes effect in your environment.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

## Setting hyperparameters

In [None]:
# Training flags (hyperparameter configuration)

# -- experiment bookkeeping --
model_name = 'cyclegan'
train_dir = 'train/{}/exp1/'.format(model_name)  # used as the checkpoint / sample image directory

# -- schedule --
max_epochs = 200
print_steps = 1          # print losses and preview images every `print_steps` global steps
save_images_epochs = 5   # save sample images every N epochs
save_model_epochs = 20   # write a checkpoint every N epochs

# -- optimisation --
batch_size = 1
learning_rate_D = learning_rate_G = 2e-4

# -- data pipeline and loss weighting --
BUFFER_SIZE = 5000       # shuffle buffer size of the training datasets
IMG_SIZE = 128           # side length of the square network input
assert IMG_SIZE in (128, 256)
LAMBDA = 10              # weight of the cycle-consistency term in the generator loss

## Load the dataset

You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/). 
This script source is borrowed from [original CycleGAN github repo.](https://github.com/junyanz/CycleGAN/blob/master/datasets/download_dataset.sh)


As mentioned in the [paper](https://arxiv.org/abs/1703.10593), we apply random jittering and mirroring to the training dataset.
* In random jittering, the image is resized to 286 x 286 and then randomly cropped to 256 x 256 (when `IMG_SIZE = 128`, as configured above, it is resized to 145 x 145 and cropped to 128 x 128).
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.

In [None]:
# Names of the datasets published with the original CycleGAN implementation.
DATASETS = ["ae_photos",
            "apple2orange",
            "summer2winter_yosemite",
            "horse2zebra",
            "monet2photo",
            "cezanne2photo",
            "ukiyoe2photo",
            "vangogh2photo",
            "maps",
            "cityscapes",
            "facades",
            "iphone2dslr_flower"]

DATASET_YOUWANT = "horse2zebra"

url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + DATASET_YOUWANT + '.zip'
datasets_path = '../datasets'
if not os.path.isdir(datasets_path):
  os.makedirs(datasets_path)
zipfile_path = os.path.join(datasets_path, DATASET_YOUWANT + '.zip')

# Download dataset
if not os.path.isfile(zipfile_path):
  # Save the archive at `zipfile_path`: that is the location checked above
  # and the location the extraction step below reads from.
  urllib.request.urlretrieve(url=url, filename=zipfile_path)
  print('download done')
else:
  print('zipfile already exists')

# Extract zipfile
PATH = os.path.join(datasets_path, DATASET_YOUWANT)
if not os.path.isdir(PATH):
  # The context manager closes the archive even if extraction fails.
  with zipfile.ZipFile(zipfile_path, 'r') as zip_ref:
    zip_ref.extractall(datasets_path)
  print('zipfile extract done')
else:
  print('zipfile already extracted')

## Set up dataset with `tf.data`

In [None]:
def load_image(image_file, is_train):
  """Read a JPEG file and convert it to a float32 image scaled to [-1, 1].

  Args:
    image_file: path of the JPEG file, passed to `tf.read_file`.
    is_train (`bool`): if True, apply random jittering (upscale, then random
      crop to `IMG_SIZE` x `IMG_SIZE`) and random horizontal mirroring;
      if False, only resize to `IMG_SIZE` x `IMG_SIZE`.

  Returns:
    `Tensor` of shape (IMG_SIZE, IMG_SIZE, 3) with values in [-1, 1].
  """
  image = tf.read_file(image_file)
  image = tf.image.decode_jpeg(image, channels=3)  # always decode to 3 channels

  input_image = tf.cast(image, tf.float32)

  if is_train:
    # random jittering:
    # upscale to 286 x 286 (145 x 145 when IMG_SIZE is 128) ...
    resize = 286 if IMG_SIZE == 256 else 145
    input_image = tf.image.resize_images(input_image, [resize, resize],
                                         align_corners=True,
                                         method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # ... then randomly crop back to IMG_SIZE x IMG_SIZE x 3
    input_image = tf.random_crop(input_image, size=[IMG_SIZE, IMG_SIZE, 3])

    # random mirroring.
    # This must be a TensorFlow op: the function is used inside
    # `Dataset.map`, where a Python-side `np.random.random()` would be
    # evaluated only once when the function is traced, so every image
    # would get the same flip decision.
    input_image = tf.image.random_flip_left_right(input_image)
  else:
    input_image = tf.image.resize_images(input_image, size=[IMG_SIZE, IMG_SIZE],
                                         align_corners=True,
                                         method=tf.image.ResizeMethod.BICUBIC)

  # keep pixel values inside the valid range before normalizing
  input_image = tf.clip_by_value(input_image, 0.0, 255.0)

  # normalizing the images to [-1, 1]
  input_image = (input_image / 127.5) - 1

  return input_image

### Use tf.data to create batches, map(do preprocessing) and shuffle the dataset

In [None]:
# Dataset of the file names matching trainA/*.jpg under PATH.
# NOTE(review): the next cell builds `trainX_dataset` again from scratch, so
# this cell looks redundant -- confirm before removing.
trainX_dataset = tf.data.Dataset.list_files(os.path.join(PATH, 'trainA/*.jpg'))

In [None]:
# Domain X training pipeline: list files -> shuffle -> load/augment -> batch.
# N_trainX is the number of trainA images found on disk.
N_trainX = len(glob.glob(os.path.join(PATH, 'trainA/*.jpg')))
trainX_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'trainA/*.jpg'))
    .shuffle(BUFFER_SIZE)
    .map(lambda path: load_image(path, True))
    .batch(batch_size, drop_remainder=True))

In [None]:
# Domain Y training pipeline: list files -> shuffle -> load/augment -> batch.
# N_trainY is the number of trainB images found on disk.
N_trainY = len(glob.glob(os.path.join(PATH, 'trainB/*.jpg')))
trainY_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'trainB/*.jpg'))
    .shuffle(BUFFER_SIZE)
    .map(lambda path: load_image(path, True))
    .batch(batch_size, drop_remainder=True))

In [None]:
# Domain X test pipeline: list files -> load (resize only, no augmentation) -> batch.
testX_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'testA/*.jpg'))
    .map(lambda path: load_image(path, False))
    .batch(batch_size, drop_remainder=True))

In [None]:
# Domain Y test pipeline: list files -> load (resize only, no augmentation) -> batch.
testY_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'testB/*.jpg'))
    .map(lambda path: load_image(path, False))
    .batch(batch_size, drop_remainder=True))

## Write the generator and discriminator models

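* **Generator**
  * The architecture of the generator is a ResNet-style encoder-decoder (it is not a U-Net: there are no skip connections between the encoder and decoder, only the residual connections inside each residual block).
  * Encoder: a 7x7 stride-1 convolution followed by two 3x3 stride-2 convolutions; each block is (Conv -> Batchnorm -> ReLU).
  * Transformer: a stack of residual blocks (5 blocks for 128x128 inputs, 9 blocks for 256x256 inputs).
  * Decoder: two blocks of (Transposed Conv -> Batchnorm -> ReLU), followed by a final 7x7 convolution with Batchnorm and a tanh activation.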

* **Discriminator**
  * The Discriminator is a PatchGAN.
  * Each downsampling block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU); the first block has no BatchNorm.
  * The shape of the output after the last layer is (batch_size, H/16, W/16, 1), i.e. (batch_size, 16, 16, 1) for 256x256 inputs and (batch_size, 8, 8, 1) for 128x128 inputs.
  * Each element of the output classifies one patch of the input image as real or fake (such an architecture is called a PatchGAN).
  * Unlike in pix2pix, each discriminator receives a single image (there is no paired input/target to concatenate).
    * A real image from its own domain, which it should classify as real.
    * A generated image (the output of the generator that translates into its domain), which it should classify as fake.
  * Shape of the input travelling through the generator and the discriminator is in the comments in the code.

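To learn more about the architecture and the hyperparameters, you can refer to the [CycleGAN paper](https://arxiv.org/abs/1703.10593); the PatchGAN discriminator is described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).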

In [None]:
class Conv(tf.keras.Model):
  """Conv2D (no bias, 'same' padding) -> BatchNormalization -> activation.

  Args:
    filters: number of output channels.
    size: side length of the square kernel.
    strides: stride of the convolution.
    activation (`str`): one of 'relu', 'tanh' or 'none' (no activation).
  """

  def __init__(self, filters, size, strides, activation='relu'):
    super(Conv, self).__init__()
    weight_init = tf.random_normal_initializer(0., 0.02)
    self.conv = layers.Conv2D(filters=filters,
                              kernel_size=(size, size),
                              strides=strides,
                              padding='same',
                              kernel_initializer=weight_init,
                              use_bias=False)
    self.batchnorm = layers.BatchNormalization()
    assert activation in ['relu', 'tanh', 'none']
    self.activation = activation

  def call(self, x, training):
    """Apply the block; `training` is forwarded to batch normalization."""
    normalized = self.batchnorm(self.conv(x), training=training)
    if self.activation == 'relu':
      return tf.nn.relu(normalized)
    if self.activation == 'tanh':
      return tf.nn.tanh(normalized)
    # 'none': return the normalized features unchanged
    return normalized

In [None]:
class ResBlock(tf.keras.Model):
  """Residual block: relu(x + conv2(conv1(x))), where conv1 has a ReLU activation.

  Args:
    filters: number of output channels of both convolutions.
    size: side length of the square kernels.
  """

  def __init__(self, filters, size):
    super(ResBlock, self).__init__()
    shared = dict(filters=filters, kernel_size=(size, size), padding='same')
    self.conv1 = layers.Conv2D(activation='relu',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               **shared)
    self.conv2 = layers.Conv2D(kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               **shared)

  def call(self, x):
    """Add the two-convolution branch to the input and apply ReLU."""
    residual = self.conv2(self.conv1(x))
    return tf.nn.relu(x + residual)

In [None]:
class ConvTranspose(tf.keras.Model):
  """Stride-2 Conv2DTranspose (no bias, 'same' padding) -> BatchNormalization -> ReLU.

  Args:
    filters: number of output channels.
    size: side length of the square kernel.
  """

  def __init__(self, filters, size):
    super(ConvTranspose, self).__init__()
    weight_init = tf.random_normal_initializer(0., 0.02)
    self.up_conv = layers.Conv2DTranspose(filters=filters,
                                          kernel_size=(size, size),
                                          strides=2,
                                          padding='same',
                                          kernel_initializer=weight_init,
                                          use_bias=False)
    self.batchnorm = layers.BatchNormalization()

  def call(self, x, training):
    """Upsample `x`; `training` is forwarded to batch normalization."""
    upsampled = self.up_conv(x)
    normalized = self.batchnorm(upsampled, training=training)
    return tf.nn.relu(normalized)

In [None]:
class Generator(tf.keras.Model):
  """Generator: c7s1-32, d64, d128, residual blocks, u64, u32, c7s1-3 (tanh).

  Five residual blocks are always created; four more (nine in total) are
  added when `inputs_shape` is 256.

  Args:
    inputs_shape (`int`): 128 or 256; only selects the number of residual blocks.
  """

  def __init__(self, inputs_shape=256):
    super(Generator, self).__init__()
    assert inputs_shape in [128, 256]
    self.inputs_shape = inputs_shape

    # encoder
    self.conv = Conv(32, 7, 1)    # c7s1-32
    self.down1 = Conv(64, 3, 2)   # d64
    self.down2 = Conv(128, 3, 2)  # d128

    # residual blocks (R128)
    self.res1 = ResBlock(128, 3)
    self.res2 = ResBlock(128, 3)
    self.res3 = ResBlock(128, 3)
    self.res4 = ResBlock(128, 3)
    self.res5 = ResBlock(128, 3)
    if inputs_shape == 256:
      self.res6 = ResBlock(128, 3)
      self.res7 = ResBlock(128, 3)
      self.res8 = ResBlock(128, 3)
      self.res9 = ResBlock(128, 3)

    # decoder
    self.up1 = ConvTranspose(64, 3)    # u64
    self.up2 = ConvTranspose(32, 3)    # u32
    self.last = Conv(3, 7, 1, 'tanh')  # c7s1-3

  #@tf.contrib.eager.defun
  def call(self, x, training):
    """Translate a batch of images; shapes below are for input (bs, H, W, 3)."""
    h = self.conv(x, training=training)   # (bs, H, W, 32)
    h = self.down1(h, training=training)  # (bs, H/2, W/2, 64)
    h = self.down2(h, training=training)  # (bs, H/4, W/4, 128)

    # the residual blocks keep the shape (bs, H/4, W/4, 128)
    res_blocks = [self.res1, self.res2, self.res3, self.res4, self.res5]
    if self.inputs_shape == 256:
      res_blocks += [self.res6, self.res7, self.res8, self.res9]
    for block in res_blocks:
      h = block(h)

    h = self.up1(h, training=training)    # (bs, H/2, W/2, 64)
    h = self.up2(h, training=training)    # (bs, H, W, 32)

    # generated images: (bs, H, W, 3)
    return self.last(h, training=training)

In [None]:
class DiscDownsample(tf.keras.Model):
  """Stride-2 Conv2D (no bias) -> optional BatchNormalization -> Leaky ReLU.

  Args:
    filters: number of output channels.
    size: side length of the square kernel.
    apply_batchnorm (`bool`): whether to batch-normalize after the convolution.
  """

  def __init__(self, filters, size, apply_batchnorm=True):
    super(DiscDownsample, self).__init__()
    self.apply_batchnorm = apply_batchnorm
    weight_init = tf.random_normal_initializer(0., 0.02)
    self.conv = layers.Conv2D(filters=filters,
                              kernel_size=(size, size),
                              strides=2,
                              padding='same',
                              kernel_initializer=weight_init,
                              use_bias=False)
    if apply_batchnorm:
      self.batchnorm = layers.BatchNormalization()

  def call(self, x, training):
    """Downsample `x`; `training` is forwarded to batch normalization if used."""
    features = self.conv(x)
    if self.apply_batchnorm:
      features = self.batchnorm(features, training=training)
    return tf.nn.leaky_relu(features)

In [None]:
class Discriminator(tf.keras.Model):
  """Four stride-2 downsampling blocks followed by a 1-channel stride-1 Conv block.

  For an input of shape (bs, H, W, 3) the output logits have shape
  (bs, H/16, W/16, 1).
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    self.down1 = DiscDownsample(64, 4, False)  # C64 (no batchnorm)
    self.down2 = DiscDownsample(128, 4)        # C128
    self.down3 = DiscDownsample(256, 4)        # C256
    self.down4 = DiscDownsample(512, 4)        # C512
    self.last = Conv(1, 4, 1, 'none')          # 1-channel logits, no activation

  #@tf.contrib.eager.defun
  def call(self, x, training):
    """Return the logits for `x`."""
    # spatial size after each stage: H/2, H/4, H/8, H/16, H/16
    for stage in (self.down1, self.down2, self.down3, self.down4, self.last):
      x = stage(x, training=training)
    return x

In [None]:
# Create the two generators and the two discriminators.
# NOTE: the `tf.contrib.eager.defun` decorators on the `call` methods of
# Generator and Discriminator are currently commented out. The original note
# here reported a performance speedup when defun is used (~25 seconds per epoch).
generator_X2Y = Generator(inputs_shape=IMG_SIZE) # This generator_X2Y corresponds to function G: X -> Y in paper's notation
generator_Y2X = Generator(inputs_shape=IMG_SIZE) # This generator_Y2X corresponds to function F: Y -> X in paper's notation
discriminator_X = Discriminator() # This discriminator_X corresponds to function D_X in paper's notation
discriminator_Y = Discriminator() # This discriminator_Y corresponds to function D_Y in paper's notation

## Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes 2 inputs; real images, generated images
  * real_loss is a sigmoid cross entropy loss of the real images and an array of ones(since these are the real images)
  * generated_loss is a sigmoid cross entropy loss of the generated images and an array of zeros(since these are the fake images)
  * Then the total_loss is the sum of real_loss and the generated_loss
* **Generator loss**
  * The adversarial part is a sigmoid cross entropy loss of the discriminator logits for the generated images and an array of ones.
  * The paper also includes a cycle-consistency loss, which is the L1 loss, i.e. MAE (mean absolute error), between an input image and its reconstruction after a round trip through both generators (X -> Y -> X or Y -> X -> Y).
  * This encourages the translated image to preserve the content of the input image.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * cycle_loss, where LAMBDA = 10. This value was decided by the authors of the paper.

In [None]:
def GANLoss(logits, is_real=True):
  """Sigmoid cross entropy between `logits` and constant all-ones / all-zeros labels.

  Args:
    logits (`Tensor`): discriminator logits.
    is_real (`bool`): True labels every logit as `1`, False labels it as `0`.

  Returns:
    loss (`0-rank Tensor`): the standard GAN loss value (binary cross entropy).
  """
  make_labels = tf.ones_like if is_real else tf.zeros_like
  return tf.losses.sigmoid_cross_entropy(multi_class_labels=make_labels(logits),
                                         logits=logits)

In [None]:
def discriminator_loss(real_logits, fake_logits):
  """Discriminator loss: real logits labeled "1" plus fake logits labeled "0".

  Args:
    real_logits: discriminator logits for real images.
    fake_logits: discriminator logits for generated images.

  Returns:
    The sum of the two GAN losses.
  """
  return (GANLoss(logits=real_logits, is_real=True)
          + GANLoss(logits=fake_logits, is_real=False))

In [None]:
def cycle_consistency_loss(X, X2Y2X):
  """Mean absolute difference between `X` and its reconstruction `X2Y2X`."""
  return tf.reduce_mean(tf.abs(X - X2Y2X))

In [None]:
def generator_loss(fake_logits, imagesX, generated_images_X2Y2X):
  """Generator loss: adversarial term plus LAMBDA-weighted cycle-consistency term.

  Args:
    fake_logits: discriminator logits for the generated images.
    imagesX: the original input images.
    generated_images_X2Y2X: the reconstruction of `imagesX` after a round trip
      through both generators.

  Returns:
    gan_loss + LAMBDA * cycle_loss.
  """
  # the generator wants its fakes to be labeled "1" by the discriminator
  adversarial_term = GANLoss(logits=fake_logits, is_real=True)

  # mean absolute error between the input and its reconstruction
  cycle_term = LAMBDA * cycle_consistency_loss(imagesX, generated_images_X2Y2X)

  return adversarial_term + cycle_term

In [None]:
# One Adam optimizer (beta1 = 0.5) is shared by both discriminators and one by
# both generators.
# Alternative that was tried for the discriminators:
#   tf.train.RMSPropOptimizer(learning_rate_D)
discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_D,
                                                 beta1=0.5)
generator_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_G,
                                             beta1=0.5)

## Checkpoints (Object-based saving)

In [None]:
# Checkpoints are written into the training directory, with the "ckpt" prefix.
checkpoint_dir = train_dir
if not tf.gfile.Exists(checkpoint_dir):
  tf.gfile.MakeDirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Objects whose state is saved and restored together.
tracked_objects = {
    'generator_optimizer': generator_optimizer,
    'discriminator_optimizer': discriminator_optimizer,
    'generator_X2Y': generator_X2Y,
    'generator_Y2X': generator_Y2X,
    'discriminator_X': discriminator_X,
    'discriminator_Y': discriminator_Y,
}
checkpoint = tf.train.Checkpoint(**tracked_objects)

## Training

In [None]:
def print_or_save_sample_images(X, X2Y, X2Y2X, name,
                                is_save=False, epoch=None, checkpoint_dir=checkpoint_dir):
  """Plot the first image of three batches side by side and optionally save the figure.

  Args:
    X: batch of source images, values in [-1, 1].
    X2Y: batch of translated images.
    X2Y2X: batch of reconstructed images.
    name (`str`): 'X2Y2X' or 'Y2X2Y'; selects the saved file name.
    is_save (`bool`): save the figure when True and `epoch` is given.
    epoch (`int`): epoch number used in the saved file name.
    checkpoint_dir (`str`): directory the figure is saved in.
  """
  plt.figure(figsize=(15, 5))
  plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)

  captions = ['X domain', 'X -> Y', 'X -> Y -> X']
  images = [X[0], X2Y[0], X2Y2X[0]]
  for position, (caption, image) in enumerate(zip(captions, images), start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    # map the pixel values from [-1, 1] to [0, 1] for plotting
    plt.imshow(image * 0.5 + 0.5)
    plt.axis('off')

  assert name in ['X2Y2X', 'Y2X2Y']
  if is_save and epoch is not None:
    template = ('image_X2Y2X_at_epoch_{:04d}.png' if name == 'X2Y2X'
                else 'image_Y2X2Y_at_epoch_{:04d}.png')
    plt.savefig(os.path.join(checkpoint_dir, template.format(epoch)))

  plt.show()

In [None]:
# Keep one fixed batch from each test set for generation (prediction), so
# it will be easier to see the improvement of the CycleGAN.
# `take(1)` yields a single batch, so the loop body runs at most once.
for inputs_X, inputs_Y in zip(testX_dataset.take(1), testY_dataset.take(1)):
  const_test_input_X = inputs_X
  const_test_input_Y = inputs_Y

In [None]:
# Check for test data X -> Y -> X
# G maps the fixed X batch into domain Y, then F maps the result back to X;
# the input, the translation and the reconstruction are plotted side by side.
X2Y = generator_X2Y(const_test_input_X, training=False)
X2Y2X = generator_Y2X(X2Y, training=False)
print_or_save_sample_images(const_test_input_X, X2Y, X2Y2X, 'X2Y2X')

In [None]:
# Check for test data Y -> X -> Y
# F (generator_Y2X) maps the fixed Y batch into domain X, then
# G (generator_X2Y) maps the result back to Y.
Y2X = generator_Y2X(const_test_input_Y, training=False)
Y2X2Y = generator_X2Y(Y2X, training=False)
print_or_save_sample_images(const_test_input_Y, Y2X, Y2X2Y, 'Y2X2Y')

In [None]:
# Print the Keras summary of the X -> Y generator.
generator_X2Y.summary()

In [None]:
# Run discriminator_Y once on a generated batch, then print its Keras summary.
# NOTE(review): presumably the call is only here to build the model before
# summary(); `dy` is not used in the cells that follow -- confirm.
dy = discriminator_Y(Y2X2Y, training=True)
discriminator_Y.summary()

In [None]:
tf.logging.info('Start Training.')
global_step = tf.train.get_or_create_global_step()
N = min(N_trainX, N_trainY)
for epoch in range(max_epochs):

  # zip() stops at the shorter dataset, so the end of this loop depends on it
  for step, (imagesX, imagesY) in enumerate(zip(trainX_dataset, trainY_dataset)):
    start_time = time.time()

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      # Image generation from one domain to another domain
      generated_images_X2Y = generator_X2Y(imagesX, training=True)  # G: X -> Y
      generated_images_Y2X = generator_Y2X(imagesY, training=True)  # F: Y -> X

      # Image generation from one domain via another domain to original domain
      generated_images_X2Y2X = generator_Y2X(generated_images_X2Y, training=True)  # F: Y -> X
      generated_images_Y2X2Y = generator_X2Y(generated_images_Y2X, training=True)  # G: X -> Y

      # Discriminate real images by Discriminator()
      real_logits_X = discriminator_X(imagesX, training=True)  # D_X
      real_logits_Y = discriminator_Y(imagesY, training=True)  # D_Y

      # Discriminate generated (fake) images by Discriminator()
      fake_logits_X2Y = discriminator_Y(generated_images_X2Y, training=True) # D_Y
      fake_logits_Y2X = discriminator_X(generated_images_Y2X, training=True) # D_X

      gen_X2Y_loss = generator_loss(fake_logits_X2Y, imagesX, generated_images_X2Y2X)
      gen_Y2X_loss = generator_loss(fake_logits_Y2X, imagesY, generated_images_Y2X2Y)
      disc_X_loss = discriminator_loss(real_logits_X, fake_logits_Y2X)
      disc_Y_loss = discriminator_loss(real_logits_Y, fake_logits_X2Y)

      total_generator_loss = gen_X2Y_loss + gen_Y2X_loss
      total_discriminator_loss = disc_X_loss + disc_Y_loss

    # Collect the variables only after the forward pass: the models create
    # their variables on first call. Only trainable variables are used, so
    # the batch-norm moving statistics (which have no gradient) are not
    # handed to the optimizers.
    generator_variables = (generator_X2Y.trainable_variables
                           + generator_Y2X.trainable_variables)
    discriminator_variables = (discriminator_X.trainable_variables
                               + discriminator_Y.trainable_variables)

    grads_generator = gen_tape.gradient(total_generator_loss, generator_variables)
    grads_discriminator = disc_tape.gradient(total_discriminator_loss,
                                             discriminator_variables)

    # only the generator update increments global_step (once per training step)
    generator_optimizer.apply_gradients(zip(grads_generator, generator_variables),
                                        global_step=global_step)
    discriminator_optimizer.apply_gradients(zip(grads_discriminator,
                                                discriminator_variables))

    # fractional epoch counter used for logging
    epochs = epoch + step * batch_size / float(N)
    duration = time.time() - start_time

    if global_step.numpy() % print_steps == 0:
      display.clear_output(wait=True)
      examples_per_sec = batch_size / float(duration)
      print("Epochs: {:.2f} global_step: {} loss_D_X: {:.3f} loss_D_Y: {:.3f} loss_G_X2Y: {:.3f} loss_F_Y2X: {:.3f} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, global_step.numpy(), disc_X_loss, disc_Y_loss, gen_X2Y_loss, gen_Y2X_loss, examples_per_sec, duration))
      # Generate sample images from one batch drawn from each test dataset.
      # The generators are run with training=False here, i.e. batch
      # normalization is not run in training mode for these previews.
      for test_inputs_X, test_inputs_Y in zip(testX_dataset.take(1), testY_dataset.take(1)):
        X2Y = generator_X2Y(test_inputs_X, training=False)
        X2Y2X = generator_Y2X(X2Y, training=False)
        print_or_save_sample_images(test_inputs_X, X2Y, X2Y2X, 'X2Y2X')

        Y2X = generator_Y2X(test_inputs_Y, training=False)
        Y2X2Y = generator_X2Y(Y2X, training=False)
        print_or_save_sample_images(test_inputs_Y, Y2X, Y2X2Y, 'Y2X2Y')


  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    # Use the fixed test batches so the saved images are comparable across
    # epochs (the previously referenced `test_input_X` / `test_input_Y`
    # were never defined).
    X2Y = generator_X2Y(const_test_input_X, training=False)
    X2Y2X = generator_Y2X(X2Y, training=False)
    print_or_save_sample_images(const_test_input_X, X2Y, X2Y2X, 'X2Y2X',
                                is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

    Y2X = generator_Y2X(const_test_input_Y, training=False)
    Y2X2Y = generator_X2Y(Y2X, training=False)
    print_or_save_sample_images(const_test_input_Y, Y2X, Y2X2Y, 'Y2X2Y',
                                is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

  # saving (checkpoint) the model every save_model_epochs
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# generating after the final epoch
display.clear_output(wait=True)
for test_inputs_X, test_inputs_Y in zip(testX_dataset.take(1), testY_dataset.take(1)):
  # X -> Y -> X (same variable naming as the previews in the training loop)
  X2Y = generator_X2Y(test_inputs_X, training=False)
  X2Y2X = generator_Y2X(X2Y, training=False)
  print_or_save_sample_images(test_inputs_X, X2Y, X2Y2X, 'X2Y2X')

  # Y -> X -> Y
  Y2X = generator_Y2X(test_inputs_Y, training=False)
  Y2X2Y = generator_X2Y(Y2X, training=False)
  print_or_save_sample_images(test_inputs_Y, Y2X, Y2X2Y, 'Y2X2Y')

## Restore the latest checkpoint

In [None]:
# restoring the latest checkpoint in checkpoint_dir
# NOTE(review): if no checkpoint has been saved yet, latest_checkpoint
# presumably returns None -- confirm restore() then behaves as intended.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Display an image using the epoch number

In [None]:
def display_image(epoch_no, name='X2Y2X', checkpoint_dir=checkpoint_dir):
  """Open the sample image that was saved for a given epoch.

  Args:
    epoch_no (`int`): epoch number used in the saved file name.
    name (`str`): 'X2Y2X' or 'Y2X2Y'; selects which saved image to open.
    checkpoint_dir (`str`): directory the images were saved in.

  Returns:
    The image object returned by `PIL.Image.open`.

  Raises:
    AssertionError: if `name` is neither 'X2Y2X' nor 'Y2X2Y'.
  """
  assert name in ['X2Y2X', 'Y2X2Y']
  if name == 'X2Y2X':
    filepath = os.path.join(checkpoint_dir, 'image_X2Y2X_at_epoch_{:04d}.png'.format(epoch_no))
  else:
    filepath = os.path.join(checkpoint_dir, 'image_Y2X2Y_at_epoch_{:04d}.png'.format(epoch_no))

  return PIL.Image.open(filepath)

In [None]:
# A notebook cell only renders the value of its last expression, so display
# both images explicitly.
display.display(display_image(max_epochs, 'X2Y2X'))
display.display(display_image(max_epochs, 'Y2X2Y'))

## Generate a GIF of all the saved images.

In [None]:
import shutil

# Saved sample images, in file-name order.
filenames = sorted(glob.glob(os.path.join(checkpoint_dir, 'image*.png')))

if not filenames:
  # Nothing was saved (e.g. training has not reached a save epoch yet).
  print('no saved images found in {}'.format(checkpoint_dir))
else:
  with imageio.get_writer(model_name + '.gif', mode='I') as writer:
    last = -1
    for i, filename in enumerate(filenames):
      # Keep a frame only when round(2 * sqrt(i)) advances, so later images
      # are sampled more sparsely than early ones.
      frame = 2*(i**0.5)
      if round(frame) > round(last):
        last = frame
      else:
        continue
      image = imageio.imread(filename)
      writer.append_data(image)
    # append the final image once more at the end of the GIF
    image = imageio.imread(filename)
    writer.append_data(image)

  # this is a hack to display the gif inside the notebook
  shutil.copyfile(model_name + '.gif', model_name + '.gif.png')

In [None]:
# Show the GIF inline, using the ".gif.png" copy created in the previous cell.
display.Image(filename=model_name + ".gif.png")