In [1]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import scipy.misc
import matplotlib.pyplot as plt
import tensorflow.contrib.slim as slim

In [2]:
def data_loader(file_path , batch_size, epochs):
    filenames = tf.train.match_filenames_once(file_path)
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=epochs)
    reader = tf.WholeFileReader()
    _, img_bytes = reader.read(filename_queue)
    image = tf.image.decode_jpeg(img_bytes, channels=3)
    
    return tf.train.batch([image], batch_size, dynamic_pad=True)

In [3]:
def lrelu(x, leaky=0.2):
    return tf.maximum(x, leaky * x)

In [4]:
def Unet_generator(cond_img):
    with tf.variable_scope('generator'):
        with slim.arg_scope([slim.conv2d],
                            activation_fn=lrelu,
                            kernel_size=[4,4],
                            stride=2):
            enc1 = slim.conv2d(cond_img , 64, scope='enc1')
            enc2 = slim.batch_norm(slim.conv2d(enc1 , 128, scope='enc2'))
            enc3 = slim.batch_norm(slim.conv2d(enc2 , 256, scope='enc3'))
            enc4 = slim.batch_norm(slim.conv2d(enc3 , 512, scope='enc4'))
            enc5 = slim.batch_norm(slim.conv2d(enc4 , 512, scope='enc5'))
            enc6 = slim.batch_norm(slim.conv2d(enc5 , 512, scope='enc6'))
            enc7 = slim.batch_norm(slim.conv2d(enc6 , 512, scope='enc7'))
            enc8 = slim.batch_norm(slim.conv2d(enc7 , 512, kernel_size=[2,2], scope='enc8'))
        
        with slim.arg_scope([slim.conv2d_transpose],
                            activation_fn = tf.nn.relu,
                            kernel_size=[4,4],
                            stride=2):
            dec1 = slim.batch_norm(slim.conv2d_transpose(enc8 ,512, kernel_size=[2,2], scope='dec1'))
            dec1 = tf.concat(3, [dec1, enc7] ,name='cat1')
            dec2 = slim.batch_norm(slim.conv2d_transpose(dec1 ,512, scope='dec2'))
            dec2 = tf.concat(3, [dec2, enc6],name='cat2')
            dec3 = slim.batch_norm(slim.conv2d_transpose(dec2 ,512, scope='dec3'))
            dec3 = tf.concat(3, [dec3, enc5],name='cat3')
            dec4 = slim.batch_norm(slim.conv2d_transpose(dec3 ,512, scope='dec4'))
            dec4 = tf.concat(3, [dec4, enc4],name='cat4')
            dec5 = slim.batch_norm(slim.conv2d_transpose(dec4 ,256, scope='dec5'))
            dec5 = tf.concat(3, [dec5, enc3],name='cat5')
            dec6 = slim.batch_norm(slim.conv2d_transpose(dec5 ,128, scope='dec6'))
            dec6 = tf.concat(3, [dec6, enc2],name='cat6')
            dec7 = slim.batch_norm(slim.conv2d_transpose(dec6 ,64, scope='dec7'))
            dec7 = tf.concat(3, [dec7, enc1],name='cat7')
            dec8 = slim.conv2d_transpose(dec7 ,3, activation_fn=tf.nn.tanh)
            
            return dec8

In [5]:
def discriminator(img,reuse=False):
    with tf.variable_scope('discriminator',reuse=reuse):
        with slim.arg_scope([slim.conv2d],
                            activation_fn=lrelu,
                            kernel_size=[4,4],
                            stride=2):
            x = slim.batch_norm(slim.conv2d(img , 64, scope='disc1'))
            x = slim.batch_norm(slim.conv2d(x , 128, scope='disc2'))
            x = slim.batch_norm(slim.conv2d(x , 256, scope='disc3'))
            x = slim.batch_norm(slim.conv2d(x , 512, scope='disc4'))
            
            logits = slim.fully_connected(slim.flatten(x), 1, activation_fn=None ,scope='fc')
            return tf.nn.sigmoid(logits) , logits

In [6]:
train_file_path = './facades/train/*.jpg'
batch_size = 10
epochs = 200
img_size = 256

In [7]:
batch_images = data_loader(train_file_path , batch_size , epochs)
real_img , cond_img = batch_images[:,:,:img_size,:] , batch_images[:,:,img_size:,:]

real_img.set_shape((batch_size , img_size, img_size, 3))
cond_img.set_shape((batch_size , img_size, img_size, 3))

real_img = tf.cast(real_img,tf.float32) / 127.5 -1
cond_img = tf.cast(cond_img,tf.float32) / 127.5 -1

In [8]:
fake_img = Unet_generator(cond_img)

In [9]:
real_img_cond = tf.concat(3, [real_img , cond_img])
fake_img_cond = tf.concat(3, [fake_img , cond_img])

In [10]:
p_real,real_logits = discriminator(real_img_cond)
p_fake,fake_logits = discriminator(fake_img_cond,reuse=True)

In [11]:
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_logits, tf.ones_like(p_real)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_logits, tf.zeros_like(p_fake)))

d_loss = d_loss_real + d_loss_fake

In [12]:
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_logits, tf.ones_like(p_fake)))
l1_loss = tf.reduce_mean(tf.abs(fake_img - real_img))
g_loss = g_loss + 100 * l1_loss

In [13]:
g_params = slim.get_variables(scope='generator')
d_params = slim.get_variables(scope='discriminator')

optimizer = tf.train.AdamOptimizer(0.0001)
d_trainer = optimizer.minimize(d_loss, var_list=d_params)
g_trainer = optimizer.minimize(g_loss, var_list=g_params)

In [14]:
init_op = tf.group(tf.initialize_all_variables(), tf.initialize_local_variables())
sess = tf.Session()
sess.run(init_op)

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

Instructions for updating:
Use `tf.global_variables_initializer` instead.
Instructions for updating:
Use `tf.local_variables_initializer` instead.


In [15]:
train_step = 1
try:
    while not coord.should_stop():
        # Run training steps or whatever
        _ , d_loss_, _, g_loss_  = sess.run([d_trainer , d_loss , g_trainer , g_loss])
        #_ , g_loss_ = sess.run([g_trainer , g_loss])
        
        epcoh = (train_step * batch_size) / 400
              
        if (train_step * batch_size) % (400*10) == 0:
            print "epoch:%s , d_loss:%s ,g_loss:%s" % (epcoh , d_loss_ ,g_loss_)
        
        if (train_step * batch_size) % (400*50) == 0: 
            real_img_ ,cond_img_ , fake_img_ = sess.run([real_img , cond_img , fake_img])
            real_img_ = (real_img_ + 1) *127.5
            cond_img_ = (cond_img_ + 1) *127.5
            fake_img_ = (fake_img_ + 1) * 127.5
            
            for i in range(batch_size):
                file_name = 'epcoh_%s_%s.jpg' % (epcoh , i)
                canvas = np.zeros((img_size , img_size*3,3))
                canvas[: , :img_size] = real_img_[i]
                canvas[: , img_size:img_size*2] = cond_img_[i]
                canvas[: , img_size*2:img_size*3] = fake_img_[i]                
                scipy.misc.imsave(file_name, np.clip(canvas, 0, 255).astype('uint8'))
                
        train_step += 1
        
except tf.errors.OutOfRangeError:
    print 'Done training -- epoch limit reached'
    
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)

epoch:10 , d_loss:1.57773 ,g_loss:35.3845
epoch:20 , d_loss:2.70714 ,g_loss:25.1553
epoch:30 , d_loss:1.70816 ,g_loss:19.685
epoch:40 , d_loss:1.84334 ,g_loss:15.8764
epoch:50 , d_loss:2.63484 ,g_loss:14.797
epoch:60 , d_loss:2.16098 ,g_loss:14.0988
epoch:70 , d_loss:2.02101 ,g_loss:12.9068
epoch:80 , d_loss:1.79013 ,g_loss:11.856
epoch:90 , d_loss:2.31743 ,g_loss:12.6805
epoch:100 , d_loss:2.46365 ,g_loss:10.4963
epoch:110 , d_loss:1.87597 ,g_loss:10.2014
epoch:120 , d_loss:2.15549 ,g_loss:9.29062
epoch:130 , d_loss:1.66803 ,g_loss:9.18493
epoch:140 , d_loss:2.12977 ,g_loss:9.40766
epoch:150 , d_loss:3.00332 ,g_loss:8.20005
epoch:160 , d_loss:2.21027 ,g_loss:8.85239
epoch:170 , d_loss:1.98511 ,g_loss:8.40414
epoch:180 , d_loss:2.71879 ,g_loss:7.79251
epoch:190 , d_loss:2.40175 ,g_loss:7.24438
Done training -- epoch limit reached


In [23]:
variables_to_save = slim.get_variables(scope="generator")

In [24]:
saver = tf.train.Saver(variables_to_save)

In [27]:
saver.save(sess, "./model/generator.ckpt")

'./model/generator.ckpt'