In [None]:
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['xtick.major.pad']='16'

def count_params():
    """Count the number of parameters in the current TensorFlow graph.

    Sums the number of elements of every variable returned by
    ``tf.global_variables()``.

    Returns:
        int: total element count. The sum is accumulated with integer
        arithmetic because ``np.prod([])`` (a scalar-shaped variable) is the
        float ``1.0``, which would otherwise turn the whole count into a float.
    """
    total = 0
    for variable in tf.global_variables():
        shape = variable.get_shape().as_list()
        # dtype=np.int64 keeps the product integral even for an empty shape.
        total += int(np.prod(shape, dtype=np.int64))
    return total


def get_session():
    """Create a ``tf.Session`` whose config has device-placement logging on.

    Returns:
        The newly created session.
    """
    # GPU memory growth is left at its default; to enable it, also set
    # gpu_options.allow_growth = True on the config before creating the session.
    return tf.Session(config=tf.ConfigProto(log_device_placement=True))



In [None]:
code_length = 10    # number of +/-1 entries in each code fed to the synthesizer
output_image = 32   # side length, in pixels, of the square synthesized image

NUM_CODES = 20000       # total number of random codes generated
NUM_VAL_CODES = 2000    # codes held out (from the end) for validation
DATA_SEED = 0           # fixed seed so the dataset survives Restart & Run All

# A dedicated seeded generator makes the dataset reproducible without touching
# NumPy's global random state. The width comes from `code_length` so it cannot
# drift from the placeholder shape used when the graph is built.
_data_rng = np.random.RandomState(DATA_SEED)
data = _data_rng.choice([-1, 1], size=(NUM_CODES, code_length))
input_data = data[:NUM_CODES - NUM_VAL_CODES]
val_data = data[-NUM_VAL_CODES:]

In [None]:
#style calculation
import vgg19

def gram_matrix(x):
    """Compute the per-image channel Gram matrix of a feature map.

    Args:
        x: a ``tf.Tensor`` of rank 4 laid out as [batch, height, width,
            channels]. Height, width and channels must be statically known;
            the batch size may be unknown.

    Returns:
        A tensor of shape [batch, channels, channels] holding ``F^T F``
        divided by ``channels * width * height``, where ``F`` is the feature
        map flattened to [height * width, channels].

    Raises:
        AssertionError: if ``x`` is not a ``tf.Tensor``.
    """
    assert isinstance(x, tf.Tensor)
    _, h, w, ch = x.get_shape().as_list()
    features = tf.reshape(x, [-1, h * w, ch])
    # transpose_a (not transpose_b) contracts over the spatial axis, giving the
    # channel-by-channel correlations. Transposing the second operand instead
    # would build an (h*w) x (h*w) matrix, which is both the wrong statistic
    # and enormous for large feature maps.
    gram = tf.matmul(features, features, transpose_a=True)
    return gram / tf.constant(ch * w * h, tf.float32)


In [None]:
tf.reset_default_graph()
sess = get_session()


def _conv3x3_relu(inputs, filters):
    """3x3, stride-1, same-padded convolution followed by ReLU."""
    return tf.layers.conv2d(inputs=inputs, filters=filters, kernel_size=3, strides=1,
                            padding="same", activation=tf.nn.relu)


def _upsample2x(inputs, filters):
    """3x3, stride-2, same-padded transposed convolution with no activation."""
    return tf.layers.conv2d_transpose(inputs=inputs, filters=filters, kernel_size=3,
                                      strides=2, padding="same")


# ---- synthesizer network: code vector -> image ----
_input = tf.placeholder(tf.float32, shape=[None, code_length])
isTraining = tf.placeholder(tf.bool)

# Project the code onto an 8x8 map with 64 channels.
sy_dl1 = tf.layers.dense(inputs=_input, units=8 * 8 * 64, activation=tf.nn.relu, name='dense1')
sy_dl1_r = tf.reshape(sy_dl1, shape=(-1, 8, 8, 64))  # [N, 8, 8, 64]

sy_cn1 = _conv3x3_relu(sy_dl1_r, 64)   # [N, 8, 8, 64]  -> [N, 8, 8, 64]
sy_cn2 = _conv3x3_relu(sy_cn1, 32)     # [N, 8, 8, 64]  -> [N, 8, 8, 32]
sy_cnt1 = _upsample2x(sy_cn2, 32)      # [N, 8, 8, 32]  -> [N, 16, 16, 32]

sy_cn3 = _conv3x3_relu(sy_cnt1, 64)    # [N, 16, 16, 32] -> [N, 16, 16, 64]
sy_cn4 = _conv3x3_relu(sy_cn3, 32)     # [N, 16, 16, 64] -> [N, 16, 16, 32]

# Final upsampling produces the 3-channel image; its values are unbounded
# because no activation is applied here.
sy_output_image = _upsample2x(sy_cn4, 3)  # [N, 16, 16, 32] -> [N, 32, 32, 3]

tf.summary.image('marker', sy_output_image)

# ---- recognizer network: synthesized image -> reconstructed code ----
def _first_global_variable(scope):
    """Return the first global variable whose name matches `scope`."""
    return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)[0]


# Stage 1: conv -> 2x2 max-pool -> batch norm -> ReLU
conv1 = tf.layers.conv2d(inputs=sy_output_image, filters=96, kernel_size=(5, 5),
                         padding='same', name='conv1')  # [32, 32, 96]
kernel_conv1 = _first_global_variable('conv1/kernel')
bias_conv1 = _first_global_variable('conv1/bias')
tf.summary.histogram('conv1/kernel', kernel_conv1)
tf.summary.histogram('conv1/bias', bias_conv1)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=(2, 2), strides=2,
                                padding='valid')  # [16, 16, 96]
batch1 = tf.layers.batch_normalization(inputs=pool1, axis=3, training=isTraining)
relu1 = tf.nn.relu(batch1)  # [16, 16, 96]

# Stage 2: same pattern on the 16x16 map
conv2 = tf.layers.conv2d(inputs=relu1, filters=96, kernel_size=(5, 5),
                         padding='same', name='conv2')  # [16, 16, 96]
kernel_conv2 = _first_global_variable('conv2/kernel')
bias_conv2 = _first_global_variable('conv2/bias')
tf.summary.histogram('conv2/kernel', kernel_conv2)
tf.summary.histogram('conv2/bias', bias_conv2)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=(2, 2), strides=2,
                                padding='valid')  # [8, 8, 96]
batch2 = tf.layers.batch_normalization(inputs=pool2, axis=3, training=isTraining)
relu2 = tf.nn.relu(batch2)  # [8, 8, 96]

# Stage 3: same pattern on the 8x8 map
conv3 = tf.layers.conv2d(inputs=relu2, filters=96, kernel_size=(5, 5),
                         padding='same', name='conv3')  # [8, 8, 96]
kernel_conv3 = _first_global_variable('conv3/kernel')
bias_conv3 = _first_global_variable('conv3/bias')
tf.summary.histogram('conv3/kernel', kernel_conv3)
tf.summary.histogram('conv3/bias', bias_conv3)
pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=(2, 2), strides=2,
                                padding='valid')  # [4, 4, 96]
batch3 = tf.layers.batch_normalization(inputs=pool3, axis=3, training=isTraining)
relu3 = tf.nn.relu(batch3)  # [4, 4, 96]

# Head: flatten, then two dense layers (neither has an activation) down to
# one output per code entry.
flat_relu3 = tf.reshape(relu3, shape=(-1, 4 * 4 * 96))
fc1 = tf.layers.dense(inputs=flat_relu3, units=192)
fc2 = tf.layers.dense(inputs=fc1, units=code_length)

# ---- style statistics of the fixed target image ----
# NOTE(review): assumes vgg19.Vgg19.build accepts a [batch, 224, 224, 3] float
# image tensor and exposes the conv*_* activations used below -- confirm
# against the vgg19 module.
vgg_s = vgg19.Vgg19()
np_target_style_image = np.asarray(plt.imread('style.png'))
# Add a batch axis and keep only the first three channels (drops alpha if present).
target_style_image = tf.constant(np_target_style_image[np.newaxis, :, :, 0:3])
vgg_s.build(tf.image.resize_images(target_style_image, (224, 224)))
feature_ = [vgg_s.conv1_2, vgg_s.conv2_2, vgg_s.conv3_3, vgg_s.conv4_3, vgg_s.conv5_3]
gram_ = [gram_matrix(l) for l in feature_]

# ---- style statistics of the synthesized image ----
# Only the first image of each batch goes through VGG, so the style loss is
# computed from a single sample per step.
vgg = vgg19.Vgg19()
vgg.build(tf.image.resize_images(sy_output_image[0:1, :, :, :], (224, 224)))
feature = [vgg.conv1_2, vgg.conv2_2, vgg.conv3_3, vgg.conv4_3, vgg.conv5_3]
gram = [gram_matrix(l) for l in feature]

# Start from a true scalar: seeding the sum with tf.zeros(1) made style_loss
# (and therefore `loss`) a length-1 vector rather than a scalar.
style_loss = tf.constant(0.0, dtype=tf.float32)
for g, g_ in zip(gram, gram_):
    style_loss += 1e1 * tf.reduce_mean(tf.subtract(g, g_) ** 2)

# Reconstruction term: the recognizer output should match the input code.
mean_loss = tf.losses.mean_squared_error(_input, fc2)

loss = mean_loss + style_loss
tf.summary.scalar('mean_loss', mean_loss)
tf.summary.scalar('style_loss', style_loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.05)

merged = tf.summary.merge_all()

# batch normalization in tensorflow requires this extra dependency
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
    train = optimizer.minimize(loss)

# Create the writer only after the optimizer ops exist, so the graph written
# to ./train is the complete training graph.
writer = tf.summary.FileWriter('./train', sess.graph)

sess.run(tf.global_variables_initializer())

In [None]:
# NOTE(review): the original author's reminder here read
# "missing Trainable batch norm thingy" -- unresolved, kept for follow-up.
NUM_STEPS = 200 * 90     # total optimisation steps
BATCH_SIZE = 200         # codes drawn (with replacement) per step
SUMMARY_EVERY = 50       # steps between TensorBoard summaries
EVAL_EVERY = 500         # steps between validation reports / sample images

for i in tqdm(range(NUM_STEPS)):
    # Sample from the whole training set; the previous hard-coded bound of
    # 9000 never touched the second half of input_data.
    batch_ = np.random.randint(input_data.shape[0], size=BATCH_SIZE)
    feed = {_input: input_data[batch_, :], isTraining: True}

    if i % SUMMARY_EVERY == 0:
        # Evaluate the merged summaries only on steps where they are written.
        summary = sess.run([train, merged], feed_dict=feed)[1]
        writer.add_summary(summary, i)
    else:
        sess.run(train, feed_dict=feed)

    if i % EVAL_EVERY == 0:
        val_loss = sess.run(loss, feed_dict={_input: val_data, isTraining: False})
        print('loss: ', val_loss)

        # Show the image synthesized for one fresh random code.
        check_data = np.random.choice([-1, 1], size=(1, code_length))
        image_output = sess.run(sy_output_image,
                                feed_dict={_input: check_data, isTraining: False})
        # The synthesizer output has no bounding activation; clip to [0, 1],
        # the valid range for float RGB data in imshow.
        image_output1 = np.clip(image_output[0], 0.0, 1.0)
        plt.figure(figsize=(4, 4))
        plt.imshow(image_output1)
        plt.show()



In [None]:
# Inspect the activations of the synthesizer's first dense layer for one
# random code.
check_data = np.random.choice([-1, 1], size=(1, code_length))
print(check_data[0])
# NOTE(review): this cell previously fetched `dense_layer1` and fed `input_`,
# neither of which is defined by the graph-building cell. They are mapped here
# to `sy_dl1` (the layer named 'dense1') and `_input` -- confirm this is the
# tensor that was meant.
image_output = sess.run([sy_dl1], feed_dict={_input: check_data, isTraining: False})
print(image_output[0].shape)
# To inspect the recognizer's reconstruction of the code instead, fetch `fc2`
# with the same feed_dict.