In [1]:
import scipy as sp
import numpy as np
import os 
import pickle

import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

from architectures.cycleGAN import *

In [2]:
# `tensorflow` is already imported as `tf` in the first cell.
# This notebook uses TF 1.x graph-mode APIs (tf.Session, tf.train.Saver,
# tf.reset_default_graph, tf.truncated_normal_initializer), so fail fast on TF 2.
assert tf.__version__.startswith('1.'), 'This notebook requires TensorFlow 1.x, found ' + tf.__version__
tf.__version__

'1.3.0'

In [3]:
# --- Configuration ---------------------------------------------------------
# Optimizer / training settings, forwarded to cycleGAN in Horse2Zebra.
LEARNING_RATE_D = 0.0001   # passed as lr_d (discriminator learning rate)
LEARNING_RATE_G = 0.0001   # passed as lr_g (generator learning rate)
BETA1 = 0.5                # passed as beta1; presumably Adam's beta1 -- confirm in architectures.cycleGAN
BATCH_SIZE = 128
EPOCHS = 1
SAVE_SAMPLE_PERIOD = 100   # passed as save_sample

# 'TRAIN': train (resuming from PATH if a checkpoint exists) and save.
# 'TEST' : restore the checkpoint stored in PATH without training.
task = 'TRAIN'
assert task in ('TRAIN', 'TEST'), "task must be 'TRAIN' or 'TEST'"

# SEED is forwarded to cycleGAN. NumPy's global RNG is also seeded, because the
# interactive sampling in Horse2Zebra picks images with np.random.choice.
SEED = 1
np.random.seed(SEED)

# Model directory; checkpoints are saved to PATH/PATH.ckpt.
PATH = 'cycleGAN_test'

In [7]:
# Weight initializers shared by the layer specs below.
trunc_normal = tf.truncated_normal_initializer(stddev=0.02)
normal = tf.random_normal_initializer(stddev=0.02)

uniform = tf.random_uniform_initializer()
glorot_uniform = tf.glorot_uniform_initializer()
glorot_normal = tf.glorot_normal_initializer()


def _discriminator_sizes():
    """Return a new discriminator layer-spec dict for cycleGAN.

    The tuple layout is interpreted by architectures.cycleGAN. It is presumably
    (n_filters, kernel, stride, batch_norm, ?, activation, initializer) for conv
    layers -- TODO confirm against that module.
    """
    return {
        'conv_layers': [(2, 6, 2, False, 1, lrelu, trunc_normal),
                        (64, 6, 2, True, 1, lrelu, trunc_normal)],
        # NOTE(review): the build log prints (262144, 1024) for each discriminator,
        # which looks like this layer's weight matrix (~268M parameters each).
        # If so, it dominates memory; consider more downsampling before it.
        'dense_layers': [(1024, True, 1, lrelu, normal)],
        'readout_layer_w_init': normal,
    }


def _generator_sizes():
    """Return a new generator layer-spec dict for cycleGAN.

    The tuple layout is interpreted by architectures.cycleGAN. Note that the
    shortcut entry has no activation field -- TODO confirm field meanings there.
    """
    return {
        'conv_layer_0': [(4, 8, 2, False, 1, tf.nn.relu, normal)],
        'conv_layer_1': [(8, 8, 2, False, 1, tf.nn.relu, normal)],
        'convblock_layer_0': [(16, 4, 2, True, 1, tf.nn.relu, normal),
                              (16, 4, 1, False, 1, tf.nn.relu, normal),
                              (16, 4, 2, True, 1, tf.nn.relu, normal)],
        'convblock_shortcut_layer_0': [(16, 1, 4, False, 1, normal)],
        'deconv_layer_0': [(8, 8, 2, False, 1, tf.nn.relu, normal)],
        'deconv_layer_1': [(3, 8, 2, False, 1, tf.nn.relu, normal)],
    }


# Both domains currently use the same architecture. Each gets its own dict
# object so that editing one direction's spec never silently changes the other.
d_sizes_A = _discriminator_sizes()
d_sizes_B = _discriminator_sizes()

g_sizes_A = _generator_sizes()
g_sizes_B = _generator_sizes()



In [8]:
def _load_image_dir(directory):
    """Read every file in `directory` with plt.imread, in sorted filename order, into one array.

    Sorting makes the later truncation to a common length reproducible, because
    os.listdir order is arbitrary. plt.imread returns uint8 for JPEG but float in
    [0, 1] for PNG.
    NOTE(review): confirm which range cycleGAN.fit expects.
    NOTE(review): assumes all images share one shape. Mixed shapes would make
    np.array build an object array, and the shape unpacking in Horse2Zebra would fail.
    """
    filenames = sorted(os.listdir(directory))
    return np.array([plt.imread(os.path.join(directory, name)) for name in filenames])


def _checkpoint_path():
    """Return the checkpoint prefix used for both saving and restoring: PATH/PATH.ckpt."""
    return PATH + '/' + PATH + '.ckpt'


def _sample_interactively(gan, images, n_H, n_W, n_C):
    """Save side-by-side figures of a random image and gan.get_sample(image) until the user says no.

    Figures are written to PATH as sample_<k>_img_<index>.png, where k counts
    samples generated in this call.
    """
    n_samples = 0
    while True:
        j = np.random.choice(len(images))
        true_img = images[j]
        sample_img = gan.get_sample(true_img)

        fig, (ax_true, ax_sample) = plt.subplots(1, 2)
        # Images have n_C channels (RGB), so keep the channel axis. A 2-D reshape
        # fails for 3 channels, and a gray colormap only applies to 2-D data.
        ax_true.imshow(np.asarray(true_img).reshape(n_H, n_W, n_C))
        ax_sample.imshow(np.asarray(sample_img).reshape(n_H, n_W, n_C))
        fig.set_size_inches(5, 8)
        fig.savefig(PATH + '/sample_{0}_img_{1}.png'.format(n_samples, j), dpi=300)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across iterations
        n_samples += 1

        ans = input("Generate another?")
        if ans and ans[0] in ('n', 'N'):
            break


def Horse2Zebra(data_dir='horse2zebra', interactive=False):
    """Build the horse<->zebra cycleGAN, then train it or restore it according to the global `task`.

    Reads module-level configuration: task, PATH, the hyperparameter constants,
    and the d_sizes_* / g_sizes_* layer specs.

    Args:
        data_dir: dataset root containing the trainA/ and trainB/ image folders.
        interactive: if True, after training or restoring, repeatedly save an
            A-domain image next to gan.get_sample's output and prompt for more.
            Defaults to False, so no sampling prompt is shown.

    Raises:
        ValueError: if `task` is neither 'TRAIN' nor 'TEST'.
    """
    if task not in ('TRAIN', 'TEST'):
        raise ValueError("task must be 'TRAIN' or 'TEST', got %r" % (task,))

    train_A = _load_image_dir(os.path.join(data_dir, 'trainA'))
    train_B = _load_image_dir(os.path.join(data_dir, 'trainB'))

    # Truncate both domains to the smaller image count, so fit receives equal-length arrays.
    m = min(train_A.shape[0], train_B.shape[0])
    _, n_H, n_W, n_C = train_A.shape

    X_train_A = train_A[:m]
    X_train_B = train_B[:m]

    tf.reset_default_graph()

    gan = cycleGAN(n_H, n_W, n_C, d_sizes_A, d_sizes_B, g_sizes_A, g_sizes_B,
                   lr_g=LEARNING_RATE_G, lr_d=LEARNING_RATE_D, beta1=BETA1,
                   batch_size=BATCH_SIZE, epochs=EPOCHS,
                   save_sample=SAVE_SAMPLE_PERIOD, path=PATH, seed=SEED)

    if task == 'TRAIN':
        init_op = tf.global_variables_initializer()
    else:
        # Trainable weights are restored from the checkpoint; only initialize the rest.
        vars_to_train = tf.trainable_variables()
        vars_to_init = list(set(tf.global_variables()) - set(vars_to_train))
        init_op = tf.variables_initializer(vars_to_init)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    ckpt = _checkpoint_path()

    with tf.Session() as sess:

        sess.run(init_op)

        if task == 'TRAIN':
            print('\n Training...')

            # Resume from an existing checkpoint if one was saved to this PATH.
            if os.path.exists(ckpt + '.index'):
                saver.restore(sess, ckpt)
                print('Model restored.')

            gan.set_session(sess)
            gan.fit(X_train_A, X_train_B)

            save_path = saver.save(sess, ckpt)
            print("Model saved in path: %s" % save_path)

        else:
            # NOTE(review): despite the message, no test split is loaded and nothing
            # is evaluated here; only the weights are restored.
            print('\n Evaluate model on test set...')
            saver.restore(sess, ckpt)
            print('Model restored.')

            gan.set_session(sess)

        if interactive:
            _sample_interactively(gan, X_train_A, n_H, n_W, n_C)

In [9]:
if __name__ == '__main__':

    if task == 'TRAIN':
        if not os.path.exists(PATH):
            os.mkdir(PATH)

        elif os.path.exists(PATH + '/checkpoint'):
            prompt = 'A previous checkpoint already exists, choose the action to perform \n \n 1) Overwrite the current model saved at '+PATH+'/checkpoint \n 2) Start training a new model \n 3) Restore and continue training the previous model \n '
            # Re-ask on a typo instead of silently falling through to option 3.
            ans = None
            while ans not in ('1', '2', '3'):
                ans = input(prompt).strip()

            if ans == '1':
                print('Overwriting existing model in '+PATH)
                # Best-effort cleanup: report files that cannot be removed, but keep going.
                for file in os.listdir(PATH):
                    file_path = os.path.join(PATH, file)
                    try:
                        if os.path.isfile(file_path):
                            os.unlink(file_path)
                    except Exception as e:
                        print(e)

            elif ans == '2':
                # Ask until we get a non-empty name that does not collide with an existing path.
                new_path = ''
                while not new_path or os.path.exists(new_path):
                    if new_path:
                        print('Directory ' + new_path + ' already exists.')
                    new_path = input('Specify the name of the model, a new directory will be created.\n').strip()
                PATH = new_path
                os.mkdir(PATH)

            # ans == '3': nothing to do here; Horse2Zebra restores PATH/PATH.ckpt
            # when its .index file exists.

        Horse2Zebra()

    elif task == 'TEST':
        if not os.path.exists(PATH+'/checkpoint'):
            print('No checkpoint to test')
        else:
            Horse2Zebra()

    else:
        raise ValueError("task must be 'TRAIN' or 'TEST', got %r" % (task,))

Convolutional Network architecture detected
Convolutional Network architecture detected
Discriminator_A
Convolution
Input for convolution shape  (?, 256, 256, 3)
(262144, 1024)
Logits shape (?, 1)
Generator_A_to_B
Input for deconvolution shape (?, 256, 256, 3)
Deconvoluted output shape (?, 256, 256, 3)
Discriminator_B
Convolution
Input for convolution shape  (?, 256, 256, 3)
(262144, 1024)
Logits shape (?, 1)
Generator_B_to_A
Input for deconvolution shape (?, 256, 256, 3)
Deconvoluted output shape (?, 256, 256, 3)
Discriminator_B
Convolution
Input for convolution shape  (?, 256, 256, 3)
(262144, 1024)
Logits shape (?, 1)
Generator_B_to_A
Input for deconvolution shape (?, 256, 256, 3)
Deconvoluted output shape (?, 256, 256, 3)
Discriminator_A
Convolution
Input for convolution shape  (?, 256, 256, 3)
(262144, 1024)
Logits shape (?, 1)
Generator_A_to_B
Input for deconvolution shape (?, 256, 256, 3)
Deconvoluted output shape (?, 256, 256, 3)
Generator_A_to_B
Input for deconvolution shape (

KeyboardInterrupt: 