In [1]:
# import dependencies
import time
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# read the MNIST data sets from the local data directory;
# one_hot=False means the labels are not one-hot encoded
mnist_dir = "data/MNIST"
mnist = input_data.read_data_sets(mnist_dir, one_hot=False)

Extracting data/MNIST/train-images-idx3-ubyte.gz
Extracting data/MNIST/train-labels-idx1-ubyte.gz
Extracting data/MNIST/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/t10k-labels-idx1-ubyte.gz


In [2]:
# Baseline classifier: three conv layers followed by a dense softmax layer.
# Rebuilds the default graph from scratch so the cell can be re-run safely.
tf.reset_default_graph()

# set parameters
learning_rate = 0.1

# create the model (a simple convolutional network)
# x holds flattened 28x28 images. y holds integer class ids: MNIST is loaded
# with one_hot=False and the training cell feeds the labels unchanged, so a
# (None, 10) one-hot placeholder would reject the fed label batch.
x = tf.placeholder(tf.float32, (None, 784))
y = tf.placeholder(tf.int32, (None,))

x_reshaped = tf.reshape(x, (-1, 28, 28, 1))

# settings shared by all convolution / pooling layers
conv_args = dict(kernel_size=(3, 3), strides=(1, 1), activation=tf.nn.relu, padding="SAME")
pool_args = dict(pool_size=(2, 2), strides=(2, 2), padding="SAME")

conv1 = tf.layers.conv2d(x_reshaped, filters=32, **conv_args)
pool1 = tf.layers.max_pooling2d(conv1, **pool_args)

conv2 = tf.layers.conv2d(pool1, filters=64, **conv_args)
pool2 = tf.layers.max_pooling2d(conv2, **pool_args)

conv3 = tf.layers.conv2d(pool2, filters=128, **conv_args)

# two stride-2 poolings shrink 28x28 to 7x7; conv3 has 128 channels
flatten = tf.reshape(conv3, (tf.shape(conv3)[0], 7 * 7 * 128))

logits = tf.layers.dense(flatten, 10)

# the sparse variant consumes integer class ids directly
cross_entropy = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))
# the training cell fetches `loss`, so expose the objective under that name as well
loss = cross_entropy
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# argmax yields int64; cast so it can be compared with the int32 labels
correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [8]:
# Classifier whose class scores are sampled from a per-class Normal
# distribution; the loss is averaged over all draws.
# Rebuilds the default graph from scratch so the cell can be re-run safely.
tf.reset_default_graph()

# set parameters
learning_rate = 0.1
sampling_steps = 50  # number of score vectors drawn per example

# create the model (a simple convolutional network)
x = tf.placeholder(tf.float32, (None, 784))
y = tf.placeholder(tf.int32, (None,))

x_reshaped = tf.reshape(x, (-1, 28, 28, 1))

# settings shared by all convolution / pooling layers
conv_args = dict(kernel_size=(3, 3), strides=(1, 1), activation=tf.nn.relu, padding="SAME")
pool_args = dict(pool_size=(2, 2), strides=(2, 2), padding="SAME")

conv1 = tf.layers.conv2d(x_reshaped, filters=32, **conv_args)
pool1 = tf.layers.max_pooling2d(conv1, **pool_args)

conv2 = tf.layers.conv2d(pool1, filters=64, **conv_args)
pool2 = tf.layers.max_pooling2d(conv2, **pool_args)

conv3 = tf.layers.conv2d(pool2, filters=128, **conv_args)

# two stride-2 poolings shrink 28x28 to 7x7; conv3 has 128 channels
flatten = tf.reshape(conv3, (tf.shape(conv3)[0], 7 * 7 * 128))

with tf.variable_scope("means"):
    means = tf.layers.dense(flatten, 10)

with tf.variable_scope("stds"):
    stds = tf.layers.dense(flatten, 10)

# the dense "stds" output is unconstrained, so it is squared to obtain a
# non-negative scale for the distribution
dist = tf.distributions.Normal(loc=means, scale=tf.square(stds))
samples = dist.sample([sampling_steps])  # (sampling_steps, batch, 10)

# (batch, 2) pairs of (row in batch, true class id)
y_index = tf.stack([tf.range(0, tf.shape(y)[0], delta=1), y], axis=1)

# Select the sampled score of the true class for every draw with one gather
# instead of a per-draw tf.map_fn loop: move the draw axis last, gather by
# (row, class), then move the draw axis back to the front.
samples_by_example = tf.transpose(samples, [1, 2, 0])  # (batch, 10, sampling_steps)
targets = tf.transpose(tf.gather_nd(samples_by_example, y_index))  # (sampling_steps, batch)
# trailing axis of size 1 broadcasts against the class axis, so no tile is needed
targets = tf.expand_dims(targets, -1)

diff = samples - targets

# logsumexp over classes of (score - true-class score), averaged over draws and batch
loss = tf.reduce_logsumexp(diff, axis=-1)
loss = tf.reduce_mean(loss)

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# prediction uses the scores averaged over all draws
logits = tf.reduce_mean(samples, axis=0)

# argmax yields int64; cast so it can be compared with the int32 labels
correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [10]:
# Train the graph built in the previous cell and report test accuracy.
# Reads x, y, train_step, loss and accuracy from the notebook namespace.

# set parameters
num_iterations = 1000
batch_size = 64
report_every = 100  # print timing and test accuracy every this many steps

# train

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  # training
  print("Training started.")

  # start the clock only now, so session start-up and variable
  # initialisation are not counted as batch time
  start = time.time()
  batches_timed = 0

  for step in range(num_iterations):
    batch_xs, batch_ys = mnist.train.next_batch(batch_size)

    # batch_loss is not printed; it stays in the namespace for inspection
    _, batch_loss = sess.run([train_step, loss], feed_dict={
      x: batch_xs,
      y: batch_ys
    })
    batches_timed += 1

    if step % report_every == 0 and step > 0:

      # divide by the number of batches actually run since the clock was reset
      # (the first interval covers steps 0..report_every, one more than later ones)
      duration = (time.time() - start) / batches_timed
      print("{:.2f} seconds per batch".format(duration))

      # test model
      test_accuracy = sess.run(accuracy, feed_dict={
        x: mnist.test.images,
        y: mnist.test.labels
      })
      print("Accuracy: %.2f%%" % (test_accuracy * 100))

      # restart the clock after evaluation so test time is excluded
      start = time.time()
      batches_timed = 0

  # evaluation
  # store the result under a new name: rebinding `accuracy` would replace the
  # graph tensor with a number and break any re-run of this cell
  final_accuracy = sess.run(accuracy, feed_dict={
    x: mnist.test.images,
    y: mnist.test.labels
  })

print("\nTrained for %d iterations" % num_iterations)
print("Accuracy: %.2f%%" % (final_accuracy * 100))

Training started.
0.13 seconds per batch
Accuracy: 86.86%
0.12 seconds per batch
Accuracy: 94.64%
0.13 seconds per batch
Accuracy: 95.22%
0.13 seconds per batch
Accuracy: 95.45%
0.13 seconds per batch
Accuracy: 96.82%
0.13 seconds per batch
Accuracy: 97.75%
0.12 seconds per batch
Accuracy: 97.91%
0.13 seconds per batch
Accuracy: 97.49%
0.13 seconds per batch
Accuracy: 97.09%

Trained for 1000 iterations
Accuracy: 97.91%
