In [None]:
%matplotlib inline
from skimage import io, filters
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os,sys
from lgan.diffeomorphism import tf_diffeomorphism
from tqdm import tqdm
from scipy import misc
from keras.preprocessing import image
from scipy import ndimage, misc

In [None]:
#hyper-parameters
data_size = 100         # number of annotation rows read from the CelebA attribute file
batch_size = 32         # images per training batch (every placeholder uses this fixed size)
input_dim = 128         # length of the generator noise vector z
image_size = 4096       # pixels per 64x64 image (not referenced elsewhere in the visible code)
attribute_size = 40     # number of CelebA attributes per image
imX = 64                # image width
imY = 64                # image height
LAMBDA = 10             # weight of the WGAN-GP gradient penalty
output_dim = imX*imY*3  # flattened RGB image length

# TensorFlow session that allocates GPU memory on demand instead of reserving it all up-front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)


In [None]:
# Stream CelebA images from disk. class_mode=None yields bare image arrays (no labels);
# rescale maps 0..255 pixel values to [0, 1].
datagen = image.ImageDataGenerator(rescale=1. / 255)
batcher = datagen.flow_from_directory(
    '/home/ben/celeba/resized',
    target_size=(64, 64),
    color_mode='rgb',
    class_mode=None,
    batch_size=batch_size,
)

In [None]:
#annotations: load the first `data_size` rows of the CelebA attribute table into
# `attributes`, shape (data_size, attribute_size). Each data row is
# "<filename> v1 v2 ... v40"; the filename column is dropped.
# NOTE(review): the published list_attr_celeba.txt appears to start with an image-count
# line before the attribute-name line; a leading all-digit line is skipped so both
# layouts parse (otherwise the names line would be fed to float()).
with open('/home/ben/celeba/Anno/list_attr_celeba.txt', "r") as attr_file:
    attr = attr_file.readline()
    if attr.strip().isdigit():
        attr = attr_file.readline()  # skip the image-count line
    attr = attr.split()[0:attribute_size]  # attribute names (40 annotations)

    attributes = np.zeros((data_size, attribute_size))
    for i in range(data_size):
        a = attr_file.readline().split()  # split() also drops empty strings and the newline
        if not a:
            raise ValueError('attribute file has fewer than {} data rows'.format(data_size))
        a = np.array([float(v) for v in a[1:attribute_size + 1]])
        attributes[i] = a

    


In [None]:
#simple plot function
def plot(samples, labels, y, x):
    """Draw `samples` as images on a y-by-x grid, titling image k with labels[k].

    Each image takes the next grid cell, so `samples` may hold at most y*x images.
    Axes are hidden. Returns the newly created matplotlib Figure.
    """
    figure = plt.figure(figsize=(10, 10))
    grid = gridspec.GridSpec(y, x)
    grid.update(wspace=0.05, hspace=0.05)
    # top=1.5 places the grid's top edge above the figure's top edge.
    plt.subplots_adjust(left=None, bottom=None, right=1, top=1.5,
                        wspace=None, hspace=None)

    for k, img in enumerate(samples):
        cell = plt.subplot(grid[k])
        cell.axis('off')
        cell.set_xticklabels([])
        cell.set_yticklabels([])
        cell.set_aspect('equal')
        cell.set_title(labels[k])
        plt.imshow(img)
    return figure

In [None]:
#Resize Images - only do once
def resize(src_dir='/home/ben/celeba/data/0', dst_dir='/home/ben/celeba/resized/0', size=(64, 64)):
    """Resize every regular file in `src_dir` to `size` and save it under the same name in `dst_dir`.

    Subdirectories are skipped. `dst_dir` (and any missing parents) is created if needed.
    The defaults produce the 64x64 copies that `flow_from_directory` reads from
    '/home/ben/celeba/resized'.

    NOTE(review): ndimage.imread / misc.imresize / misc.imsave only exist in older SciPy
    releases; newer SciPy needs a port (e.g. to skimage or PIL).
    """
    if not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    for item in os.listdir(src_dir):
        src_path = os.path.join(src_dir, item)
        if not os.path.isfile(src_path):
            continue
        img = ndimage.imread(src_path, mode="RGB")
        img_resized = misc.imresize(img, size)
        misc.imsave(os.path.join(dst_dir, item), img_resized)

# Uncomment to (re)build the resized image copies; it only needs to run once.
#resize()
# Sanity check: number of files currently in the resized image directory.
print(len(os.listdir("/home/ben/celeba/resized/0")))

In [None]:
def getImageBatch(source=None, size=None, max_tries=2):
    """Return the next batch of exactly `size` images from `source`.

    Args:
        source: object with a ``next()`` method returning a batch; defaults to the
            module-level ``batcher``.
        size: required batch length; defaults to ``batch_size``.
        max_tries: how many batches to draw before giving up.

    The directory iterator may return a short batch at the end of an epoch. Such
    batches are discarded because the graph's placeholders have a fixed leading
    dimension of ``batch_size``.

    Raises:
        ValueError: if none of the ``max_tries`` draws has length ``size`` (e.g. the
            dataset holds fewer than ``size`` images).
    """
    if source is None:
        source = batcher
    if size is None:
        size = batch_size
    for _ in range(max_tries):
        img_batch = source.next()
        if len(img_batch) == size:
            return img_batch
    raise ValueError('no batch of size {} after {} draws'.format(size, max_tries))

# Preview one real batch on an 8x4 grid (32 cells, matching batch_size = 32); labels are placeholder zeros.
plot(getImageBatch(),np.zeros(batch_size),8,4)
plt.show()

In [None]:
# Build `attribute_size` "mean images", each the pixel-wise mean of the next real
# batch returned by getImageBatch().
# TODO: group images by attribute before averaging. Right now each mean is taken over
# an arbitrary batch, so the means are not attribute-specific.
meanImages = np.zeros((attribute_size,imY,imX,3))
for i in range(attribute_size):
    image_batch = getImageBatch()  
    meanImages[i] = np.mean(image_batch, axis=0)
# 10x4 grid = 40 cells, one per mean image (attribute_size = 40).
plot(meanImages, np.zeros(attribute_size),10,4)

In [None]:
# TODO: decide how mean images should be chosen; they are currently sampled uniformly.
def getMeanImage(batch_size, means=None):
    """Sample `batch_size` mean images uniformly at random, with replacement.

    Args:
        batch_size: number of images to return.
        means: candidate mean images indexed along axis 0; defaults to the
            module-level ``meanImages``.

    Returns:
        Array of shape ``(batch_size,) + means.shape[1:]``.
    """
    if means is None:
        means = meanImages
    # Draw from every stored mean; a hard-coded range of 20 would ignore half of the 40.
    indexes = np.random.randint(0, len(means), size=batch_size)
    images = means[indexes]
    return images

In [None]:
def sample_z(batch_size):
    """Return a (batch_size, input_dim) float64 noise matrix with entries uniform on [-1, 1).

    Rows are filled in order from numpy's global RNG, which gives the same values
    as drawing one input_dim-long vector per row.
    """
    return np.random.uniform(-1., 1., size=(batch_size, input_dim))


In [None]:
def convLayer(layerinput, maps, kernel, stride,scope, act="relu"):
    """2-D convolution ('SAME' padding) followed by tanh, or by batch norm + ReLU.

    Args:
        layerinput: input feature-map tensor.
        maps: number of output channels.
        kernel: kernel size, e.g. [3, 3].
        stride: convolution stride.
        scope: variable scope name for the convolution.
        act: "tanh" applies tanh; any other value applies batch norm then ReLU.

    Returns:
        The activated output tensor.
    """
    # activation_fn=None: slim.conv2d applies ReLU by default. That would run before
    # batch norm (ReLU -> BN -> ReLU) and would confine the tanh branch to [0, 1).
    # The activation is applied explicitly below.
    layerinput = slim.conv2d(layerinput, maps, kernel, stride,
                             weights_initializer=tf.truncated_normal_initializer(stddev=1e-1),
                             scope=scope, padding='SAME', activation_fn=None)
    if act == "tanh":
        return tf.nn.tanh(layerinput)
    layerinput = slim.batch_norm(layerinput)
    return tf.nn.relu(layerinput)


def leakyReLU(x, alpha=0.2):
    """Leaky ReLU: x where x >= 0, alpha*x elsewhere (assumes 0 <= alpha <= 1)."""
    return tf.maximum(alpha*x, x)

In [None]:
#Generator Variables
# Only referenced in commented-out code.
scalar = tf.Variable(tf.reduce_mean(tf.random_normal([1], stddev=1)))
# Created inside the 'generator' scope so its name starts with 'generator/'. That puts it
# in theta_G (trainable variables collected by scope 'generator'), so G_solver trains it.
# A bare top-level tf.Variable would be named 'Variable_N' and never be updated.
with tf.variable_scope('generator'):
    # Per-example (2 x 256) projection applied to the reshaped noise in generator_detail.
    multMatrix = tf.Variable(tf.random_normal([batch_size,2,256], stddev=1), name='multMatrix')

#Generator - add Detail
def generator_detail(noise, dif_image):
    """Generate a 64x64 RGB image from a noise vector plus a conditioning image.

    Args:
        noise: (batch_size, 128) noise tensor. It is reshaped to (batch_size, 64, 2),
            so input_dim must be 128.
        dif_image: conditioning image, reshaped to (batch_size, 64, 64, 3).

    Returns:
        (batch_size, 64, 64, 3) tensor with values in (0, 1), the same range as the
        real images scaled by the rescale=1./255 data generator.
    """
    with tf.variable_scope('generator_detail'):
        # (B, 64, 2) @ (B, 2, 256) -> (B, 64, 256) = 32*32*16 values per example.
        noise = tf.reshape(noise, (batch_size,64,2))
        noise = tf.matmul(noise, multMatrix)
        noise = tf.reshape(noise,(batch_size, 32,32,16))
        noise = slim.batch_norm(noise)
        #noise = tf.scalar_mul(scalar, noise)

        # Downsample the conditioning image to the same (32, 32, 16) shape as the noise.
        details = tf.reshape(dif_image, [batch_size, 64,64,3])
        details = convLayer(details, 16, [3,3],2, scope='convReshape_1')

        detail_image = tf.add(details, noise)
        # 32x32 -> 4x4x1024, then four stride-2 transposed convs back up to 64x64.
        detail_image = convLayer(detail_image, 1024, [9,9],8, scope='convDetail_0')
        for maps in (512, 256, 128):
            # activation_fn=None: slim's default ReLU here would give ReLU -> BN -> ReLU.
            detail_image = slim.convolution2d_transpose(detail_image, maps, [3,3],2,
                                                        activation_fn=None)
            detail_image = slim.batch_norm(detail_image)
            detail_image = tf.nn.relu(detail_image)
        # No default ReLU before the output squashing. tanh(relu(x)) would pin every
        # negative pre-activation at exactly 0 with zero gradient.
        detail_image = slim.convolution2d_transpose(detail_image, 3, [3,3],2,
                                                    activation_fn=None)
        # Map tanh's (-1, 1) onto (0, 1) to match the real-image range.
        detail_image = (tf.nn.tanh(detail_image) + 1.) / 2.

        return detail_image
# Disabled: a generator that predicted diffeomorphism parameters from the noise vector.
# It is wrapped in a bare string literal, so none of this code is executed.
'''
#Generator - predict parameters for Diffeomorphism
def generator_dif_paras(noise_vector):
    with tf.variable_scope('generator_diffeo'):
        params = tf.reshape(noise_vector,[batch_size,16,8,1])
        params = convLayer(params, 32, [3, 3],2, scope='convG_1')  
        params = convLayer(params, 1, [3, 3],2, scope='convG_2')
        params = tf.nn.tanh(params)
        params= tf.reshape(params, [batch_size,2,2,2]) 
        return params
'''    
#Main Generator
def generator(z, mean_image):
    """Build the generator graph under the 'generator' variable scope.

    Args:
        z: noise tensor passed to generator_detail.
        mean_image: conditioning image(s), reshaped to (batch_size, imY, imX, 3).

    Returns:
        Two-element list: [generated image tensor, reshaped mean image tensor].

    The diffeomorphism path (tf_diffeomorphism) is currently disabled, so the mean
    image conditions generator_detail directly.
    """
    with tf.variable_scope('generator'):
        conditioning = tf.reshape(mean_image, [batch_size, imY, imX, 3])
        generated = generator_detail(z, conditioning)
    return [generated, conditioning]



In [None]:
#v1 fully conv
def discriminator(x):
    """WGAN critic: eight strided conv + leaky-ReLU layers, then a linear 1-map output conv.

    Args:
        x: images, reshaped to (batch_size, imY, imX, 3).

    Returns:
        Unbounded critic scores, shape (batch_size, 1, 1, 1) for 64x64 inputs.

    Every layer has an explicit scope, and the 'discriminator' variable scope is reused
    once its variables exist. D(real), D(fake) and D(interpolates) therefore share one
    set of weights. With auto-named layers and no reuse, each call would build a
    separate critic.
    """
    init = tf.truncated_normal_initializer(stddev=0.1)
    # (maps, kernel, stride); spatial size 64 -> 32 -> 16 -> 8 -> 8 -> 8 -> 4 -> 4 -> 2
    hidden_layers = [
        (16, [4, 4], 2),
        (32, [3, 3], 2),
        (128, [3, 3], 2),
        (128, [3, 3], 1),
        (256, [5, 5], 1),
        (512, [3, 3], 2),
        (512, [3, 3], 1),
        (1024, [3, 3], 2),
    ]
    # The first call creates the variables; every later call must reuse them.
    reuse = bool(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator'))
    with tf.variable_scope('discriminator', reuse=reuse):
        net = tf.reshape(x, [batch_size, imY, imX, 3])
        for idx, (maps, kernel, stride) in enumerate(hidden_layers):
            # Each layer consumes the previous layer's output (feeding `x` again would
            # silently discard the first layer).
            net = slim.conv2d(net, maps, kernel, stride, weights_initializer=init,
                              padding='SAME', activation_fn=None,
                              scope='d_conv{}'.format(idx))
            net = leakyReLU(net)
        net = slim.conv2d(net, 1, [2, 2], 2, weights_initializer=init,
                          padding='SAME', activation_fn=None, scope='d_out')
        return net

In [None]:
# Disabled draft of an autoencoder discriminator (V2). It sits inside a bare string
# literal, so nothing here is executed. Note that `decoder` has no body yet and that the
# second conv reads `x` rather than `net`.
'''
#V2: autoencoder
def encoder(x):
    with tf.variable_scope('auto_disc'):
        x = tf.reshape(x,[batch_size,imY,imX,3])
        net = slim.conv2d(x, 64, [3, 3],1, weights_initializer=tf.contrib.layers.xavier_initializer(), scope='convD_1')
        net = leakyReLU(net)
        net = slim.conv2d(x, 128, [3, 3],1, weights_initializer=tf.contrib.layers.xavier_initializer())
        net = leakyReLU(net)
        net = slim.conv2d(net, 256, [3, 3],1,weights_initializer=tf.truncated_normal_initializer(stddev=0.1))
        net = leakyReLU(net)
        net = slim.conv2d(net, 512, [3, 3],1,weights_initializer=tf.truncated_normal_initializer(stddev=0.1))
        net = leakyReLU(net)
        net = slim.fully_connected(net, 128, activation_fn=None,weights_initializer=tf.truncated_normal_initializer(stddev=0.1))  
        return net
    
def decoder(x):
    with tf.variable_scope('auto_disc'):
        
'''    

In [None]:
#input
Mean_image = tf.placeholder(tf.float32, shape=[batch_size, imY,imX,3])  # conditioning mean images
X = tf.placeholder(tf.float32, shape=[batch_size,imY,imX,3])            # real images
Z = tf.placeholder(tf.float32, shape=[batch_size, input_dim])           # generator noise
#Class_z = tf.placeholder(tf.float32, shape=[batch_size, attribute_size]) #class(label) of image

#Models
detail_image = generator(Z, Mean_image)  # [generated images, reshaped mean images]
D_real = discriminator(X)
D_fake = discriminator(detail_image[0])

#variables V1: trainable variables selected by scope-name prefix
theta_D = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
theta_G = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')


D_losses, G_losses = [],[]

# WGAN critic loss: raise D(real) relative to D(fake).
D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)
# Scaled L1 distance between each generated image and its mean image.
# NOTE(review): built but never added to G_loss — confirm whether it should be.
G_l1_loss = tf.multiply(0.001, tf.reduce_mean(tf.abs(detail_image[0]-detail_image[1])))
G_loss = -tf.reduce_mean(D_fake)

#improved WGAN (WGAN-GP): no weight clipping. Instead, penalize the critic's gradient
# norm at random interpolates between real and generated images.
alpha = tf.random_uniform(shape=[batch_size,1], minval=0.,maxval=1.)

differences = tf.reshape(detail_image[0] - X, (batch_size, output_dim))
interpolates = tf.reshape(X,(batch_size, output_dim)) + (alpha*differences)
interpolates = tf.reshape(interpolates, (batch_size, imY, imX, 3))
gradients = tf.gradients(discriminator(interpolates), [interpolates])[0]
# Flatten so the L2 norm is taken per example over all pixels and channels. Summing
# only axis 1 of the 4-D gradient would give one norm per (column, channel) instead.
flat_gradients = tf.reshape(gradients, (batch_size, output_dim))
slopes = tf.sqrt(tf.reduce_sum(tf.square(flat_gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
D_loss += LAMBDA*gradient_penalty

G_losses.append(G_loss)
D_losses.append(D_loss)

G_loss = tf.add_n(G_losses)
D_loss = tf.add_n(D_losses)

#Solver: each optimizer updates only its own network's variables.
D_solver = (tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9)
            .minimize(D_loss, var_list=theta_D, colocate_gradients_with_ops=True))
G_solver = (tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9)
            .minimize(G_loss, var_list=theta_G, colocate_gradients_with_ops=True))


if not os.path.exists('out-celebA/'):
    os.makedirs('out-celebA/')

#initialize Variables
sess.run(tf.global_variables_initializer())

In [None]:
# Disabled TensorBoard logging, kept as a bare string literal so it is not executed.
# The training loop never evaluates `merged_summary_op`, so re-enabling only this cell
# would not record any loss summaries.
'''
logs_path = 'logs/lgan-log'

if not os.path.exists('logs/'):
    os.makedirs('logs/')
#Instantiate Tensorboard

# Create a summary to monitor cost tensor
tf.summary.scalar("D-loss", D_loss)
# Create a summary to monitor accuracy tensor
tf.summary.scalar("G-loss", G_loss)
# Merge all summaries into a single op
merged_summary_op = tf.summary.merge_all()

summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
'''

In [None]:
#pretrain disc: 50 critic-only updates before the alternating G/D loop.
for t in tqdm(range(50)):
    # real batch, generator noise and conditioning mean images, drawn in that order
    Xdata, z, mean_image = getImageBatch(), sample_z(batch_size), getMeanImage(batch_size)
    _, D_loss_curr = sess.run([D_solver, D_loss],
                              feed_dict={X: Xdata, Z: z, Mean_image: mean_image})

In [None]:
# Main WGAN-GP loop: 5 critic updates per generator update.
i = 0          # index of the next saved sample grid
d_costs = []   # critic loss after every critic step
g_costs = []   # generator loss after every generator step
for it in range(100000):
    for q in range(5): #train discriminator
        Xdata = getImageBatch()
        z = sample_z(batch_size)
        # Fresh conditioning images for every critic step. Otherwise the first steps
        # would depend on a `mean_image` left over from an earlier cell.
        mean_image = getMeanImage(batch_size)
        _, D_loss_curr = sess.run(
            [D_solver, D_loss],
            feed_dict={X: Xdata, Z: z, Mean_image: mean_image}
        )
        d_costs.append(D_loss_curr)
    z = sample_z(batch_size)
    mean_image = getMeanImage(batch_size)
    _, G_loss_curr = sess.run(
        [G_solver, G_loss],
        feed_dict={Z: z, Mean_image: mean_image}
    )
    g_costs.append(G_loss_curr)
    if it % 100 == 0:
        print('Iter: {}; D loss: {:.4}; G_loss: {:.4}'
              .format(it, D_loss_curr, G_loss_curr))
        z = sample_z(batch_size)
        mean_image = getMeanImage(batch_size)
        samples = sess.run(detail_image, feed_dict={Z: z, Mean_image: mean_image})
        samples[0] = np.reshape(samples[0], (batch_size, imY, imX, 3))
        if it % 1000 == 0:
            # Every 1000 iterations: save a 4x4 grid of samples to disk.
            fig = plot(samples[0][:16], np.zeros(16), 4, 4)
            plt.savefig('out-celebA/{}.png'
                        .format(str(i).zfill(3)), bbox_inches='tight')
            i += 1
        else:
            # Otherwise just display a 1x4 strip.
            fig = plot(samples[0][:4], np.zeros(4), 1, 4)
        plt.show()
        # Close every figure, not only the last one, so they do not pile up in memory.
        plt.close(fig)

In [None]:
# Loss curves, skipping the first 500 recorded values of each.
# Each figure is saved BEFORE plt.show(): once shown, the figure is handed off, and a
# later plt.savefig would write a new, empty canvas.
fig, ax = plt.subplots()
ax.plot(d_costs[500:])
ax.set_xlabel('discriminator step')
ax.set_ylabel('D_Loss')
fig.savefig('out-celebA/DLoss.png', bbox_inches='tight')
plt.show()
plt.close(fig)

fig, ax = plt.subplots()
ax.plot(g_costs[500:])
ax.set_xlabel('generator step')
ax.set_ylabel('G_Loss')
fig.savefig('out-celebA/GLoss.png', bbox_inches='tight')
plt.show()
plt.close(fig)