<a href="https://colab.research.google.com/github/objectc/Generative-Models/blob/master/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorflow-gpu==2.0.0-beta1

Collecting tensorflow-gpu==2.0.0-beta1
[?25l  Downloading https://files.pythonhosted.org/packages/2b/53/e18c5e7a2263d3581a979645a185804782e59b8e13f42b9c3c3cfb5bb503/tensorflow_gpu-2.0.0b1-cp36-cp36m-manylinux1_x86_64.whl (348.9MB)
[K     |████████████████████████████████| 348.9MB 61kB/s 
Collecting tf-estimator-nightly<1.14.0.dev2019060502,>=1.14.0.dev2019060501 (from tensorflow-gpu==2.0.0-beta1)
[?25l  Downloading https://files.pythonhosted.org/packages/32/dd/99c47dd007dcf10d63fd895611b063732646f23059c618a373e85019eb0e/tf_estimator_nightly-1.14.0.dev2019060501-py2.py3-none-any.whl (496kB)
[K     |████████████████████████████████| 501kB 50.9MB/s 
Collecting tb-nightly<1.14.0a20190604,>=1.14.0a20190603 (from tensorflow-gpu==2.0.0-beta1)
[?25l  Downloading https://files.pythonhosted.org/packages/a4/96/571b875cd81dda9d5dfa1422a4f9d749e67c0a8d4f4f0b33a4e5f5f35e27/tb_nightly-1.14.0a20190603-py3-none-any.whl (3.1MB)
[K     |████████████████████████████████| 3.1MB 42.7MB/s 
Installing c

Install keras-contrib for InstanceNormalization, you can also implement it by yourself. See [this link](https://stackoverflow.com/a/48118940/1651560) for the details of InstanceNormalization and BatchNormalization.
P.S I forked the original repo and updated the code in order to make it compatible with TensorFlow 2.0.

In [2]:
!pip install --force-reinstall git+https://github.com/objectc/keras-contrib.git

Collecting git+https://github.com/objectc/keras-contrib.git
  Cloning https://github.com/objectc/keras-contrib.git to /tmp/pip-req-build-nwi76uuc
  Running command git clone -q https://github.com/objectc/keras-contrib.git /tmp/pip-req-build-nwi76uuc
Collecting keras (from keras-contrib==2.0.8)
[?25l  Downloading https://files.pythonhosted.org/packages/5e/10/aa32dad071ce52b5502266b5c659451cfd6ffcbf14e6c8c4f16c0ff5aaab/Keras-2.2.4-py2.py3-none-any.whl (312kB)
[K     |████████████████████████████████| 317kB 9.1MB/s 
[?25hCollecting keras-preprocessing>=1.0.5 (from keras->keras-contrib==2.0.8)
[?25l  Downloading https://files.pythonhosted.org/packages/28/6a/8c1f62c37212d9fc441a7e26736df51ce6f0e38455816445471f10da4f0a/Keras_Preprocessing-1.1.0-py2.py3-none-any.whl (41kB)
[K     |████████████████████████████████| 51kB 8.9MB/s 
[?25hCollecting keras-applications>=1.0.6 (from keras->keras-contrib==2.0.8)
[?25l  Downloading https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae91

This implementation of CycleGAN are based on the paper [Unpaired Image-to-Image Translation
using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf)

In [0]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, BatchNormalization, Activation, ZeroPadding2D, LeakyReLU, UpSampling2D, Conv2D
from keras_contrib.layers import InstanceNormalization
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import os

According to the paper,  reflection padding was used to reduce artifacts.

In [0]:
# Keras does not provide reflection padding implementation, we need to custom our own padding layer class.

class ReflectionPadding2D(keras.layers.Layer):
    def __init__(self, kernel_size = None, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        # if kernel_size pass in, use half of it as the padding
        if kernel_size:
          self.padding = (kernel_size//2, kernel_size//2)
        self.input_spec = [keras.layers.InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

In [0]:
class CycleGAN:
  
  def __init__(self):
    self.input_shape = (128, 128, 3)
  
  # Residual blocks
  def residual(self, X, filters):
    X_shortcut = X
#     output = InstanceNormalization()(X)
#     output = keras.layers.LeakyReLU()(output)
    output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(X)

    output = InstanceNormalization()(output)
    output = keras.layers.LeakyReLU()(output)
    output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(output)

    output = keras.layers.add([X_shortcut,output])

    return output
  
  def create_generator(self):
    
    def conv(layer_input, filters, kernel_size=3, strides=2):
      output = ReflectionPadding2D(kernel_size=kernel_size)(layer_input)
      output = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(output)
      # The paper point out that they use InstanceNormalization instead of BatchNormalization. 
      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      return output
      
    input = keras.layers.Input(shape=self.input_shape)
    # Downsampling
    d = conv(input, filters=64, kernel_size=7, strides=1)
    d = conv(d, filters=128, kernel_size=3, strides=2)
    d = conv(d, filters=256, kernel_size=3, strides=2)
    
    # resnet layers
    for i in range(6):
      r = self.residual(d, 256)
      
    # Upsampling
    u = InstanceNormalization()(r)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    output = keras.layers.Conv2DTranspose(3, kernel_size=7, padding='same', activation='tanh')(u)
    return keras.models.Model(input, output)
    

  
  def create_discriminator(self):
    input = Input(shape=self.input_shape)
    d = input
    kernel_size = 4
    for depth in range(4):
      if depth == 3:
        d = keras.layers.ZeroPadding2D(1)(d)
      d = keras.layers.Conv2D(64*(2**depth), kernel_size=kernel_size, strides=2, padding='same')(d)
      if depth > 0:
        d = InstanceNormalization()(d)
      d = keras.layers.LeakyReLU()(d)
    output = keras.layers.ZeroPadding2D(1)(d)
    output = keras.layers.Conv2D(1, kernel_size=kernel_size, activation='sigmoid')(output)
    return Model(input, output)
    

In [7]:
def create_generator():
    def residual(X, filters):
      X_shortcut = X
  #     output = InstanceNormalization()(X)
  #     output = keras.layers.LeakyReLU()(output)
      output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(X)

      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(output)

      output = keras.layers.add([X_shortcut,output])

      return output
    def conv(layer_input, filters, kernel_size=3, strides=2):
      output = ReflectionPadding2D(kernel_size=kernel_size)(layer_input)
      print(output.shape)
      output = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(output)
      # The paper point out that they use InstanceNormalization instead of BatchNormalization. 
      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      return output
      
    input = keras.layers.Input(shape=(128,128,3))
    # Downsampling
    d = conv(input, filters=64, kernel_size=7, strides=1)
    d = conv(d, filters=128, kernel_size=3, strides=2)
    d = conv(d, filters=256, kernel_size=3, strides=2)
    
    # resnet layers
    for i in range(6):
      r = residual(d, 256)
      
    # Upsampling
    u = InstanceNormalization()(r)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    output = keras.layers.Conv2DTranspose(3, kernel_size=7, padding='same', activation='tanh')(u)
    return keras.models.Model(input, output)
test = create_generator()
test.summary()

(None, 134, 134, 3)
(None, 130, 130, 64)
(None, 66, 66, 128)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
reflection_padding2d (Reflectio (None, 134, 134, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 9472        reflection_padding2d[0][0]       
__________________________________________________________________________________________________
instance_normalization (Instanc (None, 128, 128, 64) 2           conv2d[0][0]                     
_________________________________

In [0]:
GAN = CycleGAN()

The [PatchGAN](https://arxiv.org/pdf/1611.07004.pdf) / Markovian discriminator works by classifying individual (N x N) patches in the image as “real vs. fake”, opposed to classifying the entire image as “real vs. fake”. The authors reason that this enforces more constraints that encourage sharp high-frequency detail. Additionally, the PatchGAN has fewer parameters and runs faster than classifying the entire image.
Read for more details [here](https://towardsdatascience.com/pix2pix-869c17900998)

In [9]:
!bash download_dataset.sh apple2orange

for details.

--2019-06-28 18:37:13--  https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.189.73
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.189.73|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 78456409 (75M) [application/zip]
Saving to: ‘./datasets/apple2orange.zip’


2019-06-28 18:37:18 (16.7 MB/s) - ‘./datasets/apple2orange.zip’ saved [78456409/78456409]

Archive:  ./datasets/apple2orange.zip
   creating: ./datasets/apple2orange/trainA/
  inflating: ./datasets/apple2orange/trainA/n07740461_6908.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_7635.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_586.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_9813.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_6835.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_2818.jpg  
  infla

In [0]:
dataset_name = 'apple2orange'
from data_loader import DataLoader
data_loader = DataLoader(dataset_name=dataset_name)

In [0]:
optimizer = keras.optimizers.Adam(0.0002, 0.5)
d_A = GAN.create_discriminator()
d_B = GAN.create_discriminator()
d_A.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
d_B.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])


In [0]:
g_AB = GAN.create_generator()
g_BA = GAN.create_generator()
# Translate images to the other domain
img_A = Input(shape=GAN.input_shape)
img_B = Input(shape=GAN.input_shape)
fake_B = g_AB(img_A)
fake_A = g_BA(img_B)
# Translate images back to original domain
reconstr_A = g_BA(fake_B)
reconstr_B = g_AB(fake_A)

# Identity mapping of images
# 
img_A_id = g_BA(img_A)
img_B_id = g_AB(img_B)

# For the combined model we will only train the generators
d_A.trainable = False
d_B.trainable = False

# Discriminators determines validity of translated images
valid_A = d_A(fake_A)
valid_B = d_B(fake_B)

# Combined model trains generators to fool discriminators
combined = Model(inputs=[img_A, img_B],
                      outputs=[ valid_A, valid_B,
                                reconstr_A, reconstr_B,
                                img_A_id, img_B_id ])

 # Loss weights
lambda_cycle = 10.0                    # Cycle-consistency loss
# With higher identity loss, the image translation becomes more conservative, 
# so it makes less changes. 
# I tried using values like 0.1, 0.5, 1 and 10, and the result is not so different. 
lambda_id = 0.1 * lambda_cycle    # Identity loss
combined.compile(loss=['mse', 'mse',
                            'mae', 'mae',
                            'mae', 'mae'],
                    loss_weights=[  1, 1,
                                    lambda_cycle, lambda_cycle,
                                    lambda_id, lambda_id ],
                    optimizer=optimizer)

In [0]:

def sample_images(epoch, batch_i):
        os.makedirs('images/%s' % dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A = data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Demo (for GIF)
        #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
        #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')

        # Translate images to the other domain
        fake_B = g_AB.predict(imgs_A)
        fake_A = g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = g_BA.predict(fake_B)
        reconstr_B = g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
        plt.close()

In [0]:
def train_single_batch(imgs_A, imgs_B, valid, fake):
  # ----------------------
  #  Train Discriminators
  # ----------------------

  # Translate images to opposite domain
  fake_B = g_AB.predict(imgs_A)
  fake_A = g_BA.predict(imgs_B)

  # Train the discriminators (original images = real / translated = Fake)
  dA_loss_real = d_A.train_on_batch(imgs_A, valid)
  dA_loss_fake = d_A.train_on_batch(fake_A, fake)
  dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

  dB_loss_real = d_B.train_on_batch(imgs_B, valid)
  dB_loss_fake = d_B.train_on_batch(fake_B, fake)
  dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

  # Total disciminator loss
  d_loss = 0.5 * np.add(dA_loss, dB_loss)


  # ------------------
  #  Train Generators
  # ------------------

  # Train the generators
  g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                          [valid, valid,
                                          imgs_A, imgs_B,
                                          imgs_A, imgs_B])
  return d_loss, g_loss



In [0]:
def train(epochs, batch_size=10, sample_interval=20):

    start_time = datetime.datetime.now()

    # Adversarial loss ground truths
    patch = int(GAN.input_shape[0] / 2**4)
    disc_patch = (patch, patch, 1)
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)
    losses = {'g':[], 'd':[]}
    for epoch in range(epochs):
        d_loss, g_loss = None, None
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):
            d_loss_batch, g_loss_batch = train_single_batch(imgs_A, imgs_B, valid, fake)
            if d_loss is None:
              d_loss, g_loss = d_loss_batch, g_loss_batch
            else:
              d_loss += d_loss_batch
              g_loss += g_loss_batch
        losses['d'].append(d_loss)
        losses['g'].append(g_loss)
        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            sample_images(epoch, 0)
            elapsed_time = datetime.datetime.now() - start_time

            # Plot the progress
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%]  \
                   [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                    % ( epoch, epochs,
                                                                        batch_i, data_loader.n_batches,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        np.mean(g_loss[1:3]),
                                                                        np.mean(g_loss[3:5]),
                                                                        np.mean(g_loss[5:6]),
                                                                        elapsed_time))

In [0]:
import scipy.misc
from skimage.transform import resize
scipy.misc.imresize = resize
train(epochs=200, batch_size=10, sample_interval=200)

W0628 21:36:39.497565 140701518124928 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
W0628 21:36:41.275531 140701518124928 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
W0628 21:36:41.302526 140701518124928 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
W0628 21:36:43.064100 140701518124928 training.py:1952] Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?


[Epoch 0/200] [Batch 97/99] [D loss: 32.038742, acc: 4678%]                     [G loss: 16.657583, adv: 0.000263, recon: 0.760162, id: 0.649737] time: 0:01:54.333985 


In [0]:
%loadpy data_loader.py

In [0]:


# Images size
w = 256
h = 256

# Cyclic consistency factor

lmda = 10

# Optimizer parameters

lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
epsilon = 1e-08



disc_a_history = []
disc_b_history = []

gen_a2b_history = {'bc':[], 'mae':[]} 
gen_b2a_history = {'bc':[], 'mae':[]}

gen_b2a_history_new = []
gen_a2b_history_new = []
cycle_history = []

# Data loading

def loadImage(path, h, w):
    
    '''Load single image from specified path'''
    img = image.load_img(path)
    img = img.resize((w,h))
    x = image.img_to_array(img)
    return x


def loadImagesFromDataset(h, w, dataset, use_hdf5=False):

    '''Return a tuple (trainA, trainB, testA, testB) 
    containing numpy arrays populated from the
     test and train set for each part of the cGAN'''

    if (use_hdf5):
        path="./datasets/processed/"+dataset+"_data.h5"
        data = []
        print('\n', '-' * 15, 'Loading data from dataset', dataset, '-' * 15)
        with h5py.File(path, "r") as hf:
            for set_name in tqdm(["trainA_data", "trainB_data", "testA_data", "testB_data"]):
                data.append(hf[set_name][:].astype(np.float32))

        return (set_data for set_data in data)

    else:
        path = "./datasets/"+dataset
        print(path)
        train_a = glob.glob(path + "/trainA/*.png")
        train_b = glob.glob(path + "/trainB/*.png")
        test_a = glob.glob(path + "/testA/*.png")
        test_b = glob.glob(path + "/testB/*.png")

        print("Import trainA")
        if dataset == "nike2adidas" or ("adiedges" in dataset):
            tr_a = np.array([loadImage(p, h, w) for p in tqdm(train_a[:1000])])
        else:
            tr_a = np.array([loadImage(p, h, w) for p in tqdm(train_a)])

        print("Import trainB")
        if dataset == "nike2adidas" or ("adiedges" in dataset):
            tr_b = np.array([loadImage(p, h, w) for p in tqdm(train_b[:1000])])
        else:
            tr_b = np.array([loadImage(p, h, w) for p in tqdm(train_b)])

        print("Import testA")
        ts_a = np.array([loadImage(p, h, w) for p in tqdm(test_a)])

        print("Import testB")
        ts_b = np.array([loadImage(p, h, w) for p in tqdm(test_b)])

    return tr_a, tr_b, ts_a, ts_b
    


# Create a wall of generated images

def plotGeneratedImages(epoch, set_a, set_b, generator_a2b, generator_b2a, examples=6):
    
    true_batch_a = set_a[np.random.randint(0, set_a.shape[0], size=examples)]
    true_batch_b = set_b[np.random.randint(0, set_b.shape[0], size=examples)]

    # Get fake and cyclic images
    generated_a2b = generator_a2b.predict(true_batch_a)
    cycle_a = generator_b2a.predict(generated_a2b)
    generated_b2a = generator_b2a.predict(true_batch_b)
    cycle_b = generator_a2b.predict(generated_b2a)
    
    k = 0

    # Allocate figure
    plt.figure(figsize=(w/10, h/10))

    for output in [true_batch_a, generated_a2b, cycle_a, true_batch_b, generated_b2a, cycle_b]:
        output = (output+1.0)/2.0
        for i in range(output.shape[0]):
            plt.subplot(examples, examples, k*examples +(i + 1))
            img = output[i].transpose(1, 2, 0)  # Using (ch, h, w) scheme needs rearranging for plt to (h, w, ch)
            #print(img.shape)
            plt.imshow(img)
            plt.axis('off')
        plt.tight_layout()
        k += 1
    plt.savefig("images/epoch"+str(epoch)+".png")
    plt.close()


# Plot the loss from each batch

def plotLoss_new():
    plt.figure(figsize=(10, 8))
    plt.plot(disc_a_history, label='Discriminator A loss')
    plt.plot(disc_b_history, label='Discriminator B loss')
    plt.plot(gen_a2b_history_new, label='Generator a2b loss')
    plt.plot(gen_b2a_history_new, label='Generator b2a loss')
    #plt.plot(cycle_history, label="Cyclic loss")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/cyclegan_loss.png')
    plt.close()

def saveModels(epoch, genA2B, genB2A, discA, discB):
    genA2B.save('models/generatorA2B_epoch_%d.h5' % epoch)
    genB2A.save('models/generatorB2A_epoch_%d.h5' % epoch)
    discA.save('models/discriminatorA_epoch_%d.h5' % epoch)
    discB.save('models/discriminatorB_epoch_%d.h5' % epoch)


# Training

def train(epochs, batch_size, dataset, baselr, use_pseudounet=False, use_unet=False, use_decay=False, plot_models=True):

    # Load data and normalize
    x_train_a, x_train_b, x_test_a, x_test_b = loadImagesFromDataset(h, w, dataset, use_hdf5=False)

    x_train_a = (x_train_a.astype(np.float32) - 127.5) / 127.5
    x_train_b = (x_train_b.astype(np.float32) - 127.5) / 127.5
    x_test_a = (x_test_a.astype(np.float32) - 127.5) / 127.5
    x_test_b = (x_test_b.astype(np.float32) - 127.5) / 127.5

    batchCount_a = x_train_a.shape[0] / batch_size
    batchCount_b = x_train_b.shape[0] / batch_size

    # Train on same image amount, would be best to have even sets
    batchCount = min([batchCount_a, batchCount_b])

    print('\nEpochs:', epochs)
    print('Batch size:', batch_size)
    print('Batches per epoch: ', batchCount, "\n")

    #Retrieve components and save model before training, to preserve weights initialization
    disc_a, disc_b, gen_a2b, gen_b2a = components(w, h, pseudounet=use_pseudounet, unet=use_unet, plot=plot_models)
    saveModels(0, gen_a2b, gen_b2a, disc_a, disc_b)

    #Initialize fake images pools
    pool_a2b = []
    pool_b2a = []

    # Define optimizers
    adam_disc = Adam(lr=baselr, beta_1=0.5)
    adam_gen = Adam(lr=baselr, beta_1=0.5)

    # Define image batches
    true_a = gen_a2b.inputs[0]
    true_b = gen_b2a.inputs[0]

    fake_b = gen_a2b.outputs[0]
    fake_a = gen_b2a.outputs[0]

    fake_pool_a = K.placeholder(shape=(None, 3, h, w))
    fake_pool_b = K.placeholder(shape=(None, 3, h, w))

    # Labels for generator training
    y_fake_a = K.ones_like(disc_a([fake_a]))
    y_fake_b = K.ones_like(disc_b([fake_b]))

    # Labels for discriminator training
    y_true_a = K.ones_like(disc_a([true_a])) * 0.9
    y_true_b = K.ones_like(disc_b([true_b])) * 0.9

    fakelabel_a2b = K.zeros_like(disc_b([fake_b]))
    fakelabel_b2a = K.zeros_like(disc_a([fake_a]))

    # Define losses
    disc_a_loss = mse_loss(y_true_a, disc_a([true_a])) + mse_loss(fakelabel_b2a, disc_a([fake_pool_a]))
    disc_b_loss = mse_loss(y_true_b, disc_b([true_b])) + mse_loss(fakelabel_a2b, disc_b([fake_pool_b]))

    gen_a2b_loss = mse_loss(y_fake_b, disc_b([fake_b]))
    gen_b2a_loss = mse_loss(y_fake_a, disc_a([fake_a]))

    cycle_a_loss = mae_loss(true_a, gen_b2a([fake_b]))
    cycle_b_loss = mae_loss(true_b, gen_a2b([fake_a]))
    cyclic_loss = cycle_a_loss + cycle_b_loss

    # Prepare discriminator updater
    discriminator_weights = disc_a.trainable_weights + disc_b.trainable_weights
    disc_loss = (disc_a_loss + disc_b_loss) * 0.5
    discriminator_updater = adam_disc.get_updates(discriminator_weights, [], disc_loss)

    # Prepare generator updater
    generator_weights = gen_a2b.trainable_weights + gen_b2a.trainable_weights
    gen_loss = (gen_a2b_loss + gen_b2a_loss + lmda * cyclic_loss)
    generator_updater = adam_gen.get_updates(generator_weights, [], gen_loss)

    # Define trainers
    generator_trainer = K.function([true_a, true_b], [gen_a2b_loss, gen_b2a_loss, cyclic_loss], generator_updater)
    discriminator_trainer = K.function([true_a, true_b, fake_pool_a, fake_pool_b], [disc_a_loss/2, disc_b_loss/2], discriminator_updater)

    epoch_counter = 1

    # Start training
    for e in range(1, epochs + 1):
        print('\n','-'*15, 'Epoch %d' % e, '-'*15)

        #Learning rate decay
        if use_decay and (epoch_counter > 100):
            lr -= baselr/100
            adam_disc.lr = lr
            adam_gen.lr = lr


        # Initialize progbar and batch counter
        #progbar = generic_utils.Progbar(batchCount)

        np.random.shuffle(x_train_a)
        np.random.shuffle(x_train_b)

        # Cycle through batches
        for i in trange(int(batchCount)):

            # Select true images for training
            #true_batch_a = x_train_a[np.random.randint(0, x_train_a.shape[0], size=batch_size)]
            #true_batch_b = x_train_b[np.random.randint(0, x_train_b.shape[0], size=batch_size)]

            true_batch_a = x_train_a[i*batch_size:i*batch_size+batch_size]
            true_batch_b = x_train_b[i*batch_size:i*batch_size+batch_size]

            # Fake images pool 
            a2b = gen_a2b.predict(true_batch_a)
            b2a = gen_b2a.predict(true_batch_b)

            tmp_b2a = []
            tmp_a2b = []

            for element in a2b:
                if len(pool_a2b) < 50:
                    pool_a2b.append(element)
                    tmp_a2b.append(element)
                else:
                    p = random.uniform(0, 1)

                    if p > 0.5:
                        index = random.randint(0, 49)
                        tmp = np.copy(pool_a2b[index])
                        pool_a2b[index] = element
                        tmp_a2b.append(tmp)
                    else:
                        tmp_a2b.append(element)
            
            for element in b2a:
                if len(pool_b2a) < 50:
                    pool_b2a.append(element)
                    tmp_b2a.append(element)
                else:
                    p = random.uniform(0, 1)

                    if p >0.5:
                        index = random.randint(0, 49)
                        tmp = np.copy(pool_b2a[index])
                        pool_b2a[index] = element
                        tmp_b2a.append(tmp)
                    else:
                        tmp_b2a.append(element)

            pool_a = np.array(tmp_b2a)
            pool_b = np.array(tmp_a2b)

            # Update network and obtain losses
            disc_a_err, disc_b_err = discriminator_trainer([true_batch_a, true_batch_b, pool_a, pool_b])
            gen_a2b_err, gen_b2a_err, cyclic_err = generator_trainer([true_batch_a, true_batch_b])

            # progbar.add(1, values=[
            #                             ("D A", disc_a_err*2),
            #                             ("D B", disc_b_err*2),
            #                             ("G A2B loss", gen_a2b_err),
            #                             ("G B2A loss", gen_b2a_err),
            #                             ("Cyclic loss", cyclic_err)
            #                            ])

        # Save losses for plotting
        disc_a_history.append(disc_a_err)
        disc_b_history.append(disc_b_err)

        gen_a2b_history_new.append(gen_a2b_err)
        gen_b2a_history_new.append(gen_b2a_err)

        #cycle_history.append(cyclic_err[0])
        plotLoss_new()

        plotGeneratedImages(epoch_counter, x_test_a, x_test_b, gen_a2b, gen_b2a)

        if epoch_counter > 150:
            saveModels(epoch_counter, gen_a2b, gen_b2a, disc_a, disc_b)

        epoch_counter += 1


if __name__ == '__main__':
    train(200, 1, "horse2zebra", lr, use_decay=True, use_pseudounet=False, use_unet=False, plot_models=False)