In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math
from scipy import misc as im

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TF1 tutorial helper, using ./MNIST_data as the data directory
# (presumably downloading the files there on first run — confirm for your TF version).
# The training loop reshapes each returned image to 28x28x1, so samples are presumably flat 784-vectors.
mnist = input_data.read_data_sets('MNIST_data')

In [None]:
tf.reset_default_graph()

# Seed TF's graph-level RNG (weight initialisation, reparameterisation noise) and
# numpy's global RNG (latent samples drawn with np.random.normal) so reruns are
# reproducible. The TF seed is attached to the default graph, so it must be set
# *after* reset_default_graph(). GPU kernels may still be nondeterministic.
random_seed = 42
tf.set_random_seed(random_seed)
np.random.seed(random_seed)

# Network Parameters
num_hidden_1 = 256  # units in the first dense layer of both the encoder and the decoder
num_hidden_2 = 128  # latent size: units in each encoder head (mean / log-sigma) and in z
num_hidden_3 = 7 * 7 * 64  # = 3136; units in the decoder's second dense layer, reshaped to a 7x7x64 map
num_input = 784  # MNIST data input (img shape: 28*28)

batch_size = 64

# Images are fed as [batch, 28, 28, 1]; keep_prob is the scalar keep probability for dropout.
X_in = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
keep_prob = tf.placeholder(dtype=tf.float32, shape=(), name='keep_prob')

# Training Parameters
number_of_training_iterations = 10000
learning_rate = 0.0005
train_on_gpu = False  # True for training on GPU, False for training on CPU

# L2 weight penalty attached to every conv/dense kernel via kernel_regularizer.
# In TF1 these penalties are only collected in tf.GraphKeys.REGULARIZATION_LOSSES;
# they influence training only if they are added to the minimized loss.
regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)

# Function that builds computational graph of encoder
def encoder(x, keep_prob):
    """Encode an image batch into a sampled latent code (reparameterisation trick).

    Architecture: two [5x5 conv + ReLU -> dropout -> 2x2 max-pool] stages
    (32 then 64 filters), a ReLU dense layer of ``num_hidden_1`` units, then two
    linear heads of ``num_hidden_2`` units giving the posterior mean and
    log standard deviation.

    Args:
        x: input image tensor; presumably the [batch, 28, 28, 1] ``X_in`` placeholder.
        keep_prob: keep probability forwarded to ``tf.nn.dropout``.

    Returns:
        (z, z_mean, z_log_sigma), where z = z_mean + eps * exp(z_log_sigma) and
        eps ~ N(0, I). The third value is log(sigma), not log(sigma^2): it is
        exponentiated directly to scale the noise.
    """
    def conv_dropout_pool(features, filters, layer_name):
        # One feature-extraction stage: conv -> dropout -> 2x2 pooling (halves H and W).
        activations = tf.layers.conv2d(inputs=features, filters=filters, kernel_size=[5, 5],
                                       padding="same", activation=tf.nn.relu,
                                       kernel_regularizer=regularizer, name=layer_name)
        activations = tf.nn.dropout(activations, keep_prob)
        return tf.layers.max_pooling2d(inputs=activations, pool_size=[2, 2], strides=2)

    features = conv_dropout_pool(x, 32, 'enc_conv1')
    features = conv_dropout_pool(features, 64, 'enc_conv2')

    hidden = tf.layers.dense(inputs=tf.layers.flatten(features), units=num_hidden_1,
                             activation=tf.nn.relu, kernel_regularizer=regularizer,
                             name='enc_dense1')

    z_mean = tf.layers.dense(inputs=hidden, units=num_hidden_2, name='enc_dense2_mn')
    # Halving the raw head output and treating it as log(sigma).
    z_log_sigma = 0.5 * tf.layers.dense(inputs=hidden, units=num_hidden_2, name='enc_dense2_sd')

    # One standard-normal draw per example and latent dimension.
    noise = tf.random_normal([tf.shape(x)[0], num_hidden_2], dtype=tf.float32)
    z = z_mean + noise * tf.exp(z_log_sigma)

    return z, z_mean, z_log_sigma

# Function that builds computational graph of decoder
def decoder(z, keep_prob):
    """Map latent codes back to 28x28x1 images with pixel values in (0, 1).

    Two ReLU dense layers expand z to a 7x7x64 feature map. Two
    [5x5 conv-transpose + ReLU -> dropout -> bilinear resize] stages grow it to
    14x14 and then 28x28. A final sigmoid dense layer produces the pixels.

    Args:
        z: latent code tensor (last dimension is the latent size).
        keep_prob: keep probability forwarded to ``tf.nn.dropout``.

    Returns:
        Reconstructed images reshaped to [-1, 28, 28, 1].
    """
    def upsampling_stage(features, filters, layer_name, out_size):
        # conv2d_transpose uses its default stride of 1 here, so the spatial size is
        # unchanged; the bilinear resize does the actual upsampling.
        activations = tf.layers.conv2d_transpose(features, filters, [5, 5], padding='SAME',
                                                 kernel_regularizer=regularizer,
                                                 activation=tf.nn.relu, name=layer_name)
        activations = tf.nn.dropout(activations, keep_prob)
        return tf.image.resize_bilinear(activations, size=[out_size, out_size], align_corners=False)

    hidden = tf.layers.dense(inputs=z, units=num_hidden_1, activation=tf.nn.relu,
                             kernel_regularizer=regularizer, name='dec_dense1')
    hidden = tf.layers.dense(inputs=hidden, units=num_hidden_3, activation=tf.nn.relu,
                             kernel_regularizer=regularizer, name='dec_dense2')
    feature_map = tf.reshape(hidden, shape=[-1, 7, 7, 64])

    feature_map = upsampling_stage(feature_map, 64, 'dec_conv1', 14)
    feature_map = upsampling_stage(feature_map, 32, 'dec_conv2', 28)

    pixels = tf.layers.dense(inputs=tf.layers.flatten(feature_map), units=28 * 28,
                             activation=tf.nn.sigmoid, kernel_regularizer=regularizer,
                             name='dec_dense3')
    return tf.reshape(pixels, shape=[-1, 28, 28, 1])

# Autoencoder model: encode X_in into a sampled code z, then decode z.
def _build_autoencoder(images, drop_keep):
    """Wire the encoder into the decoder.

    Returns:
        (z, z_mean, z_log_sigma, reconstruction): the encoder's three outputs
        followed by the decoder output.
    """
    code, code_mean, code_log_sigma = encoder(images, drop_keep)
    return code, code_mean, code_log_sigma, decoder(code, drop_keep)

if train_on_gpu:
    # Every op created while building the model is pinned to the first GPU.
    with tf.device('/gpu:0'):
        z, mn, sd, dec = _build_autoencoder(X_in, keep_prob)
else:
    z, mn, sd, dec = _build_autoencoder(X_in, keep_prob)

# The output of the decoder is the reconstruction (prediction).
y_pred = dec

# In autoencoders, our labels are input images since the aim is to
# reconstruct the original image from its low dimensional representation
y_true = X_in

# Per-example KL divergence between q(z|x) = N(mn, sigma^2) and the N(0, I) prior.
# `sd` is log(sigma): the encoder samples z = mean + eps * exp(sd). Therefore
# log(sigma^2) = 2 * sd and sigma^2 = exp(2 * sd).
latent_loss = -0.5 * tf.reduce_sum(1.0 + 2.0 * sd - tf.square(mn) - tf.exp(2.0 * sd), 1)
# Per-example reconstruction error: squared error summed over all 28*28 pixels.
img_loss = tf.reduce_sum(tf.squared_difference(tf.reshape(y_pred, [-1, 28 * 28]),
                                               tf.reshape(y_true, [-1, 28 * 28])), 1)

# Layers built with kernel_regularizer only *register* their L2 penalties in the
# REGULARIZATION_LOSSES collection. Without adding them here, the regularizer has
# no effect on training. Now that it is active, re-tune its `scale` if
# reconstructions degrade.
regularization_loss = tf.losses.get_regularization_loss()

# Objective: batch mean of (reconstruction + KL), plus the weight penalty, minimized with Adam.
loss = tf.reduce_mean(img_loss + latent_loss) + regularization_loss
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# Built after the optimizer so Adam's slot variables are included in checkpoints.
saver = tf.train.Saver()

In [None]:
# If you just want to use a pre-trained model, please skip this cell and run the next one.
# Training

# Keep probability fed to the dropout layers during training.
# 1 disables dropout; this matches the original behaviour.
train_keep_prob = 1

# The context manager closes the session even if training is interrupted.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize variables

    for i in range(number_of_training_iterations):
        # next_batch(...)[0] holds one image per row; reshape the whole batch at once.
        batch_x = np.reshape(mnist.train.next_batch(batch_size=batch_size)[0], [-1, 28, 28, 1])

        # One optimization step; also fetch the loss and reconstructions for logging.
        _, l, d = sess.run([optimizer, loss, dec], feed_dict={X_in: batch_x, keep_prob: train_keep_prob})

        if i % 1000 == 1:
            # Decode codes drawn from the N(0, I) prior by feeding z directly
            # (bypasses the encoder). Done only when logging, not on every step.
            z_new = np.random.normal(0, 1, [batch_size, num_hidden_2])
            g = sess.run(dec, feed_dict={z: z_new, keep_prob: 1})

            plt.imshow(np.reshape(d[0], [28, 28]), cmap='gray')  # reconstruction of batch_x[0]
            plt.show()
            plt.imshow(np.reshape(g[0], [28, 28]), cmap='gray')  # a random sample from decoder
            plt.show()
            print("iter %d - loss = %f" % (i, l))

    save_path = saver.save(sess, './tf_model.ckpt')
    print("Model saved in path: %s" % save_path)

In [None]:
# Test: restore the trained weights and decode random latent codes.

number_of_samples = 100  # number of images to generate
grid_cols = 10  # images per row in the sample grid

# Start a new tensorflow session; `with` closes it even if the restore fails.
with tf.Session() as sess:
    # Restore the saved model
    saver.restore(sess, "./tf_model.ckpt")

    # Generate noise vectors from a unit Gaussian and feed them straight into z,
    # so no input images are needed.
    z_new = np.random.normal(0, 1, [number_of_samples, num_hidden_2])
    g = sess.run(dec, feed_dict={z: z_new, keep_prob: 1})

# Draw all samples in one figure instead of opening one figure per sample.
grid_rows = max(1, math.ceil(number_of_samples / grid_cols))
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols, grid_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    ax.axis('off')  # also blanks any unused cells in the last row
    if i < number_of_samples:
        ax.imshow(g[i].reshape([28, 28]), cmap="gray")
plt.show()