Define the function

In [1]:
def batches(batch_size, features, labels):
    """
    Create a batch
    """
    assert len(features)==len(labels)
    batch_output = []
    
    for idx in range(0, len(features), batch_size):
        batch_output.append([features[idx:idx+batch_size], labels[idx:idx+batch_size]])
        
    return batch_output

In [2]:
def print_epoch_stats(epoch_i, sess, last_features, last_labels):
    """
    Print cost and validation accuracy of an epoch
    """
    current_cost = sess.run(cost, feed_dict={features:last_features, labels:last_labels})
    validate_accuracy = sess.run(accuracy, feed_dict={features:last_features, labels:last_labels})
    print('Epoch: {:<4} - Cost: {:<8.3} - Validate accuracy: {:<5.3}'.format(epoch_i, current_cost, validate_accuracy))    

Import module

In [3]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

In [4]:
mnist = input_data.read_data_sets('./datasets/ud730/mnist/', one_hot=True)

Extracting ./datasets/ud730/mnist/train-images-idx3-ubyte.gz
Extracting ./datasets/ud730/mnist/train-labels-idx1-ubyte.gz
Extracting ./datasets/ud730/mnist/t10k-images-idx3-ubyte.gz
Extracting ./datasets/ud730/mnist/t10k-labels-idx1-ubyte.gz


In [16]:
n_inputs = 784
n_classes = 10
batch_size = 128
learn_rate = 0.001
epochs = 90

In [6]:
# Split data to training set, validation set and test set
train_features = mnist.train.images
validation_features = mnist.validation.images
test_features = mnist.test.images

train_labels = mnist.train.labels.astype(np.float32)
validation_labels = mnist.validation.labels.astype(np.float32)
test_labels = mnist.test.labels.astype(np.float32)

In [7]:
# Feature and label
features = tf.placeholder(tf.float32, [None, n_inputs])
labels = tf.placeholder(tf.float32, [None, n_classes])

# Weights and bias
weights = tf.Variable(tf.random_normal([n_inputs, n_classes]))
bias = tf.Variable(tf.random_normal([n_classes]))

# Logist = xW+b
logits = tf.add(tf.matmul(features, weights), bias)

# Learning rate
learning_rate = tf.placeholder(tf.float32)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

In [8]:
# Calculate the accuracy
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [9]:
# Init the variable tensor
init = tf.global_variables_initializer()

# Seperate data into a patch with patch size = 128 samples
train_batches = batches(batch_size, train_features, train_labels)

In [17]:
with tf.Session() as sess:
    sess.run(init)
    
    for epoch_i in range(epochs):
        # Loop over the batches
        for batch_features, label_features in train_batches:
            sess.run(optimizer, feed_dict={features:batch_features, labels:label_features, learning_rate:learn_rate})
        
        print_epoch_stats(epoch_i, sess, batch_features, label_features)
        
    # Caculate the accuracy for test dataset
    test_accuracy = sess.run(accuracy, feed_dict={features:test_features, labels:test_labels})
    
print('Test Accuracy: {}'.format(test_accuracy))

Epoch: 0    - Cost: 11.9     - Validate accuracy: 0.136
Epoch: 1    - Cost: 10.6     - Validate accuracy: 0.136
Epoch: 2    - Cost: 9.64     - Validate accuracy: 0.136
Epoch: 3    - Cost: 8.97     - Validate accuracy: 0.148
Epoch: 4    - Cost: 8.44     - Validate accuracy: 0.148
Epoch: 5    - Cost: 8.01     - Validate accuracy: 0.159
Epoch: 6    - Cost: 7.63     - Validate accuracy: 0.159
Epoch: 7    - Cost: 7.3      - Validate accuracy: 0.182
Epoch: 8    - Cost: 7.0      - Validate accuracy: 0.193
Epoch: 9    - Cost: 6.73     - Validate accuracy: 0.205
Epoch: 10   - Cost: 6.47     - Validate accuracy: 0.205
Epoch: 11   - Cost: 6.23     - Validate accuracy: 0.216
Epoch: 12   - Cost: 6.0      - Validate accuracy: 0.216
Epoch: 13   - Cost: 5.79     - Validate accuracy: 0.227
Epoch: 14   - Cost: 5.58     - Validate accuracy: 0.239
Epoch: 15   - Cost: 5.39     - Validate accuracy: 0.25 
Epoch: 16   - Cost: 5.2      - Validate accuracy: 0.25 
Epoch: 17   - Cost: 5.02     - Validate accuracy