[View in Colaboratory](https://colab.research.google.com/github/johnphilip283/MNIST-Denoising-Autoencoder/blob/master/MNIST_Autoencoder.ipynb)

In [None]:
import tensorflow as tf
import numpy as np
from skimage import transform
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

In [None]:
def reconstruct(data, noise_stddev=1.0):
  """Corrupt a batch of images with Gaussian noise and reconstruct them.

  Builds a small convolutional denoising autoencoder in the current TF1 graph.
  The noise is added inside the graph, so callers should feed clean images and
  compare the output against those clean images.

  Args:
    data: float tensor of images in NHWC layout. The layer arithmetic below
      assumes 32x32x1 inputs; any spatial size divisible by 8 round-trips to
      the same shape.
    noise_stddev: standard deviation of the zero-mean Gaussian noise added to
      `data` before encoding. Defaults to 1.0.

  Returns:
    Tensor of reconstructions with the same shape as `data` (for spatial sizes
    divisible by 8), with values in (0, 1).
  """
  # Corrupt a new tensor; `data` itself stays the clean reference.
  noisy = data + tf.random_normal(tf.shape(data), stddev=noise_stddev)

  # Encoder: 32 x 32 x 1 -> 16 x 16 x 32 -> 8 x 8 x 16 -> 4 x 4 x 8 latent code.
  conv1 = tf.layers.conv2d(noisy, 32, 4, 2, activation=tf.nn.relu,
                           padding="SAME")
  conv2 = tf.layers.conv2d(conv1, 16, 4, 2, activation=tf.nn.relu,
                           padding="SAME")
  conv3 = tf.layers.conv2d(conv2, 8, 4, 2, activation=tf.nn.relu,
                           padding="SAME")

  # Decoder: 4 x 4 x 8 -> 8 x 8 x 16 -> 16 x 16 x 32 -> 32 x 32 x 1.
  conv4 = tf.layers.conv2d_transpose(conv3, 16, 4, 2, activation=tf.nn.relu,
                                     padding="SAME")
  conv5 = tf.layers.conv2d_transpose(conv4, 32, 4, 2, activation=tf.nn.relu,
                                     padding="SAME")
  # Sigmoid bounds the output to (0, 1). This presumably matches the pixel
  # range of the images (MNIST scaled to [0, 1]; confirm). Unlike ReLU, it
  # cannot collapse to an all-zero output whose gradient is zero everywhere.
  final = tf.layers.conv2d_transpose(conv5, 1, 4, 2, activation=tf.nn.sigmoid,
                                     padding="SAME")

  return final

def resize_images(images, size=32):
  """Upsample a batch of 28x28 single-channel images to size x size.

  Args:
    images: array holding 28 * 28 pixels per image, e.g. flat rows of 784
      values or an (N, 28, 28, 1) array. It is reshaped to (N, 28, 28, 1)
      before resizing.
    size: side length of the square output images. Defaults to 32.

  Returns:
    A float64 NumPy array of shape (N, size, size, 1). An empty input batch
    yields an empty (0, size, size, 1) array.
  """
  # Just in case the input is flat or otherwise shaped, normalize it to NHWC.
  images = np.reshape(images, (-1, 28, 28, 1))

  # Preallocate the output batch and fill it one image at a time.
  resized_images = np.zeros((images.shape[0], size, size, 1))
  for i, image in enumerate(images):
    resized_images[i, ..., 0] = transform.resize(image[..., 0], (size, size))

  return resized_images

In [None]:
# Clean 32x32 grayscale images are fed here. `reconstruct` corrupts them
# internally, so the loss measures recovery of the uncorrupted input.
inputs = tf.placeholder(tf.float32, shape=(None, 32, 32, 1))
rec_images = reconstruct(inputs)

# Plain mean-squared reconstruction error, minimized with vanilla SGD.
loss = tf.reduce_mean(tf.square(rec_images - inputs))
optimizer = tf.train.GradientDescentOptimizer(0.3)
train_op = optimizer.minimize(loss)

# Training schedule.
batch_size = 300
epochs = 10

# Load MNIST and derive how many full batches fit in one training epoch.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
num_batches = mnist.train.num_examples // batch_size

In [None]:
def _show_grid(batch, title, figure_num, rows=5, cols=10):
  """Draw the first rows * cols images of an (N, H, W, 1) batch in one figure.

  The caption is set with `suptitle` so it labels the whole figure rather than
  an axes that the subplots are then drawn on top of.
  """
  fig = plt.figure(figure_num)
  fig.suptitle(title)
  for idx in range(min(rows * cols, batch.shape[0])):
    plt.subplot(rows, cols, idx + 1)
    plt.imshow(batch[idx, ..., 0], cmap='gray')


with tf.Session() as sess:

  # Initialize all TensorFlow variables in the current session's graph.
  sess.run(tf.global_variables_initializer())

  for epoch in range(epochs):
    for batch in range(num_batches):

      # resize_images reshapes the raw pixels to 28x28 itself before
      # upsampling them to the 32x32 size the `inputs` placeholder expects.
      images, _ = mnist.train.next_batch(batch_size)
      images = resize_images(images)

      # Run one optimizer step and fetch the batch loss.
      _, num_loss = sess.run([train_op, loss], feed_dict={inputs: images})

      print('Epoch: {} - cost= {:.5f}'.format((epoch + 1), num_loss))

      # Every 100 batches, show reconstructions next to their clean inputs.
      # Only `rec_images` is evaluated, so visualizing does not apply an
      # extra optimizer step to this batch.
      if batch % 100 == 0:
        re_images = sess.run(rec_images, feed_dict={inputs: images})
        _show_grid(re_images, 'Reconstructed Images', 1)
        _show_grid(images, 'Input Images', 2)
        plt.show()

  # Reconstruct a single held-out test digit rather than a training one, so
  # the result reflects generalization. The graph corrupts the fed image with
  # noise before encoding it.
  images, _ = mnist.test.next_batch(1)
  images = resize_images(images)

  image = sess.run(rec_images, feed_dict={inputs: images})

  plt.figure(1)
  plt.title("Original image")
  plt.imshow(images[0, ..., 0], cmap="gray")

  plt.figure(2)
  plt.title("Reconstructed image")
  plt.imshow(image[0, ..., 0], cmap="gray")

  plt.show()