In [38]:
import tensorflow as tf

import os 

from keras.models import Model
from keras.optimizers import Adam
from keras.layers import Concatenate
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import LeakyReLU
from keras.layers import Input


In [7]:
def load(image_file):
    """Read a side-by-side JPEG and return its two halves as float32 tensors.

    The file is decoded and cut along the width axis (axis 1) at ``width // 2``.
    The right-hand part is returned first as the input image and the left-hand
    part second as the real image. For an odd width the right-hand part is one
    column wider than the left.

    Args:
        image_file: path of the JPEG file, as accepted by ``tf.io.read_file``.

    Returns:
        Tuple ``(input_image, real_image)`` of float32 tensors. Pixel values
        are cast only, not rescaled.
    """
    # Raw bytes -> decoded image tensor.
    encoded = tf.io.read_file(image_file)
    decoded = tf.io.decode_jpeg(encoded)

    # Column index at which the combined picture is cut in two.
    half_width = tf.shape(decoded)[1] // 2

    # Right side is the network input, left side is the ground truth.
    right_half = decoded[:, half_width:, :]
    left_half = decoded[:, :half_width, :]

    return tf.cast(right_half, tf.float32), tf.cast(left_half, tf.float32)



In [39]:
def Discriminator(image_shape):
    """Build and compile a conditional PatchGAN discriminator.

    The network takes an input image and a target image of the same shape,
    concatenates them, and passes the result through a C64-C128-C256-C512
    stack of 4x4, stride-2 convolutions (adapted from 'Image-to-Image
    Translation with Conditional Adversarial Networks', Isola et al.).
    A final 4x4 convolution reduces the features to a single channel, and a
    sigmoid turns it into a per-patch real/fake probability map.

    Args:
        image_shape: shape of ONE image, excluding the batch dimension; it is
            used for both the 'input_image' and 'target_image' inputs.

    Returns:
        A compiled Keras ``Model`` with inputs ``[input_image, target_image]``
        and a single-channel patch output. It is compiled with binary
        cross-entropy (loss weight 0.5) and Adam(learning_rate=0.0002,
        beta_1=0.5).
    """
    # stddev can be edited
    initializer = tf.random_normal_initializer(0.0, 0.05)

    inp_img = Input(shape=image_shape, name='input_image')
    target_img = Input(shape=image_shape, name='target_image')

    # Condition the discriminator on the input image by joining both images
    # into one tensor before the first convolution.
    merged = Concatenate()([inp_img, target_img])

    # adapted from 'Image-to-Image Translation with Conditional Adversarial Networks' - Isola, UC Berkeley
    # C64-C128-C256-C512

    # C64: the first block has no batch normalization.
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=initializer)(merged)
    d = LeakyReLU(alpha=0.2)(d)

    # C128-C256-C512: each block consumes the PREVIOUS block's output `d`
    # (not `merged`), so the layers actually form a stack.
    for filters in (128, 256, 512):
        d = Conv2D(filters, (4,4), strides=(2,2), padding='same', kernel_initializer=initializer)(d)
        d = BatchNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)

    # patch output: project to one channel so the sigmoid yields one
    # real/fake score per patch instead of 512 unrelated activations.
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=initializer)(d)
    patch_out = Activation('sigmoid')(d)

    model = Model([inp_img, target_img], patch_out)

    # The optimizer must be handed to compile(); otherwise the configured
    # learning rate and beta_1 are never used.
    optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, loss_weights=[0.5])

    return model


    