In [None]:
import tensorflow as tf
import numpy as np
# import tensorflow.keras as keras

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model, load_model

import matplotlib.pyplot as plt
import scipy
import datetime

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Dataset layout: paired training images live in the 'raw/' folder under data_dir.
data_dir = 'data/'
raw_data_dir = ''.join((data_dir, 'raw/'))

def _load_scaled_gray(path, target_size):
    """Read ``path`` as grayscale, resize it to ``target_size[:2]`` and map [0, 255] to [-1, 1].

    Args:
        path: image file to read.
        target_size: (height, width[, channels]). The grayscale plane is
            repeated ``channels`` times (once if ``channels`` is omitted).

    Returns:
        Float64 array of shape (height, width, channels).
    """
    import imageio
    from skimage.transform import resize

    # np.float (an alias of the builtin float, i.e. float64) was removed in NumPy 1.24.
    img = imageio.imread(path, as_gray=True).astype(np.float64)
    # scipy.misc.imresize was removed in SciPy 1.3. Use a bilinear (order=1)
    # resize that keeps the original intensity scale (preserve_range=True), so
    # the /127.5 - 1 normalisation below still maps 0..255 to -1..1.
    # NOTE(review): imresize also min-max stretched each image to 0..255; this
    # resize does not. Confirm that is acceptable for previously trained models.
    img = resize(img, (target_size[0], target_size[1]), order=1,
                 preserve_range=True, anti_aliasing=True)
    img = np.clip(img, 0., 255.)
    n_channels = target_size[2] if len(target_size) > 2 else 1
    img = np.repeat(img[..., np.newaxis], n_channels, axis=-1)
    return img / 127.5 - 1.


def get_img(i, target_size):
    """Load the ``i``-th (x, y) image pair from ``raw_data_dir``.

    Reads ``data_<i>_x.jpg`` and ``data_<i>_y.jpg``. Each is converted to
    grayscale, resized to ``target_size[:2]`` and scaled to [-1, 1].

    Args:
        i: pair index. Anything ``str()``-able works, e.g. ``7`` or ``'7'``.
        target_size: (height, width[, channels]) of each returned image.

    Returns:
        (x, y): float64 arrays of shape (height, width, channels).
    """
    prefix = raw_data_dir + 'data_' + str(i)
    x = _load_scaled_gray(prefix + '_x.jpg', target_size)
    y = _load_scaled_gray(prefix + '_y.jpg', target_size)
    return x, y

def load_images(target_size):
    """Load every (x, y) image pair found in ``raw_data_dir``.

    Pairs are discovered from files named ``data_<index>_x.jpg``. The matching
    ``data_<index>_y.jpg`` is loaded by ``get_img``. Pairs are returned in
    ascending numeric index order.

    Args:
        target_size: (height, width, channels) of each stored image.

    Returns:
        (X, Y, n_x): X and Y are float64 arrays of shape
        (n_x, height, width, channels) holding the x and y images as produced
        by ``get_img``; n_x is the number of pairs.
    """
    import re
    from os import listdir
    from os.path import isfile, join

    pattern = re.compile(r'^data_(\d+)_x\.jpg$')
    # Use the indices actually present on disk rather than assuming the files
    # are numbered 0..n-1; a missing pair would otherwise make get_img try to
    # open a file that does not exist. The index is kept as its original string
    # so zero-padded names such as data_007_x.jpg round-trip correctly.
    indices = sorted(
        (match.group(1)
         for match in map(pattern.match, listdir(raw_data_dir))
         if match and isfile(join(raw_data_dir, match.group(0)))),
        key=int,
    )
    n_x = len(indices)
    X = np.zeros((n_x, target_size[0], target_size[1], target_size[2]))
    Y = np.zeros((n_x, target_size[0], target_size[1], target_size[2]))

    for row, index in enumerate(indices):
        X[row], Y[row] = get_img(index, target_size)
    return X, Y, n_x

In [None]:
# SOURCE https://github.com/eriklindernoren/Keras-GAN

def build_generator(input_shape, gf, name):
    """U-Net generator mapping a conditioning image to a translated image.

    Encoder: seven 4x4 stride-2 convolutions with filters gf, 2gf, 4gf, 8gf,
    8gf, 8gf and 8gf. Each is followed by LeakyReLU(0.2) and, except the
    first, BatchNormalization.

    Decoder: six steps, each an x2 upsample, a 4x4 ReLU convolution and
    BatchNormalization, concatenated with the mirrored encoder activation.
    A final x2 upsample and a 4x4 tanh convolution then produce
    ``input_shape[-1]`` channels.

    Because of the seven halvings, height and width should be multiples of
    128 so that every skip connection has a matching spatial size.

    Args:
        input_shape: (height, width, channels) of the input image.
        gf: number of filters in the first encoder layer.
        name: name given to the returned Keras Model.

    Returns:
        Keras Model named ``name`` mapping one image to one generated image.
    """
    _height, _width, n_channels = input_shape

    def conv2d(layer_input, filters, f_size=4, bn=True):
        """Downsampling step: strided conv -> LeakyReLU(0.2) -> optional BatchNorm."""
        out = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        out = LeakyReLU(alpha=0.2)(out)
        return BatchNormalization(momentum=0.8)(out) if bn else out

    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        """Upsampling step: x2 upsample -> ReLU conv -> optional Dropout -> BatchNorm -> concat skip."""
        out = UpSampling2D(size=2)(layer_input)
        out = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(out)
        if dropout_rate:
            out = Dropout(dropout_rate)(out)
        out = BatchNormalization(momentum=0.8)(out)
        return Concatenate()([out, skip_input])

    d0 = Input(shape=input_shape)

    # Encoder: keep every activation; all but the deepest become skip connections.
    encoder_filters = (gf, gf * 2, gf * 4, gf * 8, gf * 8, gf * 8, gf * 8)
    activations = []
    x = d0
    for depth, filters in enumerate(encoder_filters):
        x = conv2d(x, filters, bn=depth > 0)
        activations.append(x)

    # Decoder: walk back up, pairing each step with the mirrored encoder output (d6 ... d1).
    decoder_filters = (gf * 8, gf * 8, gf * 8, gf * 4, gf * 2, gf)
    for filters, skip in zip(decoder_filters, reversed(activations[:-1])):
        x = deconv2d(x, skip, filters)

    x = UpSampling2D(size=2)(x)
    output_img = Conv2D(n_channels, kernel_size=4, strides=1, padding='same', activation='tanh')(x)

    return Model(d0, output_img, name=name)

def build_discriminator(input_shape, df, name):
    """PatchGAN discriminator scoring (image, conditioning image) pairs.

    The two inputs are concatenated along the channel axis. They then pass
    through four 4x4 stride-2 convolutions with filters df, 2df, 4df and 8df.
    Each uses LeakyReLU(0.2), and all but the first use BatchNormalization.
    A final 4x4 stride-1 convolution produces one channel, so the output is a
    grid of per-patch scores of spatial size ceil(height/16) x ceil(width/16).

    Args:
        input_shape: (height, width, channels) of each of the two inputs.
        df: number of filters in the first layer.
        name: name given to the returned Keras Model.

    Returns:
        Keras Model named ``name`` with inputs [img_A, img_B] and one output.
    """
    # Unpacking only checks that input_shape has exactly three entries.
    _height, _width, _channels = input_shape

    def d_layer(layer_input, filters, f_size=4, bn=True):
        """Strided conv -> LeakyReLU(0.2) -> optional BatchNorm."""
        out = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        out = LeakyReLU(alpha=0.2)(out)
        return BatchNormalization(momentum=0.8)(out) if bn else out

    img_A = Input(shape=input_shape)
    img_B = Input(shape=input_shape)

    # Candidate image and conditioning image are stacked channel-wise.
    features = Concatenate(axis=-1)([img_A, img_B])
    for depth, multiplier in enumerate((1, 2, 4, 8)):
        features = d_layer(features, df * multiplier, bn=depth > 0)

    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)
    return Model([img_A, img_B], validity, name=name)

In [None]:
def build_model(input_shape, gf=64, df=64, name='pix2pix_emoji'):
    """Assemble the (uncompiled) pix2pix combined model.

    The sub-models are built as ``model_discriminator`` (via
    ``build_discriminator``) and ``model_generator`` (via ``build_generator``),
    so callers can reach them with ``get_layer``.

    Args:
        input_shape: (height, width, channels) shared by target images A,
            conditioning images B and generated images.
        gf: filters in the generator's first layer.
        df: filters in the discriminator's first layer.
        name: name of the returned combined model.

    Returns:
        Keras Model with inputs [img_A, img_B] and outputs [valid, fake_A],
        where fake_A = generator(img_B) and
        valid = discriminator([fake_A, img_B]). Nothing is compiled here.
    """
    # Unpacking only checks that input_shape has exactly three entries.
    _height, _width, _channels = input_shape

    discriminator = build_discriminator(input_shape, df, 'model_discriminator')
    generator = build_generator(input_shape, gf, 'model_generator')

    # Target images and their conditioning images.
    img_A = Input(shape=input_shape)
    img_B = Input(shape=input_shape)

    # Conditioned on B, generate a fake version of A; the discriminator then
    # scores the (fake_A, B) pair.
    fake_A = generator(img_B)
    valid = discriminator([fake_A, img_B])

    # img_A feeds no output. It is kept so the model's inputs are
    # [img_A, img_B], which is what train_epoch passes to fit().
    # The discriminator is not frozen here; callers toggle
    # discriminator.trainable themselves before training each part.
    combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A], name=name)

    return combined

In [None]:
def train_epoch(model, imgs_A, imgs_B, epochs=1, batch_size=1):
    """Run one adversarial round on the pix2pix combined model.

    The round has two phases:

    1. The discriminator is fit on (real A, B) pairs labelled 1, then on
       (generated A, B) pairs labelled 0.
    2. The discriminator is marked non-trainable and the combined model is fit
       with targets [valid, imgs_A]. The losses are whatever the combined model
       was compiled with.

    Args:
        model: combined model exposing layers 'model_generator' and
            'model_discriminator'.
        imgs_A: target images, shape (m, n_H, n_W, n_C).
        imgs_B: conditioning images, same shape as imgs_A.
        epochs: passed to every fit() call. D-real, D-fake and G each run
            this many epochs.
        batch_size: passed to every fit() call.

    Returns:
        (model, g_loss, d_loss):
            model: the same model object that was passed in.
            g_loss: the History returned by the combined model's fit().
            d_loss: per-epoch discriminator loss, averaged over the real and
                fake passes.
    """
    m, n_H, n_W, _ = imgs_A.shape
    generator = model.get_layer("model_generator")
    discriminator = model.get_layer("model_discriminator")

    # PatchGAN output grid. The discriminator applies four stride-2 'same'
    # convolutions, so each spatial dim becomes ceil(dim / 16). Height and
    # width are computed separately so non-square images get the right shape.
    disc_patch = (-(-n_H // 16), -(-n_W // 16), 1)
    # Adversarial loss ground truths.
    valid = np.ones((m,) + disc_patch)
    fake = np.zeros((m,) + disc_patch)

    # ---------------------
    #  Train Discriminator
    # ---------------------
    # Condition on B and generate a translated version.
    fake_A = generator.predict(imgs_B)
    discriminator.trainable = True
    d_hist_real = discriminator.fit(x=[imgs_A, imgs_B], y=valid, batch_size=batch_size, epochs=epochs, verbose=0)
    d_hist_fake = discriminator.fit(x=[fake_A, imgs_B], y=fake, batch_size=batch_size, epochs=epochs, verbose=0)
    d_loss = 0.5 * np.add(d_hist_real.history['loss'], d_hist_fake.history['loss'])

    # -----------------
    #  Train Generator
    # -----------------
    # NOTE(review): Keras documents that changes to `trainable` only take
    # effect once a model's training function is (re)built by compile(). This
    # freeze therefore relies on the combined model not having been fit since
    # its last compile, which is what the per-epoch reload loop arranges.
    discriminator.trainable = False
    g_loss = model.fit(x=[imgs_A, imgs_B], y=[valid, imgs_A], batch_size=batch_size, epochs=epochs)

    return model, g_loss, d_loss

In [None]:
def sample_images(generator, imgs_A, imgs_B, epoch, figsize=None, output_dir='output'):
    """Save a grid comparing target, generated and conditioning images.

    The grid has one row per sample and three columns: 'Original' (imgs_A),
    'Generated' (generator.predict(imgs_B)) and 'Condition' (imgs_B). Channel 0
    of each image is drawn in grayscale. The figure is written to
    ``<output_dir>/pix2pix_epoch_<epoch>.png``.

    Args:
        generator: model whose predict() maps imgs_B to generated images.
        imgs_A: target images, shape (m, n_H, n_W, n_C).
        imgs_B: conditioning images, same shape as imgs_A.
        epoch: integer used in the output file name.
        figsize: matplotlib figure size in inches. Defaults to 3 inches per
            panel. The previous default of (n_H, n_W) gave a 256x256-inch
            canvas for 256-pixel images.
        output_dir: directory for the PNG; created if it does not exist.
    """
    import os

    m = imgs_A.shape[0]
    generated_A = generator.predict(imgs_B)

    titles = ['Original', 'Generated', 'Condition']
    columns = [imgs_A, generated_A, imgs_B]
    if figsize is None:
        figsize = (3 * len(titles), 3 * m)
    # squeeze=False keeps axs 2-D even when m == 1, so axs[r, c] always works.
    fig, axs = plt.subplots(m, len(titles), figsize=figsize, squeeze=False)
    for r in range(m):
        for c, (title, imgs) in enumerate(zip(titles, columns)):
            ax = axs[r, c]
            ax.imshow(imgs[r, :, :, 0], cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, "pix2pix_epoch_%d.png" % (epoch)))
    # Close this specific figure; many are created over a training run.
    plt.close(fig)

In [None]:
input_shape = [256, 256, 1]
X, Y, n_x = load_images(input_shape)
# Fail here with a clear message rather than later inside model training.
assert n_x > 0, 'no *_x.jpg images found in ' + raw_data_dir

# Set to True to eyeball one random (x, y) pair before training.
SHOW_SAMPLE_PAIR = False
if SHOW_SAMPLE_PAIR:
    print(X.shape)
    i_sample = np.random.randint(n_x)
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
    ax1.imshow(X[i_sample, :, :, 0], cmap="gray")
    ax2.imshow(Y[i_sample, :, :, 0], cmap="gray")
    plt.show()

In [None]:
def get_compiled_model(model_filepath=None):
    """Build a fresh combined pix2pix model, or load a saved one, and compile it.

    Args:
        model_filepath: ``None`` to build a new model from the global
            ``input_shape``; otherwise a path passed to ``load_model``.

    Returns:
        The combined model, compiled with losses ['mse', 'mae'] weighted
        [1, 100]. Its 'model_discriminator' sub-model is compiled separately
        with 'mse' loss and an accuracy metric.
    """
    if model_filepath is None:
        model = build_model(input_shape, name='model_combined')
    else:
        # compile=False: both models are recompiled below with fresh optimizers,
        # so restoring the saved optimizer/training config would be wasted work.
        model = load_model(model_filepath, compile=False)

    # Give each model its own Adam instance. A single shared optimizer makes
    # both models share its state, such as the step counter Adam uses for bias
    # correction. Optimizers that bind to the variables they were first built
    # for (tf.keras >= 2.11 / Keras 3) also reject the second model's
    # variables outright.
    discriminator_optimizer = Adam(0.0002, 0.5)
    combined_optimizer = Adam(0.0002, 0.5)
    model.get_layer('model_discriminator').compile(
        loss='mse', optimizer=discriminator_optimizer, metrics=['accuracy'])
    model.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=combined_optimizer)
    return model

# model.summary()

In [None]:
import os

np.random.seed(3)

# Training configuration.
N_EPOCHS = 3
BATCH_SIZE = 16
N_PREVIEW = 5                   # rows in each saved preview grid
RESUME_FROM_CHECKPOINT = False  # True: start from model_filepath instead of a fresh model
model_filepath = 'saved_model/pix2pix_emoji_model__epoch_last.h5'

# Fixed preview window, so the per-epoch grids show the same images.
m, _, _, _ = X.shape
n_preview = min(N_PREVIEW, m)
# randint's upper bound is exclusive. The +1 makes the last full window
# reachable and avoids a ValueError when the dataset has <= N_PREVIEW images.
preview_start = np.random.randint(m - n_preview + 1)
X_sample = X[preview_start:preview_start + n_preview]
Y_sample = Y[preview_start:preview_start + n_preview]

# model.save() does not create missing directories.
os.makedirs(os.path.dirname(model_filepath) or '.', exist_ok=True)

model = get_compiled_model(model_filepath if RESUME_FROM_CHECKPOINT else None)
for epoch in range(N_EPOCHS):
    model, _, _ = train_epoch(model, X, Y, epochs=1, batch_size=BATCH_SIZE)
    model.save(filepath=model_filepath, overwrite=True)
    sample_images(model.get_layer('model_generator'), X_sample, Y_sample, epoch)
    # NOTE(review): reloading and recompiling each epoch gives every round
    # freshly built training functions, which Keras needs before `trainable`
    # changes take effect. It also resets the Adam state every epoch; confirm
    # that is intended.
    model = get_compiled_model(model_filepath)
    