The structure of this network follows the classic structure of CNNs, which is a mix of convolutional layers and max pooling, followed by fully-connected layers.

In [1]:
import tensorflow as tf

# Get the data

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

# Read the MNIST archives from the parent directory.
# one_hot=True  -> labels are one-hot vectors of length 10.
# reshape=False -> images keep their 28x28x1 shape instead of being flattened,
#                  matching the [None, 28, 28, 1] input the model expects.
mnist = input_data.read_data_sets(train_dir="../", one_hot=True, reshape=False)

Extracting ../train-images-idx3-ubyte.gz
Extracting ../train-labels-idx1-ubyte.gz
Extracting ../t10k-images-idx3-ubyte.gz
Extracting ../t10k-labels-idx1-ubyte.gz


# Hyperparameters

In [17]:
# Optimisation settings
learning_rate = 1e-5          # step size for plain gradient descent
epochs, batch_size = 10, 128  # passes over the training set, samples per step

# Validation / test accuracy is measured on only this many samples.
# Lower it if evaluating accuracy runs out of memory.
test_valid_size = 256

# Network settings
n_classes = 10   # one class per MNIST digit (0-9)
dropout = 0.75   # NOTE: despite the name, this is the probability of *keeping* a unit

# Weights and Biases

In [19]:
# Trainable parameters, created in a fixed order (wc1, wc2, wd1, out, then the
# biases) so the auto-generated variable names stay stable.
# Convolution filters follow TensorFlow's layout
#   [filter_height, filter_width, in_channels, out_channels];
# dense weights are [inputs, outputs].
# NOTE(review): tf.truncated_normal / tf.random_normal are called with their
# default stddev of 1.0. This probably explains the very large early losses in
# the training log. A smaller stddev (e.g. 0.1) is worth trying together with a
# re-tuned learning rate.
weights = dict(
    wc1=tf.Variable(tf.truncated_normal([5, 5, 1, 32])),       # 32 filters of 5x5x1
    wc2=tf.Variable(tf.truncated_normal([5, 5, 32, 64])),      # 64 filters of 5x5x32
    wd1=tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024])),  # flattened 7x7x64 maps -> 1024 hidden units
    out=tf.Variable(tf.truncated_normal([1024, n_classes])),   # 1024 hidden units -> class logits
)

biases = dict(
    bc1=tf.Variable(tf.random_normal([32])),         # one bias per conv1 filter
    bc2=tf.Variable(tf.random_normal([64])),         # one bias per conv2 filter
    bd1=tf.Variable(tf.random_normal([1024])),       # fully connected hidden layer
    out=tf.Variable(tf.random_normal([n_classes])),  # output layer, one per class
)

In [7]:
# Convolution layer: conv -> bias -> ReLU
def conv2d(x, W, b, stride=1):
    """Convolve `x` with `W`, add the bias `b` and apply ReLU.

    Args:
        x: input feature map. Assumed to be NHWC (batch, height, width,
           channels), which is the layout the [1, stride, stride, 1]
           strides below address.
        W: filters laid out as [filter_height, filter_width, in_channels,
           out_channels].
        b: bias vector with one entry per output channel.
        stride: step along both height and width.

    Returns:
        The ReLU-activated feature map. With 'SAME' padding, the spatial
        size is ceil(input_size / stride).
    """
    step = [1, stride, stride, 1]
    return tf.nn.relu(
        tf.nn.bias_add(tf.nn.conv2d(x, W, strides=step, padding='SAME'), b))

In [6]:
# k x k max pooling layer
def maxpool2d(x, k=2):
    """Down-sample `x` with non-overlapping k x k max pooling.

    Args:
        x: input feature map, assumed NHWC like the input to conv2d.
        k: side length of the pooling window. It is also used as the stride.

    Returns:
        The pooled feature map. With 'SAME' padding, height and width become
        ceil(size / k).
    """
    # The stride equals the window size, so pooling windows never overlap.
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

# Model 

The model has two convolutional layers, each followed by max pooling, then a fully connected layer and an output layer. The comments show how each layer transforms the data's dimensions. For example, the first convolution maps each image from ```28x28x1``` to ```28x28x32```, and the following max pooling step reduces it to ```14x14x32```. The layers are applied in order from conv1 to the output layer, producing one score (logit) for each of the 10 classes.

In [20]:
def conv_net(x, weights, biases, dropout):
    """Build the CNN forward pass and return the class logits.

    Architecture:
        conv(5x5, 32) + 2x2 max pool : 28x28x1  -> 14x14x32
        conv(5x5, 64) + 2x2 max pool : 14x14x32 -> 7x7x64
        fully connected + ReLU + dropout : 7*7*64 -> 1024
        output (linear) : 1024 -> n_classes

    Args:
        x: image batch, expected as [batch, 28, 28, 1].
        weights: dict with keys 'wc1', 'wc2', 'wd1', 'out'.
        biases: dict with keys 'bc1', 'bc2', 'bd1', 'out'.
        dropout: keep probability passed to tf.nn.dropout (1.0 disables dropout).

    Returns:
        Unscaled logits of shape [batch, n_classes].
    """
    # Two conv -> pool stages, applied in order (stage 1, then stage 2).
    features = x
    for stage in ('1', '2'):
        features = conv2d(features, weights['wc' + stage], biases['bc' + stage])
        features = maxpool2d(features, k=2)

    # Flatten to the fully connected layer's input width (7*7*64).
    flat_width = weights['wd1'].get_shape().as_list()[0]
    hidden = tf.reshape(features, [-1, flat_width])
    hidden = tf.nn.relu(tf.add(tf.matmul(hidden, weights['wd1']), biases['bd1']))
    hidden = tf.nn.dropout(hidden, dropout)

    # Linear read-out to one score per class.
    return tf.add(tf.matmul(hidden, weights['out']), biases['out'])

# Model Training & Testing

In [21]:
# Graph inputs: a batch of 28x28 single-channel images, their one-hot labels,
# and the dropout keep probability (fed on every sess.run call).
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
y = tf.placeholder(tf.float32, shape=[None, n_classes])
keep_prob = tf.placeholder(tf.float32)

# Forward pass: unscaled class scores.
logits = conv_net(x, weights, biases, keep_prob)

# Mean softmax cross-entropy over the batch, minimised with plain SGD.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Accuracy: fraction of samples whose top-scoring class matches the label.
correct_pred = tf.equal(tf.argmax(logits, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Op that initialises every variable created above.
init = tf.global_variables_initializer()

# Launch the graph, train for `epochs` passes over the training set, and
# report progress every 64 batches.
with tf.Session() as sess:
    sess.run(init)

    for epoch in range(epochs):
        # Integer division skips the final partial batch of each epoch.
        for batch in range(mnist.train.num_examples // batch_size):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={
                x: batch_x,
                y: batch_y,
                keep_prob: dropout})

            # Evaluate only on batches that are reported. Each evaluation is
            # an extra forward pass, so running them on every batch and
            # discarding most of the results wasted compute.
            if (batch + 1) % 64 == 0:
                # Loss on the batch just trained on, with dropout disabled.
                loss = sess.run(cost, feed_dict={
                    x: batch_x,
                    y: batch_y,
                    keep_prob: 1.})
                # Accuracy on the first `test_valid_size` validation samples.
                valid_acc = sess.run(accuracy, feed_dict={
                    x: mnist.validation.images[:test_valid_size],
                    y: mnist.validation.labels[:test_valid_size],
                    keep_prob: 1.})

                print('Epoch {:>2}, Batch {:>3} - '
                      'Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(
                          epoch + 1,
                          batch + 1,
                          loss,
                          valid_acc))

    # Test accuracy is measured on only the first `test_valid_size` test
    # images, not on the full test set.
    test_acc = sess.run(accuracy, feed_dict={
        x: mnist.test.images[:test_valid_size],
        y: mnist.test.labels[:test_valid_size],
        keep_prob: 1.})
    print('Testing Accuracy: {}'.format(test_acc))

Epoch  1, Batch  64 -Loss:  6574.7393 Validation Accuracy: 0.375000
Epoch  1, Batch 128 -Loss:  2949.7703 Validation Accuracy: 0.519531
Epoch  1, Batch 192 -Loss:  2365.6963 Validation Accuracy: 0.585938
Epoch  1, Batch 256 -Loss:  1407.9883 Validation Accuracy: 0.613281
Epoch  1, Batch 320 -Loss:  1149.4143 Validation Accuracy: 0.664062
Epoch  1, Batch 384 -Loss:  1170.0652 Validation Accuracy: 0.683594
Epoch  2, Batch  64 -Loss:  1286.1045 Validation Accuracy: 0.734375
Epoch  2, Batch 128 -Loss:   790.8742 Validation Accuracy: 0.765625
Epoch  2, Batch 192 -Loss:   907.0021 Validation Accuracy: 0.769531
Epoch  2, Batch 256 -Loss:   881.4524 Validation Accuracy: 0.777344
Epoch  2, Batch 320 -Loss:   817.6746 Validation Accuracy: 0.792969
Epoch  2, Batch 384 -Loss:  1186.3710 Validation Accuracy: 0.792969
Epoch  3, Batch  64 -Loss:   531.8604 Validation Accuracy: 0.796875
Epoch  3, Batch 128 -Loss:   488.2843 Validation Accuracy: 0.804688
Epoch  3, Batch 192 -Loss:   681.3949 Validation