# Batch Normalization in TensorFlow

In this notebook we explore the proper way to use batch normalization in TensorFlow. TensorFlow has several functions related to batch normalization, in the `contrib`, `nn` and `layers` modules. The latter is supposed to be the more "official", and is the one that we will be using. We build a MLP with a (uselessly) large number of layers, all of the same size. The common wisdom is that batchnorm can, at least partially, replace dropout.

In [44]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from functools import partial

mnist = input_data.read_data_sets('/tmp/data/')

n_inputs = 28 * 28
n_hidden1 = 100
n_hidden2 = 100
n_hidden3 = 100
n_hidden4 = 100
n_hidden5 = 100

learning_rate = 0.01
n_epochs = 150
batch_size = 100

tf.reset_default_graph()

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


Besides the usual placeholders for `X` and `y`, we need an extra one for the logical variable indicating whether we are in the training or in the test phase.

In [45]:
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name='X')
y = tf.placeholder(tf.int64, shape=(None), name='y')
training = tf.placeholder_with_default(False, shape=(), name='training')

The batchnorm layer should be between the inputs and the activation. We are going to share the parameters of the batch normalization layer across all the hidden layers, and we are going to use `functools.partial` for this. Similarly we are going to use the same initialization for all the hidden layers, and it makes sense to create a `my_dense_layer` layer.

In [46]:
my_dense_layer = partial(
    tf.layers.dense, 
    kernel_initializer=tf.contrib.layers.variance_scaling_initializer())

my_batch_norm = partial(tf.layers.batch_normalization, 
                        training=training,
                        momentum=0.9)

with tf.name_scope('model'):
    
    hidden1 = my_dense_layer(X, n_hidden1, name='hidden1')
    bn1 = my_batch_norm(hidden1)
    bn1_act = tf.nn.elu(bn1)
    
    hidden2 = my_dense_layer(X, n_hidden2, name='hidden2')
    bn2 = my_batch_norm(hidden2)
    bn2_act = tf.nn.elu(bn2)
    
    hidden3 = my_dense_layer(X, n_hidden3, name='hidden3')
    bn3 = my_batch_norm(hidden3)
    bn3_act = tf.nn.elu(bn3)
    
    hidden4 = my_dense_layer(X, n_hidden4, name='hidden4')
    bn4 = my_batch_norm(hidden4)
    bn4_act = tf.nn.elu(bn4)
    
    hidden5 = my_dense_layer(X, n_hidden5, name='hidden5')
    bn5 = my_batch_norm(hidden5)
    bn5_act = tf.nn.elu(bn5)
    
    logits_before_bn = tf.layers.dense(bn5_act, 10, name='output')
    logits = my_batch_norm(logits_before_bn)

We can now write the crossentropy and the loss, generate the training op, and compute the accuracy.

In [47]:
with tf.name_scope('loss'):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=y, logits=logits, name='xentropy')
    loss = tf.reduce_mean(xentropy, name='loss')

The [TensorFlow documentation](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers/batch_normalization) recommends that *when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op*. In practical terms this means that we should use a `with tf.control_dependencies` term in the code containing the training op.

In [50]:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.name_scope('train'):
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    optimizer = tf.train.AdamOptimizer()
    with tf.control_dependencies(update_ops):
        training_op = optimizer.minimize(loss)
    
with tf.name_scope('eval'):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

init = tf.global_variables_initializer()

We can now launch a session and train the model.

In [51]:
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_val = sess.run(accuracy, feed_dict={X: mnist.test.images,
                                                     y: mnist.test.labels})
        if epoch % 10 == 0:
            print(epoch, "test accuracy:", accuracy_val)

0 test accuracy: 0.9289
10 test accuracy: 0.9752
20 test accuracy: 0.9761
30 test accuracy: 0.9765
40 test accuracy: 0.9778
50 test accuracy: 0.9733
60 test accuracy: 0.9785
70 test accuracy: 0.9779
80 test accuracy: 0.9786
90 test accuracy: 0.977
100 test accuracy: 0.9788
110 test accuracy: 0.9744
120 test accuracy: 0.9783
130 test accuracy: 0.9787
140 test accuracy: 0.9768
