In [None]:
!pip install tensorflow-gpu==2.0.0-beta1

Install keras-contrib for InstanceNormalization, you can also implement it by yourself. See [this link](https://stackoverflow.com/a/48118940/1651560) for the details of InstanceNormalization and BatchNormalization.
P.S I forked the original repo and updated the code in order to make it compatible with TensorFlow 2.0.

In [9]:
!pip install --force-reinstall git+https://github.com/objectc/keras-contrib.git

Collecting git+https://github.com/objectc/keras-contrib.git
  Cloning https://github.com/objectc/keras-contrib.git to /tmp/pip-req-build-cbvzfd6d
  Running command git clone -q https://github.com/objectc/keras-contrib.git /tmp/pip-req-build-cbvzfd6d
Collecting keras (from keras-contrib==2.0.8)
[?25l  Downloading https://files.pythonhosted.org/packages/5e/10/aa32dad071ce52b5502266b5c659451cfd6ffcbf14e6c8c4f16c0ff5aaab/Keras-2.2.4-py2.py3-none-any.whl (312kB)
[K     |████████████████████████████████| 317kB 5.0MB/s eta 0:00:01
[?25hCollecting keras-preprocessing>=1.0.5 (from keras->keras-contrib==2.0.8)
[?25l  Downloading https://files.pythonhosted.org/packages/28/6a/8c1f62c37212d9fc441a7e26736df51ce6f0e38455816445471f10da4f0a/Keras_Preprocessing-1.1.0-py2.py3-none-any.whl (41kB)
[K     |████████████████████████████████| 51kB 20.0MB/s eta 0:00:01
[?25hCollecting h5py (from keras->keras-contrib==2.0.8)
[?25l  Downloading https://files.pythonhosted.org/packages/8e/fd/2ca5c4f4ed33ac41

This implementation of CycleGAN are based on the paper [Unpaired Image-to-Image Translation
using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf)

In [11]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, BatchNormalization, Activation, ZeroPadding2D, LeakyReLU, UpSampling2D, Conv2D
from keras_contrib.layers import InstanceNormalization
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import os

According to the paper,  reflection padding was used to reduce artifacts.

In [2]:
# Keras does not provide reflection padding implementation, we need to custom our own padding layer class.

class ReflectionPadding2D(keras.layers.Layer):
    def __init__(self, kernel_size = None, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        # if kernel_size pass in, use half of it as the padding
        if kernel_size:
          self.padding = (kernel_size//2, kernel_size//2)
        self.input_spec = [keras.layers.InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

In [3]:
class CycleGAN:
  
  def __init__(self):
    self.input_shape = (128, 128, 3)
  
  # Residual blocks
  def residual(self, X, filters):
    X_shortcut = X
#     output = InstanceNormalization()(X)
#     output = keras.layers.LeakyReLU()(output)
    output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(X)

    output = InstanceNormalization()(output)
    output = keras.layers.LeakyReLU()(output)
    output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(output)

    output = keras.layers.add([X_shortcut,output])

    return output
  
  def create_generator(self):
    
    def conv(layer_input, filters, kernel_size=3, strides=2):
      output = ReflectionPadding2D(kernel_size=kernel_size)(layer_input)
      output = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(output)
      # The paper point out that they use InstanceNormalization instead of BatchNormalization. 
      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      return output
      
    input = keras.layers.Input(shape=self.input_shape)
    # Downsampling
    d = conv(input, filters=64, kernel_size=7, strides=1)
    d = conv(d, filters=128, kernel_size=3, strides=2)
    d = conv(d, filters=256, kernel_size=3, strides=2)
    
    # resnet layers
    for i in range(6):
      r = self.residual(d, 256)
      
    # Upsampling
    u = InstanceNormalization()(r)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    output = keras.layers.Conv2DTranspose(3, kernel_size=7, padding='same', activation='tanh')(u)
    return keras.models.Model(input, output)
    

  
  def create_discriminator(self):
    input = Input(shape=self.input_shape)
    d = input
    kernel_size = 4
    for depth in range(4):
      if depth == 3:
        d = keras.layers.ZeroPadding2D(1)(d)
      d = keras.layers.Conv2D(64*(2**depth), kernel_size=kernel_size, strides=2, padding='same')(d)
      if depth > 0:
        d = InstanceNormalization()(d)
      d = keras.layers.LeakyReLU()(d)
    output = keras.layers.ZeroPadding2D(1)(d)
    output = keras.layers.Conv2D(1, kernel_size=kernel_size, activation='sigmoid')(output)
    return Model(input, output)
    

In [4]:
def create_generator():
    def residual(X, filters):
      X_shortcut = X
  #     output = InstanceNormalization()(X)
  #     output = keras.layers.LeakyReLU()(output)
      output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(X)

      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(output)

      output = keras.layers.add([X_shortcut,output])

      return output
    def conv(layer_input, filters, kernel_size=3, strides=2):
      output = ReflectionPadding2D(kernel_size=kernel_size)(layer_input)
      print(output.shape)
      output = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(output)
      # The paper point out that they use InstanceNormalization instead of BatchNormalization. 
      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      return output
      
    input = keras.layers.Input(shape=(128,128,3))
    # Downsampling
    d = conv(input, filters=64, kernel_size=7, strides=1)
    d = conv(d, filters=128, kernel_size=3, strides=2)
    d = conv(d, filters=256, kernel_size=3, strides=2)
    
    # resnet layers
    for i in range(6):
      r = residual(d, 256)
      
    # Upsampling
    u = InstanceNormalization()(r)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    output = keras.layers.Conv2DTranspose(3, kernel_size=7, padding='same', activation='tanh')(u)
    return keras.models.Model(input, output)

In [5]:
GAN = CycleGAN()

The [PatchGAN](https://arxiv.org/pdf/1611.07004.pdf) / Markovian discriminator works by classifying individual (N x N) patches in the image as “real vs. fake”, opposed to classifying the entire image as “real vs. fake”. The authors reason that this enforces more constraints that encourage sharp high-frequency detail. Additionally, the PatchGAN has fewer parameters and runs faster than classifying the entire image.
Read for more details [here](https://towardsdatascience.com/pix2pix-869c17900998)

In [12]:
# !bash download_dataset.sh apple2orange
!bash download_dataset.sh horse2zebra

mkdir: cannot create directory ‘mydatasets’: File exists
for details.

--2019-07-10 20:20:09--  https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.189.73
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.189.73|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116867962 (111M) [application/zip]
Saving to: ‘./mydatasets/horse2zebra.zip’


2019-07-10 20:20:10 (97.1 MB/s) - ‘./mydatasets/horse2zebra.zip’ saved [116867962/116867962]

mkdir: cannot create directory ‘./mydatasets/horse2zebra/’: File exists
Archive:  ./mydatasets/horse2zebra.zip
   creating: ./mydatasets/horse2zebra/trainA/
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_6223.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1567.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3354.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_299.jpg

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2446.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1001.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1624.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3412.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1122.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3175.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4374.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8639.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_528.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_11.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3984.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1135.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3902.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_199.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2967.jpg  
  inflating: .

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4674.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_175.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_6397.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_386.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4491.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5036.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_503.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8895.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1023.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8157.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4453.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1262.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2402.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2179.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1034.jpg  
  inflating: 

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7537.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4856.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5029.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4081.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3075.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_411.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4008.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1579.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4815.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_203.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3509.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_631.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_196.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7597.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1713.jpg  
  inflating: .

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2605.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1331.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4112.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7803.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2921.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_584.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4402.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3557.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4946.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2615.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_117.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3685.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_854.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_6388.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2883.jpg  
  inflating: 

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_177.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1127.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7528.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1835.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3702.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2598.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4661.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2575.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8877.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_767.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1537.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4385.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_2581.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7858.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3875.jpg  
  inflating:

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1524.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8704.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7613.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1008.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5782.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8765.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_559.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5674.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_941.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5156.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7165.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3956.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3348.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_421.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4482.jpg  
  inflating: 

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5075.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4103.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5601.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4632.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1098.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7888.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_6252.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_58.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8153.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_24.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4598.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4597.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_7398.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3935.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1708.jpg  
  inflating: .

  inflating: ./mydatasets/horse2zebra/trainA/n02381460_3255.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_6922.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8611.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_5739.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_6267.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1334.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_8812.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_73.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1168.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4447.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_778.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4324.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1037.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_4348.jpg  
  inflating: ./mydatasets/horse2zebra/trainA/n02381460_1523.jpg  
  inflating: 

  inflating: ./mydatasets/horse2zebra/testB/n02391049_3060.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_7740.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_1880.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_4570.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_2810.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_6890.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_5930.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_10100.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_3770.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_410.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_1100.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_4730.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_6190.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_390.jpg  
  inflating: ./mydatasets/horse2zebra/testB/n02391049_6520.jpg  
  inflating: ./mydatasets/

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2424.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_415.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_10467.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2776.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5366.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8447.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4697.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2628.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1869.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1158.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_11162.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_11195.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3239.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2698.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_9675.jpg  
  inflat

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4969.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3229.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1738.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_246.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3872.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6198.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8444.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1679.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7832.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_795.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2239.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_178.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1751.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5695.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8387.jpg  
  inflating: 

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1056.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3436.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_574.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_639.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_133.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2596.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3016.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2732.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_538.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1153.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3693.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1147.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1156.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_734.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_956.jpg  
  inflating: ./m

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2293.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6947.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_9006.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2645.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3154.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1459.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_217.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_33.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_10324.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6866.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8844.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_954.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2265.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_10269.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3236.jpg  
  inflating:

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5679.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1973.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6804.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2819.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2731.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5788.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4161.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_927.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_35.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1236.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7115.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_796.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2748.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_283.jpg  
  inflating: ./myd

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_596.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_966.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2466.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3271.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7048.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6631.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1256.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4576.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_633.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1555.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_752.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_929.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_976.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8008.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2229.jpg  
  inflating: ./m

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_152.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4889.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2651.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_9427.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3072.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_10007.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2341.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3917.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6546.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6582.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4448.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2674.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1309.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2379.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2636.jpg  
  inflatin

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2838.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7566.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3253.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2633.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2973.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5295.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6227.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8334.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1174.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3266.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3868.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5633.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4903.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3091.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2879.jpg  
  inflatin

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_421.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7236.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8288.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1567.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6517.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7721.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_984.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2747.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2749.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3053.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8831.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_392.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3096.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7306.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2588.jpg  
  inflating: 

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6378.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5889.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_10837.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_10454.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3443.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2733.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3249.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7812.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2756.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_4144.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6186.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_249.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8633.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_102.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2375.jpg  
  inflatin

  inflating: ./mydatasets/horse2zebra/trainB/n02391049_8102.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3037.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_6512.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_3265.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_968.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5333.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_1062.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2306.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7678.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2826.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5015.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_7434.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_2225.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_506.jpg  
  inflating: ./mydatasets/horse2zebra/trainB/n02391049_5005.jpg  
  inflating:

  inflating: ./mydatasets/horse2zebra/testA/n02381460_6290.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_3110.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_50.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_1620.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_1160.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_1090.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_600.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_1120.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_690.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_7890.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_7700.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_7400.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_200.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_6690.jpg  
  inflating: ./mydatasets/horse2zebra/testA/n02381460_3010.jpg  
  inflating: ./mydatasets/hors

In [13]:
dataset_name = 'horse2zebra'
from data_loader import DataLoader
data_loader = DataLoader(dataset_name=dataset_name)

In [76]:
import importlib
importlib.reload(DataLoader.module)

AttributeError: type object 'DataLoader' has no attribute 'module'

In [14]:
optimizer = keras.optimizers.Adam(0.0002, 0.5)
d_A = GAN.create_discriminator()
d_B = GAN.create_discriminator()
d_A.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
d_B.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])


In [15]:
g_AB = GAN.create_generator()
g_BA = GAN.create_generator()
# Translate images to the other domain
img_A = Input(shape=GAN.input_shape)
img_B = Input(shape=GAN.input_shape)
fake_B = g_AB(img_A)
fake_A = g_BA(img_B)
# Translate images back to original domain
reconstr_A = g_BA(fake_B)
reconstr_B = g_AB(fake_A)

# Identity mapping of images
# 
img_A_id = g_BA(img_A)
img_B_id = g_AB(img_B)

# For the combined model we will only train the generators
d_A.trainable = False
d_B.trainable = False

# Discriminators determines validity of translated images
valid_A = d_A(fake_A)
valid_B = d_B(fake_B)

# Combined model trains generators to fool discriminators
combined = Model(inputs=[img_A, img_B],
                      outputs=[ valid_A, valid_B,
                                reconstr_A, reconstr_B,
                                img_A_id, img_B_id ])

 # Loss weights
lambda_cycle = 10.0                    # Cycle-consistency loss
# With higher identity loss, the image translation becomes more conservative, 
# so it makes less changes. 
# I tried using values like 0.1, 0.5, 1 and 10, and the result is not so different. 
lambda_id = 0.1 * lambda_cycle    # Identity loss
combined.compile(loss=['mse', 'mse',
                            'mae', 'mae',
                            'mae', 'mae'],
                    loss_weights=[  1, 1,
                                    lambda_cycle, lambda_cycle,
                                    lambda_id, lambda_id ],
                    optimizer=optimizer)

In [16]:
def save_load_weights(epoch, isSave=True):
    if isSave:
        d_A.save_weights('%s_d_A_weights%08d.h5' % (dataset_name, epoch))
        d_B.save_weights('%s_d_B_weights%08d.h5' % (dataset_name, epoch))
        combined.save_weights('%s_combined_weights%08d.h5' % (dataset_name, epoch))
    else:
        d_A.load_weights('%s_d_A_weights%08d.h5' % (dataset_name, epoch))
        d_B.load_weights('%s_d_B_weights%08d.h5' % (dataset_name, epoch))
        combined.load_weights('%s_combined_weights%08d.h5' % (dataset_name, epoch))

In [15]:
save_load_weights(180, False)

In [17]:

def sample_images(epoch, batch_i):
        os.makedirs('images/%s' % dataset_name, exist_ok=True)
        r, c = 10, 3

        imgs_A = data_loader.load_data(domain="A", batch_size=5, is_testing=False, seed=1)
        imgs_B = data_loader.load_data(domain="B", batch_size=5, is_testing=False, seed=1)
        # Demo (for GIF)
        #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
        #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')

        # Translate images to the other domain
        fake_B = g_AB.predict(imgs_A)
        fake_A = g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = g_BA.predict(fake_B)
        reconstr_B = g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, imgs_B, fake_B, fake_A, reconstr_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        fig.set_size_inches(3, 12)
        cnt = 0
        for i in range(c):
            for j in range(r):
                axs[j,i].imshow(gen_imgs[(i*r+j)])
                axs[j,i].set_title(titles[i])
                axs[j,i].axis('off')
        fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i), bbox_inches='tight')
        plt.close()

In [22]:
sample_images(20, 20)

In [18]:
def train_single_batch(imgs_A, imgs_B, valid, fake):
  # ----------------------
  #  Train Discriminators
  # ----------------------

  # Translate images to opposite domain
  fake_B = g_AB.predict(imgs_A)
  fake_A = g_BA.predict(imgs_B)

  # Train the discriminators (original images = real / translated = Fake)
  dA_loss_real = d_A.train_on_batch(imgs_A, valid)
  dA_loss_fake = d_A.train_on_batch(fake_A, fake)
  dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

  dB_loss_real = d_B.train_on_batch(imgs_B, valid)
  dB_loss_fake = d_B.train_on_batch(fake_B, fake)
  dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

  # Total disciminator loss
  d_loss = 0.5 * np.add(dA_loss, dB_loss)


  # ------------------
  #  Train Generators
  # ------------------

  # Train the generators
  g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                          [valid, valid,
                                          imgs_A, imgs_B,
                                          imgs_A, imgs_B])
  return d_loss, g_loss



In [23]:
losses = {'g':[], 'd':[]}
def train(epochs, batch_size=10, sample_interval=20):

    start_time = datetime.datetime.now()

    # Adversarial loss ground truths
    patch = int(GAN.input_shape[0] / 2**4)
    disc_patch = (patch, patch, 1)
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)
    for epoch in range(20,epochs):
        d_loss, g_loss = None, None
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):
#             print(batch_i)
            d_loss_batch, g_loss_batch = train_single_batch(imgs_A, imgs_B, valid, fake)
            if d_loss is None:
              d_loss, g_loss = d_loss_batch, g_loss_batch
            else:
              d_loss += d_loss_batch
              g_loss += g_loss_batch
        losses['d'].append(d_loss)
        losses['g'].append(g_loss)
        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            sample_images(epoch, 0)
            elapsed_time = datetime.datetime.now() - start_time

            # Plot the progress
            print ("[Epoch %d/%d][D loss: %f, acc: %3d%%]  \
                   [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                    % ( epoch, epochs,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        np.mean(g_loss[1:3]),
                                                                        np.mean(g_loss[3:5]),
                                                                        np.mean(g_loss[5:6]),
                                                                        elapsed_time))
            save_load_weights(epoch)

In [18]:
save_load_weights(79)

In [21]:
import scipy.misc
from skimage.transform import resize
scipy.misc.imresize = resize
train(epochs=100, batch_size=10, sample_interval=20)

W0710 20:26:15.829469 140317945714496 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
W0710 20:26:18.123739 140317945714496 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
W0710 20:26:18.160947 140317945714496 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
W0710 20:26:20.347181 140317945714496 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?


[Epoch 80/100][D loss: 32.120403, acc: 5281%]                     [G loss: 13.668126, adv: 0.000193, recon: 0.625241, id: 0.558038] time: 0:03:38.332528 


In [None]:
train(epochs=100, batch_size=10, sample_interval=20)

[Epoch 20/100][D loss: 9.584903, acc: 9205%]                     [G loss: 3.409021, adv: 0.813346, recon: 0.081030, id: 0.070068] time: 0:03:05.264480 
[Epoch 40/100][D loss: 14.538404, acc: 8446%]                     [G loss: 3.356036, adv: 0.855646, recon: 0.074292, id: 0.077393] time: 1:05:10.893218 


In [0]:
%loadpy data_loader.py

In [0]:


# Images size
w = 256
h = 256

# Cyclic consistency factor

lmda = 10

# Optimizer parameters

lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
epsilon = 1e-08



disc_a_history = []
disc_b_history = []

gen_a2b_history = {'bc':[], 'mae':[]} 
gen_b2a_history = {'bc':[], 'mae':[]}

gen_b2a_history_new = []
gen_a2b_history_new = []
cycle_history = []

# Data loading

def loadImage(path, h, w):
    
    '''Load single image from specified path'''
    img = image.load_img(path)
    img = img.resize((w,h))
    x = image.img_to_array(img)
    return x


def loadImagesFromDataset(h, w, dataset, use_hdf5=False):

    '''Return a tuple (trainA, trainB, testA, testB) 
    containing numpy arrays populated from the
     test and train set for each part of the cGAN'''

    if (use_hdf5):
        path="./datasets/processed/"+dataset+"_data.h5"
        data = []
        print('\n', '-' * 15, 'Loading data from dataset', dataset, '-' * 15)
        with h5py.File(path, "r") as hf:
            for set_name in tqdm(["trainA_data", "trainB_data", "testA_data", "testB_data"]):
                data.append(hf[set_name][:].astype(np.float32))

        return (set_data for set_data in data)

    else:
        path = "./datasets/"+dataset
        print(path)
        train_a = glob.glob(path + "/trainA/*.png")
        train_b = glob.glob(path + "/trainB/*.png")
        test_a = glob.glob(path + "/testA/*.png")
        test_b = glob.glob(path + "/testB/*.png")

        print("Import trainA")
        if dataset == "nike2adidas" or ("adiedges" in dataset):
            tr_a = np.array([loadImage(p, h, w) for p in tqdm(train_a[:1000])])
        else:
            tr_a = np.array([loadImage(p, h, w) for p in tqdm(train_a)])

        print("Import trainB")
        if dataset == "nike2adidas" or ("adiedges" in dataset):
            tr_b = np.array([loadImage(p, h, w) for p in tqdm(train_b[:1000])])
        else:
            tr_b = np.array([loadImage(p, h, w) for p in tqdm(train_b)])

        print("Import testA")
        ts_a = np.array([loadImage(p, h, w) for p in tqdm(test_a)])

        print("Import testB")
        ts_b = np.array([loadImage(p, h, w) for p in tqdm(test_b)])

    return tr_a, tr_b, ts_a, ts_b
    


# Create a wall of generated images

def plotGeneratedImages(epoch, set_a, set_b, generator_a2b, generator_b2a, examples=6):
    
    true_batch_a = set_a[np.random.randint(0, set_a.shape[0], size=examples)]
    true_batch_b = set_b[np.random.randint(0, set_b.shape[0], size=examples)]

    # Get fake and cyclic images
    generated_a2b = generator_a2b.predict(true_batch_a)
    cycle_a = generator_b2a.predict(generated_a2b)
    generated_b2a = generator_b2a.predict(true_batch_b)
    cycle_b = generator_a2b.predict(generated_b2a)
    
    k = 0

    # Allocate figure
    plt.figure(figsize=(w/10, h/10))

    for output in [true_batch_a, generated_a2b, cycle_a, true_batch_b, generated_b2a, cycle_b]:
        output = (output+1.0)/2.0
        for i in range(output.shape[0]):
            plt.subplot(examples, examples, k*examples +(i + 1))
            img = output[i].transpose(1, 2, 0)  # Using (ch, h, w) scheme needs rearranging for plt to (h, w, ch)
            #print(img.shape)
            plt.imshow(img)
            plt.axis('off')
        plt.tight_layout()
        k += 1
    plt.savefig("images/epoch"+str(epoch)+".png")
    plt.close()


# Plot the loss from each batch

def plotLoss_new():
    plt.figure(figsize=(10, 8))
    plt.plot(disc_a_history, label='Discriminator A loss')
    plt.plot(disc_b_history, label='Discriminator B loss')
    plt.plot(gen_a2b_history_new, label='Generator a2b loss')
    plt.plot(gen_b2a_history_new, label='Generator b2a loss')
    #plt.plot(cycle_history, label="Cyclic loss")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/cyclegan_loss.png')
    plt.close()

def saveModels(epoch, genA2B, genB2A, discA, discB):
    genA2B.save('models/generatorA2B_epoch_%d.h5' % epoch)
    genB2A.save('models/generatorB2A_epoch_%d.h5' % epoch)
    discA.save('models/discriminatorA_epoch_%d.h5' % epoch)
    discB.save('models/discriminatorB_epoch_%d.h5' % epoch)


# Training

def train(epochs, batch_size, dataset, baselr, use_pseudounet=False, use_unet=False, use_decay=False, plot_models=True):

    # Load data and normalize
    x_train_a, x_train_b, x_test_a, x_test_b = loadImagesFromDataset(h, w, dataset, use_hdf5=False)

    x_train_a = (x_train_a.astype(np.float32) - 127.5) / 127.5
    x_train_b = (x_train_b.astype(np.float32) - 127.5) / 127.5
    x_test_a = (x_test_a.astype(np.float32) - 127.5) / 127.5
    x_test_b = (x_test_b.astype(np.float32) - 127.5) / 127.5

    batchCount_a = x_train_a.shape[0] / batch_size
    batchCount_b = x_train_b.shape[0] / batch_size

    # Train on same image amount, would be best to have even sets
    batchCount = min([batchCount_a, batchCount_b])

    print('\nEpochs:', epochs)
    print('Batch size:', batch_size)
    print('Batches per epoch: ', batchCount, "\n")

    #Retrieve components and save model before training, to preserve weights initialization
    disc_a, disc_b, gen_a2b, gen_b2a = components(w, h, pseudounet=use_pseudounet, unet=use_unet, plot=plot_models)
    saveModels(0, gen_a2b, gen_b2a, disc_a, disc_b)

    #Initialize fake images pools
    pool_a2b = []
    pool_b2a = []

    # Define optimizers
    adam_disc = Adam(lr=baselr, beta_1=0.5)
    adam_gen = Adam(lr=baselr, beta_1=0.5)

    # Define image batches
    true_a = gen_a2b.inputs[0]
    true_b = gen_b2a.inputs[0]

    fake_b = gen_a2b.outputs[0]
    fake_a = gen_b2a.outputs[0]

    fake_pool_a = K.placeholder(shape=(None, 3, h, w))
    fake_pool_b = K.placeholder(shape=(None, 3, h, w))

    # Labels for generator training
    y_fake_a = K.ones_like(disc_a([fake_a]))
    y_fake_b = K.ones_like(disc_b([fake_b]))

    # Labels for discriminator training
    y_true_a = K.ones_like(disc_a([true_a])) * 0.9
    y_true_b = K.ones_like(disc_b([true_b])) * 0.9

    fakelabel_a2b = K.zeros_like(disc_b([fake_b]))
    fakelabel_b2a = K.zeros_like(disc_a([fake_a]))

    # Define losses
    disc_a_loss = mse_loss(y_true_a, disc_a([true_a])) + mse_loss(fakelabel_b2a, disc_a([fake_pool_a]))
    disc_b_loss = mse_loss(y_true_b, disc_b([true_b])) + mse_loss(fakelabel_a2b, disc_b([fake_pool_b]))

    gen_a2b_loss = mse_loss(y_fake_b, disc_b([fake_b]))
    gen_b2a_loss = mse_loss(y_fake_a, disc_a([fake_a]))

    cycle_a_loss = mae_loss(true_a, gen_b2a([fake_b]))
    cycle_b_loss = mae_loss(true_b, gen_a2b([fake_a]))
    cyclic_loss = cycle_a_loss + cycle_b_loss

    # Prepare discriminator updater
    discriminator_weights = disc_a.trainable_weights + disc_b.trainable_weights
    disc_loss = (disc_a_loss + disc_b_loss) * 0.5
    discriminator_updater = adam_disc.get_updates(discriminator_weights, [], disc_loss)

    # Prepare generator updater
    generator_weights = gen_a2b.trainable_weights + gen_b2a.trainable_weights
    gen_loss = (gen_a2b_loss + gen_b2a_loss + lmda * cyclic_loss)
    generator_updater = adam_gen.get_updates(generator_weights, [], gen_loss)

    # Define trainers
    generator_trainer = K.function([true_a, true_b], [gen_a2b_loss, gen_b2a_loss, cyclic_loss], generator_updater)
    discriminator_trainer = K.function([true_a, true_b, fake_pool_a, fake_pool_b], [disc_a_loss/2, disc_b_loss/2], discriminator_updater)

    epoch_counter = 1

    # Start training
    for e in range(1, epochs + 1):
        print('\n','-'*15, 'Epoch %d' % e, '-'*15)

        #Learning rate decay
        if use_decay and (epoch_counter > 100):
            lr -= baselr/100
            adam_disc.lr = lr
            adam_gen.lr = lr


        # Initialize progbar and batch counter
        #progbar = generic_utils.Progbar(batchCount)

        np.random.shuffle(x_train_a)
        np.random.shuffle(x_train_b)

        # Cycle through batches
        for i in trange(int(batchCount)):

            # Select true images for training
            #true_batch_a = x_train_a[np.random.randint(0, x_train_a.shape[0], size=batch_size)]
            #true_batch_b = x_train_b[np.random.randint(0, x_train_b.shape[0], size=batch_size)]

            true_batch_a = x_train_a[i*batch_size:i*batch_size+batch_size]
            true_batch_b = x_train_b[i*batch_size:i*batch_size+batch_size]

            # Fake images pool 
            a2b = gen_a2b.predict(true_batch_a)
            b2a = gen_b2a.predict(true_batch_b)

            tmp_b2a = []
            tmp_a2b = []

            for element in a2b:
                if len(pool_a2b) < 50:
                    pool_a2b.append(element)
                    tmp_a2b.append(element)
                else:
                    p = random.uniform(0, 1)

                    if p > 0.5:
                        index = random.randint(0, 49)
                        tmp = np.copy(pool_a2b[index])
                        pool_a2b[index] = element
                        tmp_a2b.append(tmp)
                    else:
                        tmp_a2b.append(element)
            
            for element in b2a:
                if len(pool_b2a) < 50:
                    pool_b2a.append(element)
                    tmp_b2a.append(element)
                else:
                    p = random.uniform(0, 1)

                    if p >0.5:
                        index = random.randint(0, 49)
                        tmp = np.copy(pool_b2a[index])
                        pool_b2a[index] = element
                        tmp_b2a.append(tmp)
                    else:
                        tmp_b2a.append(element)

            pool_a = np.array(tmp_b2a)
            pool_b = np.array(tmp_a2b)

            # Update network and obtain losses
            disc_a_err, disc_b_err = discriminator_trainer([true_batch_a, true_batch_b, pool_a, pool_b])
            gen_a2b_err, gen_b2a_err, cyclic_err = generator_trainer([true_batch_a, true_batch_b])

            # progbar.add(1, values=[
            #                             ("D A", disc_a_err*2),
            #                             ("D B", disc_b_err*2),
            #                             ("G A2B loss", gen_a2b_err),
            #                             ("G B2A loss", gen_b2a_err),
            #                             ("Cyclic loss", cyclic_err)
            #                            ])

        # Save losses for plotting
        disc_a_history.append(disc_a_err)
        disc_b_history.append(disc_b_err)

        gen_a2b_history_new.append(gen_a2b_err)
        gen_b2a_history_new.append(gen_b2a_err)

        #cycle_history.append(cyclic_err[0])
        plotLoss_new()

        plotGeneratedImages(epoch_counter, x_test_a, x_test_b, gen_a2b, gen_b2a)

        if epoch_counter > 150:
            saveModels(epoch_counter, gen_a2b, gen_b2a, disc_a, disc_b)

        epoch_counter += 1


if __name__ == '__main__':
    train(200, 1, "horse2zebra", lr, use_decay=True, use_pseudounet=False, use_unet=False, plot_models=False)