<a href="https://colab.research.google.com/github/lingchm/datascience/blob/master/DL_exercises/CycleGAN_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building CycleGAN from Scratch 

References:
* https://www.codeproject.com/Articles/5304922/Building-a-Mobile-Style-Transfer-CycleGAN-with-Ker
* CycleGAN paper (2017): https://arxiv.org/abs/1703.10593

Framework: Keras and TensorFlow

Dataset: `horse2zebra`

### Preparation

In [4]:
import tensorflow as tf
import tensorflow_datasets as tfds
#from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

### Data loading and processing

Let’s load the dataset and augment the training images with random jittering and mirroring, which helps reduce overfitting:
* Image jittering resizes the image to 286 by 286 pixels and then crops it to 256 by 256 pixels from a randomly selected origin point.
* Image mirroring randomly flips the image horizontally, from left to right.
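
To make the two augmentations concrete, here is a minimal pure-Python sketch of the same arithmetic, run on a nested-list stand-in for an image rather than a TensorFlow tensor. The helper names `sample_jitter` and `apply_jitter` are hypothetical and exist only for illustration. The 50% mirroring probability is an assumption that matches the usual behavior of a random horizontal flip.

```python
import random

RESIZE, CROP = 286, 256  # sizes quoted in the text above


def sample_jitter(rng):
    """Draw one jitter: a crop origin and a mirror flag (illustrative helper)."""
    max_offset = RESIZE - CROP          # 30: largest valid top-left coordinate
    top = rng.randint(0, max_offset)    # randint is inclusive on both ends
    left = rng.randint(0, max_offset)
    mirrored = rng.random() < 0.5       # assumed 50% chance of a flip
    return top, left, mirrored


def apply_jitter(image, top, left, mirrored):
    """Crop a CROP x CROP window from a nested-list image, then optionally mirror it."""
    rows = [row[left:left + CROP] for row in image[top:top + CROP]]
    return [row[::-1] for row in rows] if mirrored else rows


rng = random.Random(0)
# Stand-in for an already resized 286 x 286 image: each pixel stores its (row, col).
image = [[(r, c) for c in range(RESIZE)] for r in range(RESIZE)]
top, left, mirrored = sample_jitter(rng)
out = apply_jitter(image, top, left, mirrored)

# 31 possible offsets per axis, times two mirror states.
num_variants = (RESIZE - CROP + 1) ** 2 * 2
print(len(out), len(out[0]), num_variants)  # 256 256 1922
```

Under these assumptions, each training image can appear in up to 1922 slightly different versions, which is why a small dataset benefits from jittering.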

In [5]:
dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

Downloading and preparing dataset cycle_gan/horse2zebra/2.0.0 (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to /root/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0...

Dataset cycle_gan downloaded and prepared to /root/tensorflow_datasets/cycle_gan/horse2zebra/2.0.0. Subsequent calls will reuse this data.


In [12]:
metadata

tfds.core.DatasetInfo(
    name='cycle_gan',
    version=2.0.0,
    description='A dataset consisting of images from two classes A and B (For example: horses/zebras, apple/orange,...)',
    homepage='https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
    }),
    total_num_examples=2661,
    splits={
        'testA': 120,
        'testB': 140,
        'trainA': 1067,
        'trainB': 1334,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{DBLP:journals/corr/ZhuPIE17,
      author    = {Jun{-}Yan Zhu and
                   Taesung Park and
                   Phillip Isola and
                   Alexei A. Efros},
      title     = {Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial
                   Networks},
      journal   = {CoRR},
      volume    = {abs/1703.10593},

In [18]:
# Preprocessing functions
def random_crop(image):
  # Crop a random IMG_HEIGHT x IMG_WIDTH window (constants are set in the next cell)
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
  return cropped_image

def normalize(image):
  # Scale uint8 pixel values from [0, 255] to [-1, 1]
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image

def random_jitter(image):
  # Resize to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  # Randomly crop to 256 x 256 x 3
  image = random_crop(image)
  # Randomly mirror left to right
  image = tf.image.random_flip_left_right(image)
  return image

# The dataset was loaded with as_supervised=True, so each element is an
# (image, label) pair. Functions passed to Dataset.map must accept both
# arguments; the label is not needed and is dropped.
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image

def preprocess_image_test(image, label):
  image = normalize(image)
  return image
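
The scaling in `normalize` and the `* 0.5 + 0.5` used in the plotting code further down are inverse steps of one mapping. The sketch below checks that arithmetic on plain Python numbers. The helper names `normalize_pixel` and `to_display` are hypothetical and are not part of the notebook.

```python
def normalize_pixel(p):
    """Map an 8-bit intensity in [0, 255] to [-1, 1], mirroring normalize()."""
    return p / 127.5 - 1


def to_display(x):
    """Map a value in [-1, 1] to [0, 1], the float range plt.imshow expects."""
    return x * 0.5 + 0.5


for p in (0, 255):
    x = normalize_pixel(p)
    print(p, x, to_display(x))
# 0 -1.0 0.0
# 255 1.0 1.0
```

Composing the two functions gives `p / 255`, so the displayed image is the original rescaled to the unit interval.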

In [19]:
IMG_HEIGHT, IMG_WIDTH = 256, 256
BUFFER_SIZE = 20

# Cache the raw training images before the random augmentation. Caching after
# map() would freeze the first epoch's jitter and replay it in every epoch.
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(1)

# Test preprocessing is deterministic, so caching after map() is safe.
test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

# Images are in [-1, 1]; rescale to [0, 1] for display.
# Draw both panels of a figure before calling plt.show().
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
plt.show()

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
plt.show()
