In [1]:
import os
import datetime
import matplotlib.pyplot as plt
import numpy as np

In [2]:
from data_loader import DataLoader

In [3]:
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dropout, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam

Using TensorFlow backend.


# Generator

In [4]:
def conv2d(layer_input, filters, f_size=4, normalize=True):
    """Downsampling stage: stride-2 convolution followed by LeakyReLU.

    Args:
        layer_input: tensor fed into the convolution.
        filters: number of convolution filters.
        f_size: convolution kernel size.
        normalize: when true, instance normalization is applied after the
            activation.

    Returns:
        The output tensor of the last layer applied.
    """
    features = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
    features = LeakyReLU(alpha=0.2)(features)
    if not normalize:
        return features
    return InstanceNormalization()(features)

def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """Upsampling stage whose result is concatenated with a skip tensor.

    Args:
        layer_input: tensor to upsample.
        skip_input: tensor concatenated onto the processed result.
        filters: number of filters of the stride-1 convolution.
        f_size: convolution kernel size.
        dropout_rate: when truthy, a Dropout layer with this rate is inserted
            between the convolution and the instance normalization.

    Returns:
        The concatenation of the processed tensor and ``skip_input``.
    """
    # Layers are created in the same order in which they are applied.
    stages = [
        UpSampling2D(size=2),
        Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu'),
    ]
    if dropout_rate:
        stages.append(Dropout(dropout_rate))
    stages.append(InstanceNormalization())

    upsampled = layer_input
    for stage in stages:
        upsampled = stage(upsampled)

    return Concatenate()([upsampled, skip_input])

In [5]:
def build_generator(img_shape, channels, gf):
    """Assemble the U-Net generator.

    Seven downsampling stages are followed by six upsampling stages, each of
    which receives the matching downsampling output as its skip input. A final
    upsample and a tanh convolution produce the output image.

    Args:
        img_shape: shape passed to the model's Input layer.
        channels: number of filters of the final (output) convolution.
        gf: filter count of the first downsampling stage; later stages use
            multiples of it.

    Returns:
        A Model mapping the input image to the generated image.
    """
    # NOTE(review): skip concatenation presumably needs spatial sizes that
    # survive seven halvings cleanly — confirm for non-128x128 inputs.
    image_in = Input(shape=img_shape)

    # Downsampling: the first stage skips normalization, the rest widen to gf*8
    encoder = [conv2d(image_in, gf, normalize=False)]
    for width in (2, 4, 8, 8, 8, 8):
        encoder.append(conv2d(encoder[-1], gf*width))

    # Upsampling: walk back up, pairing each stage with the mirrored encoder output
    decoded = encoder[-1]
    for skip, width in zip(reversed(encoder[:-1]), (8, 8, 8, 4, 2, 1)):
        decoded = deconv2d(decoded, skip, gf*width)

    decoded = UpSampling2D(size=2)(decoded)
    output_img = Conv2D(channels, kernel_size=4, strides=1,
                        padding='same', activation='tanh')(decoded)

    return Model(image_in, output_img)

# Discriminator

In [6]:
def d_layer(layer_input, filters, f_size=4, normalization=True):
    """Discriminator stage: stride-2 convolution followed by LeakyReLU.

    Args:
        layer_input: tensor fed into the convolution.
        filters: number of convolution filters.
        f_size: convolution kernel size.
        normalization: when true, instance normalization is applied after the
            activation.

    Returns:
        The output tensor of the last layer applied.
    """
    activated = LeakyReLU(alpha=0.2)(
        Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
    ) if False else None
    convolved = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
    activated = LeakyReLU(alpha=0.2)(convolved)
    if not normalization:
        return activated
    return InstanceNormalization()(activated)

In [7]:
def build_discriminator(img_shape, df):
    """Assemble the patch discriminator.

    Four stride-2 discriminator stages (the first without normalization) are
    followed by a single-filter, stride-1 convolution with no activation.

    Args:
        img_shape: shape passed to the model's Input layer.
        df: filter count of the first stage; later stages use df*2, df*4, df*8.

    Returns:
        A Model mapping an image to its single-channel validity map.
    """
    image_in = Input(shape=img_shape)

    features = d_layer(image_in, df, normalization=False)
    for width in (2, 4, 8):
        features = d_layer(features, df*width)

    patch_scores = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)

    return Model(image_in, patch_scores)

# Helper function

In [8]:
def sample_images(data_loader, dataset_name, g_AB, g_BA, epoch, batch_i):
    """Save a 2x3 grid of original / translated / reconstructed test images.

    One test pair is loaded, pushed through both generators and back again,
    and the six resulting images are written to
    ``images/<dataset_name>/<epoch>_<batch_i>.png``.

    Args:
        data_loader: object providing ``load_data(batch_size=..., is_testing=...)``
            returning one batch for each domain.
        dataset_name: sub-directory of ``images/`` the figure is saved into.
        g_AB: generator translating domain A to B (must provide ``predict``).
        g_BA: generator translating domain B to A (must provide ``predict``).
        epoch: epoch index, used in the file name.
        batch_i: batch index, used in the file name.
    """
    os.makedirs('images/%s' % dataset_name, exist_ok=True)
    r, c = 2, 3

    imgs_A, imgs_B = data_loader.load_data(batch_size=1, is_testing=True)

    # Translate images to the other domain
    fake_B = g_AB.predict(imgs_A)
    fake_A = g_BA.predict(imgs_B)
    # Translate back to original domain
    reconstr_A = g_BA.predict(fake_B)
    reconstr_B = g_AB.predict(fake_A)

    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    # Rescale images to 0 - 1 (assumes inputs are scaled like the generators'
    # tanh output, i.e. [-1, 1] — TODO confirm against the data loader).
    # Values can still land slightly outside that range, which makes imshow
    # emit a clipping warning for every sample; clip explicitly instead.
    gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0.0, 1.0)

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axs = plt.subplots(r, c)
    try:
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
    finally:
        # Close this specific figure, even if plotting or saving failed, so
        # repeated sampling during training does not accumulate open figures.
        plt.close(fig)

# Train

In [9]:
def train(data_loader, dataset_name,
          g_AB, g_BA, d_A, d_B, combined,
          disc_patch,
          epochs, batch_size=128, sample_interval=50):
    """Alternate discriminator and generator updates over the dataset.

    For every batch the two discriminators are trained on real images
    (target ones) and on generator translations (target zeros). The combined
    model is then trained with the targets
    ``[valid, valid, imgs_B, imgs_A, imgs_A, imgs_B]``.

    Args:
        data_loader: object providing ``load_batch(batch_size)`` (an iterable
            of ``(imgs_A, imgs_B)`` pairs) and an ``n_batches`` attribute.
        dataset_name: forwarded to ``sample_images`` for the output directory.
        g_AB, g_BA: generators; only ``predict`` is used here.
        d_A, d_B: compiled discriminators; ``train_on_batch`` must return an
            indexable result whose first entry is the loss.
        combined: compiled combined model taking ``[imgs_A, imgs_B]``.
        disc_patch: shape of one discriminator output, without the batch axis.
        epochs: number of passes over ``data_loader.load_batch``.
        batch_size: batch size requested from the data loader.
        sample_interval: sample images are saved every this many batches.
    """
    start_time = datetime.datetime.now()

    # Accept a list as well as a tuple for the patch shape
    disc_patch = tuple(disc_patch)

    # Adversarial loss ground truths
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):
            print("epoch", epoch, "- batch #", batch_i)

            # The labels must have as many rows as the batch actually
            # delivered; a loader that yields a short final batch would
            # otherwise make train_on_batch fail on a shape mismatch.
            # Assumes imgs_A and imgs_B have the same length.
            n_samples = len(imgs_A)
            if n_samples != len(valid):
                valid = np.ones((n_samples,) + disc_patch)
                fake = np.zeros((n_samples,) + disc_patch)

            # ----------------------
            #  Train Discriminators
            # ----------------------

            # Translate images to opposite domain
            fake_B = g_AB.predict(imgs_A)
            fake_A = g_BA.predict(imgs_B)

            # Train the discriminators (original images = real / translated = Fake)
            dA_loss_real = d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            dB_loss_real = d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            # Total discriminator loss
            d_loss = 0.5 * np.add(dA_loss, dB_loss)

            # ------------------
            #  Train Generators
            # ------------------

            # Targets follow the order of the combined model's outputs:
            # two validity maps, two translations, two reconstructions.
            g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                             [valid, valid,
                                              imgs_B, imgs_A,
                                              imgs_A, imgs_B])

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print ("[%d] [%d/%d] time: %s, [d_loss: %f, g_loss: %f]" % (epoch, batch_i,
                                                                        data_loader.n_batches,
                                                                        elapsed_time,
                                                                        d_loss[0], g_loss[0]))

            # If at save interval => save generated image samples
            if batch_i % sample_interval == 0:
                sample_images(data_loader, dataset_name, g_AB, g_BA, epoch, batch_i)

# main()

In [10]:
# Input shape, ordered (rows, cols, channels)
img_rows, img_cols, channels = 128, 128, 3
img_shape = (img_rows, img_cols, channels)

In [11]:
# Configure data loader: dataset to read and the (rows, cols) resolution,
# taken from the spatial part of img_shape
dataset_name = 'edges2shoes'
data_loader = DataLoader(dataset_name=dataset_name, img_res=img_shape[:2])

In [12]:
# Calculate output shape of D (PatchGAN)
# D stacks four stride-2, 'same'-padded convolutions; each maps a spatial
# size n to ceil(n / 2), so together they give ceil(n / 2**4). The final
# stride-1 convolution keeps that size and has a single filter.
# Rows and columns are computed separately so non-square inputs work, and
# integer ceiling division avoids the float round trip.
patch_rows = -(-img_rows // 2**4)
patch_cols = -(-img_cols // 2**4)
patch = patch_rows  # kept so existing references to `patch` still resolve
disc_patch = (patch_rows, patch_cols, 1)

In [13]:
# Number of filters in the first layer of G (gf) and D (df)
gf = df = 64

In [14]:
# Create the optimizer; this one instance is passed to every compile() call.
# Positional Adam arguments: learning rate, then beta_1.
adam_args = (0.0002, 0.5)
optimizer = Adam(*adam_args)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [15]:
# Build one discriminator per domain, then compile both with the same
# loss, optimizer and metric
d_A = build_discriminator(img_shape, df)
d_B = build_discriminator(img_shape, df)
for discriminator in (d_A, d_B):
    discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

In [16]:
#-------------------------
# Construct Computational
#   Graph of Generators
#-------------------------

# Build two identically configured generators, one per translation direction
# (g_AB is built first, then g_BA)
g_AB, g_BA = (build_generator(img_shape, channels, gf) for _ in range(2))

In [17]:
# Symbolic inputs for one image from each domain (A first, then B)
img_A, img_B = Input(shape=img_shape), Input(shape=img_shape)

In [18]:
# Translate each input into the opposite domain: A -> B and B -> A
fake_B, fake_A = g_AB(img_A), g_BA(img_B)

In [19]:
# Map the translations back to the domain they started from
reconstr_A, reconstr_B = g_BA(fake_B), g_AB(fake_A)

In [20]:
# Freeze both discriminators: the combined model should only update the generators
for discriminator in (d_A, d_B):
    discriminator.trainable = False

In [21]:
# Each discriminator scores the translated image that landed in its domain
valid_A, valid_B = d_A(fake_A), d_B(fake_B)

In [22]:
# Objectives, in output order:
# + Adversarial (MSE, 2 outputs): fool the domain discriminators
# + Translation (MAE, 2 outputs): match e.g. fake B against true B
# + Cycle-consistency (MAE, 2 outputs): match reconstructed images against originals
combined_outputs = [valid_A, valid_B, fake_B, fake_A, reconstr_A, reconstr_B]
combined_losses = ['mse'] * 2 + ['mae'] * 4

combined = Model(inputs=[img_A, img_B], outputs=combined_outputs)
combined.compile(loss=combined_losses, optimizer=optimizer)

In [None]:
# Number of passes over the dataset (epochs=20 is the alternative kept from earlier)
epochs = 5
train(data_loader=data_loader,
      dataset_name=dataset_name,
      g_AB=g_AB,
      g_BA=g_BA,
      d_A=d_A,
      d_B=d_B,
      combined=combined,
      disc_patch=disc_patch,
      epochs=epochs,
      batch_size=1,
      sample_interval=200)

epoch 0 # 0


  'Discrepancy between trainable weights and collected trainable'


[0] [0/49825] time: 0:00:02.753157, [d_loss: 0.325678, g_loss: 3.150854]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


epoch 0 # 1
[0] [1/49825] time: 0:00:06.512421, [d_loss: 0.272776, g_loss: 2.676764]
epoch 0 # 2
[0] [2/49825] time: 0:00:09.725880, [d_loss: 0.245721, g_loss: 2.466418]
epoch 0 # 3
[0] [3/49825] time: 0:00:12.909697, [d_loss: 0.258993, g_loss: 1.671192]
epoch 0 # 4
[0] [4/49825] time: 0:00:15.965880, [d_loss: 0.188796, g_loss: 2.790794]
epoch 0 # 5
[0] [5/49825] time: 0:00:19.117180, [d_loss: 0.191661, g_loss: 1.591216]
epoch 0 # 6
[0] [6/49825] time: 0:00:22.311919, [d_loss: 0.196063, g_loss: 2.830503]
epoch 0 # 7
[0] [7/49825] time: 0:00:25.948970, [d_loss: 0.163234, g_loss: 2.287107]
epoch 0 # 8
[0] [8/49825] time: 0:00:29.338723, [d_loss: 0.349116, g_loss: 1.565508]
epoch 0 # 9
[0] [9/49825] time: 0:00:32.507564, [d_loss: 0.321941, g_loss: 1.596231]
epoch 0 # 10
[0] [10/49825] time: 0:00:35.999136, [d_loss: 0.193477, g_loss: 2.406057]
epoch 0 # 11
[0] [11/49825] time: 0:00:39.392599, [d_loss: 0.350549, g_loss: 1.200775]
epoch 0 # 12
[0] [12/49825] time: 0:00:42.521781, [d_loss: 0.