In [2]:
import tensorflow as tf
import numpy as np

slim = tf.contrib.slim

  return f(*args, **kwds)


In [3]:
tf.reset_default_graph()

In [4]:
(train_data, train_labels), (test_data, test_labels) = \
                        tf.keras.datasets.mnist.load_data()
    
train_data = train_data / 255.
train_labels = np.asarray(train_labels, dtype=np.int32)

test_data = test_data / 255.
test_labels = np.asarray(test_labels, dtype=np.int32)

In [5]:
batch_size = 32

train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_dataset = train_dataset.shuffle(10000)
train_dataset = train_dataset.batch(batch_size=batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_dataset = test_dataset.shuffle(buffer_size = 10000)
test_dataset = test_dataset.batch(batch_size = len(test_data))

In [6]:
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                              train_dataset.output_types,
                                              train_dataset.output_shapes)

x, y = iterator.get_next()
x = tf.cast(x, tf.float32)
y = tf.cast(y, tf.int32)

In [7]:
def residual_block(x, output_channel, is_training=True, down_sampling=False, is_end=False):
    
    if down_sampling:
        stride = 2
    else:
        stride = 1
    
    #  project shortcut? --> 1x1 convnet
    
    with slim.arg_scope([slim.conv2d], weights_regularizer=slim.l2_regularizer(scale=.0003),
                       normalizer_fn=slim.batch_norm,
                       normalizer_params= {'decay': .9, 'is_training': is_training}):
    
        h1 = slim.conv2d(x, output_channel, [3, 3], stride=stride)
        
        h2 = slim.conv2d(h1, output_channel, [3, 3], activation_fn=None)

    if down_sampling:
        x = slim.conv2d(x, output_channel, [1, 1], stride=stride, activation_fn=None)
    
    if is_end:
        return h2 + x
    return tf.nn.relu(h2 + x)

In [8]:
def build_resnet(x, layer_n):
    x = tf.reshape(x, [-1, 28, 28, 1])
    
    with tf.variable_scope('conv0'):
        net = slim.conv2d(x, 16, [3, 3], activation_fn=None)
        net = tf.layers.batch_normalization(net, epsilon=1e-5, training=is_training)
        net = tf.nn.relu(net)
        
    with tf.variable_scope('res0'):
        for i in range(layer_n):
            net = residual_block(net, 16, is_training)
    
    with tf.variable_scope('res1'):
        for i in range(layer_n):
            net = residual_block(net, 32, is_training, down_sampling=(i==0))
            
    with tf.variable_scope('res2'):
        for i in range(layer_n):
            net = residual_block(net, 64, is_training, down_sampling=(i==0), is_end=(i==layer_n-1))
            
    with tf.variable_scope('fc'):
        net = slim.flatten(net)
        net = slim.fully_connected(net, 10, activation_fn=None)
        
    return net

In [9]:
is_training = tf.placeholder(tf.bool)
logits = build_resnet(x, 3)

In [12]:
loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=y)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

In [13]:
with tf.name_scope('summaries'):
  tf.summary.scalar('loss/cross_entropy', loss)
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)
  # merge all summaries
  summary_op = tf.summary.merge_all()

In [14]:
graph_location = 'graphs/resnet_mnist'
print('Saving graph to: %s' % graph_location)
train_writer = tf.summary.FileWriter(graph_location)
train_writer.add_graph(tf.get_default_graph()) 

Saving graph to: graphs/resnet_mnist


In [15]:
saver = tf.train.Saver()

In [None]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

train_iterator = train_dataset.make_initializable_iterator()
train_handle = sess.run(train_iterator.string_handle())

epochs = 3
step = 0

for epoch in range(epochs):
    sess.run(train_iterator.initializer)
    
    while True:
        try:
            _, loss_ = sess.run([train_op, loss], 
                                feed_dict={handle: train_handle, is_training: True})
            
            if step % 100 == 0:
                print("Step: %d, Loss: %g" % (step, loss_))
                
                summary_str = sess.run(summary_op, 
                                      feed_dict={handle: train_handle, is_training: True})
                train_writer.add_summary(summary_str, global_step=step)
            
            if step % 300 == 0:
                print('Saving the model in ', graph_location)
                saver.save(sess, graph_location + 'model.ckpt', global_step=step)
            
            step += 1
            
        except tf.errors.OutOfRangeError:
            print('End of dataset')
            break
        
    print('Epoch: ', epoch)

train_writer.close()
print('Training done!')

In [20]:
model_ckpt = tf.train.latest_checkpoint('graphs/')
saver.restore(sess, model_ckpt)

INFO:tensorflow:Restoring parameters from graphs/resnet_mnistmodel.ckpt-1500


In [21]:
test_iterator = test_dataset.make_initializable_iterator()
test_handle = sess.run(test_iterator.string_handle())
sess.run(test_iterator.initializer)

In [22]:
accuracy, acc_op = tf.metrics.accuracy(labels=y, predictions=tf.argmax(logits, 1), name='accuracy')
sess.run(tf.local_variables_initializer())

sess.run(acc_op, feed_dict={handle: test_handle, is_training: False})
print("test accuracy:", sess.run(accuracy, feed_dict={handle: test_handle, is_training: False}))

test accuracy: 0.9804
