In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from scipy.misc import imsave
import os
import shutil
from PIL import Image
import time
import random
import sys

from layers import *
from model import *

In [2]:
# ---- Run-mode switches -------------------------------------------------
# to_restore is read by CycleGAN.train() to decide whether to resume from the
# latest checkpoint; to_train/to_test are not referenced in the visible cells.
to_train = True
to_test = False
to_restore = False
# Read by CycleGAN.train(): dump sample images at the start of every epoch.
save_training_images = True

# ---- Output locations --------------------------------------------------
output_path = "./output"
check_dir = "./output/checkpoints/"

# ---- Training schedule -------------------------------------------------
max_epoch = 1
# Number of images loaded per domain, and iterations per epoch.
max_images = 600

# ---- Not referenced in the visible notebook cells ----------------------
# NOTE(review): kept because other code may still read these names.
temp_check = 0
h1_size = 150
h2_size = 300
z_size = 100
sample_size = 10

In [3]:
class CycleGAN():

    '''
    Builds, trains and tests a CycleGAN between two image domains A and B
    using the TensorFlow 1.x graph/session API.

    train() -> builds the graph, loads the training images, optimises the networks and writes checkpoints
    test()  -> rebuilds the graph, restores the latest checkpoint and writes translated test images

    The network builders (build_generator_resnet_9blocks, build_gen_discriminator) and the
    constants img_height, img_width, img_layer, img_size, batch_size and pool_size are not
    defined in this class; they are expected to be in scope at call time
    (presumably via the star imports from layers/model -- verify).
    '''

    @staticmethod
    def _learning_rate(epoch):
        '''
        Learning rate used for the given epoch.

        Constant 0.0002 for epochs below 100, then decreasing linearly so that it
        reaches 0 at epoch 200. The value is clamped at 0 so that running for more
        than 200 epochs can never produce a negative learning rate.
        '''
        base_lr = 0.0002
        if epoch < 100:
            return base_lr
        return max(0.0, base_lr - base_lr*(epoch-100)/100)

    @staticmethod
    def _to_uint8(image):
        '''
        Maps an image with values in [-1,1] to uint8 pixel values via (x+1)*127.5.

        Values are clipped to [0,255] before the cast so that anything slightly
        outside [-1,1] saturates instead of wrapping around.
        '''
        return np.clip((image + 1) * 127.5, 0, 255).astype(np.uint8)

    @staticmethod
    def _decode_and_scale(image_file):
        ''' Decodes a JPEG string tensor, resizes it to [img_height, img_width] and scales it by x/127.5 - 1. '''
        decoded = tf.image.decode_jpeg(image_file)
        # resize_images takes [new_height, new_width]; using the model constants keeps
        # this in step with the reshape performed in input_read.
        resized = tf.image.resize_images(decoded, [img_height, img_width])
        return tf.subtract(tf.div(resized, 127.5), 1)

    def input_setup(self, mode='train'):

        '''
        This function sets up the graph ops used for reading the input images.

        mode -> 'train' reads ./maps/trainA and ./maps/trainB, 'test' reads ./maps/testA and ./maps/testB.
                Any other value raises ValueError.
        self.queue_length_A/self.queue_length_B -> Tensors holding the number of matched files
        self.image_A/self.image_B -> Input image with each value ranging from [-1,1]
        '''

        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))

        filenames_A = tf.train.match_filenames_once("./maps/" + mode + "A/*.jpg")
        self.queue_length_A = tf.size(filenames_A)
        filenames_B = tf.train.match_filenames_once("./maps/" + mode + "B/*.jpg")
        self.queue_length_B = tf.size(filenames_B)

        filename_queue_A = tf.train.string_input_producer(filenames_A)
        filename_queue_B = tf.train.string_input_producer(filenames_B)

        image_reader = tf.WholeFileReader()
        _, image_file_A = image_reader.read(filename_queue_A)
        _, image_file_B = image_reader.read(filename_queue_B)

        self.image_A = self._decode_and_scale(image_file_A)
        self.image_B = self._decode_and_scale(image_file_B)

    def _load_images(self, sess, image_op, target, label):
        '''
        Fills target[i] with successive evaluations of image_op.

        An image whose element count does not match img_size*batch_size*img_layer
        cannot be reshaped, so its slot is left as zeros. The number of such slots
        is reported on stdout and returned.
        '''
        expected_size = img_size*batch_size*img_layer
        skipped = 0
        for i in range(len(target)):
            image_tensor = sess.run(image_op)
            if image_tensor.size == expected_size:
                target[i] = image_tensor.reshape((batch_size, img_height, img_width, img_layer))
            else:
                skipped += 1
        if skipped:
            print("WARNING: %d of %d images for domain %s had an unexpected size and were left as zeros"
                  % (skipped, len(target), label))
        return skipped

    def input_read(self, sess):

        '''
        It reads the input images from the image folders into memory.

        sess -> Session in which the local variables have already been initialised
        self.fake_images_A/self.fake_images_B -> Pools of generated images used for the loss of the Discriminators
        self.A_input/self.B_input -> Arrays holding max_images input images per domain

        Raises ValueError if no file was matched for one of the two domains.
        '''

        # Loading images into the tensors
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # The queue-runner threads are always stopped, even if reading fails half way.
        try:
            num_files_A = sess.run(self.queue_length_A)
            num_files_B = sess.run(self.queue_length_B)
            if num_files_A == 0 or num_files_B == 0:
                raise ValueError("No input images found (domain A: %d files, domain B: %d files)"
                                 % (num_files_A, num_files_B))

            self.fake_images_A = np.zeros((pool_size, 1, img_height, img_width, img_layer))
            self.fake_images_B = np.zeros((pool_size, 1, img_height, img_width, img_layer))

            self.A_input = np.zeros((max_images, batch_size, img_height, img_width, img_layer))
            self.B_input = np.zeros((max_images, batch_size, img_height, img_width, img_layer))

            self._load_images(sess, self.image_A, self.A_input, "A")
            self._load_images(sess, self.image_B, self.B_input, "B")
        finally:
            coord.request_stop()
            coord.join(threads)

    def model_setup(self):

        ''' This function sets up the model to train

        self.input_A/self.input_B -> Placeholders for one batch of training images.
        self.fake_pool_A/self.fake_pool_B -> Placeholders for images taken from the pools of generated images
        self.fake_A/self.fake_B -> Generated images by corresponding generator of input_B and input_A
        self.lr -> Learning rate placeholder
        self.cyc_A/ self.cyc_B -> Images generated after feeding self.fake_B/self.fake_A to the opposite generator. This is used to calculate the cyclic loss
        self.rec_A/self.rec_B, self.fake_rec_A/self.fake_rec_B, self.fake_pool_rec_A/self.fake_pool_rec_B ->
            Discriminator outputs on the real images, on the freshly generated images and on the pool images
        '''

        # Shapes are (batch, height, width, layer), the layout used by the arrays built in input_read.
        self.input_A = tf.placeholder(tf.float32, [batch_size, img_height, img_width, img_layer], name="input_A")
        self.input_B = tf.placeholder(tf.float32, [batch_size, img_height, img_width, img_layer], name="input_B")

        self.fake_pool_A = tf.placeholder(tf.float32, [None, img_height, img_width, img_layer], name="fake_pool_A")
        self.fake_pool_B = tf.placeholder(tf.float32, [None, img_height, img_width, img_layer], name="fake_pool_B")

        self.global_step = tf.Variable(0, name="global_step", trainable=False)

        self.num_fake_inputs = 0

        self.lr = tf.placeholder(tf.float32, shape=[], name="lr")

        with tf.variable_scope("Model") as scope:
            self.fake_B = build_generator_resnet_9blocks(self.input_A, name="g_A")
            self.fake_A = build_generator_resnet_9blocks(self.input_B, name="g_B")
            self.rec_A = build_gen_discriminator(self.input_A, "d_A")
            self.rec_B = build_gen_discriminator(self.input_B, "d_B")

            # Everything built below reuses the variables created above.
            scope.reuse_variables()

            self.fake_rec_A = build_gen_discriminator(self.fake_A, "d_A")
            self.fake_rec_B = build_gen_discriminator(self.fake_B, "d_B")
            self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B")
            self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A")

            self.fake_pool_rec_A = build_gen_discriminator(self.fake_pool_A, "d_A")
            self.fake_pool_rec_B = build_gen_discriminator(self.fake_pool_B, "d_B")

    def loss_calc(self):

        ''' In this function we are defining the variables for loss calculations and training the model

        d_loss_A/d_loss_B -> loss for discriminator A/B
        g_loss_A/g_loss_B -> loss for generator A/B (10 * cyclic loss + squared-error adversarial term)
        *_trainer -> Adam minimisation ops for the above loss functions, each restricted to its own network's variables
        *_summ -> Summary variables for the above loss functions'''

        cyc_loss = tf.reduce_mean(tf.abs(self.input_A-self.cyc_A)) + tf.reduce_mean(tf.abs(self.input_B-self.cyc_B))

        # The generators are rewarded when the discriminators output 1 on generated images.
        disc_loss_A = tf.reduce_mean(tf.squared_difference(self.fake_rec_A,1))
        disc_loss_B = tf.reduce_mean(tf.squared_difference(self.fake_rec_B,1))

        g_loss_A = cyc_loss*10 + disc_loss_B
        g_loss_B = cyc_loss*10 + disc_loss_A

        # The discriminators are pushed towards 0 on pool images and towards 1 on real images.
        d_loss_A = (tf.reduce_mean(tf.square(self.fake_pool_rec_A)) + tf.reduce_mean(tf.squared_difference(self.rec_A,1)))/2.0
        d_loss_B = (tf.reduce_mean(tf.square(self.fake_pool_rec_B)) + tf.reduce_mean(tf.squared_difference(self.rec_B,1)))/2.0

        optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)

        self.model_vars = tf.trainable_variables()

        # Variables are assigned to a network by the scope name embedded in the variable name.
        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
        g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
        g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]

        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)

        for var in self.model_vars: print(var.name)

        #Summary variables for tensorboard

        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)

    def save_training_images(self, sess, epoch):

        '''
        Writes generated, cyclic and input images for the first (at most 10) training
        pairs to ./output/imgs, tagged with the epoch number.

        sess -> Session holding the current weights
        epoch -> Epoch number used in the file names
        '''

        if not os.path.exists("./output/imgs"):
            os.makedirs("./output/imgs")

        # Never index past the number of images that were actually loaded.
        for i in range(min(10, len(self.A_input))):
            fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run(
                [self.fake_A, self.fake_B, self.cyc_A, self.cyc_B],
                feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})
            suffix = str(epoch) + "_" + str(i)+".jpg"
            # NOTE(review): fake_A is written as "fakeB_" and fake_B as "fakeA_". The file
            # names are kept as they were; confirm whether this labelling is intended.
            imsave("./output/imgs/fakeB_"+ suffix, self._to_uint8(fake_A_temp[0]))
            imsave("./output/imgs/fakeA_"+ suffix, self._to_uint8(fake_B_temp[0]))
            imsave("./output/imgs/cycA_"+ suffix, self._to_uint8(cyc_A_temp[0]))
            imsave("./output/imgs/cycB_"+ suffix, self._to_uint8(cyc_B_temp[0]))
            imsave("./output/imgs/inputA_"+ suffix, self._to_uint8(self.A_input[i][0]))
            imsave("./output/imgs/inputB_"+ suffix, self._to_uint8(self.B_input[i][0]))

    def fake_image_pool(self, num_fakes, fake, fake_pool):
        ''' This function saves the generated image to the corresponding pool of images.
        At the start it keeps on filling the pool till it is full. After that, with
        probability 0.5 it swaps the new image with a randomly selected stored image and
        returns the stored one, otherwise it returns the new image and leaves the pool untouched.

        num_fakes -> Number of images generated so far
        fake -> Newly generated image
        fake_pool -> Pool to update in place; its length is the pool capacity
        '''

        capacity = len(fake_pool)

        if(num_fakes < capacity):
            fake_pool[num_fakes] = fake
            return fake

        if random.random() > 0.5:
            random_id = random.randint(0, capacity-1)
            # Indexing a numpy pool yields a view, so the stored image has to be copied
            # before its slot is overwritten; otherwise the "old" image returned here
            # would silently turn into the new one.
            temp = np.copy(fake_pool[random_id])
            fake_pool[random_id] = fake
            return temp

        return fake

    def train(self):

        '''
        Training Function

        Builds the graph, loads the training images and runs epochs from the current
        value of global_step up to max_epoch. A checkpoint is written to check_dir after
        every epoch and loss summaries are written to ./output/2.

        Raises ValueError if to_restore is set but check_dir holds no checkpoint.
        '''

        # Start from an empty graph (as test() does) so that re-running the notebook cell
        # does not collide with variables left over from a previous call.
        tf.reset_default_graph()

        # Load Dataset from the dataset folder
        self.input_setup(mode='train')

        #Build the network
        self.model_setup()

        #Loss function calculations
        self.loss_calc()

        # Initializing the global variables
        init = (tf.global_variables_initializer(), tf.local_variables_initializer())
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init)

            #Read input to nd array
            self.input_read(sess)

            #Restore the model to run the model from last checkpoint
            if to_restore:
                chkpt_fname = tf.train.latest_checkpoint(check_dir)
                if chkpt_fname is None:
                    raise ValueError("to_restore is set but no checkpoint was found in " + check_dir)
                saver.restore(sess, chkpt_fname)

            if not os.path.exists(check_dir):
                os.makedirs(check_dir)

            writer = tf.summary.FileWriter("./output/2")

            # The writer is closed on every exit path so buffered summaries reach the disk.
            try:
                # global_step holds the number of completed epochs, so a restored run resumes where it stopped.
                start_epoch = int(sess.run(self.global_step))

                # Training Loop
                for epoch in range(start_epoch, max_epoch):
                    print ("In the epoch ", epoch)

                    # Dealing with the learning rate as per the epoch number
                    curr_lr = self._learning_rate(epoch)

                    if(save_training_images):
                        self.save_training_images(sess, epoch)

                    for ptr in range(0,max_images):
                        print("In the iteration ",ptr)
                        print("Starting",time.time()*1000.0)

                        step = epoch*max_images + ptr
                        real_feed = {self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr}

                        # Optimizing the G_A network
                        _, fake_B_temp, summary_str = sess.run(
                            [self.g_A_trainer, self.fake_B, self.g_A_loss_summ], feed_dict=real_feed)
                        writer.add_summary(summary_str, step)
                        fake_B_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_B_temp, self.fake_images_B)

                        # Optimizing the D_B network
                        d_B_feed = dict(real_feed)
                        d_B_feed[self.fake_pool_B] = fake_B_temp1
                        _, summary_str = sess.run([self.d_B_trainer, self.d_B_loss_summ], feed_dict=d_B_feed)
                        writer.add_summary(summary_str, step)

                        # Optimizing the G_B network
                        _, fake_A_temp, summary_str = sess.run(
                            [self.g_B_trainer, self.fake_A, self.g_B_loss_summ], feed_dict=real_feed)
                        writer.add_summary(summary_str, step)
                        fake_A_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_A_temp, self.fake_images_A)

                        # Optimizing the D_A network
                        d_A_feed = dict(real_feed)
                        d_A_feed[self.fake_pool_A] = fake_A_temp1
                        _, summary_str = sess.run([self.d_A_trainer, self.d_A_loss_summ], feed_dict=d_A_feed)
                        writer.add_summary(summary_str, step)

                        self.num_fake_inputs+=1

                    sess.run(tf.assign(self.global_step, epoch + 1))

                    # Saved after the epoch so that the checkpoint holds the weights that were
                    # just trained; saving only before the epoch left the last epoch unsaved.
                    saver.save(sess,os.path.join(check_dir,"cyclegan"),global_step=epoch + 1)

                writer.add_graph(sess.graph)
            finally:
                writer.close()

    def test(self):

        '''
        Testing Function

        Rebuilds the graph, restores the latest checkpoint from check_dir and writes the
        translated and input images for the first (at most 100) test pairs to ./output/imgs/test/.

        Raises ValueError if check_dir holds no checkpoint.
        '''

        print("Testing the results")
        tf.reset_default_graph()
        self.input_setup(mode='test')

        self.model_setup()
        saver = tf.train.Saver()
        init = (tf.global_variables_initializer(), tf.local_variables_initializer())

        # Fail before the images are loaded if there is nothing to restore.
        chkpt_fname = tf.train.latest_checkpoint(check_dir)
        if chkpt_fname is None:
            raise ValueError("No checkpoint found in " + check_dir + "; run train() first")

        with tf.Session() as sess:

            sess.run(init)

            self.input_read(sess)

            saver.restore(sess, chkpt_fname)

            if not os.path.exists("./output/imgs/test/"):
                os.makedirs("./output/imgs/test/")

            # Never index past the number of images that were actually loaded.
            for i in range(min(100, len(self.A_input))):
                fake_A_temp, fake_B_temp = sess.run(
                    [self.fake_A, self.fake_B],
                    feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})
                # NOTE(review): fake_A is written as "fakeB_" and fake_B as "fakeA_". The file
                # names are kept as they were; confirm whether this labelling is intended.
                imsave("./output/imgs/test/fakeB_"+str(i)+".jpg", self._to_uint8(fake_A_temp[0]))
                imsave("./output/imgs/test/fakeA_"+str(i)+".jpg", self._to_uint8(fake_B_temp[0]))
                imsave("./output/imgs/test/inputA_"+str(i)+".jpg", self._to_uint8(self.A_input[i][0]))
                imsave("./output/imgs/test/inputB_"+str(i)+".jpg", self._to_uint8(self.B_input[i][0]))


In [None]:
# Build a CycleGAN instance and run training. train() reads images from
# ./maps/trainA and ./maps/trainB and writes checkpoints to check_dir.
model = CycleGAN()
model.train()


Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.map(tf.read_file)`.
Model/g_A/c1/Conv/weights:0

Instructions for updating:
To construct input pipelines, use the `tf.data` module.
In the epoch  0


`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.


In the iteration  0
Starting 1555076189706.056
In the iteration  1
Starting 1555076338214.9229
In the iteration  2
Starting 1555076368674.821
In the iteration  3
Starting 1555076395075.777
In the iteration  4
Starting 1555076420154.1528
In the iteration  5
Starting 1555076455106.498
In the iteration  6
Starting 1555076486882.574
In the iteration  7
Starting 1555076517777.314
In the iteration  8
Starting 1555076548785.961
In the iteration  9
Starting 1555076580363.172
In the iteration  10
Starting 1555076611301.406
In the iteration  11
Starting 1555076642772.0059
In the iteration  12
Starting 1555076675962.43
In the iteration  13
Starting 1555076708824.759
In the iteration  14
Starting 1555076740402.145
In the iteration  15
Starting 1555076773062.061
In the iteration  16
Starting 1555076804457.106
In the iteration  17
Starting 1555076835698.669
In the iteration  18
Starting 1555076867625.419
In the iteration  19
Starting 1555076898907.158
In the iteration  20
Starting 1555076929602.751


In the iteration  169
Starting 1555081728471.872
In the iteration  170
Starting 1555081761616.4648
In the iteration  171
Starting 1555081796267.427
In the iteration  172
Starting 1555081825361.135
In the iteration  173
Starting 1555081852721.3042
In the iteration  174
Starting 1555081879835.202
In the iteration  175
Starting 1555081907353.454
In the iteration  176
Starting 1555081934874.0698
In the iteration  177
Starting 1555081969744.595
In the iteration  178
Starting 1555082018305.133
In the iteration  179
Starting 1555082067192.98
In the iteration  180
Starting 1555082121772.3232
In the iteration  181
Starting 1555082191262.2341
In the iteration  182
Starting 1555082263295.403
In the iteration  183
Starting 1555082333456.523
In the iteration  184
Starting 1555082405451.231
In the iteration  185
Starting 1555082480882.932
In the iteration  186
Starting 1555082553111.04
In the iteration  187
Starting 1555082647880.358
In the iteration  188
Starting 1555082901049.259
In the iteration 

In the iteration  336
Starting 1555088446904.012
In the iteration  337
Starting 1555088474800.009
In the iteration  338
Starting 1555088503202.5442
In the iteration  339
Starting 1555088530546.201
In the iteration  340
Starting 1555088557811.6758
In the iteration  341
Starting 1555088585104.702
In the iteration  342
Starting 1555088612412.2568
In the iteration  343
Starting 1555088639633.173
In the iteration  344
Starting 1555088666648.545
In the iteration  345
Starting 1555088693951.05
In the iteration  346
Starting 1555088721518.021
In the iteration  347
Starting 1555088748774.03
In the iteration  348
Starting 1555088776440.675
In the iteration  349
Starting 1555088803850.719
In the iteration  350
Starting 1555088831586.158
In the iteration  351
Starting 1555088861925.589
In the iteration  352
Starting 1555088889303.148
In the iteration  353
Starting 1555088916573.87
In the iteration  354
Starting 1555088944105.865
In the iteration  355
Starting 1555088971565.58
In the iteration  356

In the iteration  503
Starting 1555093778702.865
In the iteration  504
Starting 1555093810528.599
In the iteration  505
Starting 1555093848626.2651
In the iteration  506
Starting 1555093880685.987
In the iteration  507
Starting 1555093914410.625
In the iteration  508
Starting 1555093948611.984
In the iteration  509
Starting 1555093983366.4639
In the iteration  510
Starting 1555094016463.48
In the iteration  511
Starting 1555094049744.171
In the iteration  512
Starting 1555094082712.1611
In the iteration  513
Starting 1555094117993.534
In the iteration  514
Starting 1555094156832.214
In the iteration  515
Starting 1555094193212.513
In the iteration  516
Starting 1555094228029.57
In the iteration  517
Starting 1555094268308.227
In the iteration  518
Starting 1555094302906.789
In the iteration  519
Starting 1555094359238.496
In the iteration  520
Starting 1555094408122.209
In the iteration  521
Starting 1555094451906.0747


In [None]:
# Requires `model` from the training cell above. test() restores the latest
# checkpoint found in check_dir and writes images to ./output/imgs/test/.
model.test()