In [None]:
import numpy as np
# import tensorflow.keras as keras

from keras.optimizers import Adam
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU, UpSampling2D, Conv2D
from keras.models import Sequential, Model, load_model
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator

from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import scipy
from tensorflow import logging
import imageio, skimage
import matplotlib.image as mpimg

import datetime, os, pickle
from os import listdir
from os.path import isfile, join

In [None]:
import warnings

# Silence every warning category for the rest of the session.
warnings.simplefilter('ignore')

In [None]:
def get_img(img_filepath,target_size):
    """Read an image file, resize it and rescale pixel values with x / 127.5 - 1.

    Args:
        img_filepath: Path of the image file to read.
        target_size: Sequence ``(height, width, channels)``. ``channels``
            selects the colour mode passed to the reader: 1 -> 'L'
            (greyscale), 3 -> 'RGB'.

    Returns:
        Array holding the resized image; for single-channel input a trailing
        channel axis is appended so the result is always 3-D. Values are
        ``pixel / 127.5 - 1`` (i.e. [-1, 1] assuming the resize step yields
        0..255 pixel values -- TODO confirm for the installed SciPy).

    Raises:
        ValueError: If ``channels`` is neither 1 nor 3.
    """
    _, _, n_C = target_size
    modes = {1: 'L', 3: 'RGB'}
    if n_C not in modes:
        # ValueError is still caught by callers that catch Exception.
        raise ValueError('Unexpected number of chanel '+str(n_C)+'!')

    # NOTE(review): scipy.misc.imread / imresize were deprecated in SciPy 1.0
    # and removed in later releases; this cell needs an old SciPy (with
    # Pillow) or a port to imageio / skimage -- confirm the pinned version.
    # The builtin ``float`` replaces the ``np.float`` alias, which no longer
    # exists in current NumPy releases (it was always the same type).
    x = scipy.misc.imread(img_filepath, mode=modes[n_C]).astype(float)
    x = scipy.misc.imresize(x, target_size)
    if n_C == 1:
        # Greyscale comes back as (H, W); add the channel axis.
        x = np.expand_dims(x, -1)
    return np.array(x)/127.5 - 1.

def load_images(target_size, max_images=10):
    """Load paired ``data_<id>_x.jpg`` / ``data_<id>_y.jpg`` images from ``data/raw/``.

    Args:
        target_size: ``(height, width, channels)`` every image is resized to.
        max_images: Upper bound on the number of pairs loaded. Defaults to 10,
            the limit that was previously hard-coded; pass ``None`` to load
            every pair found.

    Returns:
        Tuple ``(X, Y)`` of arrays shaped ``(n, height, width, channels)``;
        ``X[i]`` is read from an ``_x`` file and ``Y[i]`` from the ``_y`` file
        that shares its ``data_<id>`` stem.
    """
    # Unpack the argument itself; the previous code read a notebook-global
    # ``input_shape`` and only worked when that happened to be defined.
    n_H, n_W, n_C = target_size
    data_dir='data/raw/'
    x_suffix, y_suffix = '_x.jpg', '_y.jpg'

    # Sorted so the chosen subset is the same on every run (listdir order is
    # not guaranteed).
    x_files = sorted(
        f for f in listdir(data_dir)
        if isfile(join(data_dir, f)) and f.startswith('data_') and f.endswith(x_suffix)
    )
    if max_images is not None:
        # Slicing also copes with directories holding fewer files than the cap.
        x_files = x_files[:max_images]

    n_x = len(x_files)
    X=np.zeros((n_x,n_H, n_W, n_C))
    Y=np.zeros((n_x,n_H, n_W, n_C))
    for i, x_file in enumerate(x_files):
        # Derive the partner name by swapping the suffix. str.strip() removes
        # a *set of characters*, not a substring, so it corrupted ids that
        # begin or end with any of the letters in 'data_' / '_x.jpg'.
        y_file = x_file[:-len(x_suffix)] + y_suffix
        X[i] = get_img(join(data_dir, x_file), target_size)
        Y[i] = get_img(join(data_dir, y_file), target_size)
    return X,Y

def load_facades_images(input_shape, max_images=10):
    """Load side-by-side image pairs from ``data/facades/train/``.

    Each ``.jpg`` file is read at twice the target width and cut in half:
    the left half is stored in ``Y`` and the right half in ``X``.

    Args:
        input_shape: ``(height, width, channels)`` of each returned half.
        max_images: Upper bound on the number of files loaded. Defaults to 10,
            the limit that was previously hard-coded; pass ``None`` to load
            every file found.

    Returns:
        Tuple ``(X, Y)`` of arrays shaped ``(n, height, width, channels)``.
    """
    n_H, n_W, n_C = input_shape
    test_data='data/facades/train/'

    # Sorted for a reproducible selection; endswith() avoids matching names
    # that merely contain '.jpg' somewhere in the middle.
    img_files = sorted(
        f for f in listdir(test_data)
        if isfile(join(test_data, f)) and f.endswith('.jpg')
    )
    if max_images is not None:
        # Slicing (instead of a fixed count) avoids an IndexError when the
        # directory holds fewer files than the cap.
        img_files = img_files[:max_images]

    n_x = len(img_files)
    X=np.zeros((n_x,n_H, n_W, n_C))
    Y=np.zeros((n_x,n_H, n_W, n_C))
    for i, img_file in enumerate(img_files):
        # Read the combined picture at double width, then split it.
        img = get_img(join(test_data, img_file), [n_H, n_W*2, n_C])
        Y[i], X[i] = img[:, :n_W, :], img[:, n_W:, :]
    return X,Y

def load_realworld_images(input_shape):
    """Load every ``.jpg`` in ``data/test/`` as an unpaired test batch.

    Args:
        input_shape: ``(height, width, channels)`` every image is resized to.

    Returns:
        Array shaped ``(n, height, width, channels)`` with one entry per file,
        in sorted filename order.
    """
    # The body used to reference an undefined name ``target_size``; the
    # function's own parameter is what was meant.
    n_H, n_W, n_C = input_shape
    test_data_dir='data/test/'
    img_files = sorted(
        f for f in listdir(test_data_dir)
        if isfile(join(test_data_dir, f)) and f.endswith('.jpg')
    )
    X_test=np.zeros((len(img_files), n_H, n_W, n_C))
    for i, img_file in enumerate(img_files):
        X_test[i] = get_img(join(test_data_dir, img_file), input_shape)
    return X_test


In [None]:
# SOURCE https://github.com/eriklindernoren/Keras-GAN

def build_generator(input_shape,gf,name):
    """Assemble the U-Net generator as a Keras ``Model``.

    Seven stride-2 encoder stages are followed by six upsampling decoder
    stages, each concatenated with the encoder output of matching depth, and
    a final upsample + tanh convolution producing ``n_C`` channels.

    Args:
        input_shape: ``(height, width, channels)`` of the input image.
        gf: Filter count of the first encoder stage; deeper stages use
            multiples of it (x2, x4, x8).
        name: Name given to the returned model.

    Returns:
        Keras ``Model`` mapping the input image to the generated image.
    """
    _, _, n_C = input_shape

    def encode(tensor, n_filters, kernel=4, normalise=True):
        """Downsampling stage: stride-2 conv, LeakyReLU, optional batch norm."""
        out = Conv2D(n_filters, kernel_size=kernel, strides=2, padding='same')(tensor)
        out = LeakyReLU(alpha=0.2)(out)
        return BatchNormalization(momentum=0.8)(out) if normalise else out

    def decode(tensor, skip, n_filters, kernel=4, drop=0):
        """Upsampling stage: upsample, conv+ReLU, optional dropout, batch norm, skip concat."""
        out = UpSampling2D(size=2)(tensor)
        out = Conv2D(n_filters, kernel_size=kernel, strides=1, padding='same', activation='relu')(out)
        if drop:
            out = Dropout(drop)(out)
        out = BatchNormalization(momentum=0.8)(out)
        return Concatenate()([out, skip])

    source = Input(shape=input_shape)

    # Encoder: keep every stage output so the decoder can use it as a skip input.
    encoder_filters = (gf, gf*2, gf*4, gf*8, gf*8, gf*8, gf*8)
    stages = []
    current = source
    for depth, n_filters in enumerate(encoder_filters):
        # The first stage is the only one without batch normalisation.
        current = encode(current, n_filters, normalise=depth > 0)
        stages.append(current)

    # Decoder: start from the deepest stage and pair each step with the
    # encoder outputs in reverse order (the deepest one has no skip partner).
    decoder_filters = (gf*8, gf*8, gf*8, gf*4, gf*2, gf)
    for skip, n_filters in zip(reversed(stages[:-1]), decoder_filters):
        current = decode(current, skip, n_filters)

    current = UpSampling2D(size=2)(current)
    output_img = Conv2D(n_C, kernel_size=4, strides=1, padding='same', activation='tanh')(current)

    return Model(source, output_img, name=name)

def build_discriminator(input_shape, df, name):
    """Assemble the two-input patch discriminator as a Keras ``Model``.

    Both inputs are concatenated along the channel axis, passed through four
    stride-2 convolution stages and reduced to a single-channel map by a
    final stride-1 convolution.

    Args:
        input_shape: ``(height, width, channels)`` of each input image.
        df: Filter count of the first stage; later stages use x2, x4, x8.
        name: Name given to the returned model.

    Returns:
        Keras ``Model`` taking ``[img_A, img_B]`` and returning the validity map.
    """
    # Unpacking only checks that the shape has exactly three entries.
    n_H, n_W, n_C = input_shape

    def stage(tensor, n_filters, kernel=4, normalise=True):
        """Stride-2 conv, LeakyReLU, optional batch norm."""
        out = Conv2D(n_filters, kernel_size=kernel, strides=2, padding='same')(tensor)
        out = LeakyReLU(alpha=0.2)(out)
        return BatchNormalization(momentum=0.8)(out) if normalise else out

    img_A = Input(shape=input_shape)
    img_B = Input(shape=input_shape)

    # Image and its conditioning image are stacked channel-wise.
    features = Concatenate(axis=-1)([img_A, img_B])

    for index, n_filters in enumerate((df, df*2, df*4, df*8)):
        # No batch normalisation on the very first stage.
        features = stage(features, n_filters, normalise=index != 0)

    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)

    return Model([img_A, img_B], validity, name=name)

In [None]:
def build_model(input_shape, gf=64, df=64, name='combined',init_model=True):
    """Create the generator, the discriminator and the combined training model.

    Args:
        input_shape: ``(height, width, channels)`` shared by all image inputs.
        gf: Base filter count handed to ``build_generator``.
        df: Base filter count handed to ``build_discriminator``.
        name: Name of the combined model.
        init_model: Accepted for the caller's convenience; not used here.

    Returns:
        Tuple ``(generator, discriminator, combined)``. The discriminator is
        compiled with MSE loss; ``combined`` takes ``[img_A, img_B]``, outputs
        ``[validity, fake_A]`` and is compiled with losses ``['mse', 'mae']``
        weighted ``[1, 100]``.
    """
    def new_optimizer():
        """Each compiled model gets its own Adam instance with the same settings."""
        return Adam(0.0002, 0.5)

    # The discriminator is compiled while it is still trainable.
    discriminator = build_discriminator(input_shape,df,'discriminator')
    discriminator.compile(loss='mse', optimizer=new_optimizer(), metrics=['accuracy'])

    generator = build_generator(input_shape,gf,'generator')

    # Target image (A) and the image it is conditioned on (B).
    img_A = Input(shape=input_shape)
    img_B = Input(shape=input_shape)

    # The generator produces a fake A from B.
    fake_A = generator(img_B)

    # Freeze the discriminator *after* it was compiled and *before* the
    # combined model is compiled: it keeps learning through its own
    # training calls but is not updated when the combined model is trained.
    discriminator.trainable = False

    # Validity of the (generated image, condition) pair.
    valid = discriminator([fake_A, img_B])

    combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A], name=name)
    combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=new_optimizer())

    return generator, discriminator, combined

In [None]:
def sample_images(generator, imgs_X, imgs_Y,epoch):
    """Save a comparison grid of condition, generated and target images.

    One row is drawn per sample with three columns: the image fed to the
    generator (``imgs_X``), the generator output, and the paired image
    (``imgs_Y``). The figure is written to ``output/pix2pix_epoch_<epoch>.png``.

    Args:
        generator: Model exposing ``predict(x, batch_size=...)``.
        imgs_X: Batch ``(m, height, width, channels)`` passed to the generator.
        imgs_Y: Batch paired with ``imgs_X``, shown for reference.
        epoch: Integer used in the output filename.
    """
    m, n_H, n_W, _ = imgs_X.shape

    # imgs_X is what the generator is conditioned on and imgs_Y is the real
    # counterpart; the previous labels had these two swapped.
    titles = ['Condition', 'Generated', 'Original']

    # Take the DPI from rcParams: plt.gcf() would create an extra empty
    # figure when none exists, and that figure was never closed.
    dpi = float(plt.rcParams['figure.dpi'])
    figsize = ((len(titles)*n_W)/dpi, (m*n_H)/dpi)

    generated_Y = generator.predict(imgs_X,batch_size=1)

    # squeeze=False keeps ``axs`` two-dimensional even for a single sample.
    fig, axs = plt.subplots(m, len(titles), figsize=figsize, squeeze=False)
    batches = (imgs_X, generated_Y, imgs_Y)
    for r in range(m):
        for c, (title, batch) in enumerate(zip(titles, batches)):
            # Map values from [-1, 1] back to [0, 1] for display.
            axs[r, c].imshow(0.5 * batch[r] + 0.5)
            axs[r, c].set_title(title)
            axs[r, c].axis('off')

    # Make sure the target directory exists before writing.
    os.makedirs('output', exist_ok=True)
    fig.savefig("output/pix2pix_epoch_%d.png" % (epoch))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
def train_epoch(generator, discriminator, combined,imgs_A, imgs_B, epochs=1, batch_size=1):
    """Run one discriminator pass followed by one generator pass.

    Args:
        generator: Model used to produce fake images from ``imgs_B``.
        discriminator: Model fitted on real and generated pairs.
        combined: Model fitted on ``[imgs_A, imgs_B]`` against
            ``[valid, imgs_A]``.
        imgs_A: Batch of target images.
        imgs_B: Batch of conditioning images, shape ``(m, height, width, channels)``.
        epochs: ``epochs`` argument forwarded to every ``fit`` call.
        batch_size: ``batch_size`` forwarded to ``predict`` and ``fit``.

    Returns:
        ``(generator, discriminator, combined, loss)`` where ``loss`` maps
        ``'d_loss_real'``, ``'d_loss_fake'`` and ``'g_loss'`` to the objects
        returned by the corresponding ``fit`` calls.
    """
    n_samples, height, width, _ = imgs_B.shape

    # Label maps sized to the discriminator output: one value per patch,
    # 1/16 of the input resolution along each spatial axis.
    label_shape = (n_samples, int(height/16), int(width/16), 1)
    valid = np.ones(label_shape)
    fake = np.zeros(label_shape)

    fit_options = dict(batch_size=batch_size, epochs=epochs)

    logging.info('Training Discriminator')
    # Translate the conditioning images, then teach the discriminator that
    # real pairs are valid and generated pairs are fake.
    fake_A = generator.predict(imgs_B, batch_size=batch_size)
    d_loss_real = discriminator.fit(x=[imgs_A, imgs_B], y=valid, verbose=0, **fit_options)
    d_loss_fake = discriminator.fit(x=[fake_A, imgs_B], y=fake, verbose=0, **fit_options)
    # Mean of the two discriminator loss curves (computed but not returned).
    d_loss = np.add(d_loss_real.history['loss'], d_loss_fake.history['loss']) * 0.5

    logging.info('Training Generator')
    # Generator update through the combined model (the discriminator is
    # expected to be non-trainable inside it).
    g_loss = combined.fit(x=[imgs_A, imgs_B], y=[valid, imgs_A], **fit_options)

    loss = {
        'd_loss_real': d_loss_real,
        'd_loss_fake': d_loss_fake,
        'g_loss': g_loss,
    }
    return generator, discriminator, combined, loss


In [None]:
# Working resolution shared by the loaders and the models: (height, width, channels).
input_shape = [256, 256, 3]

# Paired training data; call load_facades_images(input_shape) instead to use the facades set.
X, Y = load_images(input_shape)

# A train/validation split (sklearn train_test_split) was tried here and is currently disabled.

# Visual sanity check: show one randomly chosen pair, mapping [-1, 1] back to [0, 1].
i_sample = np.random.randint(len(X))
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.imshow(X[i_sample] * 0.5 + 0.5)
ax2.imshow(Y[i_sample] * 0.5 + 0.5)
plt.show()

In [None]:
logging.set_verbosity(logging.ERROR)

generator, discriminator, combined = build_model(input_shape,init_model=True)

# Checkpoints are written below; make sure the directory exists first.
os.makedirs('saved_model', exist_ok=True)

# Pick the preview slice once, with a private RandomState. Seeding the
# global NumPy generator inside the loop reset it on every epoch, which also
# affects any other code drawing from it. RandomState(3) yields the same
# index the seeded global generator did.
# max(..., 1) keeps randint valid when there are five samples or fewer.
m = X.shape[0]
_s = np.random.RandomState(3).randint(max(m - 5, 1))
X_sample, Y_sample = X[_s:_s+5], Y[_s:_s+5]

for epoch in range(0, 200):

    generator, discriminator, combined, _ = train_epoch(
        generator, discriminator, combined,
        imgs_A=Y, imgs_B=X, epochs=1, batch_size=16)

    logging.info('saving model')
    # The rolling checkpoint is rewritten every epoch; every fifth epoch an
    # additional numbered copy is kept.
    # NOTE(review): pickling a Keras model is fragile across library
    # versions; consider the model's own save API -- kept as pickle here so
    # existing consumers of the .pkl files keep working.
    checkpoint_paths = ['saved_model/pix2pix_emoji_combined.pkl']
    if epoch % 5 == 0:
        checkpoint_paths.append('saved_model/pix2pix_emoji_combined_epoch'+str(epoch)+'.pkl')
    for checkpoint_path in checkpoint_paths:
        # ``with`` closes the handle; the bare open() calls leaked one file
        # descriptor per save.
        with open(checkpoint_path, "wb") as checkpoint_file:
            pickle.dump(combined, checkpoint_file)

    # Render the same preview samples after every epoch for comparison.
    sample_images(generator, X_sample, Y_sample, epoch)
    