In [None]:
# Imports and GPU configuration.
# `os` must be imported before os.environ is touched, and the CUDA variables are
# set before TensorFlow is imported so they are in place before any CUDA context
# is created (previously os/tf were used on the lines *above* their imports).
import os
import time

# Use GPU No 0 (devices enumerated in PCI bus order).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import ImageFile

import keras
from keras import backend as K
from keras.activations import sigmoid, relu
from keras.callbacks import TensorBoard
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import (Input, concatenate, multiply, GaussianNoise, Flatten, Dense,
                          Lambda, Subtract, Activation, Add, Conv2DTranspose,
                          AveragePooling2D, MaxPooling2D)
from keras.models import Model, Sequential, load_model
from keras.optimizers import Adam, Adagrad
from keras.utils import to_categorical

from Datasets import Dataset, get_model_inputs
from perceptual_model import get_perceptual_model

# Session with device-placement logging.
# NOTE(review): this session is never handed to Keras (no K.set_session(sess)),
# so Keras creates its own session — confirm whether that is intended.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

In [None]:
# Side length, in pixels, of the square input images (each image is c x c).
c = 2 ** 7

In [None]:
def save_models(savekey, epoch_reg, save_disc=True, save_gen=True,
                checkpoint_dir="../../data/team6/model_checkpoints"):
    """Save the discriminator and/or generator models as .h5 checkpoints.

    Files are named "<savekey>_dis_model<epoch_reg>.h5" and
    "<savekey>_gen_model<epoch_reg>.h5" inside `checkpoint_dir`; saving again with
    the same key/epoch overwrites the previous file.

    Args:
        savekey: filename prefix for the checkpoints.
        epoch_reg: epoch number appended to the filenames.
        save_disc: save the global `dis_model` when True.
        save_gen: save the global `generator` when True.
        checkpoint_dir: output directory; created if it does not exist.
    """
    if not (save_disc or save_gen):
        return
    # Model.save fails if the directory is missing, so make sure it exists.
    os.makedirs(checkpoint_dir, exist_ok=True)
    if save_disc:
        dis_model.save(os.path.join(checkpoint_dir, "{}_dis_model{}.h5".format(savekey, epoch_reg)))
    if save_gen:
        generator.save(os.path.join(checkpoint_dir, "{}_gen_model{}.h5".format(savekey, epoch_reg)))

In [None]:
def MRU(X,I, filter_depth, deconv=False, verbose=True):
    """Gated MRU block fusing a feature map `X` with an image/skip tensor `I`.

    Computes, with sigmoid gates m and n built from concat(X, I):
        y = n * relu(conv(concat(m * X, I))) + (1 - n) * relu(conv(X))
    followed by a 2x max-pool (encoder mode) and a final ReLU.
    NOTE(review): this appears to follow the Mask Residual Unit of SketchyGAN
    (the "mi"/"ni" gate naming matches) — confirm before relying on the paper.

    Args:
        X: feature map; its channel count is kept by the `m` gate.
        I: tensor concatenated with X; must have the same spatial size as X
            (after X is upsampled when deconv=True).
        filter_depth: number of output channels.
        deconv: if True, first upsample X 2x with a strided Conv2DTranspose and
            skip the final max-pool (decoder mode); otherwise the output is
            max-pooled by 2 (encoder mode).
        verbose: print the input shapes (debug aid; on by default as before).

    Returns:
        Output tensor with `filter_depth` channels, passed through ReLU.
    """
    if verbose:
        print("Input sizes: ",I.shape,X.shape)

    out_size = X.get_shape().as_list()[-1]

    # Decoder mode: 2x upsampling, channel count unchanged.
    if deconv:
        X = Conv2DTranspose(filters = out_size, kernel_size=(2,2), strides=(2,2))(X)

    merge_one        = concatenate([X,I])

    # m_i gate: one value per channel of X.
    conv_sig_m       = Conv2D(filters = out_size, kernel_size=(3, 3), padding="same", activation="sigmoid")(merge_one)

    # n_i gate: mixes the new path with the X-only path.
    conv_sig_n       = Conv2D(filters = filter_depth, kernel_size=(3, 3), padding="same", activation="sigmoid")(merge_one)

    mul_lower        = multiply([conv_sig_m,X])

    merge_two        = concatenate([mul_lower,I])

    conv_func_one    = Conv2D(filters = filter_depth, kernel_size=(3, 3), padding="same", activation="relu")(merge_two)

    mul_higher       = multiply([conv_sig_n, conv_func_one])

    # 1 - n: complementary gate for the X-only path.
    conv_sig_sub     = Lambda(lambda x: 1 - x)(conv_sig_n)

    conv_x           = Conv2D(filters=filter_depth, kernel_size=(3, 3), padding="same", activation="relu")(X)

    mul_final        = multiply([conv_x,conv_sig_sub])

    sum_final        = Add()([mul_higher,mul_final])

    # Encoder mode halves the spatial size.
    if not deconv:
        sum_final    = MaxPooling2D(pool_size = 2)(sum_final)

    # Extra output ReLU (not part of the gated sum itself).
    sum_final        = Activation("relu")(sum_final)

    return sum_final

In [None]:
def get_xx_model(img_shape=(c,c,3)):
    """Build a one-layer model: a valid Conv2D whose kernel spans the whole image.

    Args:
        img_shape: (H, W, C) input shape; defaults to (c, c, 3), with `c` read
            when this function is defined.

    Returns:
        Model named 'xx_layer' mapping an image to a 1x1x9 tensor.
    """
    I = Input(shape=img_shape)

    # The kernel must equal the input's spatial size to produce a 1x1 output.
    # It was hard-coded to (c, c), which broke for any non-default img_shape.
    out = Conv2D(filters=9, kernel_size=tuple(img_shape[:2]))(I)

    return Model(I, out, name='xx_layer')

In [None]:
def get_encoder(I):
    """Build the MRU encoder and the skip tensors used by the decoder.

    Each stage runs an MRU (which halves the spatial size) and average-pools the
    image to the same size. The features are then concatenated with the pooled
    image to form a skip tensor.

    Args:
        I: image input tensor.

    Returns:
        (encoder, K4, K3, K2, K1): `encoder` is a Model named 'encoder' from I to
        the 256-channel output of the last MRU (1/16 of the input resolution).
        K1 is I concatenated with itself (full resolution). K2, K3 and K4 are the
        skips at 1/2, 1/4 and 1/8 resolution.
    """
    K1 = concatenate([I, I])

    features, image = I, I
    skips = []
    for depth in (32, 64, 128):
        features = MRU(features, image, depth)
        image = AveragePooling2D(pool_size=2)(image)
        skips.append(concatenate([features, image]))
    K2, K3, K4 = skips

    out = MRU(features, image, 256)

    encoder = Model(I, out, name='encoder')

    return encoder, K4, K3, K2, K1
    
def get_decoder(feature_map, K4,K3,K2,K1,I):
    """Build the MRU decoder.

    A fresh Input shaped like `feature_map` (minus the batch axis) is upsampled
    by four deconv MRU stages. Each stage fuses a skip tensor (K4, K3, K2, K1 in
    that order) and produces 128, 64, 32 and finally 3 channels.

    Args:
        feature_map: tensor whose dims 1..3 give the decoder input shape.
        K4, K3, K2, K1: skip tensors, ordered from coarsest to finest.
        I: image input the skip tensors were derived from; a second model input.

    Returns:
        Model named 'decoder' with inputs [decoder_in, I].
    """
    height, width, depth = (int(dim) for dim in feature_map.shape[1:4])
    decoder_in = Input(shape=(height, width, depth))

    features = decoder_in
    for skip, filters in ((K4, 128), (K3, 64), (K2, 32), (K1, 3)):
        features = MRU(features, skip, filters, deconv=True)

    return Model([decoder_in, I], features, name='decoder')

def get_generator(img_shape, color_info_shape):
    """Build the generator: an MRU encoder/decoder conditioned on a color map.

    Args:
        img_shape: (H, W, C) shape of the input image.
        color_info_shape: shape of the color input. It is assumed to be 1x1
            spatially (e.g. (1, 1, 9)) so the transposed conv below expands it
            to the encoder grid.

    Returns:
        Model named 'generator_model' with inputs [image, color_info].
    """
    I = Input(shape=img_shape)
    R = Input(shape=color_info_shape)

    encoder,K4,K3,K2,K1 = get_encoder(I)
    encoded = encoder(I)

    # Expand the color info to the encoder's spatial grid. A valid transposed conv
    # with a kernel equal to that grid maps a 1x1 input to grid size (8x8 for a
    # 128x128 image, which was the hard-coded kernel before). The result is then
    # stacked onto the encoded features.
    grid = (int(encoded.shape[1]), int(encoded.shape[2]))
    xy_layer = Conv2DTranspose(filters=9, kernel_size=grid)
    colors = xy_layer(R)
    encolor = concatenate([colors, encoded])

    # Decode the color-conditioned features. Previously `encolor` was computed
    # but the decoder received `encoded`, so R never influenced the output.
    # NOTE: this changes the architecture, so old checkpoints will not match.
    decoder = get_decoder(encolor, K4,K3,K2,K1,I)
    out = decoder([encolor, I])

    return Model([I,R], out, name='generator_model')

def get_discriminator(img_shape):
    """Build the discriminator skeleton: one 8-filter MRU applied to the image
    (the image is used as both MRU inputs).

    Returns:
        Model named 'discriminator_skeleton' without any output heads.
    """
    image_in = Input(shape=img_shape)
    features = MRU(image_in, image_in, 8)
    return Model(image_in, features, name='discriminator_skeleton')


In [None]:
# Create datasets.
# Root of the CelebA data, relative to the notebook: ../../data/team6/celeba_ydk
dataset_parent_path = os.path.join(*['..', '..', 'data', 'team6', 'celeba_ydk'])
dataset = Dataset("celeba", dataset_parent_path, False, 1)

# Prefix used in checkpoint and result-image filenames.
save_key = "celeba-ct"

# Number of classes predicted by the discriminator's auxiliary classifier head.
num_class = 2

In [None]:
# GAN parameters: square c x c inputs with 3 channels.
channels = 3
img_rows = img_cols = c
img_shape = (img_rows, img_cols, channels)

In [None]:
# All image objects of the dataset. The training loop slices batches from this
# sequence and derives its iteration count from its length.
img_objects = dataset.get_all_image_objects()

In [None]:
# Perceptual network from the project's perceptual_model module. The GAN loss
# compares (MSE) its output on generated images with its output on real images.
percept_model = get_perceptual_model()

In [None]:
# Construct xx_model: a single valid Conv2D whose kernel spans the whole image,
# so each image maps to a 1x1x9 color descriptor. The training loop feeds this
# descriptor to the generator as its color map.
# NOTE(review): only xx_model.predict is called in this notebook, so these
# weights stay at their random initialization — confirm that is intended.

in_real    = Input(shape=img_shape)

# Kernel size follows img_shape instead of a hard-coded (128, 128).
xx_layer   = Conv2D(filters=9, kernel_size=img_shape[:2], name="xx_model", kernel_initializer="random_uniform")

color_inf1 = xx_layer(in_real)

# `inputs=` (not the legacy Keras-1 `input=` keyword).
xx_model   = Model(inputs=in_real, outputs=color_inf1)

# Compile xx_model: MSE on its single output (keyed by the layer name).
losses = {'xx_model': 'mse'}

lossWeights = {"xx_model": 1.0}

opt = keras.optimizers.Adam(lr=2e-4)
xx_model.compile(optimizer=opt, loss=losses, loss_weights=lossWeights, metrics=['mse'])

In [None]:
# Construct the generator, discriminator and combined GAN.

# Shape of the per-image color map fed to the generator
# (random noise at prediction time, xx_model output during training).
color_info_shape = (1, 1, 9)

# --- Generator ---
in_gen = Input(shape=img_shape)          # source image
in_Rmp = Input(shape=color_info_shape)   # color map; shared by gen_model and gan_model

generator = get_generator(img_shape, color_info_shape)

out_gen = generator([in_gen, in_Rmp])

# Stand-alone generator model (used for prediction).
gen_model = Model([in_gen, in_Rmp], out_gen)

# --- Discriminator ---
# Image input: receives real images or generator outputs during training.
disc_in = Input(shape=img_shape)

disc = get_discriminator(img_shape)

x = disc(disc_in)

# Output heads. These same layer objects are reused in gan_model below,
# so the weights are shared.
validity_dense = Dense(1, activation='sigmoid', name='end_sigmoid')
classification_dense = Dense(num_class, activation='softmax', name='end_softmax')

flat_white = Flatten()(x)

dis_valid_dense = validity_dense(flat_white)
dis_class_dense = classification_dense(flat_white)

dis_model = Model(disc_in, [dis_valid_dense, dis_class_dense])

# --- GAN: generator -> (perceptual model, discriminator) ---
in_gan = Input(shape=img_shape)
out_gen_for_gan = generator([in_gan, in_Rmp])

# Perceptual features of the generated image, used by the perceptual loss.
percept_out_for_gan = percept_model(out_gen_for_gan)

# Discriminator skeleton applied to the generated image.
gan_dis_sket_out = disc(out_gen_for_gan)

flat_black = Flatten()(gan_dis_sket_out)

gan_valid_out = validity_dense(flat_black)        # validity head
gan_class_out = classification_dense(flat_black)  # auxiliary classification head

# `inputs=` (not the legacy Keras-1 `input=` keyword).
gan_model = Model(
    inputs=[in_gan, in_Rmp],
    outputs=[
        out_gen_for_gan,
        gan_valid_out,
        gan_class_out,
        percept_out_for_gan,
    ],
)

In [None]:
# Compiling gan model
losses = {
    'perceptual-model': 'mse',
    'generator_model': 'mse',
    'end_softmax': 'categorical_crossentropy',
    'end_sigmoid': 'binary_crossentropy',
    }

lossWeights = {"end_softmax": -0.5, # Want to maximize it in GAN
               "end_sigmoid": 0.5,
               "generator_model": 0.5,
               "perceptual-model": 0.01,
              }

# gan_model.train_on_batch should only update the generator. Freeze the
# discriminator (skeleton + both heads) and the perceptual model while compiling.
# Freeze by object reference: indexing gan_model.layers[-3:] depends on Keras'
# internal layer ordering and is not guaranteed to include the skeleton `disc`.
for layer in (disc, validity_dense, classification_dense, percept_model):
    layer.trainable = False

gan_opt = keras.optimizers.Adam(lr=2e-4)

gan_model.compile(loss=losses, optimizer=gan_opt,loss_weights=lossWeights, metrics=['accuracy'])

# In Keras 2.x, compile() snapshots the trainable weights, so gan_model keeps
# the frozen view. Re-enable the discriminator parts so dis_model (compiled
# later) can train them; previously its heads stayed frozen. Keras may warn
# about a trainable-weights discrepancy; that is expected.
# percept_model stays frozen: it is never trained directly.
for layer in (disc, validity_dense, classification_dense):
    layer.trainable = True

In [None]:
# Compile the stand-alone generator with a down-weighted pixel-wise MSE loss.
gen_opt = keras.optimizers.Adam(lr=2e-4)
losses = dict(generator_model='mse')
lossWeights = dict(generator_model=0.1)
gen_model.compile(optimizer=gen_opt, loss=losses, loss_weights=lossWeights, metrics=['mse'])

In [None]:
# Compiling discriminator model (MSE on both the validity and the class output).
losses = {'end_sigmoid': 'mse',
          'end_softmax': 'mse'
         }

lossWeights = {"end_sigmoid": 1.0,
               "end_softmax": 1.0,
              }

dis_opt = keras.optimizers.Adam(lr=1e-6)

dis_model.trainable = True
# Depending on the Keras version, `dis_model.trainable = True` does not reset
# the sub-layers' own flags, which may have been cleared while compiling the
# GAN. Re-enable the skeleton and both heads explicitly so they train here.
for layer in (disc, validity_dense, classification_dense):
    layer.trainable = True
# NOTE(review): this freezes the second layer of the discriminator skeleton; it
# is kept from the original — confirm which layer this is and why it is frozen.
dis_model.layers[1].layers[1].trainable = False
dis_model.compile(loss=losses, optimizer=dis_opt,loss_weights=lossWeights, metrics=['mse'])

In [None]:
# Save TensorBoard logs to ./v2_logs (created if missing; race-free with exist_ok).
logdir = 'v2_logs'
os.makedirs(logdir, exist_ok=True)
callback = TensorBoard(logdir, write_graph=True)
callback.set_model(gan_model)

In [None]:
def write_log(callback, names, logs, batch_no):
    """Write scalar metrics to the TensorBoard callback's event writer.

    Args:
        callback: TensorBoard callback whose `writer` has been created
            (presumably by `set_model` — confirm for the Keras version in use).
        names: metric names, used as TensorBoard tags.
        logs: metric values, paired with `names` (extra items are ignored).
        batch_no: global step recorded with the values.
    """
    # One Summary holding all metrics, written and flushed once, instead of a
    # separate summary and flush per metric.
    summary = tf.Summary()
    for name, value in zip(names, logs):
        entry = summary.value.add()
        entry.tag = name
        entry.simple_value = value
    if not summary.value:
        return  # nothing to log; matches the original (no write for empty input)
    callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()

In [None]:
#Save example predictions.

def save_imgs(epoch, input_data, iteration,col=np.array([])):
    """Save a 2x2 grid of generator predictions for four fixed samples.

    Writes results/<save_key>/<save_key>_<epoch>_<iteration>.png using the
    global `gen_model`.

    Args:
        epoch: epoch number for the title and filename.
        input_data: generator inputs; needs at least 13 samples along axis 0
            (indices 12, 5, 2 and 10 are plotted).
        iteration: iteration number for the title and filename.
        col: color maps for the four samples, shaped (4, 1, 1, 9). If None or
            with at most one element, random color maps are drawn.
    """
    rows, cols = 2, 2
    idx = np.array([12,5,2,10])
    elements = input_data[idx,:,:]
    if col is None or col.size <= 1:
        # One color map per sample. The batch size was `elements.size` (every
        # pixel value), which gave gen_model a mismatched number of samples.
        col = np.random.random_sample((elements.shape[0], 1, 1, 9))
    gen_imgs = gen_model.predict([elements, col])

    fig, axs = plt.subplots(rows, cols)
    for ax, gen_img in zip(axs.flat, gen_imgs):
        # Generator output is ReLU (unbounded above); clip into imshow's float range.
        ax.imshow(np.clip(gen_img / 255, 0, 1))
        ax.axis('off')
    fig.suptitle('EPOCH-ITERATION: {}-{}'.format(epoch, iteration))
    os.makedirs("results/{}".format(save_key), exist_ok=True)
    fig.savefig("results/{}/{}_{}_{}.png".format(save_key,save_key,epoch,iteration))
    plt.close(fig)

In [None]:
# Set to True only to continue an interrupted training run in the SAME kernel.
# Resuming relies on `reg_all_epochs` and `iter_num` left over from the
# interrupted training loop.
resuming_training = False

if resuming_training:
    # Keep the absolute epoch counter `reg_all_epochs` from the interrupted run.
    # It was previously overwritten with the loop-local `epoch`, which drops the
    # offset of any earlier resume.
    goto_save = True
    iter_reg = iter_num
else:
    reg_all_epochs = 0
    goto_save = False
    iter_reg = 1

In [None]:
time_init_train = time.time()

batch_size    = 32
iter_count    = len(img_objects) // batch_size
total_count   = 0                 # global step for TensorBoard
initial_epoch = reg_all_epochs
epochs        = 100

for epoch in range(epochs):

    reg_all_epochs = initial_epoch + epoch

    print("Enter epoch {}.".format(reg_all_epochs))

    epoch_init_time = time.time()

    # When resuming, start the first epoch at the saved iteration. Otherwise
    # start at 1: batch 0 is never used, and sv_imgs_in is captured from
    # batch 1 before the first save at iteration 100.
    if goto_save:
        iter_low  = iter_reg
        goto_save = False
    else:
        iter_low  = 1

    for iter_num in range(iter_low, iter_count):

        total_count += 1

        # Obtain batch of inputs
        batch_objects = img_objects[batch_size*iter_num:batch_size*(iter_num+1)]
        in_x0, in_x1, class_labels = get_model_inputs(batch_objects, img_shape, num_class)

        if epoch == 0 and iter_num == 1:
            # Fixed sample used for the example images.
            sv_imgs_in = in_x0

        # Every 100 iterations, save example predictions.
        if iter_num % 100 == 0:
            save_imgs(reg_all_epochs, sv_imgs_in, iter_num)
            print("{}/{}".format(iter_num, iter_count))

        colors = xx_model.predict(in_x1)

        # Forward pass the generator to get predicted examples
        gen_sketches = gen_model.predict([in_x0, colors])

        # Perceptual features of the real images (targets for the perceptual loss)
        org_percepts = percept_model.predict(in_x1)

        # Alternate fake (validity 0) and real (validity 1) discriminator batches.
        # The parentheses matter: `epoch + iter_num %2 == 0` parses as
        # `epoch + (iter_num % 2) == 0`, which only holds for even iterations of
        # epoch 0, so every later epoch trained on real images only.
        if (epoch + iter_num) % 2 == 0:
            disc_losses = dis_model.train_on_batch(gen_sketches, [np.zeros(batch_size), class_labels])
        else:
            disc_losses = dis_model.train_on_batch(in_x1, [np.ones(batch_size), class_labels])

        # Generator step through the GAN: reconstruct in_x1, look real to the
        # discriminator, and match the real images' perceptual features.
        gan_losses = gan_model.train_on_batch([in_x0, colors], [in_x1, np.ones(batch_size),
                                                                class_labels,
                                                                org_percepts,
                                                               ])

        write_log(callback, gan_model.metrics_names, gan_losses, total_count)

        # Save models every 2000 iterations.
        if iter_num % 2000 == 0:
            save_models(save_key, reg_all_epochs, save_disc=True, save_gen=True)
            print("\tTotal time elapsed: {:.2f} secs.".format(time.time()-time_init_train))

    print("Epoch {} finished in {:.2f} secs.".format(reg_all_epochs, time.time()-epoch_init_time))
    

In [None]:
import cv2
import matplotlib.pyplot as plt
 
# Predict from source files. 
# gt_bit 1 if ground-truth exists, 0 otherwise.

def predict_real(addr, interpolation = 0, gt_bit=0):
    """Run the generator on an image file and plot input, prediction and optional ground truth.

    Args:
        addr: path of the source image (loaded with cv2, i.e. BGR channel order).
        interpolation: OpenCV interpolation flag for cv2.resize
            (0 == cv2.INTER_NEAREST).
        gt_bit: 1 if a ground-truth image with the same filename exists in a
            directory named "data" that is a sibling of addr's parent directory;
            0 otherwise.

    Returns:
        The matplotlib Figure (the original returned None).

    Raises:
        FileNotFoundError: if the source or ground-truth image cannot be read.
    """
    img_original = cv2.imread(addr)
    if img_original is None:
        # cv2.imread signals failure by returning None, not by raising.
        raise FileNotFoundError("Could not read image: {}".format(addr))

    img = cv2.resize(img_original, (c, c), interpolation=interpolation)
    imgx = np.expand_dims(img, 0)
    # A random color map stands in for the color information at prediction time.
    img_result = gen_model.predict([imgx, np.random.random_sample((1, 1, 1, 9))])[0]
    # ReLU output is unbounded above; clip into imshow's float range.
    img_result = np.clip(img_result / 255, 0, 1)

    fig, axs = plt.subplots(1, 2 + gt_bit)

    # Show the input in RGB, consistent with the ground-truth panel below.
    axs[0].imshow(img[..., ::-1])
    axs[0].axis('off')
    axs[1].imshow(img_result)
    axs[1].axis('off')
    if gt_bit == 1:
        # Same filename, with addr's parent directory replaced by "data".
        parent, filename = os.path.split(addr)
        gt_addr = os.path.join(os.path.dirname(parent), "data", filename)
        img_gt = cv2.imread(gt_addr)
        if img_gt is None:
            raise FileNotFoundError("Could not read ground-truth image: {}".format(gt_addr))
        img_gt = cv2.resize(img_gt, (c, c))
        axs[2].imshow(img_gt[..., ::-1])
        axs[2].axis('off')
    return fig

In [None]:
# Save source images

def save_src(input_data=None):
    """Plot, save and show the source images for the samples used by `save_imgs`.

    Writes results/<save_key>/src.png.

    Args:
        input_data: source images; needs at least 13 samples along axis 0. If
            None, the `sv_imgs_in` sample captured by the training loop is used.
            It is looked up at call time; the old default `=sv_imgs_in` was
            evaluated at definition time and raised NameError before training.
    """
    if input_data is None:
        input_data = sv_imgs_in

    rows, cols = 2, 2
    # Same indices as save_imgs, so the source panels match the predicted ones
    # (previously [12, 4, 2, 11]).
    idx = np.array([12, 5, 2, 10])

    # The old `gen_imgs,_ = gen_model.predict(elements)` line was removed: it
    # passed one input to the two-input generator, unpacked a 4-sample batch
    # into two names, and its result was unused.
    fig, axs = plt.subplots(rows, cols)
    for ax, sample in zip(axs.flat, idx):
        ax.imshow(input_data[sample] / 255)
        ax.axis('on')
    fig.suptitle('SOURCE EDGES')

    # Save before showing: some backends release the figure on show().
    os.makedirs("results/{}".format(save_key), exist_ok=True)
    fig.savefig("results/{}/src.png".format(save_key))
    plt.show()
    plt.close(fig)