In [None]:
%matplotlib inline
from skimage import io, filters
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os,sys
from lgan.diffeomorphism import tf_diffeomorphism
from tqdm import tqdm
from scipy import misc
from keras.preprocessing import image
from scipy import ndimage, misc
from tflearn.data_augmentation import ImageAugmentation
import copy

In [None]:
#PREPROCESS IMAGES - only needs to run once
# NOTE: ndimage.imread / misc.imresize / misc.imsave are deprecated and were
# removed from newer SciPy releases; these helpers need an older SciPy (+ Pillow).
def _cropResizeDirectory(src_dir, dst_dir, rows, cols, size=(64, 64)):
    """Crop every image file in src_dir and resize it, writing results to dst_dir.

    src_dir: directory to read from; sub-directories are skipped.
    dst_dir: destination directory, created if it does not exist yet.
    rows, cols: slice objects selecting the crop window (rows, then columns).
    size: output (height, width) passed to misc.imresize.
    Each output file keeps the name of its source file.
    """
    if not os.path.isdir(dst_dir):
        # misc.imsave does not create missing parent directories
        os.makedirs(dst_dir)
    for item in os.listdir(src_dir):
        src_path = os.path.join(src_dir, item)
        if not os.path.isfile(src_path):
            continue
        inread = ndimage.imread(src_path, mode="RGB")
        image_cut = inread[rows, cols, :]
        image_resized = misc.imresize(image_cut, size)
        misc.imsave(os.path.join(dst_dir, item), image_resized)


def resizeFaces(path='/home/dorian/MyGans/LabelGAN/faces_only',
                out_path='/home/dorian/MyGans/LabelGAN/faces/0/'):
    """Crop face renders to rows 0:900, cols 200:1100 (900x900) and resize to 64x64."""
    _cropResizeDirectory(path, out_path, slice(0, 900), slice(200, 1100))


def cutImages(path='/home/ben/celeba/data/0', out_path='/home/ben/celeba/cut/0/'):
    """Crop CelebA images to rows 40:198, cols 10:168 (158x158) and resize to 64x64."""
    _cropResizeDirectory(path, out_path, slice(40, 198), slice(10, 168))
            

#resizeFaces()

# Sanity check: how many preprocessed images each training directory holds.
for image_dir in ("/home/ben/celeba/cut/0", '/home/dorian/MyGans/LabelGAN/faces/0/'):
    print(len(os.listdir(image_dir)))

In [None]:
#hyper-parameters
data_size = 202599  # presumably the CelebA image count; not referenced in the visible code
batch_size = 32     # images per training step
input_dim = 10      # length of each noise vector z

imX = 64            # image width in pixels
imY = 64            # image height in pixels
image_size = imX * imY        # pixels per channel (4096)
output_dim = imX * imY * 3    # length of a flattened RGB image
LAMBDA = 10         # weight of the WGAN-GP gradient penalty

#Session: let TensorFlow grow GPU memory on demand instead of reserving it all
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)


In [None]:
#simple plot function
def plot(samples, labels,y,x):
    """Draw samples on a y-by-x grid of borderless, equal-aspect subplots.

    samples: sequence of images accepted by imshow; at most y*x entries.
    labels: a single title string, repeated above every subplot.
    Returns the newly created figure (it is neither shown nor closed here).
    """
    fig = plt.figure(figsize=(10, 10))
    grid = gridspec.GridSpec(y, x)
    grid.update(wspace=0.05, hspace=0.05)
    plt.subplots_adjust(left=None, bottom=None, right=1, top=1.5,
                        wspace=None, hspace=None)

    for idx, img in enumerate(samples):
        ax = fig.add_subplot(grid[idx])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.set_title(labels)
        plt.imshow(img)
    return fig


#images plotted while learning
def preprocessPlotImages(samples, inset=20):
    """Compose preview images: each generated image with its input face inset.

    samples: the [generated, input] pair obtained by running face_images; each
        entry is indexed per image (assumed (batch_size, imY, imX, 3) - TODO confirm).
    inset: side length in pixels of the input-face thumbnail pasted into the
        bottom-right corner (default 20, as before).
    Returns a float array of shape (batch_size, imY, imX, 3); samples is left
    unmodified.
    """
    images = np.zeros((batch_size, imY, imX, 3))
    for i in range(batch_size):
        # copy so the inset does not write into the session output
        composed = copy.copy(samples[0][i])
        # imresize yields 0..255 values (it byte-scales float input), hence the
        # division back to a [0, 1] display range
        thumbnail = misc.imresize(samples[1][i], (inset, inset)) / 255.0
        # anchor the inset from imY/imX instead of a hard-coded 44:64
        composed[imY - inset:imY, imX - inset:imX] = thumbnail
        images[i] = composed
    return images
        
        
        
    

In [None]:
# Keras loaders; rescale maps 8-bit pixel values into [0, 1].
datagen = image.ImageDataGenerator(rescale=1./255)

# Shared loader settings. target_size is (height, width) and follows the
# imY/imX constants; class_mode=None yields image batches without labels.
# flow_from_directory reads images from class sub-folders (e.g. .../cut/0).
flow_options = dict(
    target_size=(imY, imX),
    color_mode='rgb',
    class_mode=None,
    batch_size=batch_size)

batcher_celeb = datagen.flow_from_directory(
    directory='/home/ben/celeba/cut', **flow_options)

batcher_faces = datagen.flow_from_directory(
    directory='/home/dorian/MyGans/LabelGAN/faces', **flow_options)

In [None]:
def nextFullBatch(batcher, size, max_tries=2):
    """Return the next batch from batcher that holds exactly size images.

    A short batch (presumably the last, partial batch of an epoch) is discarded
    and another one drawn, up to max_tries draws in total.
    Raises RuntimeError if no full batch was obtained, e.g. when the source
    directory contains fewer than size images.
    """
    lengths = []
    for _attempt in range(max_tries):
        img_batch = batcher.next()
        if len(img_batch) == size:
            return img_batch
        lengths.append(len(img_batch))
    raise RuntimeError(
        'expected a batch of {} images, got sizes {}'.format(size, lengths))


def getImageBatch():
    """Next full batch of CelebA training images."""
    return nextFullBatch(batcher_celeb, batch_size)


def getFacesBatch():
    """Next full batch of face images fed to the generator."""
    return nextFullBatch(batcher_faces, batch_size)
  
# Sanity check: show 16 images from each loader before training.
facebatch = getFacesBatch()
plot(facebatch[:16], "GenImage", 4, 4)
plt.show()

celeb = getImageBatch()
plot(celeb[:16], "CelebA Image", 4, 4)
plt.show()

In [None]:
def sample_z(batch_size, dim=None):
    """Draw a batch of noise vectors uniformly from [-1, 1).

    batch_size: number of vectors (rows).
    dim: length of each vector; defaults to the module-level input_dim.
    Returns a float64 array of shape (batch_size, dim).
    """
    if dim is None:
        dim = input_dim
    # A single vectorized draw consumes the global NumPy RNG in the same order
    # as drawing one row at a time.
    return np.random.uniform(-1., 1., size=(batch_size, dim))


In [None]:
def leakyReLU(x, alpha=0.2):
    """Leaky ReLU: x where x > 0, alpha * x otherwise (for 0 < alpha < 1)."""
    negative_branch = alpha * x
    return tf.maximum(negative_branch, x)

def resnet_block(inputs, maps=64, kernel=[3, 3],stride=1):
    """Residual block: inputs + BN(conv(ReLU(BN(conv(inputs))))).

    Both convolutions are linear, SAME-padded and use `maps` filters; no
    activation follows the addition. The addition needs the residual branch to
    match inputs' shape, i.e. maps equal to the input channels and stride=1.
    """
    residual = inputs
    for conv_index in range(2):
        residual = slim.conv2d(residual, maps, kernel, stride,
                               weights_initializer=tf.truncated_normal_initializer(stddev=1e-1),
                               padding='SAME', activation_fn=None)
        residual = slim.batch_norm(residual)
        if conv_index == 0:
            residual = tf.nn.relu(residual)
    return tf.add(inputs, residual)

In [None]:
def simGenerator(image, noise, num_blocks=5):
    """Refine an input image conditioned on a noise vector.

    image: reshaped to [batch_size, imY, imX, 3].
    noise: [batch_size, input_dim] noise, projected to one imY x imX plane and
        concatenated to the image as a fourth channel.
    num_blocks: number of 64-map residual blocks (default 5, as before).
    Returns the tanh-squashed output, [batch_size, imY, imX, 3] in [-1, 1].
    (slim layers here keep their default ReLU activation.)
    """
    image = tf.reshape(image, [batch_size, imY, imX, 3])
    # One projected value per pixel so it reshapes to a single extra channel;
    # tied to imY*imX instead of a hard-coded 4096.
    noise = slim.fully_connected(noise, imY * imX,
                                 weights_initializer=tf.truncated_normal_initializer(stddev=0.01))
    noise = tf.reshape(noise, [batch_size, imY, imX, 1])
    net = tf.concat([image, noise], 3)
    net = slim.conv2d(net, 64, [3, 3], 1)
    for _block in range(num_blocks):
        net = resnet_block(net, maps=64, kernel=[3, 3], stride=1)
    net = slim.conv2d(net, 32, [3, 3], 1)
    net = slim.conv2d(net, 3, [1, 1], 1)
    return tf.nn.tanh(net)

#Main Generator
def generator(face_image, z):
    """Build the generator graph inside the 'generator' variable scope.

    face_image is reshaped to [batch_size, imY, imX, 3] and refined by
    simGenerator conditioned on z. Returns [detailed_image, reshaped input] so
    callers can compare the output with its input.
    (An earlier tf_diffeomorphism warping stage is disabled.)
    """
    with tf.variable_scope('generator'):
        reshaped_face = tf.reshape(face_image, [batch_size, imY, imX, 3])
        detailed_image = simGenerator(reshaped_face, z)
    return [detailed_image, reshaped_face]



In [None]:
def discriminator(x, reuse=None):
    """WGAN-GP critic built inside the 'discriminator' variable scope.

    x: image batch; reshaped here to [batch_size, imY, imX, 3], so a flat
        [batch_size, imY*imX*3] tensor is accepted as well.
    reuse: None (default) reuses the critic weights when they already exist in
        the graph, so every call shares ONE critic; pass True/False to force.

    Returns [score, features]:
        score    - unbounded critic output; [batch_size, 1, 1, 1] for 64x64 input
                   (six stride-2 SAME convolutions).
        features - ReLU of the first conv layer's leaky-ReLU output,
                   [batch_size, imY/2, imX/2, 16].
    """
    if reuse is None:
        # Each call re-enters the same scope and slim's auto-named layers
        # ('Conv', 'Conv_1', ...) restart their numbering there. Without reuse, a
        # second call therefore collides with the first call's variables (or,
        # depending on the TF version, builds an independent critic). The
        # Wasserstein losses need real, fake and interpolated inputs scored by
        # the same weights.
        reuse = bool(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator/'))
    with tf.variable_scope('discriminator', reuse=reuse):
        def conv(inputs, maps, kernel):
            # stride-2, SAME-padded, linear conv; activation applied by caller
            return slim.conv2d(inputs, maps, kernel, 2,
                               weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                               padding='SAME', activation_fn=None)

        x = tf.reshape(x, [batch_size, imY, imX, 3])
        x = leakyReLU(conv(x, 16, [4, 4]))
        net = x
        for maps in (32, 128, 256, 512):
            net = leakyReLU(conv(net, maps, [3, 3]))
        net = conv(net, 1, [2, 2])
        return [net, tf.nn.relu(x)]

In [None]:
def predictStartImage(blurImage):
    """Decode a feature map back into an image inside the 'predictor' scope.

    blurImage: input feature map (in this notebook, the critic's first-layer
        features). Two stride-2 5x5 convs downsample and three stride-2 3x3
        transposed convs upsample, with BN + ReLU between layers.
    NOTE: the last transposed conv keeps slim's default ReLU, so the final tanh
    output lies in [0, 1).
    """
    with tf.variable_scope('predictor'):
        net = blurImage
        for maps in (16, 32):
            net = slim.conv2d(net, maps, [5, 5], 2, activation_fn=None)
            net = tf.nn.relu(slim.batch_norm(net))
        for maps in (64, 32):
            net = slim.convolution2d_transpose(net, maps, [3, 3], 2, activation_fn=None)
            net = tf.nn.relu(slim.batch_norm(net))
        net = slim.convolution2d_transpose(net, 3, [3, 3], 2)
        prediction = tf.nn.tanh(net)
    return prediction

In [None]:
def blurBatch(images, sigma=1.0):
    """Gaussian-blur each image of a batch over its two spatial axes only.

    images: array-like laid out (batch, height, width, channels).
    sigma: standard deviation of the Gaussian, in pixels.
    Returns a new float64 array of the same shape; images is not modified.
    """
    images = np.asarray(images, dtype=np.float64)
    blurred = np.empty_like(images)
    for idx, img in enumerate(images):
        # sigma 0 on the channel axis keeps R, G and B from bleeding together
        blurred[idx] = ndimage.gaussian_filter(img, sigma=(sigma, sigma, 0))
    return blurred


def getBlurImages(faces,z, sigma=1.0):
    """Run the generator on (faces, z) and return its outputs, Gaussian-blurred.

    The blurred result is returned as-is (it is no longer overwritten by the
    unblurred sample).
    """
    samples = sess.run(face_images, feed_dict={Face:faces,Z:z})
    return blurBatch(samples[0], sigma)

In [None]:
# NOTE(review): disabled experiment kept as a bare string literal; it never runs.
# If re-enabled: as written, all four branches zero cells (row, col) with
# row + col <= 20, i.e. only the top-left corner. The mirrored indices
# presumably should be imY-1-y / imX-1-x to reach the other three corners.
'''
#create bool mask
mask = np.ones((batch_size,imY,imX,3))
for y in range(imY):
    for x in range(imX):
        if y+x<=20:
            mask[:,y,x,:]=0
        if x+(imY-y)<=20:
            mask[:,(imY-y),x,:]=0
        if (imX-x)+y<=20:
            mask[:,y,(imX-x),:]=0
        if(imX-x)+(imY-y)<=20:
            mask[:,(imY-y),(imX-x),:]=0
plt.imshow(mask[0])
plt.show()
'''

In [None]:
#input
Face = tf.placeholder(tf.float32, shape=[batch_size, imY,imX,3])
X = tf.placeholder(tf.float32, shape=[batch_size,imY,imX,3])
Z = tf.placeholder(tf.float32, shape=[batch_size, input_dim])

#Models
# All discriminator(...) calls in this cell must score with one shared set of
# critic weights for the WGAN-GP losses to be consistent.
face_images = generator(Face, Z)   # [generated image, reshaped input face]
D_real = discriminator(X)
D_fake = discriminator(face_images[0])
predictImage = predictStartImage(D_fake[1])

#variables
theta_D = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
theta_G = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
theta_Pred = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'predictor')

#Predictor loss: sum of squared errors between prediction and the input face
P_loss = tf.reduce_sum(tf.squared_difference(predictImage, Face))

#Discriminator Wasserstein loss
D_loss = tf.reduce_mean(D_fake[0]) - tf.reduce_mean(D_real[0])

#Generator L1 loss (used only by the L1 pretraining cell)
G_l1_loss = tf.reduce_sum(tf.abs(face_images[0]-face_images[1]))

#Generator loss: adversarial term plus predictor reconstruction term
G_loss = -tf.reduce_mean(D_fake[0]) + tf.multiply(0.1, P_loss)

#improved WGAN (gradient penalty instead of weight clipping)
alpha = tf.random_uniform(shape=[batch_size,1], minval=0.,maxval=1.)
differences = tf.reshape(face_images[0] - X, (batch_size, output_dim))
# Keep the interpolates flat: discriminator() reshapes its input itself, and a
# flat tensor makes axis 1 span the whole per-sample gradient.
interpolates = tf.reshape(X, (batch_size, output_dim)) + (alpha*differences)
# Differentiate the critic score only ([0]); the returned list also carries the
# first-layer feature map, whose gradient must not enter the penalty.
D_interpolates = discriminator(interpolates)[0]
gradients = tf.gradients(D_interpolates, [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
D_loss += LAMBDA*gradient_penalty

#Solver
D_solver = (tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9)
            .minimize(D_loss, var_list=theta_D, colocate_gradients_with_ops=True))
G_solver = (tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9)
            .minimize(G_loss, var_list=theta_G, colocate_gradients_with_ops=True))
G_solver_L1 = (tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9)
            .minimize(G_l1_loss, var_list=theta_G, colocate_gradients_with_ops=True))
P_solver = (tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9)
            .minimize(P_loss, var_list=theta_Pred, colocate_gradients_with_ops=True))


if not os.path.exists('out-imdb-faces/'):
    os.makedirs('out-imdb-faces/')


#initialize variables: restore the last checkpoint if there is one
saver = tf.train.Saver()
if os.path.isfile("checkpoints-celeba/model-celeba.index"):
    print("Restore variables")
    saver.restore(sess, "checkpoints-celeba/model-celeba")
else:
    print("Instantiate variables")
    sess.run(tf.global_variables_initializer())

In [None]:
#pretrain gen with L1 Loss (pulls the generator output towards its input face)
G_L1_PRETRAIN_STEPS = 0  # 0 disables this warm-up
for step in tqdm(range(G_L1_PRETRAIN_STEPS)):
    faces = getFacesBatch()
    z = sample_z(batch_size)
    l1_feed = {Face: faces, Z: z}
    _, G_loss_curr = sess.run([G_solver_L1, G_l1_loss], feed_dict=l1_feed)


In [None]:
#pretrain disc (critic-only warm-up)
D_PRETRAIN_STEPS = 0  # 0 disables this warm-up
for step in tqdm(range(D_PRETRAIN_STEPS)):
    Xdata = getImageBatch()
    faces = getFacesBatch()
    z = sample_z(batch_size)
    critic_feed = {X: Xdata, Face: faces, Z: z}
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict=critic_feed)

In [None]:

#pretrain predictor (reconstruct the input face from critic features)
P_PRETRAIN_STEPS = 0  # 0 disables this warm-up
for step in tqdm(range(P_PRETRAIN_STEPS)):
    faces = getFacesBatch()
    z = sample_z(batch_size)
    predictor_feed = {Face: faces, Z: z}
    _, P_loss_curr = sess.run([P_solver, P_loss], feed_dict=predictor_feed)


In [None]:
# Main WGAN-GP loop: N_CRITIC critic updates per generator + predictor update.
N_ITERATIONS = 100000
N_CRITIC = 5
CHECKPOINT_PATH = "checkpoints-celeba/model-celeba"
# tf.train.Saver.save will not create a missing parent directory.
if not os.path.isdir(os.path.dirname(CHECKPOINT_PATH)):
    os.makedirs(os.path.dirname(CHECKPOINT_PATH))

i = 0  # index of the next snapshot image written to out-imdb-faces/
d_costs = []
g_costs = []
for it in range(N_ITERATIONS):
    for q in range(N_CRITIC):  # train discriminator (critic)
        Xdata = getImageBatch()
        faces = getFacesBatch()
        z = sample_z(batch_size)
        _, D_loss_curr = sess.run(
            [D_solver, D_loss],
            feed_dict={X: Xdata, Face: faces, Z: z}
        )
        d_costs.append(D_loss_curr)
    # train generator
    faces = getFacesBatch()
    z = sample_z(batch_size)
    _, G_loss_curr = sess.run(
        [G_solver, G_loss],
        feed_dict={Face: faces, Z: z}
    )
    g_costs.append(G_loss_curr)
    # train predictor on the same batch
    _, P_loss_curr = sess.run(
        [P_solver, P_loss],
        feed_dict={Face: faces, Z: z}
    )
    if it % 100 == 0:
        print('Iter: {}; D loss: {:.4}; G_loss: {:.4}; P_loss: {:.4}'
              .format(it, D_loss_curr, G_loss_curr, P_loss_curr))
        faces = getFacesBatch()
        z = sample_z(batch_size)
        samples = sess.run(face_images, feed_dict={Face: faces, Z: z})
        imagesToPlot = preprocessPlotImages(samples)
        if it % 1000 == 0:
            fig = plot(imagesToPlot[:16], "GenImage", 4, 4)
            fig.savefig('out-imdb-faces/{}.png'.format(str(i).zfill(3)),
                        bbox_inches='tight')
            plt.show()
            plt.close(fig)  # free it explicitly; not every backend closes on show()
            save_path = saver.save(sess, CHECKPOINT_PATH)
            print("Model saved in file: %s" % save_path)
            i += 1
        else:
            baseImages = sess.run(predictImage, feed_dict={Face: faces, Z: z})
            fig = plot(imagesToPlot[:4], "GenImage", 1, 4)
            plt.show()
            plt.close(fig)
            fig2 = plot(baseImages[:4], "PredictImage", 1, 4)
            plt.show()
            plt.close(fig2)

In [None]:
def plotLossCurve(costs, ylabel, path, start=500, xlabel='step'):
    """Plot costs[start:] to a new figure, save it to path, then display it.

    The figure is saved BEFORE plt.show(): the inline backend closes figures on
    show, so saving afterwards would write an empty image. The parent directory
    of path is created if missing.
    """
    out_dir = os.path.dirname(path)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    fig, ax = plt.subplots()
    ax.plot(costs[start:])
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    fig.savefig(path, bbox_inches='tight')
    plt.show()
    plt.close(fig)


# d_costs gets N_CRITIC entries per iteration, g_costs one.
plotLossCurve(d_costs, 'D_Loss', 'out-celebA/DLoss.png', xlabel='critic step')
plotLossCurve(g_costs, 'G_Loss', 'out-celebA/GLoss.png', xlabel='generator step')