In [None]:
import glob
import os
import time
from random import randint, shuffle, uniform

import numpy as np
import tensorflow as tf
from IPython.display import display, clear_output
from PIL import Image

from keras import backend as K
from keras.models import Sequential, Model
from keras.initializers import RandomNormal
from keras.layers import BatchNormalization, Conv2D, ZeroPadding2D, Input, Dropout, Concatenate
from keras.layers import Conv2DTranspose, UpSampling2D, Activation, Add, Lambda, Cropping2D
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.optimizers import Adam

In [None]:
# Expose only GPU 0 (devices numbered in PCI bus order) to TensorFlow.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Session config: log op placement, let GPU memory grow on demand,
# and cap this process at half of the device's memory.
config = tf.ConfigProto(log_device_placement=True)
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.5

# Make Keras build and run its graph inside this configured session.
sess = tf.Session(config=config)
K.set_session(sess)

In [None]:
def batchnorm():
    """Return a new BatchNormalization layer for channels-last (axis 3) feature maps.

    Uses momentum 0.9, epsilon 1e-5 and draws the initial gamma from N(1, 0.02).
    """
    gamma_init = RandomNormal(1., 0.02)
    return BatchNormalization(axis=3, momentum=0.9, epsilon=1e-5,
                              gamma_initializer=gamma_init)


def conv2d(x, *a, **k):
    """Apply a freshly created Conv2D layer to tensor `x` and return the result.

    All arguments after `x` are forwarded to keras Conv2D; the kernel is
    initialized from N(0, 0.02).
    """
    layer = Conv2D(*a, kernel_initializer=RandomNormal(0, 0.02), **k)
    return layer(x)


def conv_block(x, filters, size, stride=(2, 2), has_norm_layer=True, use_norm_instance=False,
               has_activation_layer=True, use_leaky_relu=False, padding='same'):
    """Conv2D followed by optional normalization and optional activation.

    Args:
        x: input tensor (channels-last, as produced by the Input layers in this notebook).
        filters: number of output channels.
        size: square kernel side length.
        stride: conv strides.
        has_norm_layer: add a normalization layer after the conv.
        use_norm_instance: use InstanceNormalization instead of batchnorm().
        has_activation_layer: add an activation after normalization.
        use_leaky_relu: use LeakyReLU(0.2) instead of ReLU.
        padding: conv padding mode.

    Returns:
        The output tensor.
    """
    x = conv2d(x, filters, (size, size), strides=stride, padding=padding)
    if has_norm_layer:
        if use_norm_instance:
            # NOTE(review): InstanceNormalization is not imported in this notebook
            # (presumably keras_contrib); this branch raises NameError until it is.
            # Feature maps are channels-last (batchnorm() uses axis=3), so the
            # per-channel axis is the last one, not axis 1.
            x = InstanceNormalization(axis=-1)(x)
        else:
            x = batchnorm()(x)
    if has_activation_layer:
        if use_leaky_relu:
            x = LeakyReLU(alpha=0.2)(x)
        else:
            x = Activation('relu')(x)
    return x

In [None]:
def resnet_block(x, filters=256, padding='same'):
    """Residual block: 3x3 conv -> LeakyReLU(0.2) -> 3x3 conv, added back onto `x`.

    Both convs use stride 1, so with padding='same' the spatial size is kept.
    `filters` must equal the channel count of `x` for the final Add to be valid.

    Returns:
        Tensor `x + F(x)` with the same shape as `x`.
    """
    # Each layer must consume the previous layer's output (y), not the block input.
    y = conv2d(x, filters, kernel_size=3, strides=1, padding=padding)
    y = LeakyReLU(alpha=0.2)(y)
    y = conv2d(y, filters, kernel_size=3, strides=1, padding=padding)

    return Add()([y, x])


def up_block(x, filters, size, use_norm_instance=False):
    """Upsample 2x: stride-2 Conv2DTranspose -> batchnorm() -> ReLU.

    `use_norm_instance` only switches on the transposed conv's bias;
    batch normalization is applied either way.
    """
    upsample = Conv2DTranspose(filters, kernel_size=size, strides=2, padding='same',
                               use_bias=bool(use_norm_instance),
                               kernel_initializer=RandomNormal(0, 0.02))
    y = upsample(x)
    y = batchnorm()(y)
    return Activation('relu')(y)

In [None]:
# def generator(image_size=256, channels=3, res_blocks=6):
#     """Builds the generator that consists of an encoder, 
#     a transformer and a decoder."""
#     inputs = Input(shape=(None, None, channels))
#     x = inputs
    
#     # Encoder
#     x = conv2d(x, 64, kernel_size=7, strides=1, padding='same')
#     x = conv2d(x, 128, kernel_size=3, strides=2, padding='same')
#     x = conv2d(x, 256, kernel_size=3, strides=2, padding='same')
    
#     # Transformer
#     for i in range(res_blocks):
#         x = resnet_block(x)
        
#     # Decoder
#     x = Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(x)
#     x = batchnorm()(x)
#     x = Activation('relu')(x)
#     x = Conv2DTranspose(64, kernel_size=3, strides=2, padding='same')(x)
    
#     x = conv2d(x, 3, (7, 7), activation='tanh', strides=(1, 1) ,padding='same')    
#     outputs = x
    
#     return Model(inputs=inputs, outputs=[outputs])
def generator(image_size=256, channels=3, res_blocks=6):
    """Build the ResNet-style generator: encoder -> residual blocks -> decoder.

    Args:
        image_size: side length of the square input; the Input layer has shape
            (image_size, image_size, channels).
        channels: number of input channels.
        res_blocks: number of residual blocks at the bottleneck.

    Returns:
        A keras Model whose output has 3 channels squashed by tanh.
    """
    inputs = Input(shape=(image_size, image_size, channels))
    x = inputs

    # Encoder: 7x7 stride-1 conv, then two stride-2 downsampling convs.
    x = conv_block(x, 64, 7, (1, 1))
    x = conv_block(x, 128, 3, (2, 2))
    x = conv_block(x, 256, 3, (2, 2))

    # Transformer: residual blocks at 256 channels (resnet_block's default
    # filter count, so its residual Add matches the encoder output).
    for _ in range(res_blocks):
        x = resnet_block(x)

    # Decoder: two stride-2 transposed convs undo the encoder's downsampling.
    x = up_block(x, 128, 3)
    x = up_block(x, 64, 3)

    # Project back to 3 channels; tanh matches the [-1, 1] scaling of read_image.
    outputs = conv2d(x, 3, (7, 7), activation='tanh', strides=(1, 1), padding='same')
    return Model(inputs=inputs, outputs=[outputs])


def unet_generator(isize=256, nc_in=3, nc_out=3, ngf=64, fixed_input_size=True):
    """Build a U-Net generator: a recursive encoder/decoder with skip connections.

    Each recursion level halves the spatial size with a stride-2 4x4 conv,
    recurses into the next level, concatenates its own features with the inner
    result (the skip connection) and upsamples back with a stride-2 4x4
    transposed conv. The outermost level maps to `nc_out` channels, then tanh.

    NOTE(review): the training cells in this notebook build generator(), not
    this function.

    Args:
        isize: input side length. `block` asserts an even size at every level
            while halving, so isize must be a power of two.
        nc_in: number of input channels.
        nc_out: number of output channels.
        ngf: filters of the outermost conv; doubled per level, capped at 8*ngf.
        fixed_input_size: if True the Input has shape (isize, isize, nc_in),
            otherwise the spatial dims are None.

    Returns:
        keras Model with one input and one tanh-activated output.
    """
    max_nf = 8*ngf
    def block(x, s, nf_in, use_batchnorm=True, nf_out=None, nf_next=None):
        # One U-Net level operating at spatial size `s`.
        # nf_next: filters of this level's downsampling conv (default 2*nf_in, capped).
        # nf_out: channels produced by this level's upsampling (default nf_in).
        assert s>=2 and s%2==0
        if nf_next is None:
            nf_next = min(nf_in*2, max_nf)
        if nf_out is None:
            nf_out = nf_in
        # Bias is dropped only when a BatchNorm follows (s > 2 and use_batchnorm).
        x = conv2d(x, nf_next, kernel_size=4, strides=2, use_bias=(not (use_batchnorm and s>2)),
                   padding="same", name = 'conv_{0}'.format(s))
        if s>2:
            if use_batchnorm:
                # training=1 forces batch statistics even when the model is run for inference.
                x = batchnorm()(x, training=1)
            x2 = LeakyReLU(alpha=0.2)(x)
            x2 = block(x2, s//2, nf_next)
            # Skip connection: this level's features alongside the inner level's output.
            x = Concatenate(axis=-1)([x, x2])
        x = Activation("relu")(x)
        # 'valid' 4x4/stride-2 transposed conv gives 2*in + 2 pixels per side;
        # Cropping2D(1) trims one pixel from each edge to get exactly 2*in.
        x = Conv2DTranspose(nf_out, kernel_size=4, strides=2, use_bias=not use_batchnorm,
                            kernel_initializer = RandomNormal(0, 0.02),
                            name = 'convt.{0}'.format(s))(x)
        x = Cropping2D(1)(x)
        if use_batchnorm:
            x = batchnorm()(x, training=1)  # training parameter?
        if s <=8:
            # Dropout on the innermost levels; training=1 keeps it active at inference too.
            x = Dropout(0.5)(x, training=1)
        return x

    s = isize if fixed_input_size else None

    y = inputs = Input(shape=(s, s, nc_in))
    # Outermost level: no batchnorm, outputs nc_out channels, first conv has ngf filters.
    y = block(y, isize, nc_in, False, nf_out=nc_out, nf_next=ngf)
    y = Activation('tanh')(y)
    return Model(inputs=inputs, outputs=[y])

In [None]:
# def discriminator(image_size=256, channels=3, fl_filters=64, hidden_layers=3):
#     inputs = Input(shape=(None, None, channels))
#     x = inputs

#     x = ZeroPadding2D(padding=(1, 1))(x)
#     x = conv2d(x, fl_filters, kernel_size=4, strides=2, padding='valid')
#     x = LeakyReLU(alpha=0.2)(x)

#     x = ZeroPadding2D(padding=(1, 1))(x)
#     for i in range(1, hidden_layers + 1):
#         nf = 2 ** i * fl_filters
#         x = conv2d(x, fl_filters, kernel_size=4, strides=2, padding='valid')
#         x = LeakyReLU(alpha=0.2)(x)
#         x = ZeroPadding2D(padding=(1, 1))(x)
#     x = conv2d(x, 1, kernel_size=4, activation='sigmoid', strides=(1, 1))
#     outputs = x
    
#     return Model(inputs=[inputs], outputs=outputs)


def discriminator(channels=3, ndf=64, hidden_layers=3, channel_first=False):
    """Convolutional discriminator producing a spatial map of sigmoid scores.

    Args:
        channels: channels of the input image (channels-last, any height/width).
        ndf: filters of the first layer; layer k of the downsampling stack uses ndf * 2**k.
        hidden_layers: controls depth; there are hidden_layers - 1 normalized
            stride-2 layers after the first, then one stride-1 layer with
            ndf * 2**hidden_layers filters.
        channel_first: unused.

    Returns:
        keras Model mapping an image to a 1-channel map of scores in [0, 1].
    """
    def norm_lrelu(t):
        # BatchNorm forced into training mode (batch statistics), then LeakyReLU.
        t = batchnorm()(t, training=1)
        return LeakyReLU(alpha=0.2)(t)

    inputs = Input(shape=(None, None, channels))

    # First stride-2 layer has no normalization.
    h = LeakyReLU(alpha=0.2)(conv2d(inputs, ndf, kernel_size=4, strides=2, padding="same"))

    # Stride-2 downsampling layers with doubling filter counts.
    for depth in range(1, hidden_layers):
        h = conv2d(h, ndf * 2 ** depth, kernel_size=4, strides=2, padding="same", use_bias=False)
        h = norm_lrelu(h)

    # Stride-1 4x4 conv (default 'valid' padding) after padding by one pixel.
    h = ZeroPadding2D(padding=(1, 1))(h)
    h = conv2d(h, ndf * 2 ** hidden_layers, kernel_size=4, use_bias=False)
    h = norm_lrelu(h)

    # Output: one sigmoid channel per spatial location.
    h = ZeroPadding2D(padding=(1, 1))(h)
    h = conv2d(h, 1, kernel_size=4, activation="sigmoid")
    return Model(inputs=[inputs], outputs=h)

In [None]:
def mse(output, target):
    """Mean squared error between backend tensors `output` and `target`."""
    # A square is already non-negative, so no abs() is needed around it.
    squared_error = K.square(output - target)
    return K.mean(squared_error)

def disc_loss(disc, real, fake):
    """Least-squares discriminator loss.

    Pushes disc(real) toward 1 and disc(fake) toward 0 with squared error,
    returning the average of the two terms.
    """
    score_real = disc([real])
    score_fake = disc([fake])
    real_term = mse(score_real, K.ones_like(score_real))
    fake_term = mse(score_fake, K.zeros_like(score_fake))
    return (real_term + fake_term) / 2

def cycle_loss(reconstructed, real):
    """Cycle-consistency loss: mean absolute (L1) error between a round-trip
    reconstruction and the original image tensor."""
    residual = reconstructed - real
    return K.mean(K.abs(residual))

def gen_loss(disc, fake):
    """Least-squares generator loss: squared error between disc(fake) and the 'real' target 1."""
    score = disc([fake])
    target = K.ones_like(score)
    return mse(score, target)

In [None]:
d_a = discriminator()
d_b = discriminator()
g_a = generator()
g_b = generator()
real_a = g_b.inputs[0]
fake_b = g_b.outputs[0]
rec_a = g_a([fake_b])
real_b = g_a.inputs[0]
fake_a = g_a.outputs[0]
rec_b = g_b([fake_a])

d_a_loss = disc_loss(d_a, real_a, fake_a)
d_b_loss = disc_loss(d_b, real_b, fake_b)
g_a_loss = gen_loss(d_a, fake_a)
g_b_loss = gen_loss(d_b, fake_b)

cycleA_generate = K.function([real_a], [fake_b, rec_a])
cycleB_generate = K.function([real_b], [fake_a, rec_b])

cycle_loss = cycle_loss(rec_a, real_a) + cycle_loss(rec_b, real_b)
g_total_a = g_a_loss + 10*cycle_loss
g_total_b = g_b_loss + 10*cycle_loss
g_total = g_a_loss + g_b_loss + 10*cycle_loss
d_total = d_a_loss + d_b_loss

In [None]:
# Separate parameter sets so each step only updates one side of the GAN,
# even though d_total depends on generator outputs and g_total on the discriminators.
weights_d = d_a.trainable_weights + d_b.trainable_weights
weights_g = g_a.trainable_weights + g_b.trainable_weights

# Each step gets its own Adam(lr=2e-4, beta_1=0.5) optimizer. K.function runs the
# returned update ops every time the function is called.
# NOTE(review): get_updates(params, constraints, loss) is the Keras 2.0.x argument
# order; later Keras versions use get_updates(loss, params) — confirm the pinned version.
training_updates = Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(weights_d, [], d_total)
d_train_function = K.function([real_a, real_b], [d_a_loss, d_b_loss], training_updates)
training_updates = Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(weights_g, [], g_total)
g_train_function = K.function([real_a, real_b], [g_a_loss, g_b_loss, cycle_loss], training_updates)

In [None]:
def read_image(img, imagesize=256):
    """Load an image file as a float32 RGB array scaled to [-1, 1].

    Args:
        img: path (or file object) accepted by PIL.Image.open.
        imagesize: side length of the square output; the image is resized to
            (imagesize, imagesize) with bicubic interpolation.

    Returns:
        np.ndarray of shape (imagesize, imagesize, 3), dtype float32, where
        pixel value 0 maps to -1 and 255 maps to 1.
    """
    # Context manager closes the file handle promptly; this runs once per
    # sample for the whole training run. The previous version ignored
    # `imagesize` and always resized to 256x256.
    with Image.open(img) as src:
        rgb = src.convert('RGB').resize((imagesize, imagesize), Image.BICUBIC)
    pixels = np.asarray(rgb, dtype=np.float32)
    return (pixels - 127.5) / 127.5

# Image file paths for each domain; print how many were found in each.
trainA = glob.glob('data/trainA/*')
trainB = glob.glob('data/trainB/*')
print(len(trainA), len(trainB), sep='\n')

In [None]:
def minibatch(data, batchsize=1, loader=None):
    """Infinite generator of (epoch, batch) pairs drawn from `data`.

    A private copy of `data` is shuffled and consumed in consecutive batches.
    When the next batch would run past the end, the position wraps to 0, the
    epoch counter increments and the copy is reshuffled. Samples left over at
    the end of an epoch are skipped for that epoch.

    The batch size can be changed per step with ``gen.send(size)``; ``next()``
    (or sending None) uses `batchsize`.

    Args:
        data: sequence of samples (file paths for the default loader); not modified.
        batchsize: default number of samples per batch.
        loader: callable turning one sample into an array; defaults to read_image.

    Yields:
        (epoch, np.float32 array stacking the loaded samples of the batch)

    Raises:
        ValueError: when a requested batch is larger than `data` (including empty data).
    """
    load = read_image if loader is None else loader
    pool = list(data)  # copy so the caller's list is not reordered
    length = len(pool)
    shuffle(pool)
    epoch = pos = 0
    requested = None
    while True:
        size = requested if requested else batchsize
        if size > length:
            raise ValueError(
                'batch size {} exceeds number of samples {}'.format(size, length))
        if pos + size > length:
            pos = 0
            epoch += 1
            shuffle(pool)  # fresh order every epoch
        batch = [load(pool[j]) for j in range(pos, pos + size)]
        pos += size
        requested = yield epoch, np.float32(batch)

def minibatchAB(dataA, dataB, batchsize=1):
    """Infinite generator of paired batches (epoch, batch_A, batch_B).

    Drives one minibatch() stream per domain in lockstep. A size sent into
    this generator is forwarded to both streams. The reported epoch is the
    larger of the two streams' epoch counters.
    """
    stream_a = minibatch(dataA, batchsize)
    stream_b = minibatch(dataB, batchsize)
    requested = None
    while True:
        epoch_a, batch_a = stream_a.send(requested)
        epoch_b, batch_b = stream_b.send(requested)
        requested = yield max(epoch_a, epoch_b), batch_a, batch_b

In [None]:
def display_image(X, rows=1, image_size=256):
    """Show a batch of [-1, 1] RGB images as one tiled picture.

    X is rescaled to 0..255, split into `rows` rows of equal length and laid
    out as a (rows*image_size, cols*image_size, 3) grid.
    """
    pixels = (X * 127.5 + 127.5).clip(0, 255).astype('uint8')
    tiles = pixels.reshape(rows, -1, image_size, image_size, 3)
    # (rows, cols, H, W, 3) -> (rows, H, cols, W, 3) so rows/cols interleave into a grid.
    grid = tiles.swapaxes(1, 2).reshape(rows * image_size, -1, 3)
    display(Image.fromarray(grid))

# Sanity-check the data pipeline: show one batch of 4 images per domain, then free it.
train_batch = minibatchAB(trainA, trainB, 4)
_, A, B = next(train_batch)
for preview_batch in (A, B):
    display_image(preview_batch)
del train_batch, A, B, preview_batch

In [None]:
class ImagePool():
    """Buffer of previously seen (generated) images.

    Until `pool_size` images have been stored, every incoming image is stored
    and returned unchanged. Afterwards each incoming image is, with
    probability 0.5, swapped with a randomly chosen stored image (the stored
    one is returned and the new one takes its slot); otherwise it is returned
    as-is. A pool_size of 0 or less disables the buffer.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch the same length as `images`, mixing in stored images.

        Args:
            images: iterable of images (e.g. an array of shape (N, H, W, C)).

        Returns:
            `images` itself when the pool is disabled, otherwise np.stack of
            the selected images along axis 0.
        """
        # <= 0: __init__ only creates the buffer for positive sizes.
        if self.pool_size <= 0:
            return images

        return_images = []
        for image in images:
            if self.num_imgs < self.pool_size:
                # Pool not full yet: remember the image and pass it through.
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            elif uniform(0, 1) > 0.5:
                # Return an older image and keep the new one in its slot.
                random_id = randint(0, self.pool_size - 1)  # inclusive bounds
                return_images.append(self.images[random_id])
                self.images[random_id] = image
            else:
                return_images.append(image)
        return np.stack(return_images, axis=0)

In [None]:
def showG(A,B):
    """Display a 3-row grid: inputs (A, B), translations, then reconstructions.

    Row 1: A then B. Row 2: fake B (from A) then fake A (from B).
    Row 3: reconstructed A then reconstructed B.
    """
    def run_cycle(cycle_fn, X):
        # Run one image at a time; each call returns [translated, reconstructed].
        per_image = [cycle_fn([X[k:k+1]]) for k in range(X.shape[0])]
        # (N, 2, 1, H, W, C) -> (2, N, H, W, C)
        return np.array(per_image).swapaxes(0, 1)[:, :, 0]

    outA = run_cycle(cycleA_generate, A)
    outB = run_cycle(cycleB_generate, B)
    grid = np.concatenate([A, B, outA[0], outB[0], outA[1], outB[1]])
    display_image(grid, 3)

In [None]:
t0 = time.time()
epoch = 0
EPOCHS = 200
DISPLAY_STEP = 50
counter = 0

train_batch = minibatchAB(trainA, trainB)
while epoch < EPOCHS:
    epoch, A, B = next(train_batch)
    if epoch >= EPOCHS:
        # Once minibatchAB reports epoch == EPOCHS, that many passes are complete
        # and this batch already belongs to the next one; stop instead of training on it.
        break
    # One alternating update: discriminators first, then generators.
    d_err_a, d_err_b = d_train_function([A, B])
    g_err_a, g_err_b, cyc_err = g_train_function([A, B])
    counter += 1
    if counter % DISPLAY_STEP == 0:
        clear_output()
        print('[Epoch {}/{}][Iteration {}]'.format(epoch, EPOCHS, counter))
        print('D_A {:.4f}  D_B {:.4f}  G_A {:.4f}  G_B {:.4f}  cycle {:.4f}  ({:.0f}s elapsed)'.format(
            float(d_err_a), float(d_err_b), float(g_err_a), float(g_err_b),
            float(cyc_err), time.time() - t0))
        showG(A, B)

    # TODO save 2 pictures every 10 epochs
    # TODO connect tensorboard