#### Convolutional forward network

In [1]:
import tensorflow as tf

# Number of input channels: 3 for an RGB image; usually 1 for a grey-scale image
# or a word-embedding matrix.
num_channels = 1

# Spatial sizes [height, width] of the three convolution filters; cnn_forward
# builds one convolution branch per entry.
filter_sizes = [[3, 3], [4, 4], [5, 5]]

# Number of feature maps (output channels) produced by each convolution branch
num_fea_maps = 5

# Number of units in the fully connected hidden layer that precedes the output (logits) layer
hidden_units = 50

def cnn_forward(inputs):
    """Build the multi-branch convolutional graph and return class logits.

    One convolution -> ReLU -> 2x2 max-pool branch is created for every entry
    of the module-level ``filter_sizes``. The branch outputs are concatenated
    along the channel axis, flattened, and passed through a fully connected
    hidden layer (``hidden_units`` units, ReLU) followed by a linear layer
    with 10 outputs.

    Args:
        inputs: 4-D float tensor fed to ``tf.nn.conv2d``; the training cell
            passes shape ``[None, 28, 28, 1]``. The height, width and channel
            dimensions must be statically known, because the flattened size is
            computed from the static shape.

    Returns:
        A 2-D tensor of un-normalised logits with 10 columns (no activation
        is applied on the last layer).
    """
    pooled_outputs = []
    for i, filter_size in enumerate(filter_sizes):
        with tf.variable_scope("Convolution-%s" % i):
            # Kernel shape = [height, width, num_channels, num_fea_maps].
            # Build it in a NEW list: appending to `filter_size` directly would
            # mutate the module-level `filter_sizes`, so a second call to this
            # function would see 6-element shapes and fail.
            kernel_shape = list(filter_size) + [num_channels, num_fea_maps]

            W = tf.get_variable(name='W', shape=kernel_shape)
            b = tf.get_variable(name='b', shape=[num_fea_maps])

            # Convolution layer (stride 1, SAME padding keeps the spatial size)
            conv = tf.nn.conv2d(inputs,
                                W,
                                strides=[1, 1, 1, 1],
                                padding='SAME',
                                name='Conv')
            # Activation
            activated = tf.nn.relu(tf.nn.bias_add(conv, b))

            # Pooling (2x2 window, stride 2 halves height and width)
            pooled = tf.nn.max_pool(activated, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

            pooled_outputs.append(pooled)

    # One pooled tensor per filter size; for 28x28 inputs each has shape
    # (batch_size, 14, 14, num_fea_maps). Concatenate them along the channel
    # axis to get (batch_size, 14, 14, len(filter_sizes) * num_fea_maps).
    total_pooled = tf.concat(pooled_outputs, 3)
    pool_shape = total_pooled.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]

    # Flatten to (batch_size, num_features)
    reshaped = tf.reshape(total_pooled, [-1, nodes])

    # Fully connected hidden layer followed by the linear output (logits) layer
    before_outputs = tf.contrib.layers.fully_connected(reshaped, hidden_units, activation_fn=tf.nn.relu)
    outputs = tf.contrib.layers.fully_connected(before_outputs, 10, activation_fn=None)
    return outputs

#### Read MNIST dataset

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the local dataset directory with one-hot encoded labels;
# the `mnist.train` and `mnist.validation` splits are used by the training cell.
mnist = input_data.read_data_sets('../VAE/Datasets/MNIST_data', one_hot=True)

Extracting ../VAE/Datasets/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../VAE/Datasets/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../VAE/Datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../VAE/Datasets/MNIST_data/t10k-labels-idx1-ubyte.gz


#### Training

In [3]:
batch_size = 64
# Total number of optimisation steps and how often (in steps) to report validation metrics
num_steps = 10000
eval_every = 1000

# Images are fed as (batch, 28, 28, 1); labels are one-hot vectors over 10 classes
X = tf.placeholder(name='inputs', shape=[None, 28, 28, 1], dtype=tf.float32)
y = tf.placeholder(name='labels', shape=[None, 10], dtype=tf.float32)

logits = cnn_forward(X)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Accuracy = fraction of examples whose arg-max logit matches the arg-max label
predictions = tf.argmax(logits, 1)
correct_predictions = tf.equal(predictions, tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

# The validation set never changes, so reshape it once outside the loop.
# Convert shape (num_validation, 784) to (num_validation, 28, 28, 1); -1 lets
# NumPy infer the number of examples instead of hard-coding it.
valid_feed_X = mnist.validation.images.reshape(-1, 28, 28, 1)
valid_feed = {X: valid_feed_X, y: mnist.validation.labels}

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(num_steps):
        xs, ys = mnist.train.next_batch(batch_size)
        # Convert shape (batch, 784) to (batch, 28, 28, 1)
        train_feed_X = xs.reshape(-1, 28, 28, 1)
        loss_val, _ = sess.run([loss, train_op], feed_dict={X: train_feed_X, y: ys})

        if i % eval_every == 0:
            valid_loss, valid_acc = sess.run([loss, accuracy], feed_dict=valid_feed)
            print("Valid loss %.5f, valid accuracy %.2f%%" % (valid_loss, valid_acc*100))

Valid loss 2.59421, valid accuracy 11.96%
Valid loss 0.13061, valid accuracy 96.14%
Valid loss 0.07981, valid accuracy 97.50%
Valid loss 0.06634, valid accuracy 98.00%
Valid loss 0.06003, valid accuracy 98.10%
Valid loss 0.05467, valid accuracy 98.40%
Valid loss 0.05546, valid accuracy 98.18%
Valid loss 0.05451, valid accuracy 98.34%
Valid loss 0.05189, valid accuracy 98.58%
Valid loss 0.05434, valid accuracy 98.36%
