# CycleGAN


Ref: CHU, Casey; ZHMOGINOV, Andrey; SANDLER, Mark. Cyclegan, a master of steganography. arXiv preprint arXiv:1712.02950, 2017.

In [1]:
# load libraries
import datetime
import matplotlib.pyplot as plt
from data_loader import DataLoader
import numpy as np
import os

In [2]:
# from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dropout, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam

Using TensorFlow backend.


## Data loader

In [3]:
from data_loader import DataLoader

## Helper functions for the generator

In [4]:
def conv2d(layer_input, filters, f_size=4, strides=2, alpha=0.2):
    """Downsampling stage: convolution -> LeakyReLU -> instance normalization.

        :param layer_input: tensor fed into the convolution
        :param filters: number of convolution filters
        :param f_size: convolution kernel size
        :param strides: convolution stride
        :param alpha: ``alpha`` argument passed to LeakyReLU
        :return: output tensor of the instance-normalization layer
    """
    stages = (
        Conv2D(filters, kernel_size=f_size, strides=strides, padding='same'),
        LeakyReLU(alpha=alpha),
        InstanceNormalization(),
    )
    tensor = layer_input
    for stage in stages:
        tensor = stage(tensor)
    return tensor

def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0, up_size=2, strides=1):
    """ Upsampling stage with a skip connection.

        The input is upsampled, convolved with a ReLU activation, optionally
        passed through dropout, instance-normalized and finally concatenated
        with ``skip_input``.

        :param layer_input: tensor to upsample
        :param skip_input: layer to make skip connection to
        :param filters: number of convolution filters
        :param f_size: convolution kernel size
        :param dropout_rate: dropout rate; a falsy value disables dropout
        :param up_size: upsampling factor
        :param strides: convolution stride
        :return: concatenation of the normalized features and ``skip_input``
    """
    upsampled = UpSampling2D(size=up_size)(layer_input)
    features = Conv2D(filters, kernel_size=f_size, strides=strides,
                      padding='same', activation='relu')(upsampled)
    if dropout_rate:
        features = Dropout(dropout_rate)(features)
    normalized = InstanceNormalization()(features)
    return Concatenate()([normalized, skip_input])

## Generator

In [5]:
def build_generator(img_shape, gf, channels):
    """ U-Net Generator

        Four downsampling stages are followed by three upsampling stages that
        each take a skip connection from the matching downsampling stage, and
        a final upsample + tanh convolution that produces the output image.

        :param img_shape: shape of the input image tensor
        :param gf: generator filter size in the first layer
        :param channels: number of channels of the generated image
        :return: Model mapping the input image to the generated image

        ref.: RONNEBERGER, Olaf; FISCHER, Philipp; BROX, Thomas. U-net: Convolutional networks for biomedical image segmentation.
              In: International Conference on Medical image computing and computer-assisted intervention.
              Springer, Cham, 2015. p. 234-241.
    """
    image_in = Input(shape=img_shape)

    # Encoder: keep every stage output so the decoder can skip-connect to it
    encoder_stages = []
    tensor = image_in
    for multiplier in (1, 2, 4, 8):
        tensor = conv2d(tensor, gf * multiplier)
        encoder_stages.append(tensor)

    # Decoder: start from the deepest stage and connect to the remaining
    # encoder outputs from deepest to shallowest
    skip_tensors = encoder_stages[-2::-1]
    for skip, multiplier in zip(skip_tensors, (4, 2, 1)):
        tensor = deconv2d(tensor, skip, gf * multiplier)

    # Final upsample and projection to the requested number of channels
    upsampled = UpSampling2D(size=2)(tensor)
    image_out = Conv2D(channels, kernel_size=4, strides=1,
                       padding='same', activation='tanh')(upsampled)

    return Model(image_in, image_out)

## Helper function for the discriminator

In [6]:
def d_layer(layer_input, filters, f_size=4, normalization=True, strides=2):
    """Discriminator layer: convolution -> LeakyReLU -> optional instance norm.

        :param layer_input: tensor fed into the convolution
        :param filters: number of convolution filters
        :param f_size: convolution kernel size
        :param normalization: when true, append an InstanceNormalization layer
        :param strides: convolution stride
        :return: output tensor of the last layer applied
    """
    conv_out = Conv2D(filters, kernel_size=f_size, strides=strides, padding='same')(layer_input)
    activated = LeakyReLU(alpha=0.2)(conv_out)
    if not normalization:
        return activated
    return InstanceNormalization()(activated)

## Discriminator

In [7]:
def build_discriminator(img_shape, df):
    """Build the discriminator.

        Four discriminator layers (the first without normalization) are
        followed by a single-filter convolution whose output is the
        per-location validity map.

        :param img_shape: shape of the input image tensor
        :param df: number of filters in the first discriminator layer
        :return: Model mapping an image to its validity map
    """
    image_in = Input(shape=img_shape)

    # The first layer skips normalization; the rest widen the filter count
    features = d_layer(image_in, df, normalization=False)
    for multiplier in (2, 4, 8):
        features = d_layer(features, df * multiplier)

    validity_map = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)

    return Model(image_in, validity_map)

In [8]:
def sample_images(epoch, batch_i,
                  g_AB, g_BA,  # generators
                  dataset_name, data_loader):
    """Save a 2x3 grid of original / translated / reconstructed images.

        One test image per domain is loaded, translated to the other domain
        and translated back. The first row shows domain A, the second row
        domain B. The figure is written to
        ``images/<dataset_name>/<epoch>_<batch_i>.png``.

        :param epoch: epoch index used in the file name
        :param batch_i: batch index used in the file name
        :param g_AB: generator translating A -> B
        :param g_BA: generator translating B -> A
        :param dataset_name: sub-directory of ``images/`` to write into
        :param data_loader: object providing ``load_data``
    """
    os.makedirs('images/%s' % dataset_name, exist_ok=True)
    n_rows, n_cols = 2, 3

    real_A = data_loader.load_data(domain="A", batch_size=1, is_testing=True)
    real_B = data_loader.load_data(domain="B", batch_size=1, is_testing=True)

    # Forward translation, then back to the source domain
    translated_B = g_AB.predict(real_A)
    translated_A = g_BA.predict(real_B)
    cycled_A = g_BA.predict(translated_B)
    cycled_B = g_AB.predict(translated_A)

    # Row-major order of the grid; rescaled with 0.5 * x + 0.5 for display
    panels = 0.5 * np.concatenate([real_A, translated_B, cycled_A,
                                   real_B, translated_A, cycled_B]) + 0.5

    column_titles = ['Original', 'Translated', 'Reconstructed']
    fig, axs = plt.subplots(n_rows, n_cols)
    for idx in range(n_rows * n_cols):
        row, col = divmod(idx, n_cols)
        cell = axs[row, col]
        cell.imshow(panels[idx])
        cell.set_title(column_titles[col])
        cell.axis('off')
    fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
    plt.close()

## Perform training of the GAN

In [9]:
def train(epochs,
          g_AB, g_BA,  # generators
          d_A, d_B,  # discriminators
          combined,  # full GAN
          dataset_name, img_rows, img_cols,
          batch_size=1, sample_interval=50,
          ):
    """Alternately train the discriminators and the combined generator model.

        For every batch, each discriminator is updated on real images
        (target: ones) and on freshly translated images (target: zeros).
        Then ``combined`` is updated once with targets
        ``[valid, valid, imgs_A, imgs_B, imgs_A, imgs_B]``.
        Progress is printed after every batch, and ``sample_images`` is called
        every ``sample_interval`` batches.

        :param epochs: number of passes over the data
        :param g_AB: generator translating A -> B
        :param g_BA: generator translating B -> A
        :param d_A: discriminator for domain A; its ``train_on_batch`` result
                    is indexed as ``[loss, accuracy]``
        :param d_B: discriminator for domain B (same expectations as ``d_A``)
        :param combined: model taking ``[imgs_A, imgs_B]`` whose
                         ``train_on_batch`` result is indexed as
                         ``[total, adv_A, adv_B, recon_A, recon_B, id_A, id_B]``
        :param dataset_name: dataset passed to ``DataLoader`` and ``sample_images``
        :param img_rows: image height
        :param img_cols: image width
        :param batch_size: number of images per batch
        :param sample_interval: save sample images every this many batches
    """
    # Output shape of D (PatchGAN). The discriminator built in this file
    # applies four stride-2 convolutions with 'same' padding, each of which
    # halves a spatial dimension rounding up, so every dimension shrinks to
    # ceil(n / 16). Rows and columns are computed separately so non-square
    # images get a matching target shape.
    downsampling = 2 ** 4
    patch_rows = int(-(-img_rows // downsampling))
    patch_cols = int(-(-img_cols // downsampling))
    disc_patch = (patch_rows, patch_cols, 1)

    # class that loads the data in batches
    data_loader = DataLoader(dataset_name=dataset_name,
                             img_res=(img_rows, img_cols))

    start_time = datetime.datetime.now()

    # Adversarial loss ground truths
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):

            # ----------------------
            #  Train Discriminators
            # ----------------------

            # Translate images to opposite domain
            fake_B = g_AB.predict(imgs_A)
            fake_A = g_BA.predict(imgs_B)

            # Train the discriminators (original images = real / translated = fake)
            dA_loss_real = d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            dB_loss_real = d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            # Total discriminator loss
            d_loss = 0.5 * np.add(dA_loss, dB_loss)

            # ------------------
            #  Train Generators
            # ------------------

            # Targets follow the output order of the combined model:
            # two adversarial, two cycle-reconstruction, two identity outputs
            g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                             [valid, valid,
                                              imgs_A, imgs_B,
                                              imgs_A, imgs_B])

            elapsed_time = datetime.datetime.now() - start_time

            # Plot the progress. g_loss[0] is the total; the remaining six
            # entries are averaged pairwise (adv: 1-2, recon: 3-4, id: 5-6).
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                    % ( epoch, epochs,
                                                                        batch_i, data_loader.n_batches,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        np.mean(g_loss[1:3]),
                                                                        np.mean(g_loss[3:5]),
                                                                        np.mean(g_loss[5:7]),
                                                                        elapsed_time))

            # If at save interval => save generated image samples
            if batch_i % sample_interval == 0:
                sample_images(epoch, batch_i,
                              g_AB, g_BA,  # generators
                              dataset_name, data_loader)

## main()

In [10]:
# Input image geometry: height, width and number of channels
img_rows, img_cols, channels = 128, 128, 3
img_shape = (img_rows, img_cols, channels)

# Dataset handed to the data loader
dataset_name = 'apple2orange'

In [11]:
# Weights of the generator loss terms
# Cycle-consistency loss weight
lambda_cycle = 10.0
# Identity loss weight, defined as a tenth of the cycle weight
lambda_id = lambda_cycle * 0.1

In [12]:
# Single Adam instance that is passed to the compile() calls of both
# discriminators and of the combined model further down.
# NOTE(review): the two positional arguments are presumably the learning rate
# (0.0002) and beta_1 (0.5) — confirm against the installed Keras version.
optimizer = Adam(0.0002, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [13]:
# First-layer filter counts: gf for the generators, df for the discriminators
gf, df = 32, 64

In [14]:
# One discriminator per domain, each trained with a mean-squared-error loss
# and reporting accuracy as an extra metric
d_A = build_discriminator(img_shape, df)
d_A.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

d_B = build_discriminator(img_shape, df)
d_B.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

In [15]:
# ---------------------------------------------
#  Construct computational graph of generators
# ---------------------------------------------

# Two generators with identical architecture: A -> B and B -> A
g_AB, g_BA = (build_generator(img_shape, gf, channels) for _ in range(2))

In [16]:
# Symbolic inputs for one image from each domain
img_A, img_B = Input(shape=img_shape), Input(shape=img_shape)

In [17]:
# Cross-domain translations: A -> B and B -> A
fake_B, fake_A = g_AB(img_A), g_BA(img_B)

In [18]:
# Cycle step: map each translated image back to its source domain
reconstr_A, reconstr_B = g_BA(fake_B), g_AB(fake_A)

In [19]:
# Identity step: feed each generator an image from its own target domain
img_A_id, img_B_id = g_BA(img_A), g_AB(img_B)

In [20]:
# Freeze both discriminators so that the combined model only trains the generators
d_A.trainable = d_B.trainable = False

# Discriminator verdicts on the translated images
valid_A, valid_B = d_A(fake_A), d_B(fake_B)

In [21]:
# Combined model trains generators to fool discriminators
# It has six outputs, in this order: the two discriminator verdicts on the
# translated images, the two cycle reconstructions, and the two identity
# mappings. train() supplies its targets in the same order.
combined = Model(inputs=[img_A, img_B],
                      outputs=[ valid_A, valid_B,
                                reconstr_A, reconstr_B,
                                img_A_id, img_B_id ])
# Losses and weights are listed per output, position by position:
# 'mse' with weight 1 for the verdicts, 'mae' weighted by lambda_cycle for the
# reconstructions, and 'mae' weighted by lambda_id for the identity mappings.
combined.compile(loss=['mse', 'mse',
                            'mae', 'mae',
                            'mae', 'mae'],
                    loss_weights=[  1, 1,
                                    lambda_cycle, lambda_cycle,
                                    lambda_id, lambda_id ],
                    optimizer=optimizer)

In [22]:
# Short run of 3 epochs (200 is noted as the alternative value)
epochs = 3
train(
    epochs=epochs,
    g_AB=g_AB, g_BA=g_BA,
    d_A=d_A, d_B=d_B,
    combined=combined,
    dataset_name=dataset_name,
    img_rows=img_rows, img_cols=img_cols,
    batch_size=1,
    sample_interval=200,
)

TypeError: Cannot handle this data type: (1, 1, 3), <i4