# CycleGAN


Ref: CHU, Casey; ZHMOGINOV, Andrey; SANDLER, Mark. Cyclegan, a master of steganography. arXiv preprint arXiv:1712.02950, 2017.

In [26]:
# load libraries
import datetime
import matplotlib.pyplot as plt
from data_loader import DataLoader
import numpy as np
import os

In [27]:
# from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dropout, Concatenate, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose
from keras.models import Model
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.layers.merge import add

In [28]:
import tensorflow as tf
from keras.engine.topology import Layer
from keras.engine import InputSpec

class ReflectionPadding2D(Layer):
    """ ref. https://stackoverflow.com/questions/50677544/reflection-padding-conv2d
    """
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

## Data loader

In [29]:
from data_loader2 import DataLoader

## Helper function for the generator

In [30]:
def conv7s1(layer_input, filters, final, weight_init):
    y = ReflectionPadding2D(padding =(3,3))(layer_input)
    y = Conv2D(filters, kernel_size=(7,7), strides=1, padding='valid', kernel_initializer = weight_init)(y)
    if final:
        y = Activation('tanh')(y)
    else:
        y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
        y = Activation('relu')(y)

    return y

def downsample(layer_input,filters, weight_init):
    y = Conv2D(filters, kernel_size=(3,3), strides=2, padding='same', kernel_initializer = weight_init)(layer_input)
    y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
    y = Activation('relu')(y)
    
    return y

def residual(layer_input, filters, weight_init):
    shortcut = layer_input
    y = ReflectionPadding2D(padding =(1,1))(layer_input)
    y = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='valid', kernel_initializer = weight_init)(y)
    y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
    y = Activation('relu')(y)

    y = ReflectionPadding2D(padding =(1,1))(y)
    y = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='valid', kernel_initializer = weight_init)(y)
    y = InstanceNormalization(axis = -1, center = False, scale = False)(y)

    return add([shortcut, y])

def upsample(layer_input,filters, weight_init):
    y = Conv2DTranspose(filters, kernel_size=(3, 3), strides=2, padding='same', kernel_initializer = weight_init)(layer_input)
    y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
    y = Activation('relu')(y)
    
    return y

## Generator

In [31]:
def build_generator(img_shape, gf, weight_init):
    """ U-Net Generator
        :param gf: generator filter size in the first layer
        
        
        ref.: Kaiming He et al., “Deep Residual Learning for Image Recognition” 
              10 December 2015, https://arxiv.org/abs/1512.03385
    """
    # Image input
    img = Input(shape=img_shape)

    y = img

    y = conv7s1(y, gf, False, weight_init)
    y = downsample(y, gf * 2, weight_init)
    y = downsample(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = residual(y, gf * 4, weight_init)
    y = upsample(y, gf * 2, weight_init)
    y = upsample(y, gf, weight_init)
    y = conv7s1(y, 3, True, weight_init)
    output = y

    return Model(img, output)

## Helper function for Discriminator

In [32]:
def conv4(layer_input, filters, weight_init, stride = 2, norm=True):
    y = Conv2D(filters, kernel_size=(4,4), strides=stride, padding='same', kernel_initializer = weight_init)(layer_input)
    if norm:
        y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
    y = LeakyReLU(0.2)(y)

    return y

# Discriminator

In [33]:
def build_discriminator(img_shape, df, weight_init):

    img = Input(shape=img_shape)

    y = conv4(img, df, weight_init, stride = 2, norm = False)
    y = conv4(y, df * 2, weight_init, stride = 2)
    y = conv4(y, df * 4, weight_init, stride = 4)
    y = conv4(y, df * 8, weight_init, stride = 1)

    output = Conv2D(1, kernel_size=4, strides=1, padding='same', kernel_initializer = weight_init)(y)

    return Model(img, output)

In [34]:
def sample_images(epoch, batch_i,
                  g_AB, g_BA, # generators
                  dataset_name, data_loader):
    os.makedirs('images/%s' % dataset_name, exist_ok=True)
    r, c = 2, 3 # rows, columns

    imgs_A = data_loader.load_data(domain="A", batch_size=1, is_testing=True)
    imgs_B = data_loader.load_data(domain="B", batch_size=1, is_testing=True)

    # Translate images to the other domain
    fake_B = g_AB.predict(imgs_A)
    fake_A = g_BA.predict(imgs_B)
    # Translate back to original domain
    reconstr_A = g_BA.predict(fake_B)
    reconstr_B = g_AB.predict(fake_A)

    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):  # rows
        for j in range(c):  # columns
            axs[i,j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[j])
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
    plt.close()

## Perform training of the GAN

In [35]:
def train(epochs, 
          g_AB, g_BA, # generators
          d_A, d_B, # discriminators
          combined, # full GAN
          dataset_name, img_rows, img_cols,
          batch_size=1, sample_interval=50,
          ):
    
    # Calculate output shape of D (PatchGAN)
    patch = int(img_rows / 2**4)
    disc_patch = (patch, patch, 1)

    # class that loads the data in batches
    data_loader = DataLoader(dataset_name=dataset_name,
                             img_res=(img_rows, img_cols))

    start_time = datetime.datetime.now()

    # Adversarial loss ground truths
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):

            # ----------------------
            #  Train Discriminators
            # ----------------------

            # Translate images to opposite domain
            fake_B = g_AB.predict(imgs_A)
            fake_A = g_BA.predict(imgs_B)

            # Train the discriminators (original images = real / translated = Fake)
            dA_loss_real = d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            dB_loss_real = d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            # Total disciminator loss
            d_loss = 0.5 * np.add(dA_loss, dB_loss)


            # ------------------
            #  Train Generators
            # ------------------

            # Train the generators
            g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                             [valid, valid,
                                              imgs_A, imgs_B,
                                              imgs_A, imgs_B])

            elapsed_time = datetime.datetime.now() - start_time

            # Plot the progress
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                    % ( epoch, epochs,
                                                                        batch_i, data_loader.n_batches,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        np.mean(g_loss[1:3]),
                                                                        np.mean(g_loss[3:5]),
                                                                        np.mean(g_loss[5:6]),
                                                                        elapsed_time))

            # If at save interval => save generated image samples
            if batch_i % sample_interval == 0:
                sample_images(epoch, batch_i, 
                              g_AB, g_BA, # generators
                              dataset_name, data_loader)

## main()

In [36]:
img_rows = 128
img_cols = 128
channels = 3
img_shape = (img_rows, img_cols, channels)

# Configure data loader
dataset_name = 'apple2orange'

In [37]:
weight_init = RandomNormal(mean=0., stddev=0.02)

In [38]:
# Loss weights
lambda_cycle = 10.0               # Cycle-consistency loss
lambda_id = 0.1 * lambda_cycle    # Identity loss

In [39]:
# optimizer
optimizer = Adam(0.0002, 0.5)

In [40]:
# Number of filters in the first layer of G and D
gf = 32
df = 64

In [41]:
# Build and compile the discriminators
d_A = build_discriminator(img_shape, df, weight_init)
d_B = build_discriminator(img_shape, df, weight_init)
d_A.compile(loss='mse',
    optimizer=optimizer,
    metrics=['accuracy'])
d_B.compile(loss='mse',
    optimizer=optimizer,
    metrics=['accuracy'])

d_A.summary()

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_65 (Conv2D)           (None, 64, 64, 64)        3136      
_________________________________________________________________
leaky_re_lu_17 (LeakyReLU)   (None, 64, 64, 64)        0         
_________________________________________________________________
conv2d_66 (Conv2D)           (None, 32, 32, 128)       131200    
_________________________________________________________________
instance_normalization_59 (I (None, 32, 32, 128)       0         
_________________________________________________________________
leaky_re_lu_18 (LeakyReLU)   (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 8, 8, 256)         5245

In [42]:
#-------------------------
# Construct Computational
#   Graph of Generators
#-------------------------

# Build the generators
g_AB = build_generator(img_shape, gf, weight_init)
g_BA = build_generator(img_shape, gf, weight_init)

In [43]:
# Input images from both domains
img_A = Input(shape=img_shape)
img_B = Input(shape=img_shape)

In [44]:
# Translate images to the other domain
fake_B = g_AB(img_A)
fake_A = g_BA(img_B)

In [45]:
# Translate images back to original domain
reconstr_A = g_BA(fake_B)
reconstr_B = g_AB(fake_A)

In [46]:
# Identity mapping of images
img_A_id = g_BA(img_A)
img_B_id = g_AB(img_B)

In [47]:
# For the combined model we will only train the generators
d_A.trainable = False
d_B.trainable = False

# Discriminators determines validity of translated images
valid_A = d_A(fake_A)
valid_B = d_B(fake_B)

In [48]:
# Combined model trains generators to fool discriminators
combined = Model(inputs=[img_A, img_B],
                      outputs=[ valid_A, valid_B,
                                reconstr_A, reconstr_B,
                                img_A_id, img_B_id ])
combined.compile(loss=['mse', 'mse',
                            'mae', 'mae',
                            'mae', 'mae'],
                    loss_weights=[  1, 1,
                                    lambda_cycle, lambda_cycle,
                                    lambda_id, lambda_id ],
                    optimizer=optimizer)

In [None]:
epochs = 2
train(epochs=epochs, batch_size=1, sample_interval=200,
      g_AB=g_AB, g_BA=g_BA, d_A=d_A, d_B=d_B, combined=combined,
      dataset_name=dataset_name, img_rows=img_rows, img_cols=img_cols
      )

  'Discrepancy between trainable weights and collected trainable'


[Epoch 0/2] [Batch 0/995] [D loss: 2.000792, acc:  28%] [G loss: 29.565210, adv: 5.295656, recon: 0.864134, id: 0.956028] time: 0:00:38.102565 
[Epoch 0/2] [Batch 1/995] [D loss: 4.242106, acc:  25%] [G loss: 22.354902, adv: 2.932665, recon: 0.749814, id: 0.863008] time: 0:00:40.991532 
[Epoch 0/2] [Batch 2/995] [D loss: 3.104193, acc:  25%] [G loss: 19.969934, adv: 3.053705, recon: 0.629672, id: 0.812977] time: 0:00:43.269533 
[Epoch 0/2] [Batch 3/995] [D loss: 3.100769, acc:  26%] [G loss: 21.144304, adv: 3.783855, recon: 0.617508, id: 0.729241] time: 0:00:45.609533 
[Epoch 0/2] [Batch 4/995] [D loss: 4.996716, acc:  20%] [G loss: 20.225487, adv: 3.791628, recon: 0.575086, id: 0.587869] time: 0:00:48.011532 
[Epoch 0/2] [Batch 5/995] [D loss: 2.492306, acc:  25%] [G loss: 19.631672, adv: 2.330153, recon: 0.677704, id: 0.785379] time: 0:00:50.436533 
[Epoch 0/2] [Batch 6/995] [D loss: 3.546727, acc:  24%] [G loss: 15.401932, adv: 2.195482, recon: 0.500669, id: 0.578063] time: 0:00:52.

[Epoch 0/2] [Batch 57/995] [D loss: 0.580573, acc:  48%] [G loss: 13.380840, adv: 0.654946, recon: 0.546718, id: 0.508078] time: 0:03:29.930535 
[Epoch 0/2] [Batch 58/995] [D loss: 0.659199, acc:  39%] [G loss: 12.706997, adv: 0.539571, recon: 0.525923, id: 0.381460] time: 0:03:33.382532 
[Epoch 0/2] [Batch 59/995] [D loss: 0.692901, acc:  40%] [G loss: 14.649514, adv: 0.488370, recon: 0.619799, id: 0.626680] time: 0:03:36.874532 
[Epoch 0/2] [Batch 60/995] [D loss: 0.687771, acc:  39%] [G loss: 9.404800, adv: 0.677464, recon: 0.363756, id: 0.372094] time: 0:03:40.271532 
[Epoch 0/2] [Batch 61/995] [D loss: 0.807573, acc:  39%] [G loss: 9.734346, adv: 0.772186, recon: 0.370794, id: 0.372183] time: 0:03:43.628533 
[Epoch 0/2] [Batch 62/995] [D loss: 0.926175, acc:  39%] [G loss: 9.457861, adv: 0.506066, recon: 0.381674, id: 0.479276] time: 0:03:47.032535 
[Epoch 0/2] [Batch 63/995] [D loss: 0.622785, acc:  41%] [G loss: 10.122426, adv: 0.717768, recon: 0.394109, id: 0.338855] time: 0:03