In [1]:
import numpy as np
import tensorflow as tf

from keras.datasets import mnist
from keras.utils import to_categorical

Using TensorFlow backend.


In [2]:
# mnist.load_data() returns a (train, test) pair, each itself an (images, labels) pair.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [3]:
# Flatten every image to a 784-long row and divide the pixel values by 255.
# NOTE: this cell rebinds its own inputs, so run it only once per kernel.
x_train, x_test = (images.reshape(-1, 784) / 255 for images in (x_train, x_test))

# Turn the integer class labels into one-hot rows.
y_train, y_test = (to_categorical(labels) for labels in (y_train, y_test))

In [4]:
# Clear the default graph before (re)building the model in this notebook.
tf.reset_default_graph()

# The leading None leaves the batch dimension open; each sample is a 784-long row.
input_shape = (None, 784)
X = tf.placeholder(dtype=tf.float32, shape=input_shape, name='X')

# `(None)` is plain None rather than a 1-tuple, so this placeholder declares no
# static shape at all and accepts label arrays of any rank.
y = tf.placeholder(dtype=tf.int32, shape=None, name='y')

In [5]:
# Hyper-parameters of the network.
dropout_rate = 0.2
hidden1_nodes = 256
n_outputs = 10

# Scalar boolean switch that defaults to False; the training loop feeds True
# so that the dropout layers are only active while fitting.
training = tf.placeholder_with_default(False, shape=(), name='training')

# Dropout applied directly to the input features.
X_drop = tf.layers.dropout(X, rate=dropout_rate, training=training)

In [6]:
def shuffle_batch(X, y, batch_size):
    """Yield shuffled mini-batches that together cover every sample exactly once.

    Args:
        X: Array of samples; must support ``len()`` and indexing with an
            integer index array.
        y: Labels aligned with ``X`` along the first axis.
        batch_size: Target number of samples per batch (must be >= 1).

    Yields:
        ``(X_batch, y_batch)`` pairs. The samples are split into
        ``len(X) // batch_size`` nearly equal parts, so a batch may hold
        slightly more than ``batch_size`` rows. When there are fewer samples
        than ``batch_size`` a single batch with all samples is yielded, and
        nothing is yielded for empty input.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    n_samples = len(X)
    if n_samples == 0:
        return

    rnd_idx = np.random.permutation(n_samples)
    # np.array_split rejects 0 sections, which is what the integer division
    # produces when the data set is smaller than one batch.
    n_batches = max(1, n_samples // batch_size)
    for batch_idx in np.array_split(rnd_idx, n_batches):
        yield X[batch_idx], y[batch_idx]

In [7]:
# Build the network, loss, optimiser and evaluation ops on the default graph.
with tf.name_scope("ann"):
    # Feed the dropout-wrapped inputs: X_drop was built for this purpose but the
    # hidden layer previously read X directly, so input dropout was never applied.
    hidden1 = tf.layers.dense(X_drop, hidden1_nodes, activation=tf.nn.relu, name='hidden')
    hidden1_dropout = tf.layers.dropout(hidden1, rate=dropout_rate, training=training)
    logits = tf.layers.dense(hidden1_dropout, n_outputs, name='outputs')

with tf.name_scope("loss"):
    # Labels are fed as one-hot rows, which is the form this loss expects.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
    loss = tf.reduce_mean(cross_entropy, name='loss')

with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    # in_top_k needs a 1-D vector of class ids, so recover the class index from
    # each one-hot label row before comparing against the logits.
    true_class = tf.argmax(y, axis=1)
    correct = tf.nn.in_top_k(logits, true_class, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')

init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [None]:
n_epoch = 5
batch_size = 32

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epoch):
        # One pass over the shuffled training data with dropout switched on.
        for X_batch, y_batch in shuffle_batch(x_train, y_train, batch_size):
            train_feed = {X: X_batch, y: y_batch, training: True}
            sess.run(training_op, feed_dict=train_feed)

        # `training` is not fed here, so it keeps its default of False.
        # Note that the reported figure is measured on the test split.
        eval_feed = {X: x_test, y: y_test}
        accuracy_val = sess.run(accuracy, feed_dict=eval_feed)
        print(epoch, "Validation accuracy:", accuracy_val)