<a href="https://colab.research.google.com/github/objectc/Generative-Models/blob/master/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorflow-gpu==2.0.0-beta1

Collecting tensorflow-gpu==2.0.0-beta1
  Downloading https://files.pythonhosted.org/packages/2b/53/e18c5e7a2263d3581a979645a185804782e59b8e13f42b9c3c3cfb5bb503/tensorflow_gpu-2.0.0b1-cp36-cp36m-manylinux1_x86_64.whl (348.9MB)
Installing collected packages: tensorflow-gpu
Successfully installed tensorflow-gpu-2.0.0b1


Install keras-contrib to get the `InstanceNormalization` layer; you can also implement it yourself. See [this link](https://stackoverflow.com/a/48118940/1651560) for the differences between InstanceNormalization and BatchNormalization.
P.S. I forked the original repo and updated the code to make it compatible with TensorFlow 2.0.
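
As a rough illustration of the difference between the two normalizations, here is a minimal pure-Python sketch. It is not the keras-contrib implementation: the helper names and the nested-list layout (`batch[n][c]` is the flat list of pixels for sample `n`, channel `c`) are assumptions made for this example, and the learned scale and offset parameters are left out.

```python
import math

def _normalize(values, eps=1e-5):
    """Shift and scale a flat list of numbers to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def instance_norm(batch):
    """Statistics are computed separately for every (sample, channel) pair."""
    return [[_normalize(channel) for channel in sample] for sample in batch]

def batch_norm(batch):
    """Statistics are shared by all samples: one mean/variance per channel."""
    n_channels = len(batch[0])
    out = [[None] * n_channels for _ in batch]
    for c in range(n_channels):
        # Pool the pixels of channel c from every sample in the batch
        pooled = [v for sample in batch for v in sample[c]]
        normed = _normalize(pooled)
        size = len(batch[0][c])
        for n in range(len(batch)):
            out[n][c] = normed[n * size:(n + 1) * size]
    return out

# Two single-channel "images" with very different brightness
batch = [[[0.0, 1.0, 2.0, 3.0]], [[100.0, 101.0, 102.0, 103.0]]]
inst = instance_norm(batch)
bn = batch_norm(batch)
```

With instance normalization both images end up with the same values, because each one is normalized with its own statistics. With batch normalization the dark image stays below zero and the bright one above it, because they share one mean.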

In [3]:
!pip --upgrade --force-reinstall install git+https://github.com/objectc/keras-contrib.git


Usage:   
  pip3 <command> [options]

no such option: --upgrade


In [6]:
!pip install --force-reinstall git+https://github.com/objectc/keras-contrib.git

Collecting git+https://github.com/objectc/keras-contrib.git
  Cloning https://github.com/objectc/keras-contrib.git to /tmp/pip-req-build-rpbtvuvi
  Running command git clone -q https://github.com/objectc/keras-contrib.git /tmp/pip-req-build-rpbtvuvi
Collecting keras (from keras-contrib==2.0.8)
  Downloading https://files.pythonhosted.org/packages/5e/10/aa32dad071ce52b5502266b5c659451cfd6ffcbf14e6c8c4f16c0ff5aaab/Keras-2.2.4-py2.py3-none-any.whl (312kB)
Collecting keras-preprocessing>=1.0.5 (from keras->keras-contrib==2.0.8)
  Downloading https://files.pythonhosted.org/packages/28/6a/8c1f62c37212d9fc441a7e26736df51ce6f0e38455816445471f10da4f0a/Keras_Preprocessing-1.1.0-py2.py3-none-any.whl (41kB)
Collecting scipy>=0.14 (from keras->keras-contrib==2.0.8)
  Downloading https://files.pythonhosted.org/packages/72/4c/5f81e7264b0a7a8bd570810f48cd346

In [21]:
!ls /usr/local/lib/python3.6/dist-packages/keras_contrib/layers/normalization/

ls: cannot access '/usr/local/lib/python3.6/dist-packages/keras_contrib/layers/normalization/': No such file or directory


This implementation of CycleGAN is based on the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf).

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, BatchNormalization, Activation, ZeroPadding2D, LeakyReLU, UpSampling2D, Conv2D
from keras_contrib.layers import InstanceNormalization
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import functools

Using TensorFlow backend.


According to the paper, reflection padding is used to reduce artifacts.

In [0]:
# Keras does not provide a reflection padding layer, so we define a custom one.

class ReflectionPadding2D(keras.layers.Layer):
    def __init__(self, kernel_size=None, padding=(1, 1), **kwargs):
        # Initialize the base Layer first, before setting any attributes
        super(ReflectionPadding2D, self).__init__(**kwargs)
        # padding is (height_pad, width_pad)
        self.padding = tuple(padding)
        # If kernel_size is passed in, use half of it as the padding
        if kernel_size:
            self.padding = (kernel_size // 2, kernel_size // 2)
        self.input_spec = [keras.layers.InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Assumes the "channels_last" format: (batch, height, width, channels)."""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        # Keep the same (height, width) order as compute_output_shape
        h_pad, w_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
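
To make the effect of reflection padding concrete, here is a small pure-Python sketch of what `'REFLECT'` mode does along a single axis, plus the size arithmetic behind "pad by half the kernel size". Both helper names are hypothetical and exist only for this illustration.

```python
def reflect_pad_1d(xs, pad):
    """Mirror `pad` values around each end without repeating the edge value,
    which is how 'REFLECT' padding behaves along one axis."""
    if not 0 <= pad < len(xs):
        raise ValueError("pad must be non-negative and smaller than len(xs)")
    left = xs[pad:0:-1]          # values after the first element, reversed
    right = xs[-2:-pad - 2:-1]   # values before the last element, reversed
    return left + xs + right

def conv_output_size(size, kernel_size, stride, pad):
    """Output length of an unpadded ('valid') convolution applied after padding."""
    return (size + 2 * pad - kernel_size) // stride + 1

row = [1, 2, 3, 4]
padded = reflect_pad_1d(row, 2)
print(padded)                           # [3, 2, 1, 2, 3, 4, 3, 2]
print(conv_output_size(128, 7, 1, 3))   # 128: a 7x7 kernel at stride 1 keeps the size
print(conv_output_size(128, 3, 2, 1))   # 64: a 3x3 kernel at stride 2 halves it
```

Padding by `kernel_size // 2` and then convolving without further padding gives the same output size as `padding='same'` would for these settings, but the border pixels are mirrored image content rather than zeros.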

In [0]:
class CycleGAN:
  
  def __init__(self):
    self.input_shape = (128, 128, 3)
    self.conv = functools.partial(keras.layers.Conv2D, kernel_size=3, strides=2, activation=LeakyReLU, padding='same')
  
  # Residual blocks
  def residual(self, X, filters):
    X_shortcut = X
#     output = InstanceNormalization()(X)
#     output = keras.layers.LeakyReLU()(output)
    output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(X)

    output = InstanceNormalization()(output)
    output = keras.layers.LeakyReLU()(output)
    output = Conv2D(filters=filters, kernel_size=[3, 3], strides=[1, 1], padding="same")(output)

    output = keras.layers.add([X_shortcut,output])

    return output
  
  def create_generator(self):
    def conv(layer_input, filters, kernel_size=3, strides=2):
      # Reflection-pad by kernel_size // 2, then apply an unpadded ('valid') convolution
      output = ReflectionPadding2D(kernel_size=kernel_size)(layer_input)
      output = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(output)
      # The paper points out that InstanceNormalization is used instead of BatchNormalization.
      output = InstanceNormalization()(output)
      output = keras.layers.LeakyReLU()(output)
      return output
      
    input = keras.layers.Input(shape=self.input_shape)
    # Downsampling
    d = conv(input, filters=64, kernel_size=7, strides=1)
    d = conv(d, filters=128, kernel_size=3, strides=2)
    d = conv(d, filters=256, kernel_size=3, strides=2)
    
    # ResNet layers: chain six residual blocks, each one feeding the next
    r = d
    for i in range(6):
      r = self.residual(r, 256)
      
    # Upsampling
    u = InstanceNormalization()(r)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    u = keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(u)
    u = InstanceNormalization()(u)
    u = keras.layers.LeakyReLU()(u)
    output = keras.layers.Conv2DTranspose(3, kernel_size=7, padding='same', activation='tanh')(u)
    return keras.models.Model(input, output)
    

  
  def create_discriminator(self):
    input = Input(shape=self.input_shape)
    d = input
    kernel_size = 4
    for depth in range(4):
      if depth == 3:
        d = keras.layers.ZeroPadding2D(1)(d)
      d = keras.layers.Conv2D(64*(2**depth), kernel_size=kernel_size, strides=2, padding='same')(d)
      if depth > 0:
        d = InstanceNormalization()(d)
      d = keras.layers.LeakyReLU()(d)
    output = keras.layers.ZeroPadding2D(1)(d)
    output = keras.layers.Conv2D(1, kernel_size=kernel_size, activation='sigmoid')(output)
    return Model(input, output)
    

In [0]:
GAN = CycleGAN()
gen = GAN.create_generator()
dis = GAN.create_discriminator()

The [PatchGAN](https://arxiv.org/pdf/1611.07004.pdf) / Markovian discriminator classifies individual (N x N) patches of the image as "real vs. fake", as opposed to classifying the entire image at once. The authors reason that this enforces more constraints that encourage sharp high-frequency detail. Additionally, the PatchGAN has fewer parameters and runs faster than a discriminator that classifies the entire image.
See [this article](https://towardsdatascience.com/pix2pix-869c17900998) for more details.
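
As a rough check of what "patch-level" output means for the discriminator defined above, the sketch below traces the spatial size through its layers using the usual size rules for `'same'` and `'valid'` convolutions. The helper functions are hypothetical and only mirror the layer sequence in `create_discriminator`; they do not build a model.

```python
import math

def same_conv(size, stride):
    """Spatial size after a convolution with padding='same'."""
    return math.ceil(size / stride)

def valid_conv(size, kernel_size, stride=1):
    """Spatial size after a convolution with padding='valid'."""
    return (size - kernel_size) // stride + 1

def discriminator_output_size(size=128, kernel_size=4):
    """Follow the layer order of create_discriminator for one spatial axis."""
    for depth in range(4):
        if depth == 3:
            size += 2                     # ZeroPadding2D(1) adds one pixel per side
        size = same_conv(size, stride=2)  # strided Conv2D with padding='same'
    size += 2                             # final ZeroPadding2D(1)
    return valid_conv(size, kernel_size)  # final Conv2D with default 'valid' padding

patch_grid = discriminator_output_size(128)
print(patch_grid)  # 8
```

Under these assumptions a 128 x 128 input yields an 8 x 8 map of scores, so each output value judges one region of the image rather than the image as a whole.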

In [0]:
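import glob
import random

import h5py
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing import image
from tqdm import tqdm, trange

# components(), mse_loss() and mae_loss(), which train() calls below,
# are not defined in this notebook and must be provided separately.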

# Image size
w = 256
h = 256

# Cycle-consistency loss weight (lambda)

lmda = 10

# Optimizer parameters

lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
epsilon = 1e-08



disc_a_history = []
disc_b_history = []

gen_a2b_history = {'bc':[], 'mae':[]} 
gen_b2a_history = {'bc':[], 'mae':[]}

gen_b2a_history_new = []
gen_a2b_history_new = []
cycle_history = []

# Data loading

def loadImage(path, h, w):
    
    '''Load single image from specified path'''
    img = image.load_img(path)
    img = img.resize((w,h))
    x = image.img_to_array(img)
    return x


def loadImagesFromDataset(h, w, dataset, use_hdf5=False):

    '''Return a tuple (trainA, trainB, testA, testB) 
    containing numpy arrays populated from the
     test and train set for each part of the cGAN'''

    if (use_hdf5):
        path="./datasets/processed/"+dataset+"_data.h5"
        data = []
        print('\n', '-' * 15, 'Loading data from dataset', dataset, '-' * 15)
        with h5py.File(path, "r") as hf:
            for set_name in tqdm(["trainA_data", "trainB_data", "testA_data", "testB_data"]):
                data.append(hf[set_name][:].astype(np.float32))

        # Return a real tuple, as the docstring promises, instead of a generator
        return tuple(data)

    else:
        path = "./datasets/"+dataset
        print(path)
        train_a = glob.glob(path + "/trainA/*.png")
        train_b = glob.glob(path + "/trainB/*.png")
        test_a = glob.glob(path + "/testA/*.png")
        test_b = glob.glob(path + "/testB/*.png")

        print("Import trainA")
        if dataset == "nike2adidas" or ("adiedges" in dataset):
            tr_a = np.array([loadImage(p, h, w) for p in tqdm(train_a[:1000])])
        else:
            tr_a = np.array([loadImage(p, h, w) for p in tqdm(train_a)])

        print("Import trainB")
        if dataset == "nike2adidas" or ("adiedges" in dataset):
            tr_b = np.array([loadImage(p, h, w) for p in tqdm(train_b[:1000])])
        else:
            tr_b = np.array([loadImage(p, h, w) for p in tqdm(train_b)])

        print("Import testA")
        ts_a = np.array([loadImage(p, h, w) for p in tqdm(test_a)])

        print("Import testB")
        ts_b = np.array([loadImage(p, h, w) for p in tqdm(test_b)])

    return tr_a, tr_b, ts_a, ts_b
    


# Create a wall of generated images

def plotGeneratedImages(epoch, set_a, set_b, generator_a2b, generator_b2a, examples=6):
    
    true_batch_a = set_a[np.random.randint(0, set_a.shape[0], size=examples)]
    true_batch_b = set_b[np.random.randint(0, set_b.shape[0], size=examples)]

    # Get fake and cyclic images
    generated_a2b = generator_a2b.predict(true_batch_a)
    cycle_a = generator_b2a.predict(generated_a2b)
    generated_b2a = generator_b2a.predict(true_batch_b)
    cycle_b = generator_a2b.predict(generated_b2a)
    
    k = 0

    # Allocate figure
    plt.figure(figsize=(w/10, h/10))

    for output in [true_batch_a, generated_a2b, cycle_a, true_batch_b, generated_b2a, cycle_b]:
        output = (output+1.0)/2.0
        for i in range(output.shape[0]):
            plt.subplot(examples, examples, k*examples +(i + 1))
            img = output[i].transpose(1, 2, 0)  # Using (ch, h, w) scheme needs rearranging for plt to (h, w, ch)
            #print(img.shape)
            plt.imshow(img)
            plt.axis('off')
        plt.tight_layout()
        k += 1
    plt.savefig("images/epoch"+str(epoch)+".png")
    plt.close()


# Plot the losses recorded at the end of each epoch

def plotLoss_new():
    plt.figure(figsize=(10, 8))
    plt.plot(disc_a_history, label='Discriminator A loss')
    plt.plot(disc_b_history, label='Discriminator B loss')
    plt.plot(gen_a2b_history_new, label='Generator a2b loss')
    plt.plot(gen_b2a_history_new, label='Generator b2a loss')
    #plt.plot(cycle_history, label="Cyclic loss")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('images/cyclegan_loss.png')
    plt.close()

def saveModels(epoch, genA2B, genB2A, discA, discB):
    genA2B.save('models/generatorA2B_epoch_%d.h5' % epoch)
    genB2A.save('models/generatorB2A_epoch_%d.h5' % epoch)
    discA.save('models/discriminatorA_epoch_%d.h5' % epoch)
    discB.save('models/discriminatorB_epoch_%d.h5' % epoch)


# Training

def train(epochs, batch_size, dataset, baselr, use_pseudounet=False, use_unet=False, use_decay=False, plot_models=True):

    # Load data and normalize
    x_train_a, x_train_b, x_test_a, x_test_b = loadImagesFromDataset(h, w, dataset, use_hdf5=False)

    x_train_a = (x_train_a.astype(np.float32) - 127.5) / 127.5
    x_train_b = (x_train_b.astype(np.float32) - 127.5) / 127.5
    x_test_a = (x_test_a.astype(np.float32) - 127.5) / 127.5
    x_test_b = (x_test_b.astype(np.float32) - 127.5) / 127.5

    # Integer division: only full batches are used
    batchCount_a = x_train_a.shape[0] // batch_size
    batchCount_b = x_train_b.shape[0] // batch_size

    # Train on the same number of batches from each set; evenly sized sets are best
    batchCount = min(batchCount_a, batchCount_b)

    print('\nEpochs:', epochs)
    print('Batch size:', batch_size)
    print('Batches per epoch: ', batchCount, "\n")

    #Retrieve components and save model before training, to preserve weights initialization
    disc_a, disc_b, gen_a2b, gen_b2a = components(w, h, pseudounet=use_pseudounet, unet=use_unet, plot=plot_models)
    saveModels(0, gen_a2b, gen_b2a, disc_a, disc_b)

    #Initialize fake images pools
    pool_a2b = []
    pool_b2a = []

    # Define optimizers
    adam_disc = Adam(lr=baselr, beta_1=0.5)
    adam_gen = Adam(lr=baselr, beta_1=0.5)

    # Define image batches
    true_a = gen_a2b.inputs[0]
    true_b = gen_b2a.inputs[0]

    fake_b = gen_a2b.outputs[0]
    fake_a = gen_b2a.outputs[0]

    fake_pool_a = K.placeholder(shape=(None, 3, h, w))
    fake_pool_b = K.placeholder(shape=(None, 3, h, w))

    # Labels for generator training
    y_fake_a = K.ones_like(disc_a([fake_a]))
    y_fake_b = K.ones_like(disc_b([fake_b]))

    # Labels for discriminator training
    y_true_a = K.ones_like(disc_a([true_a])) * 0.9
    y_true_b = K.ones_like(disc_b([true_b])) * 0.9

    fakelabel_a2b = K.zeros_like(disc_b([fake_b]))
    fakelabel_b2a = K.zeros_like(disc_a([fake_a]))

    # Define losses
    disc_a_loss = mse_loss(y_true_a, disc_a([true_a])) + mse_loss(fakelabel_b2a, disc_a([fake_pool_a]))
    disc_b_loss = mse_loss(y_true_b, disc_b([true_b])) + mse_loss(fakelabel_a2b, disc_b([fake_pool_b]))

    gen_a2b_loss = mse_loss(y_fake_b, disc_b([fake_b]))
    gen_b2a_loss = mse_loss(y_fake_a, disc_a([fake_a]))

    cycle_a_loss = mae_loss(true_a, gen_b2a([fake_b]))
    cycle_b_loss = mae_loss(true_b, gen_a2b([fake_a]))
    cyclic_loss = cycle_a_loss + cycle_b_loss

    # Prepare discriminator updater
    discriminator_weights = disc_a.trainable_weights + disc_b.trainable_weights
    disc_loss = (disc_a_loss + disc_b_loss) * 0.5
    discriminator_updater = adam_disc.get_updates(discriminator_weights, [], disc_loss)

    # Prepare generator updater
    generator_weights = gen_a2b.trainable_weights + gen_b2a.trainable_weights
    gen_loss = (gen_a2b_loss + gen_b2a_loss + lmda * cyclic_loss)
    generator_updater = adam_gen.get_updates(generator_weights, [], gen_loss)

    # Define trainers
    generator_trainer = K.function([true_a, true_b], [gen_a2b_loss, gen_b2a_loss, cyclic_loss], generator_updater)
    discriminator_trainer = K.function([true_a, true_b, fake_pool_a, fake_pool_b], [disc_a_loss/2, disc_b_loss/2], discriminator_updater)

    epoch_counter = 1

    # Start training
    for e in range(1, epochs + 1):
        print('\n','-'*15, 'Epoch %d' % e, '-'*15)

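        # Learning rate decay: after epoch 100, decrease linearly by baselr/100 per epoch.
        # Computed from baselr so that no global lr is modified inside the function.
        if use_decay and (epoch_counter > 100):
            lr = baselr * (1 - (epoch_counter - 100) / 100)
            adam_disc.lr = lr
            adam_gen.lr = lr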


        # Initialize progbar and batch counter
        #progbar = generic_utils.Progbar(batchCount)

        np.random.shuffle(x_train_a)
        np.random.shuffle(x_train_b)

        # Cycle through batches
        for i in trange(int(batchCount)):

            # Select true images for training
            #true_batch_a = x_train_a[np.random.randint(0, x_train_a.shape[0], size=batch_size)]
            #true_batch_b = x_train_b[np.random.randint(0, x_train_b.shape[0], size=batch_size)]

            true_batch_a = x_train_a[i*batch_size:i*batch_size+batch_size]
            true_batch_b = x_train_b[i*batch_size:i*batch_size+batch_size]

            # Fake images pool 
            a2b = gen_a2b.predict(true_batch_a)
            b2a = gen_b2a.predict(true_batch_b)

            tmp_b2a = []
            tmp_a2b = []

            for element in a2b:
                if len(pool_a2b) < 50:
                    pool_a2b.append(element)
                    tmp_a2b.append(element)
                else:
                    p = random.uniform(0, 1)

                    if p > 0.5:
                        index = random.randint(0, 49)
                        tmp = np.copy(pool_a2b[index])
                        pool_a2b[index] = element
                        tmp_a2b.append(tmp)
                    else:
                        tmp_a2b.append(element)
            
            for element in b2a:
                if len(pool_b2a) < 50:
                    pool_b2a.append(element)
                    tmp_b2a.append(element)
                else:
                    p = random.uniform(0, 1)

                    if p >0.5:
                        index = random.randint(0, 49)
                        tmp = np.copy(pool_b2a[index])
                        pool_b2a[index] = element
                        tmp_b2a.append(tmp)
                    else:
                        tmp_b2a.append(element)

            pool_a = np.array(tmp_b2a)
            pool_b = np.array(tmp_a2b)

            # Update network and obtain losses
            disc_a_err, disc_b_err = discriminator_trainer([true_batch_a, true_batch_b, pool_a, pool_b])
            gen_a2b_err, gen_b2a_err, cyclic_err = generator_trainer([true_batch_a, true_batch_b])

            # progbar.add(1, values=[
            #                             ("D A", disc_a_err*2),
            #                             ("D B", disc_b_err*2),
            #                             ("G A2B loss", gen_a2b_err),
            #                             ("G B2A loss", gen_b2a_err),
            #                             ("Cyclic loss", cyclic_err)
            #                            ])

        # Save losses for plotting
        disc_a_history.append(disc_a_err)
        disc_b_history.append(disc_b_err)

        gen_a2b_history_new.append(gen_a2b_err)
        gen_b2a_history_new.append(gen_b2a_err)

        #cycle_history.append(cyclic_err[0])
        plotLoss_new()

        plotGeneratedImages(epoch_counter, x_test_a, x_test_b, gen_a2b, gen_b2a)

        if epoch_counter > 150:
            saveModels(epoch_counter, gen_a2b, gen_b2a, disc_a, disc_b)

        epoch_counter += 1


if __name__ == '__main__':
    train(200, 1, "horse2zebra", lr, use_decay=True, use_pseudounet=False, use_unet=False, plot_models=False)
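
The fake-image pool logic appears twice inside `train`, once for each translation direction. A minimal sketch of how it could be factored into a helper is shown below. `ImagePool`, `query`, and the `rng` parameter are hypothetical names introduced for this example, and plain strings stand in for image arrays.

```python
import random

class ImagePool:
    """History buffer of generated images (hypothetical helper, not part of the notebook)."""

    def __init__(self, max_size=50, rng=None):
        self.max_size = max_size
        self.images = []
        self.rng = rng or random.Random()

    def query(self, batch):
        """Return one image per input image, taken from the batch or from the pool."""
        out = []
        for img in batch:
            if len(self.images) < self.max_size:
                # Pool not full yet: store the new image and use it directly
                self.images.append(img)
                out.append(img)
            elif self.rng.random() > 0.5:
                # Swap: hand back a stored image and keep the new one in its place
                idx = self.rng.randrange(self.max_size)
                out.append(self.images[idx])
                self.images[idx] = img
            else:
                # Use the new image without storing it
                out.append(img)
        return out

pool = ImagePool(max_size=3, rng=random.Random(0))
first = pool.query(["f0", "f1", "f2"])         # fills the pool
later = pool.query(["f3", "f4", "f5", "f6"])   # mixes new and stored images
```

The discriminators would then be trained on the result of `query`, which mixes freshly generated images with older ones, matching the behavior of the two loops in `train`.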