In [1]:
#!/usr/bin/env python3

In [2]:
%matplotlib inline
import numpy as np 
import time 
from datetime import timedelta
import tensorflow as tf
from readImages import *
from build_tensorflow_graph import *
import matplotlib.pyplot as plt

In [3]:
#
## ---------- VAE Architecture ----------
#

# Convolution Layer 1
numFilters1 = 32
cnnArchitecture1 = CNN_Architecture(numFilters = numFilters1, 
                                    filterSize = (3, 3), 
                                    strides = 1, 
                                    toPadding = True, 
                                    useReLU = True, 
                                    numInputChannels = 3, 
                                    maxPoolingSize=None)
dcnnArchitecture1 = cnnArchitecture1
dcnnArchitecture1.useReLU = False
# Convolution Layer 2
numFilters2 = 64
cnnArchitecture2 = CNN_Architecture(numFilters = numFilters2, 
                                    filterSize = (3, 3), 
                                    strides = 1, 
                                    toPadding = True, 
                                    useReLU = True, 
                                    numInputChannels = numFilters1, 
                                    maxPoolingSize=None)
dcnnArchitecture2 = cnnArchitecture2
# Convolution Layer 3
numFilters3 = 128
cnnArchitecture3 = CNN_Architecture(numFilters = numFilters3,
                                    filterSize = (2, 2), 
                                    strides = 1, 
                                    toPadding = True, 
                                    useReLU = True, 
                                    numInputChannels = numFilters2, 
                                    maxPoolingSize=None)
dcnnArchitecture3 = cnnArchitecture3
# Fully Connected 1
fc1_size = 1024 
# Fully Connected 2
fc2_size = 512
# Lattern Code 
z_dim = 256

In [4]:
#  
#    input layer: 64 * 64 * 3 = 12288
#    conv1 layer: 64 * 64 * 32 = 131072
#    conv2 layer: 64 * 64 * 64 = 262144
#    conv3 layer: 64 * 64 * 128 = 524288
#    fc1_size = 2048
#    fc2_size = 512
#    z_dim = 32

In [5]:
class Image:
    """
    Structure for input images
    """
    size = 32
    numChannels = 3

In [6]:
class VAE:
    def __init__(self, batch_size):
        self.batch_size = batch_size 
        self.sess = tf.Session()
        # ---------- build model ----------
        bs = self.batch_size
        self.inputImages = tf.placeholder(tf.float32, 
                                          shape=[bs, Image.size, Image.size, Image.numChannels])
        self.lattenCode = tf.placeholder(tf.float32, shape= [None, z_dim])
        self.mu, self.sigma = self.encoder(self.inputImages, is_training=True, reuse=False)
        z = self.mu + \
               self.sigma * tf.random_normal(tf.shape(self.mu), 0, 1, dtype=tf.float32)
        
        # Reconstruct
        reconstruct = self.decoder(z, is_training=True, reuse=False)
        self.reconstruct = tf.clip_by_value(reconstruct, 1e-8, 1 - 1e-8)
        
        # Define Cost
        regularizer = tf.reduce_sum(tf.exp(self.sigma) - (1 + self.sigma) + tf.square(self.mu) , axis = 1)
        lms = tf.reduce_mean(tf.square(self.inputImages - self.reconstruct), axis = [1, 2, 3])
        self.cost = tf.reduce_sum(regularizer + lms)

        self.optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(self.cost)
        
        # Testing
        self.fake_images = self.decoder(self.lattenCode, is_training=False, reuse=True)
        
        
    def encoder(self, inputLayer, is_training = True, reuse = False):
        with tf.variable_scope("encoder", reuse = reuse):
            conv1, convWeights1 = new_convLayer(inputLayer, cnnArchitecture1, name="en_conv1")
            conv2, convWeights2 = new_convLayer(conv1, cnnArchitecture2, name="en_conv2")
            conv2 = bn(conv2, is_training=is_training, scope="en_bn2")
            conv3, convWeights3 = new_convLayer(conv2, cnnArchitecture3, name="en_conv3")
            conv3 = bn(conv3, is_training=is_training, scope="en_bn3")
            flat_layer, numAttrs = flattenLayer(conv3)
            
            self.flat_numAttrs = (numAttrs)
            
            fc1 = new_fcLayer(flat_layer, 
                              inputChannels = numAttrs, 
                              outputChannels = fc1_size,
                              useReLU=True, 
                              name = "en_fc4")
            fc1 = bn(fc1, is_training=is_training, scope="en_bn4")
            fc2 = new_fcLayer(fc1, 
                              inputChannels = fc1_size, 
                              outputChannels = fc2_size,
                              useReLU=True, 
                              name = "en_fc5")
            fc2 = bn(fc2, is_training=is_training, scope="en_bn5")
            gaussian_para = new_fcLayer(fc2, 
                                        inputChannels = fc2_size, 
                                        outputChannels = z_dim * 2,
                                        useReLU=False, 
                                        name = "en_fc6")
            mean = gaussian_para[:, :z_dim]
            stdev = tf.exp(gaussian_para[:, z_dim:])
            return mean, stdev
        
    def decoder(self, lattenCode, is_training = True, reuse = False):
        with tf.variable_scope("decoder", reuse=reuse):
            dfc2 = new_fcLayer(lattenCode, 
                               inputChannels = z_dim, 
                               outputChannels = fc2_size,
                               useReLU=True, 
                               name = "de_fc1")
            dfc2 = bn(dfc2, is_training=is_training, scope="de_bn1")
            dfc1 = new_fcLayer(dfc2, 
                               inputChannels = fc2_size, 
                               outputChannels = fc1_size, 
                               useReLU = True, 
                               name = "de_fc2")
            dfc1 = bn(dfc1, is_training=is_training, scope="de_bn2")
            dflat = new_fcLayer(dfc1, 
                                inputChannels = fc1_size,  
                                outputChannels = self.flat_numAttrs, 
                                useReLU = True, 
                                name = "de_conv3")
            dflat = bn(dflat, is_training=is_training, scope="de_bn3")
            dconv3 = tf.reshape(dflat, shape=[-1, Image.size, Image.size, 128])
            
            dconv2 =  new_dconvLayer(dconv3, dcnnArchitecture3, 
                                     [self.batch_size, Image.size, Image.size, 64], name = "de_conv4")
            dconv2 = bn(dconv2, is_training=is_training, scope="de_bn4")
            dconv1 =  new_dconvLayer(dconv2, dcnnArchitecture2, 
                                     [self.batch_size, Image.size, Image.size, 32], name = "de_conv5")
            dconv1 = bn(dconv1, is_training=is_training, scope="de_bn5")
            reconstruct = new_dconvLayer(dconv1, dcnnArchitecture1, 
                                         [self.batch_size, Image.size, Image.size, 3], name = "de_conv6")
            reconstruct = tf.nn.sigmoid(reconstruct)
            return reconstruct
    def train(self, fileDir, epochs = 100):
        costHistory = []
        start = time.time()
        print('  * Start Training ...')
        self.sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            print('  * processing epoch #{} '.format(epoch))
            count = 0
            for batch in readImagesIn(directory=fileDir, size=(Image.size, Image.size), noiseStdev=0.03, batch_size = self.batch_size):
                count += 1
                batch_images = batch.getAttrs()
                feed_dict_train = {self.inputImages : batch_images}
                _, cost = self.sess.run([self.optimizer, self.cost],feed_dict = feed_dict_train)
                
                if count % 200 == 1:
                    print('\tepoch #{} , iterations #{}, cost = {}'.format(epoch, count, cost))
                    costHistory.append(cost)
            print('\tDONE! cost: {}'.format(cost))
                
                
        # -------- Plot Learning Curve --------
        plt.figure()
        plt.title('Learning Curve')
        plt.plot(costHistory)
            
        
    def generateFakeImages(self):
        z = np.random.normal(0, 1, size=(z_dim, ))
        feed_dict_test = {self.lattenCode : np.array([z]) }
        fakeImg = self.sess.run(self.fake_images, feed_dict=feed_dict_test)
        return fakeImg

In [None]:
if __name__ == '__main__':
    directory = '../faces/'
#     imgs = readImagesIn(directory=directory, size=(64, 64), noiseStdev=0.03)
    
    print('  * Building Model ...', end="")
    vae = VAE(batch_size=40)
    print('  Finished!!')
    
    

  * Building Model ...  Finished!!


In [None]:
vae.train(fileDir=directory, epochs = 10)

  * Start Training ...
  * processing epoch #0 
	epoch #0 , iterations #1, cost = 9253.59765625
	epoch #0 , iterations #201, cost = 105.4503173828125
	epoch #0 , iterations #401, cost = 28.509605407714844
	epoch #0 , iterations #601, cost = 16.092824935913086
	epoch #0 , iterations #801, cost = 10.006964683532715
	epoch #0 , iterations #1001, cost = 9.446853637695312
	epoch #0 , iterations #1201, cost = 8.401418685913086
	DONE! cost: 6.267484664916992
  * processing epoch #1 
	epoch #1 , iterations #1, cost = 6.061323165893555
	epoch #1 , iterations #201, cost = 5.246826648712158
	epoch #1 , iterations #401, cost = 5.640420913696289
	epoch #1 , iterations #601, cost = 5.532613754272461
	epoch #1 , iterations #801, cost = 5.467437267303467


In [None]:
# vae.lattenCode = tf.placeholder(tf.float32, shape= [None, z_dim])
# vae.fake_images = vae.decoder(vae.lattenCode, is_training=False, reuse=True)

In [None]:
# z = np.random.normal(0, 1, size=(40, z_dim))
z = np.random.uniform(-1, 1, size=(40, z_dim))
feed_dict_test = {vae.lattenCode : z }
imgs = vae.sess.run(vae.fake_images, feed_dict=feed_dict_test)

numImgs = len(imgs)
numGrids = math.ceil(math.sqrt(numImgs))
fig, axes = plt.subplots(numGrids, numGrids)
    
for i, ax in enumerate(axes.flat):
    if i < numImgs:
        pic = imgs[i]
        # Clipping Data in 0 - 1
        for index, x in np.ndenumerate(pic):
            if x > 1.0:
                pic[index] = 1.0
            elif x < 0.0:
                pic[index] = 0.0
        ax.imshow(pic)
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
# imgs = [vae.generateFakeImages() for _ in range(9)]
# plotImages(imgs)