In [1]:
from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
%matplotlib inline


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


### build generator

In [2]:
def build_generator(img_shape, gf):

    def conv2d(layer_input, filters, f_size=4):
        """Layers used during downsampling"""
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        d = InstanceNormalization()(d)
        return d

    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        """Layers used during upsampling"""
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
        if dropout_rate:
            u = Dropout(dropout_rate)(u)
        u = InstanceNormalization()(u)
        u = Concatenate()([u, skip_input])
        return u

    # Image input
    d0 = Input(shape=img_shape)

    # Downsampling
    d1 = conv2d(d0, gf)
    d2 = conv2d(d1, gf*2)
    d3 = conv2d(d2, gf*4)
    d4 = conv2d(d3, gf*8)

    # Upsampling
    u1 = deconv2d(d4, d3, gf*4)
    u2 = deconv2d(u1, d2, gf*2)
    u3 = deconv2d(u2, d1, gf)

    u4 = UpSampling2D(size=2)(u3)
    output_img = Conv2D(img_shape[2], kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

    return Model(d0, output_img)

### build discriminator

In [3]:
def build_discriminator(img_shape, df):

    def d_layer(layer_input, filters, f_size=4, normalization=True):
        """Discriminator layer"""
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d

    img = Input(shape=img_shape)

    d1 = d_layer(img, df, normalization=False)
    d2 = d_layer(d1, df*2)
    d3 = d_layer(d2, df*4)
    d4 = d_layer(d3, df*8)

    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

    return Model(img, validity)

In [4]:
img_rows = 128
img_cols = 128
channels = 3
img_shape = (img_rows, img_cols, channels)


# Number of filters in the first layer of G and D
gf = 32
df = 64

# Loss weights
lambda_cycle = 10.0    # Cycle-consistency loss
lambda_id = 0.1 * lambda_cycle    # Identity loss

optimizer = Adam(0.0002, 0.5)

# Build and compile the discriminators
d_A = build_discriminator(img_shape, df)
d_B = build_discriminator(img_shape, df)
d_A.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
d_B.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

#-------------------------
# Construct Computational
#   Graph of Generators
#-------------------------

# Build the generators
g_AB = build_generator(img_shape, gf)
g_BA = build_generator(img_shape, gf)

# Input images from both domains
img_A = Input(shape=img_shape)
img_B = Input(shape=img_shape)

# Translate images to the other domain
fake_B = g_AB(img_A)
fake_A = g_BA(img_B)
# Translate images back to original domain
reconstr_A = g_BA(fake_B)
reconstr_B = g_AB(fake_A)
# Identity mapping of images
img_A_id = g_BA(img_A)
img_B_id = g_AB(img_B)

# For the combined model we will only train the generators
d_A.trainable = False
d_B.trainable = False

# Discriminators determines validity of translated images
valid_A = d_A(fake_A)
valid_B = d_B(fake_B)

# Combined model trains generators to fool discriminators
combined = Model(inputs=[img_A, img_B], outputs=[ valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id ])
combined.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
                 loss_weights=[1, 1,lambda_cycle, lambda_cycle, lambda_id, lambda_id],
                 optimizer=optimizer)


In [18]:
def train( models, data_loader, epochs, batch_size=1, sample_interval=50):

    start_time = datetime.datetime.now()
    
    # Calculate output shape of D (PatchGAN)
    patch = int(128 / 2**4)
    disc_patch = (patch, patch, 1)
    
    # Adversarial loss ground truths
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)
    
    g_AB, g_BA, d_A, d_B, combined = models
    
    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size)):
            #  Train Discriminators
            # Translate images to opposite domain
            fake_B = g_AB.predict(imgs_A)
            fake_A = g_BA.predict(imgs_B)

            # Train the discriminators (original images = real / translated = Fake)
            dA_loss_real = d_A.train_on_batch(imgs_A, valid)
            dA_loss_fake = d_A.train_on_batch(fake_A, fake)
            dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

            dB_loss_real = d_B.train_on_batch(imgs_B, valid)
            dB_loss_fake = d_B.train_on_batch(fake_B, fake)
            dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

            # Total disciminator loss
            d_loss = 0.5 * np.add(dA_loss, dB_loss)

            # ------------------
            #  Train Generators
            # ------------------
            # Train the generators
            g_loss = combined.train_on_batch([imgs_A, imgs_B], [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

            elapsed_time = datetime.datetime.now() - start_time

            # If at save interval => save generated image samples
            if batch_i % sample_interval == 0:
                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " 
                       % ( epoch, epochs, batch_i, data_loader.n_batches, d_loss[0], 100*d_loss[1],
                          g_loss[0], np.mean(g_loss[1:3]), np.mean(g_loss[3:5]), np.mean(g_loss[5:6]), elapsed_time))
                sample_images(g_AB, g_BA, epoch, batch_i)

In [19]:
def sample_images(g_AB, g_BA,epoch, batch_i):
    os.makedirs('%s' % dataset_name, exist_ok=True)
    r, c = 2, 3

    imgs_A = data_loader.load_data(domain="A", batch_size=1, is_testing=True)
    imgs_B = data_loader.load_data(domain="B", batch_size=1, is_testing=True)

    # Translate images to the other domain
    fake_B = g_AB.predict(imgs_A)
    fake_A = g_BA.predict(imgs_B)
    # Translate back to original domain
    reconstr_A = g_BA.predict(fake_B)
    reconstr_B = g_AB.predict(fake_A)

    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    titles = ['Original', 'Translated', 'Reconstructed']
    fig, axs = plt.subplots(r, c, figsize=(16,10))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[j])
            axs[i,j].axis('off')
            cnt += 1
            
    if not os.path.exists("images"):
        os.makedirs("images")        
    fig.savefig("images/%s/%d_%d.png" % (dataset_name, epoch, batch_i))
    plt.close()

In [20]:
import cv2
from glob import glob
import numpy as np

class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            if not is_testing:
                img = cv2.resize(img, self.img_res)

                if np.random.random() > 0.5:
                    img = np.fliplr(img)
            else:
                img = cv2.resize(img, self.img_res)
            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        # Sample n_batches * batch_size from each path list so that model sees all
        # samples from both domains
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches-1):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = cv2.resize(img_A, self.img_res)
                img_B = cv2.resize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = cv2.resize(img, self.img_res)
        img = img/127.5 - 1.
        return img[np.newaxis, :, :, :]

    def imread(self, path):
        image = cv2.imread(path)
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        ## or 
        # b,g,r = cv2.split(bgr_img)       # get b,g,r
        # image = cv2.merge([r,g,b])     # switch it to rgb
        return image.astype(np.float)

In [21]:
# Configure data loader
dataset_name = 'apple2orange'
data_loader = DataLoader(dataset_name=dataset_name, img_res=(img_rows, img_cols))


In [22]:
models = (g_AB, g_BA, d_A, d_B, combined)

train( models, data_loader, epochs=15, batch_size=1, sample_interval=100)

  'Discrepancy between trainable weights and collected trainable'


[Epoch 0/15] [Batch 0/995] [D loss: 0.246307, acc:  46%] [G loss: 4.771256, adv: 0.326632, recon: 0.176090, id: 0.328710] time: 0:00:00.397468 
[Epoch 0/15] [Batch 100/995] [D loss: 0.135102, acc:  94%] [G loss: 4.359006, adv: 0.574339, recon: 0.141728, id: 0.200707] time: 0:00:22.213006 
[Epoch 0/15] [Batch 200/995] [D loss: 0.227203, acc:  62%] [G loss: 4.331206, adv: 0.426881, recon: 0.144653, id: 0.441087] time: 0:00:43.967412 
[Epoch 0/15] [Batch 300/995] [D loss: 0.204758, acc:  65%] [G loss: 3.301114, adv: 0.399094, recon: 0.107834, id: 0.192739] time: 0:01:05.901087 
[Epoch 0/15] [Batch 400/995] [D loss: 0.216108, acc:  72%] [G loss: 3.454809, adv: 0.389701, recon: 0.107474, id: 0.335664] time: 0:01:28.011275 
[Epoch 0/15] [Batch 500/995] [D loss: 0.371639, acc:  29%] [G loss: 3.009381, adv: 0.288328, recon: 0.101007, id: 0.209313] time: 0:01:50.062860 
[Epoch 0/15] [Batch 600/995] [D loss: 0.295168, acc:  38%] [G loss: 2.883001, adv: 0.294755, recon: 0.092823, id: 0.288110] ti

[Epoch 5/15] [Batch 700/995] [D loss: 0.385364, acc:  16%] [G loss: 2.250860, adv: 0.229021, recon: 0.073281, id: 0.242639] time: 0:20:57.517025 
[Epoch 5/15] [Batch 800/995] [D loss: 0.305820, acc:  21%] [G loss: 2.876603, adv: 0.360311, recon: 0.088219, id: 0.274772] time: 0:21:19.575527 
[Epoch 5/15] [Batch 900/995] [D loss: 0.362853, acc:  21%] [G loss: 2.809940, adv: 0.280507, recon: 0.090926, id: 0.262621] time: 0:21:41.733558 
[Epoch 6/15] [Batch 0/995] [D loss: 0.245251, acc:  60%] [G loss: 2.322829, adv: 0.338223, recon: 0.064041, id: 0.133950] time: 0:22:02.306714 
[Epoch 6/15] [Batch 100/995] [D loss: 0.234603, acc:  50%] [G loss: 2.724847, adv: 0.398576, recon: 0.078256, id: 0.139600] time: 0:22:24.249090 
[Epoch 6/15] [Batch 200/995] [D loss: 0.445503, acc:   6%] [G loss: 2.297245, adv: 0.244856, recon: 0.069934, id: 0.255076] time: 0:22:46.534336 
[Epoch 6/15] [Batch 300/995] [D loss: 0.186331, acc:  70%] [G loss: 3.242939, adv: 0.614805, recon: 0.082883, id: 0.091141] ti

[Epoch 11/15] [Batch 400/995] [D loss: 0.212401, acc:  72%] [G loss: 2.534119, adv: 0.410410, recon: 0.067307, id: 0.204661] time: 0:41:48.941865 
[Epoch 11/15] [Batch 500/995] [D loss: 0.220940, acc:  57%] [G loss: 1.926300, adv: 0.328975, recon: 0.047449, id: 0.144063] time: 0:42:11.205142 
[Epoch 11/15] [Batch 600/995] [D loss: 0.440641, acc:  18%] [G loss: 2.068254, adv: 0.221147, recon: 0.068098, id: 0.090202] time: 0:42:33.516404 
[Epoch 11/15] [Batch 700/995] [D loss: 0.213700, acc:  75%] [G loss: 3.140627, adv: 0.341080, recon: 0.106541, id: 0.188717] time: 0:42:55.737846 
[Epoch 11/15] [Batch 800/995] [D loss: 0.176858, acc:  77%] [G loss: 2.740025, adv: 0.493278, recon: 0.068177, id: 0.133771] time: 0:43:17.911432 
[Epoch 11/15] [Batch 900/995] [D loss: 0.138113, acc:  84%] [G loss: 2.418736, adv: 0.566497, recon: 0.054516, id: 0.061577] time: 0:43:40.134572 
[Epoch 12/15] [Batch 0/995] [D loss: 0.129284, acc:  95%] [G loss: 3.225924, adv: 0.557990, recon: 0.088928, id: 0.209