In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
import cv2

from tensorflow.examples.tutorials.mnist import input_data
# Module-level MNIST dataset handle consumed by load_mnist_batch_data();
# it is read from the "MNIST_data/" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

In [None]:
def load_svhn_data(path="./dataset/train_32x32.mat"):
    """Load an SVHN-style ``.mat`` file and return image-first arrays.

    Args:
        path: location of the ``.mat`` file. It must contain an ``'X'`` entry
            whose last axis indexes the samples and a ``'y'`` entry with the
            labels. Defaults to the training split used by this notebook.

    Returns:
        A ``(images, labels)`` tuple. ``images`` is ``data['X']`` with its
        axes permuted from ``(d0, d1, d2, N)`` to ``(N, d0, d1, d2)``;
        ``labels`` is ``data['y']`` exactly as stored in the file.
    """
    data = loadmat(path)
    images = data['X']
    labels = data['y']

    # The file keeps the sample index on the last axis; move it to the front
    # so that images[i] is a single image.
    return np.transpose(images, (3, 0, 1, 2)), labels

def make_batch(data, batch_size):
    """Pick ``batch_size`` random items from ``data`` without replacement.

    A full index permutation is drawn with ``np.random.shuffle`` and its first
    ``batch_size`` positions are used, so at most ``len(data)`` items come
    back. The result is a plain Python list of ``data[i]`` entries.
    """
    positions = np.arange(0, len(data))
    np.random.shuffle(positions)

    return [data[pos] for pos in positions[:batch_size]]

def load_mnist_batch_data(batch_size):
    """Fetch the next MNIST training batch as 32x32 single-channel images.

    Each flat image returned by ``mnist.train.next_batch`` is reshaped to
    28x28, resized to 32x32 with ``cv2.resize`` and the batch is returned as
    an array of shape ``(batch_size, 32, 32, 1)``. Labels are discarded.
    """
    flat_images, _ = mnist.train.next_batch(batch_size)

    upscaled = [
        cv2.resize(np.reshape(flat_images[k], (28, 28)), (32, 32))
        for k in range(batch_size)
    ]

    return np.reshape(upscaled, (batch_size, 32, 32, 1))

def data_normalize(data):
    """Min-max normalise ``data`` along its first axis.

    Every position along axis 0 is mapped to ``(x - min) / (max - min)``,
    where ``min`` and ``max`` are taken over axis 0, so each feature ends up
    in ``[0, 1]``.

    Args:
        data: array-like whose first axis indexes the samples.

    Returns:
        Array with the same shape as ``data``. Features that are constant
        across axis 0 (``max == min``) are returned as 0 instead of NaN/inf.
    """
    data = np.asarray(data)
    lowest = np.min(data, 0)
    span = np.max(data, 0) - lowest

    # Guard the division: a constant feature has zero range and its
    # numerator is zero as well, so dividing by 1 yields a clean 0.
    safe_span = np.where(span == 0, 1, span)

    return (data - lowest) / safe_span


def color_img_normalize(data):
    """Scale a batch of 3-channel images by 1/255 into a float32 array.

    The output has shape ``(len(data), rows, cols, 3)`` where ``rows`` and
    ``cols`` are taken from the first image; every image is divided by 255.0
    and stored as float32.
    """
    n_images = len(data)
    height = len(data[0])
    width = len(data[0][0])

    scaled = np.zeros((n_images, height, width, 3), dtype=np.float32)
    for idx in range(n_images):
        scaled[idx] = data[idx] / 255.0

    return scaled

In [None]:
class model():
    """Domain-transfer GAN (SVHN source -> MNIST target) as a TF 1.x graph.

    Constructing an instance creates the two image placeholders and wires a
    feature extractor, a generator and a discriminator together with their
    losses (see :meth:`model`). :meth:`train` runs the optimisation loop and
    :meth:`test` evaluates the restored networks on one random batch.
    """

    # Checkpoint location shared by train() and test().
    SAVE_PATH = "./Weight/Weight.ckpt"
    # Directory the graph definition is written to for TensorBoard.
    LOG_DIR = './mygraph'

    def __init__(self, batch_size, d_lr, g_lr, epochs):
        """Store hyper-parameters, create the placeholders and build the graph.

        Args:
            batch_size: number of images drawn from each domain per step.
            d_lr: Adam learning rate for the discriminator.
            g_lr: Adam learning rate for the generator and feature extractor.
            epochs: number of epochs :meth:`train` runs.
        """
        self.batch_size = batch_size
        self.d_lr = d_lr
        self.g_lr = g_lr
        self.epochs = epochs
        self.img_width = 32
        self.img_height = 32
        self.img_channel = 3
        self.mnist_channel = 1
        # Source domain: 32x32 colour images; target domain: 32x32 grayscale.
        self.src_domain_img = tf.placeholder(
            dtype=tf.float32,
            shape=[None, self.img_height, self.img_width, self.img_channel])
        self.trg_domain_img = tf.placeholder(
            dtype=tf.float32,
            shape=[None, self.img_height, self.img_width, self.mnist_channel])
        self.model()

    @staticmethod
    def _conv_block(net, layer_fn, filters, kernel_size, padding, strides,
                    activation=None, batch_norm=True):
        """Apply ``layer_fn`` (conv or transposed conv), then optional BN and activation.

        Layers are created in the same order as when written inline, so the
        auto-generated variable names inside the enclosing variable scope
        (and therefore existing checkpoints) are unaffected.
        """
        net = layer_fn(inputs=net, filters=filters, kernel_size=kernel_size,
                       padding=padding, strides=strides,
                       kernel_initializer=tf.contrib.layers.xavier_initializer())
        if batch_norm:
            # Batch statistics are always used (training=True), also in test().
            net = tf.layers.batch_normalization(net, training=True, momentum=0.95)
        if activation is not None:
            net = activation(net)
        return net

    @staticmethod
    def _recent_mean(values, count):
        """Return the mean of the last ``count`` entries of ``values``.

        Fewer than ``count`` entries are averaged over what is available;
        an empty window (no values or ``count <= 0``) yields NaN.
        """
        if count <= 0 or len(values) == 0:
            return float("nan")
        window = values[-count:]
        return float(sum(window)) / len(window)

    @staticmethod
    def _restore_weights(sess, saver, save_path):
        """Best-effort restore of ``save_path`` into ``sess``.

        A failed restore (e.g. no checkpoint yet) is reported and the freshly
        initialised variables are kept. Only ``Exception`` is caught so that
        KeyboardInterrupt / SystemExit still propagate.
        """
        try:
            print("Existed Weight.ckpt load")
            saver.restore(sess, save_path)
        except Exception:
            print("No Weight exist")
            print("Start Training with newly made Weight.ckpt")

    @staticmethod
    def _show_strip(images, count):
        """Draw the first ``count`` images side by side on a new figure."""
        # squeeze=False keeps a 2-D axes array even when count == 1.
        fig, axes = plt.subplots(1, count, figsize=(count, 2), squeeze=False)
        for i in range(count):
            axes[0][i].set_axis_off()
            axes[0][i].imshow(images[i])
        return fig

    def feature_extractor(self, _input):
        """Encode a batch of 32x32 images into a ``(batch, 1, 1, 128)`` feature map.

        Single-channel inputs are first expanded to RGB so the same weights
        (scope ``feature_extractor``, AUTO_REUSE) serve both domains.
        """
        if _input.get_shape().as_list()[3] == 1:
            _input = tf.image.grayscale_to_rgb(_input)
        with tf.variable_scope("feature_extractor", reuse=tf.AUTO_REUSE):
            conv = tf.layers.conv2d
            # 32 -> 16 -> 8 -> 4 (stride 2), two stride-1 layers at 4x4,
            # then a 4x4 VALID conv collapses the map to 1x1.
            net = self._conv_block(_input, conv, 64, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, conv, 128, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, conv, 256, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, conv, 512, [3, 3], "SAME", (1, 1), tf.nn.relu)
            net = self._conv_block(net, conv, 256, [3, 3], "SAME", (1, 1), tf.nn.relu)
            net = self._conv_block(net, conv, 128, [4, 4], "VALID", (2, 2), tf.nn.tanh)

            return net

    def discriminator(self, _input):
        """Score a batch of 32x32 images as real/fake target-domain samples.

        Returns:
            ``(probabilities, logits)``, both flattened to ``(batch, 1)``;
            ``probabilities`` is the sigmoid of ``logits``.
        """
        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
            conv = tf.layers.conv2d
            net = self._conv_block(_input, conv, 128, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, conv, 256, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, conv, 512, [3, 3], "SAME", (2, 2), tf.nn.relu)
            # Final 4x4 VALID conv reduces the 4x4 map to a single logit.
            net = self._conv_block(net, conv, 1, [4, 4], "VALID", (2, 2), batch_norm=False)
            net = tf.layers.flatten(net)
            return tf.nn.sigmoid(net), net

    def generator(self, _input):
        """Decode a ``(batch, 1, 1, C)`` feature map into 32x32 single-channel images.

        The output passes through ``tanh``, so pixel values lie in (-1, 1).
        """
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            deconv = tf.layers.conv2d_transpose
            # 1 -> 4 (VALID 4x4), two stride-1 layers at 4x4, then 4 -> 8 -> 16 -> 32.
            net = self._conv_block(_input, deconv, 512, [4, 4], "VALID", (2, 2), tf.nn.relu)
            net = self._conv_block(net, deconv, 512, [3, 3], "SAME", (1, 1), tf.nn.relu)
            net = self._conv_block(net, deconv, 512, [3, 3], "SAME", (1, 1), tf.nn.relu)
            net = self._conv_block(net, deconv, 256, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, deconv, 128, [3, 3], "SAME", (2, 2), tf.nn.relu)
            net = self._conv_block(net, deconv, self.mnist_channel, [3, 3], "SAME", (2, 2),
                                   tf.nn.tanh, batch_norm=False)

            return net

    def model(self):
        """Wire the networks together and define ``Loss_D`` and ``Loss_G``.

        Discriminator loss: generated images from both domains are labelled 0
        and real target images are labelled 1. Generator loss: adversarial
        term (generated images labelled 1) plus feature-constancy (``L_const``),
        target identity (``L_tid``) and total-variation (``L_tv``) terms.
        """
        # Source-domain path: x -> f(x) -> g(f(x)) -> f(g(f(x))).
        self.src_fx = self.feature_extractor(self.src_domain_img)
        self.src_gfx = self.generator(self.src_fx)
        self.src_fgfx = self.feature_extractor(self.src_gfx)
        self.D1, self.D1_logits = self.discriminator(self.src_gfx)

        # Target-domain path: x -> f(x) -> g(f(x)).
        self.trg_fx = self.feature_extractor(self.trg_domain_img)
        self.trg_gfx = self.generator(self.trg_fx)
        self.D2, self.D2_logits = self.discriminator(self.trg_gfx)

        # Real target images.
        self.D3, self.D3_logits = self.discriminator(self.trg_domain_img)

        self.L_D1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D1_logits, labels=tf.zeros_like(self.D1_logits)))
        self.L_D2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D2_logits, labels=tf.zeros_like(self.D2_logits)))
        self.L_D3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D3_logits, labels=tf.ones_like(self.D3_logits)))

        self.Loss_D = self.L_D1 + self.L_D2 + self.L_D3

        self.L_G1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D1_logits, labels=tf.ones_like(self.D1_logits)))
        self.L_G2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D2_logits, labels=tf.ones_like(self.D2_logits)))

        self.L_Gang = self.L_G1 + self.L_G2

        # NOTE(review): L_const is scaled by 15 here and again by alpha below
        # (effective weight 225) - confirm this double weighting is intended.
        self.L_const = tf.reduce_mean(tf.square(self.src_fx - self.src_fgfx))*15
        self.L_tid = tf.reduce_mean(tf.square(self.trg_domain_img - self.trg_gfx))

        # Total variation: squared differences between vertically and
        # horizontally adjacent pixels of both generated batches.
        self.L_tv = tf.reduce_mean(tf.squared_difference(self.trg_gfx[:, 1:, :, :], self.trg_gfx[:, :-1, :, :])) + \
                    tf.reduce_mean(tf.squared_difference(self.trg_gfx[:, :, 1:, :], self.trg_gfx[:, :, :-1, :])) + \
                    tf.reduce_mean(tf.squared_difference(self.src_gfx[:, 1:, :, :], self.src_gfx[:, :-1, :, :])) + \
                    tf.reduce_mean(tf.squared_difference(self.src_gfx[:, :, 1:, :], self.src_gfx[:, :, :-1, :]))

        alpha = 15
        beta = 15
        gamma = 1
        self.Loss_G = self.L_Gang + alpha*self.L_const + beta*self.L_tid + gamma*self.L_tv

    def train(self):
        """Run the adversarial training loop and plot progress after every epoch.

        Each step draws a random SVHN batch (scaled to [0, 1]) and the next
        MNIST batch, then updates the discriminator followed by the
        generator/feature extractor. After every epoch the weights are saved
        to ``SAVE_PATH`` and sample images plus loss curves are displayed.

        Raises:
            ValueError: if the dataset holds fewer samples than one batch.
        """
        import os  # local import: only needed to create the checkpoint directory

        # Collect the variables each optimiser is allowed to update.
        trainable_variables = tf.trainable_variables()
        self.d_var = [var for var in trainable_variables if 'discriminator' in var.name]
        self.g_var = [var for var in trainable_variables if 'generator' in var.name]
        self.g_var.extend([var for var in trainable_variables if 'feature_extractor' in var.name])

        self.optimize_d = tf.train.AdamOptimizer(learning_rate=self.d_lr).minimize(self.Loss_D, var_list=self.d_var)
        self.optimize_g = tf.train.AdamOptimizer(learning_rate=self.g_lr).minimize(self.Loss_G, var_list=self.g_var)

        # Data load.
        source_x, _ = load_svhn_data()
        batch_size = self.batch_size
        # Derive the number of steps per epoch from the data actually loaded.
        total_batch = len(source_x) // batch_size
        if total_batch == 0:
            raise ValueError("batch_size is larger than the number of source images")

        with tf.Session() as sess:
            # Only the graph is logged, so the writer can be closed right away.
            writer = tf.summary.FileWriter(self.LOG_DIR, sess.graph)
            writer.close()

            sess.run(tf.global_variables_initializer())

            # Weight load (falls back to the fresh initialisation).
            saver = tf.train.Saver()
            self._restore_weights(sess, saver, self.SAVE_PATH)
            # Make sure the checkpoint directory exists before the first save.
            os.makedirs(os.path.dirname(self.SAVE_PATH), exist_ok=True)

            print("train start")
            loss_gen_list = []
            loss_dis_list = []

            for epoch in range(self.epochs):
                print("---------------------------------------------------------")
                for iteration in range(total_batch):
                    batch_source_x = color_img_normalize(make_batch(source_x, batch_size))
                    batch_target_x = load_mnist_batch_data(batch_size)
                    feed = {self.src_domain_img: batch_source_x,
                            self.trg_domain_img: batch_target_x}

                    _, discriminator_loss = sess.run([self.optimize_d, self.Loss_D], feed_dict=feed)
                    _, generator_loss = sess.run([self.optimize_g, self.Loss_G], feed_dict=feed)
                    loss_dis_list.append(discriminator_loss)
                    loss_gen_list.append(generator_loss)

                output = sess.run(self.src_gfx, feed_dict={self.src_domain_img: batch_source_x})

                # Average over this epoch's steps (total_batch of them).
                print(epoch+1, "epoch average loss_d : ", self._recent_mean(loss_dis_list, total_batch))
                print(epoch+1, "epoch average loss_g : ", self._recent_mean(loss_gen_list, total_batch))

                saver.save(sess, self.SAVE_PATH)

                # The generator output is already a float image in (-1, 1);
                # show it as is (casting to int would zero nearly every pixel).
                output = np.reshape(output, (batch_size, self.img_width, self.img_height))
                plt.imshow(output[0])
                plt.show()

                # Last training batch (top figure) next to its translations.
                self._show_strip(batch_source_x, batch_size)
                self._show_strip(output, batch_size)
                plt.show()

                plt.plot(loss_dis_list)
                plt.title("discriminator's loss")
                plt.show()

                plt.plot(loss_gen_list)
                plt.title("generator's loss")
                plt.show()

                # Reuse the SVHN array already in memory instead of re-reading it.
                batch_source_x, batch_target_x, d_src_gfx, d_trg_gfx, d_trg_img, output0, output1 = self.test(source_x)
                index = 0
                plt.imshow(batch_source_x[index])
                plt.show()

                batch_target_x = np.reshape(batch_target_x, (self.batch_size, 32, 32))
                plt.imshow(batch_target_x[index])
                plt.show()

                n_preview = min(4, self.batch_size)

                print("source domain generate")
                tmp = np.reshape(output0, (self.batch_size, 32, 32))
                for i in range(n_preview):
                    plt.imshow(tmp[i])
                    plt.show()

                print("target domain generate")
                tmp = np.reshape(output1, (self.batch_size, 32, 32))
                for i in range(n_preview):
                    plt.imshow(tmp[i])
                    plt.show()

    def test(self, source_x=None):
        """Evaluate the restored networks on one random batch per domain.

        Inputs are preprocessed exactly as in :meth:`train` (SVHN scaled to
        [0, 1] via ``color_img_normalize``, MNIST fed as loaded).

        Args:
            source_x: optional pre-loaded SVHN image array; when omitted the
                data is read with ``load_svhn_data()``.

        Returns:
            ``(batch_source_x, batch_target_x, d_src_gfx, d_trg_gfx,
            d_trg_img, output0, output1)`` - the two input batches, the
            discriminator probabilities for generated-from-source,
            generated-from-target and real target images, and the images
            generated from the source and target batches.
        """
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Weight load.
            saver = tf.train.Saver()
            self._restore_weights(sess, saver, self.SAVE_PATH)

            # Data load.
            if source_x is None:
                source_x, _ = load_svhn_data()
            batch_source_x = color_img_normalize(make_batch(source_x, self.batch_size))
            batch_target_x = load_mnist_batch_data(self.batch_size)

            feed = {self.src_domain_img: batch_source_x,
                    self.trg_domain_img: batch_target_x}

            # Discriminator scores and generator outputs in a single pass.
            d_src_gfx, d_trg_gfx, d_trg_img, output0, output1 = sess.run(
                [self.D1, self.D2, self.D3, self.src_gfx, self.trg_gfx],
                feed_dict=feed)

            return batch_source_x, batch_target_x, d_src_gfx, d_trg_gfx, d_trg_img, output0, output1
