In [1]:
import tensorflow as tf
import numpy as np
import glob
import os
import re
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
%matplotlib inline

In [2]:
def plotPoke(x):
    f, a = plt.subplots(2, 8, figsize=(8, 2))
    for i in range(8):
        a[0][i].imshow(x[i], cmap=plt.get_cmap('gray'))
        a[0,i].axis('off')
        a[1][i].imshow(x[i+8], cmap=plt.get_cmap('gray'))
        a[1,i].axis('off')
    f.show()
    plt.draw()

In [3]:
# Create an empty array to store pokemon pics
orig_img = np.empty((0, 40, 40, 3), dtype='float32')
# Load all images and append into orig_img
path = os.path.abspath("./AE_RGB.ipynb")
path = re.sub('[a-zA-Z\s._]+$', '', path)
for pic in glob.glob(path+'Pokemon/*.png'):
    img = mpimg.imread(pic)
    # remove alpha channel  %some alpha=0 but RGB is not equal to [1., 1., 1.]
    img[img[:,:,3]==0] = np.ones((1,4))
    img = img[:,:,0:3]
    orig_img = np.append(orig_img, [img], axis=0)

# Use plt to show original images 
print 'Input data shape: {}'.format(orig_img.shape)
# plotPoke(orig_img)

Input data shape: (792, 40, 40, 3)


In [12]:
# Parameters
learning_rate = 0.001
training_epochs = 1000
batch_size = 24
display_step = 100
examples_to_show = 8

# Network Parameters
n_input = [40, 40, 3] # Pokemon data input (img shape: 40*40*3)
n_channel1 = 16
n_channel2 = 32
n_channel3 = 64
gen_dim = 100

# tf Graph input
X = tf.placeholder(tf.float32, [None]+n_input)
Z = tf.placeholder(tf.float32, [None, gen_dim])

In [13]:
# Store layers weights & biases
weights = {
    'dis_h1': tf.Variable(tf.truncated_normal([5, 5, 3, n_channel2], stddev=0.01)),
    'dis_h2': tf.Variable(tf.truncated_normal([5, 5, n_channel2, n_channel3], stddev=0.01)),
    'dis_h3': tf.Variable(tf.truncated_normal([10*10*n_channel3, 1], stddev=0.01)),

    'gen_h1': tf.Variable(tf.truncated_normal([gen_dim, 5*5*n_channel3], stddev=0.01)),
    'gen_h2': tf.Variable(tf.truncated_normal([5, 5, n_channel2, n_channel3], stddev=0.01)),
    'gen_h3': tf.Variable(tf.truncated_normal([5, 5, n_channel1, n_channel2], stddev=0.01)),
    'gen_h4': tf.Variable(tf.truncated_normal([5, 5, 3, n_channel1], stddev=0.01))
}

biases = {
    'dis_h1': tf.Variable(tf.truncated_normal([n_channel2], stddev=0.01)),
    'dis_h2': tf.Variable(tf.truncated_normal([n_channel3], stddev=0.01)),
    'dis_h3': tf.Variable(tf.truncated_normal([1], stddev=0.01)),
    
    'gen_h1': tf.Variable(tf.truncated_normal([5*5*n_channel3], stddev=0.01)),
    'gen_h2': tf.Variable(tf.truncated_normal([n_channel2], stddev=0.01)),
    'gen_h3': tf.Variable(tf.truncated_normal([n_channel1], stddev=0.01)),
    'gen_h4': tf.Variable(tf.truncated_normal([3], stddev=0.01))
}

In [14]:
def conv2d(x, W, b, strides=2):
    # Conv2D wrapper, with bias and relu activation
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)

def deconv2d(x, W, b, out_shape, strides=2):
    x = tf.nn.conv2d_transpose(x, W, out_shape, strides=[1, strides, strides, 1], 
                               padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.sigmoid(x)

def project(x, W, b):
    return tf.nn.relu(tf.add(tf.matmul(x, W), b))

In [15]:
def generator(z):
    hidden_g1 = project(z, weights['gen_h1'], biases['gen_h1'])
    hidden_g1 = tf.reshape(hidden_g1, [-1, 5, 5, n_channel3])
    
    output_dim2 = tf.stack([tf.shape(z)[0], 10, 10, n_channel2])
    hidden_g2 = deconv2d(hidden_g1, weights['gen_h2'], biases['gen_h2'], output_dim2)
    
    output_dim3 = tf.stack([tf.shape(z)[0], 20, 20, n_channel1])
    hidden_g3 = deconv2d(hidden_g2, weights['gen_h3'], biases['gen_h3'], output_dim3)
    
    output_dim4 = tf.stack([tf.shape(z)[0], 40, 40, 3])
    hidden_g4 = deconv2d(hidden_g3, weights['gen_h4'], biases['gen_h4'], output_dim4)
    
    return hidden_g4

def discriminator(x):
    hidden_d1 = conv2d(x, weights['dis_h1'], biases['dis_h1'])
    hidden_d2 = conv2d(hidden_d1, weights['dis_h2'], biases['dis_h2'])
    hidden_d2 = tf.reshape(hidden_d2, [-1, 10*10*n_channel3])
    
    hidden_d3 = project(hidden_d2, weights["dis_h3"], biases['dis_h3'])
    return hidden_d3

In [16]:
# Construct discriminator and generator
gen_sample = generator(Z)
dis_real= discriminator(X)
dis_fake = discriminator(gen_sample)

# Define loss
dis_loss = -tf.reduce_mean(tf.log(dis_real) + tf.log(1.-dis_fake))
gen_loss = -tf.reduce_mean(tf.log(dis_fake))

# Optimizer for discriminator
var_dis = [weights[i] for i in weights if re.match('dis', i)]+[biases[i] for i in biases if re.match('dis', i)]
dis_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(dis_loss, var_list= var_dis)
# Optimizer for generator parameters
var_gen = [weights[i] for i in weights if re.match('gen', i)]+[biases[i] for i in biases if re.match('gen', i)]
gen_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(gen_loss, var_list= var_gen)

In [20]:
# Define function: sample data for generator
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

In [18]:
# Initializing the variables
init = tf.global_variables_initializer()

# Create session and graph, initial variables
sess = tf.InteractiveSession()
sess.run(init)

In [None]:
# Load previous trained model and rewrite to variables, if exists
# Before run this cell, you have to run the cell above first, to define variables and init it.
weightSaver = tf.train.Saver(var_list=weights)
biaseSaver = tf.train.Saver(var_list=biases)

weightSaver.restore(sess, "./saved_model/Conv_AE_weights.ckpt")
biaseSaver.restore(sess, "./saved_model/Conv_AE_biases.ckpt")

print "Model restored."

In [23]:
total_batch = int(orig_img.shape[0]/batch_size)
# Training cycle
for epoch in range(training_epochs):
    # Loop over all batches
    start = 0; end = batch_size
    for i in range(total_batch):
        index = np.arange(start, end)
        np.random.shuffle(index)
        batch_xs = orig_img[index]
        start = end+1; end = start+batch_size
        # Run optimization op (backprop) and loss op (to get loss value)
        _, d_loss_train = sess.run([dis_optimizer, dis_loss], feed_dict = {X: batch_xs, Z: sample_Z(batch_size, gen_dim)})
        _, g_loss_train = sess.run([gen_optimizer, gen_loss], feed_dict = {Z: sample_Z(batch_size, gen_dim)})
    # Display logs per epoch step
    if ((epoch == 0) or (epoch+1) % display_step == 0) or ((epoch+1) == training_epochs):
        print 'Epoch: {0:04d}   Discriminator loss: {1:f}   Generator loss: {1:f}'.format(epoch+1, d_loss_train, g_loss_train)

print("Optimization Finished!")

Epoch: 00001   Discriminator loss: 1.386294   Generator loss: 1.386294
Epoch: 00100   Discriminator loss: 1.386294   Generator loss: 1.386294


KeyboardInterrupt: 