In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# Load the MNIST splits (train / validation / test are used further down)
# from the local 'MNIST_data/' directory, with labels requested as one-hot rows.
mnist = input_data.read_data_sets(
    'MNIST_data/',
    one_hot=True,
)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
def preproc(x):
    """Linearly rescale ``x`` to ``x * 2 - 1.0``.

    Maps 0 to -1 and 1 to +1, so inputs in [0, 1] end up centred on zero.
    NOTE(review): assumes the pixel values fed in are already in [0, 1] --
    confirm against the data loader.
    """
    return x * 2 - 1.0

In [4]:
class Solver:
    """Runs a model's training and evaluation fetches inside a given session.

    The model is expected to expose ``X``, ``y`` and ``training`` (used as
    feed-dict keys) plus ``train_op``, ``loss`` and ``accuracy`` (fetched).
    """

    def __init__(self, sess, model):
        self.sess = sess
        self.model = model

    def _feed(self, X, y, training):
        """Build the feed dict for one ``sess.run`` call."""
        return {
            self.model.X: X,
            self.model.y: y,
            self.model.training: training,
        }

    def train(self, X, y):
        """Run one training step on (X, y) with the training flag set to True.

        Returns whatever ``sess.run`` returns for ``[train_op, loss]``.
        """
        fetches = [self.model.train_op, self.model.loss]
        return self.sess.run(fetches, feed_dict=self._feed(X, y, True))

    def evaluate(self, X, y, batch_size=None):
        """Compute loss and accuracy on (X, y) with the training flag set to False.

        Args:
            X, y: inputs and labels; ``X`` must support ``.shape[0]`` and slicing
                when ``batch_size`` is given.
            batch_size: if truthy, evaluate in consecutive slices of this size and
                return the example-weighted averages as a ``(loss, accuracy)``
                tuple. If falsy (``None`` or 0), evaluate everything in a single
                ``sess.run`` call and return its result directly.
        """
        if not batch_size:
            fetches = [self.model.loss, self.model.accuracy]
            return self.sess.run(fetches, feed_dict=self._feed(X, y, False))

        n_examples = X.shape[0]
        loss_sum = 0
        acc_sum = 0

        for start in range(0, n_examples, batch_size):
            stop = start + batch_size
            X_chunk = X[start:stop]
            y_chunk = y[start:stop]

            fetches = [self.model.loss, self.model.accuracy]
            chunk_loss, chunk_acc = self.sess.run(
                fetches, feed_dict=self._feed(X_chunk, y_chunk, False))

            # The last slice may be shorter, so weight each slice by its own size.
            chunk_size = X_chunk.shape[0]
            loss_sum += chunk_loss * chunk_size
            acc_sum += chunk_acc * chunk_size

        loss_sum /= n_examples
        acc_sum /= n_examples

        return loss_sum, acc_sum

In [20]:
class Model:
    """CNN classifier for flattened 28x28 single-channel images with 10 classes.

    The whole graph is built under ``tf.variable_scope(name)``. Attributes
    intended for feeding / fetching:

    * ``X`` -- float32 placeholder, shape [None, 784].
    * ``y`` -- float32 placeholder, shape [None, 10]; the true class is taken
      as the argmax of each row.
    * ``training`` -- bool placeholder passed to batch norm and dropout.
    * ``loss`` -- mean softmax cross-entropy over the batch.
    * ``train_op`` -- Adam minimisation step on ``loss``.
    * ``pred`` -- argmax of the logits per example.
    * ``accuracy`` -- fraction of examples where ``pred`` equals argmax(``y``).
    """

    def __init__(self, name, lr=0.001, dropout_rate=0.3):
        """Build the graph.

        Args:
            name: variable scope name; also used to select this model's
                entries from the UPDATE_OPS collection.
            lr: learning rate handed to ``tf.train.AdamOptimizer``.
            dropout_rate: ``rate`` passed to the dropout layer that ends every
                conv block (defaults to the previously hard-coded 0.3).
        """
        with tf.variable_scope(name):
            self.X = tf.placeholder(tf.float32, [None, 784], name='X')
            self.y = tf.placeholder(tf.float32, [None, 10], name='y')
            self.training = tf.placeholder(tf.bool, name='training')

            x = preproc(self.X)
            x_img = tf.reshape(x, [-1, 28, 28, 1])

            # Three stages. Each stage is a 3x3 stride-1 block followed by a
            # 5x5 stride-2 block; with 'SAME' padding the stride-2 block halves
            # the spatial size (rounding up), and the filter count doubles
            # after every stage:
            #   input:   [28, 28, 1]
            #   stage 1: [14, 14, 64]
            #   stage 2: [7, 7, 128]
            #   stage 3: [4, 4, 256]
            net = x_img
            n_filters = 64
            for _ in range(3):
                net = self._conv_block(net, n_filters, [3, 3], 1, dropout_rate)
                net = self._conv_block(net, n_filters, [5, 5], 2, dropout_rate)
                n_filters *= 2

            # 4 * 4 * 256 = 4096 features -> 10 logits through one dense layer.
            net = tf.contrib.layers.flatten(net)
            # tf.layers.dense takes `kernel_initializer`; `weights_initializer`
            # is the tf.contrib.layers.fully_connected spelling and is rejected
            # here with a TypeError.
            logits = tf.layers.dense(net, 10, kernel_initializer=tf.contrib.layers.xavier_initializer())

            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.y))

            # Make every train step also run the update ops registered under
            # this scope (the batch-norm layers add theirs to UPDATE_OPS).
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=name)
            with tf.control_dependencies(update_ops):
                self.train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.loss)

            self.pred = tf.argmax(logits, axis=1)
            self.accuracy = tf.reduce_mean(tf.cast(tf.equal(self.pred, tf.argmax(self.y, axis=1)), tf.float32))

    def _conv_block(self, net, n_filters, kernel_size, strides, dropout_rate):
        """Apply bias-free conv -> batch norm -> ReLU -> dropout and return the result."""
        net = tf.layers.conv2d(net, n_filters, kernel_size, strides=strides,
                               kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
                               padding='SAME', use_bias=False)
        net = tf.layers.batch_normalization(net, training=self.training)
        net = tf.nn.relu(net)
        return tf.layers.dropout(net, rate=dropout_rate, training=self.training)

In [21]:
# Start from an empty default graph so re-running this cell does not stack a
# second copy of the model on top of the previous one.
tf.reset_default_graph()

# The graph-level seed must be set BEFORE any op is created: random ops derive
# their seed from the graph at construction time, so seeding after the model
# is built leaves its initializers and dropout unseeded.
tf.set_random_seed(777)

sess = tf.Session()

basic_cnn = Model('basic_cnn', lr=0.001)
solver = Solver(sess, basic_cnn)

In [29]:
# (Re)initialise every variable, so running this cell again restarts training from scratch.
sess.run(tf.global_variables_initializer())

batch_size = 50
epoch_n = 60
eval_batch_size = 1000  # slice size used for the per-epoch evaluation passes
N = mnist.train.num_examples

max_train_acc = 0
max_valid_acc = 0
max_test_acc = 0
# Pre-set the summary lines so the final prints cannot raise NameError when no
# epoch runs (epoch_n == 0) or an accuracy never rises above 0.
train_line = valid_line = test_line = "n/a"

for epoch in range(epoch_n):
    # One pass over the training set in mini-batches (a trailing partial batch is skipped).
    for _ in range(N // batch_size):
        batches = mnist.train.next_batch(batch_size)
        solver.train(batches[0], batches[1])

    # Evaluate all three splits in inference mode after every epoch.
    train_loss, train_acc = solver.evaluate(mnist.train.images, mnist.train.labels, eval_batch_size)
    valid_loss, valid_acc = solver.evaluate(mnist.validation.images, mnist.validation.labels, eval_batch_size)
    test_loss, test_acc = solver.evaluate(mnist.test.images, mnist.test.labels, eval_batch_size)
    line = "[{:0>2d}/{}] train: {:.4f}, {:.3%} / valid: {:.4f}, {:.2%} / test: {:.4f}, {:.2%}".format(
        epoch + 1, epoch_n, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc)
    print(line)

    # Remember the log line of the best epoch for each split (first best wins on ties).
    if train_acc > max_train_acc:
        max_train_acc = train_acc
        train_line = line
    if valid_acc > max_valid_acc:
        max_valid_acc = valid_acc
        valid_line = line
    # NOTE(review): picking the best epoch by test accuracy is an optimistic
    # estimate; the "[valid max]" line is the one to use for model selection.
    if test_acc > max_test_acc:
        max_test_acc = test_acc
        test_line = line


print("[train max] {}".format(train_line))
print("[valid max] {}".format(valid_line))
print("[ test max] {}".format(test_line))

[01/60] train: 0.0675, 98.113% / valid: 0.0682, 98.18% / test: 0.0521, 98.47%
[02/60] train: 0.0673, 98.100% / valid: 0.0734, 98.20% / test: 0.0521, 98.32%
[03/60] train: 0.0384, 98.885% / valid: 0.0515, 98.52% / test: 0.0417, 98.78%
[04/60] train: 0.0294, 99.127% / valid: 0.0396, 98.92% / test: 0.0345, 99.01%
[05/60] train: 0.0348, 98.911% / valid: 0.0469, 98.74% / test: 0.0451, 98.69%
[06/60] train: 0.0149, 99.527% / valid: 0.0262, 99.30% / test: 0.0232, 99.32%
[07/60] train: 0.0142, 99.571% / valid: 0.0266, 99.20% / test: 0.0234, 99.24%
[08/60] train: 0.0165, 99.495% / valid: 0.0349, 98.98% / test: 0.0247, 99.31%
[09/60] train: 0.0119, 99.627% / valid: 0.0303, 99.24% / test: 0.0242, 99.23%
[10/60] train: 0.0116, 99.653% / valid: 0.0309, 99.14% / test: 0.0209, 99.37%
[11/60] train: 0.0061, 99.820% / valid: 0.0223, 99.36% / test: 0.0185, 99.45%
[12/60] train: 0.0139, 99.516% / valid: 0.0382, 99.04% / test: 0.0283, 99.18%
[13/60] train: 0.0245, 99.247% / valid: 0.0511, 98.70% / test: 0

## Results

* basic, without preproc (no zero-centering of the input): 98.60%
* basic: 98.95%
* BN: 98.68%
    * BN-0.01: 94%
    * BN-0.05: 98.29%


* 2-strided models
    * BN: 99.26%
        * bias: 99.15%
    * No-BN, No-bias: 99.10%
    * No-BN: 99.26%
        * added 1 more 1024 dense layer: 99.00%
        * normalized input: 99.15%
* [(3,3),1] + [(5,5),2] model + BN
    * 2-FC + dropout: 99.39%, 99.46%
    * 1-FC: 99.52%, 99.49%
        * conv dropout 0.5: 99.53%
        * conv dropout 0.2: 99.45% (?)
        * conv dropout 0.7: 99.11%, 99.35% (batch size 50)
        * conv dropout 0.3 + batch size 50 + epoch 30: 99.55%
* [(3,3),1] + [(3,3),2] model + BN
    * 1-FC
        * conv dropout 0.1: 99.41%
* max pooling: 99.35%