# Training Deep Neural Networks Using Batch Normalization in TensorFlow

Batch normalization really pays off when we train _really_ deep neural networks. 

Our goal in this notebook is to demonstrate this property. To that end, we'll build two deep neural networks (one with and one without batch normalization) to recognize handwritten digits from the MNIST database.

**DISCLAIMER**: These architectures are NOT the best for the MNIST dataset. They're far deeper than the task needs, and a simpler network would produce better results. We made them this convoluted on purpose to demonstrate the usefulness of batch normalization.

## Preliminaries

Let's load the data.

In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data/', one_hot=True, reshape=False)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


## High Level Network

This version of the network uses the helper functions in `tf.layers`, which are very high level (at least by TensorFlow standards). Later we'll redo all the work in the following cells using a lower level API.

Let's start by building a network **without** batch normalization.

In [2]:
def fully_connected(previous_layer, number_of_units):
    return tf.layers.dense(previous_layer, number_of_units, activation=tf.nn.relu)

def conv_layer(previous_layer, layer_depth):
    strides = 2 if layer_depth % 3 == 0 else 1
    return tf.layers.conv2d(previous_layer, layer_depth * 4, 3, strides, 'same', activation=tf.nn.relu)

def train(number_of_batches, batch_size, learning_rate):
    inputs = tf.placeholder(tf.float32, (None, 28, 28, 1))
    labels = tf.placeholder(tf.float32, (None, 10))
    
    network = inputs
    for i in range(1, 20):
        network = conv_layer(network, layer_depth=i)
        
    # Flatten
    original_shape = network.get_shape().as_list()
    network = tf.reshape(network, shape=(-1, original_shape[1] * original_shape[2] * original_shape[3]))
    
    network = fully_connected(network, 100)
    
    logits = tf.layers.dense(network, 10)
    
    model_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
    train_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(model_loss)
    
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Time to train
    with tf.Session() as s:
        s.run(tf.global_variables_initializer())
        
        for i in range(number_of_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            
            s.run(train_optimizer, feed_dict={inputs: batch_xs, labels: batch_ys})
            
            # Check validation or training loss and accuracy
            if i % 100 == 0:
                loss, acc = s.run([model_loss, accuracy], feed_dict={inputs: mnist.validation.images, 
                                                                     labels: mnist.validation.labels})
                print(f'Batch: {i}, Validation loss: {loss}, Validation accuracy: {acc}')
            elif i % 25 == 0:
                loss, acc = s.run([model_loss, accuracy], feed_dict={inputs: batch_xs, labels: batch_ys})
                print(f'Batch: {i}, Training loss: {loss}, Training accuracy: {acc}')
              
        # Final accuracy for both validation and test sets
        acc = s.run(accuracy, feed_dict={inputs: mnist.validation.images, 
                                         labels: mnist.validation.labels})
        print(f'Final validation accuracy: {acc}')
        acc = s.run(accuracy, feed_dict={inputs: mnist.test.images, 
                                         labels: mnist.test.labels})
        print(f'Final test accuracy: {acc}')
        
        # Score the first 100 test images.
        correct = 0.0
        for i in range(100):
            correct += s.run(accuracy, feed_dict={inputs: [mnist.test.images[i]], 
                                                  labels: [mnist.test.labels[i]]})
            
        print(f'Accuracy on 100 samples: {correct/100}')
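
To see what the stride rule in `conv_layer` does to the image as it passes through the 19 convolutional layers, the shapes can be traced by hand. The following is a minimal, framework-free sketch; `trace_shapes` is a hypothetical helper introduced here, and it assumes the usual `'same'` padding rule where the output size is the input size divided by the stride, rounded up.

```python
import math

def trace_shapes(input_size=28, first_layer=1, last_layer=19):
    """Return (layer_depth, stride, spatial_size, channels) for each conv layer."""
    size = input_size
    shapes = []
    for layer_depth in range(first_layer, last_layer + 1):
        # Same stride rule as conv_layer: every third layer downsamples.
        stride = 2 if layer_depth % 3 == 0 else 1
        # With 'same' padding the output size is ceil(input / stride).
        size = math.ceil(size / stride)
        channels = layer_depth * 4
        shapes.append((layer_depth, stride, size, channels))
    return shapes

shapes = trace_shapes()
for layer_depth, stride, size, channels in shapes:
    print(f'layer {layer_depth:2d}: stride {stride}, output {size}x{size}x{channels}')

# The flatten step multiplies height, width, and channels of the last layer.
_, _, final_size, final_channels = shapes[-1]
flattened_units = final_size * final_size * final_channels
print(f'Flattened size: {flattened_units}')  # Flattened size: 76
```

Under these assumptions the image shrinks from 28x28 to 1x1 by layer 15, so the fully connected layer receives only 76 values.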

Let's train the network:

In [3]:
NUM_BATCHES = 800
BATCH_SIZE = 32
LEARNING_RATE = 0.002

tf.reset_default_graph()
with tf.Graph().as_default():
    train(NUM_BATCHES, BATCH_SIZE, LEARNING_RATE)

Batch: 0, Validation loss: 0.690942108631134, Validation accuracy: 0.10700000077486038
Batch: 25, Training loss: 0.3634071350097656, Training accuracy: 0.0625
Batch: 50, Training loss: 0.3253903090953827, Training accuracy: 0.1875
Batch: 75, Training loss: 0.33056578040122986, Training accuracy: 0.03125
Batch: 100, Validation loss: 0.32672563195228577, Validation accuracy: 0.10999999940395355
Batch: 125, Training loss: 0.3267457187175751, Training accuracy: 0.09375
Batch: 150, Training loss: 0.32452020049095154, Training accuracy: 0.15625
Batch: 175, Training loss: 0.32730501890182495, Training accuracy: 0.0625
Batch: 200, Validation loss: 0.3261047601699829, Validation accuracy: 0.09759999811649323
Batch: 225, Training loss: 0.3210393786430359, Training accuracy: 0.21875
Batch: 250, Training loss: 0.32537251710891724, Training accuracy: 0.125
Batch: 275, Training loss: 0.3301261365413666, Training accuracy: 0.0625
Batch: 300, Validation loss: 0.3255506157875061, Validation accuracy: 0

Given the depth of the network, it will take a **really long time** to learn anything. In fact, after 800 batches, it only reaches about 10% accuracy, which is no better than guessing one of the ten digits at random. That's not good. Let's see how it goes **with** batch normalization.
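
The plateau in the log above has a simple possible explanation. As a rough check, assuming the network has collapsed to predicting the same probability for every class, the sketch below computes the loss the model would report in two situations: all logits at zero (an untrained network) and every sigmoid outputting the class frequency of 1/10. The helper `mean_sigmoid_cross_entropy` is a hypothetical name introduced here; it mirrors the mean over classes that `tf.reduce_mean` takes in `train`.

```python
import math

NUM_CLASSES = 10

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def mean_sigmoid_cross_entropy(logit, num_classes=NUM_CLASSES):
    """Mean sigmoid cross-entropy for a one-hot label when every class
    receives the same logit (one positive target, the rest negative)."""
    p = sigmoid(logit)
    positive_term = -math.log(p)         # the single class labelled 1
    negative_term = -math.log(1.0 - p)   # each of the classes labelled 0
    return (positive_term + (num_classes - 1) * negative_term) / num_classes

# Untrained network: logits near 0, so every sigmoid outputs about 0.5.
initial_loss = mean_sigmoid_cross_entropy(0.0)

# Constant predictor: every sigmoid outputs the class frequency 1/10.
prior = 1.0 / NUM_CLASSES
constant_loss = mean_sigmoid_cross_entropy(math.log(prior / (1.0 - prior)))

print(f'Loss with all logits at 0: {initial_loss:.4f}')          # 0.6931
print(f'Loss when always predicting 1/10: {constant_loss:.4f}')  # 0.3251
```

These two values are close to the losses logged at batch 0 and from batch 50 onward, which suggests that the network without batch normalization settles on the class prior instead of learning from the images.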

In [4]:
def fully_connected(previous_layer, number_of_units, is_training):
    layer = tf.layers.dense(previous_layer, number_of_units, use_bias=False, activation=None)
    layer = tf.layers.batch_normalization(layer, training=is_training)
    return tf.nn.relu(layer)

def conv_layer(previous_layer, layer_depth, is_training):
    strides = 2 if layer_depth % 3 == 0 else 1
    layer = tf.layers.conv2d(previous_layer, layer_depth * 4, 3, strides, 'same', use_bias=False, activation=None)
    layer = tf.layers.batch_normalization(layer, training=is_training)
    return tf.nn.relu(layer)

def train(number_of_batches, batch_size, learning_rate):
    inputs = tf.placeholder(tf.float32, (None, 28, 28, 1))
    labels = tf.placeholder(tf.float32, (None, 10))
    is_training = tf.placeholder(tf.bool)
    
    network = inputs
    for i in range(1, 20):
        network = conv_layer(network, layer_depth=i, is_training=is_training)
        
    # Flatten
    original_shape = network.get_shape().as_list()
    network = tf.reshape(network, shape=(-1, original_shape[1] * original_shape[2] * original_shape[3]))
    
    network = fully_connected(network, 100, is_training)
    
    logits = tf.layers.dense(network, 10)
    
    model_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
    
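    # Batch normalization registers its moving-average updates in UPDATE_OPS;
    # making the train op depend on them ensures they run at every training step.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(model_loss)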
    
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Time to train
    with tf.Session() as s:
        s.run(tf.global_variables_initializer())
        
        for i in range(number_of_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            
            s.run(train_optimizer, feed_dict={inputs: batch_xs, 
                                              labels: batch_ys, 
                                              is_training: True})
            
            # Check validation or training loss and accuracy
            if i % 100 == 0:
                loss, acc = s.run([model_loss, accuracy], feed_dict={inputs: mnist.validation.images, 
                                                                     labels: mnist.validation.labels,
                                                                     is_training: False})
                print(f'Batch: {i}, Validation loss: {loss}, Validation accuracy: {acc}')
            elif i % 25 == 0:
                loss, acc = s.run([model_loss, accuracy], feed_dict={inputs: batch_xs, 
                                                                     labels: batch_ys,
                                                                     is_training: False})
                print(f'Batch: {i}, Training loss: {loss}, Training accuracy: {acc}')
              
        # Final accuracy for both validation and test sets
        acc = s.run(accuracy, feed_dict={inputs: mnist.validation.images, 
                                         labels: mnist.validation.labels, 
                                         is_training: False})
        print(f'Final validation accuracy: {acc}')
        acc = s.run(accuracy, feed_dict={inputs: mnist.test.images, 
                                         labels: mnist.test.labels, 
                                         is_training: False})
        print(f'Final test accuracy: {acc}')
        
        # Score the first 100 test images.
        correct = 0.0
        for i in range(100):
            correct += s.run(accuracy, feed_dict={inputs: [mnist.test.images[i]], 
                                                  labels: [mnist.test.labels[i]],
                                                  is_training: False})
            
        print(f'Accuracy on 100 samples: {correct/100}')
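
The layers above pass `use_bias=False` to `tf.layers.dense` and `tf.layers.conv2d`. A likely reason is that batch normalization subtracts the batch mean, so any constant bias added beforehand is cancelled, and the learned shift of the normalization layer takes over that role. The sketch below illustrates the cancellation in plain Python; the activations, the bias value, and the `normalize` helper are all hypothetical.

```python
import math

def normalize(values, epsilon=0.001):
    """Standardize a list of activations with its own mean and variance."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(variance + epsilon) for v in values]

# Hypothetical pre-activations of one unit across a batch of four examples.
pre_activations = [0.5, -1.2, 3.0, 0.7]
bias = 2.5  # hypothetical bias value

without_bias = normalize(pre_activations)
with_bias = normalize([v + bias for v in pre_activations])

# The bias shifts the mean by the same amount, so it disappears after centering.
max_difference = max(abs(a - b) for a, b in zip(without_bias, with_bias))
print(f'Largest difference after normalization: {max_difference:.2e}')
```

Up to floating point rounding, both lists are identical, so keeping the bias would only add parameters that have no effect on the output.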

Let's train the network again, but this time using batch normalization.

In [5]:
NUM_BATCHES = 800
BATCH_SIZE = 32
LEARNING_RATE = 0.002

tf.reset_default_graph()
with tf.Graph().as_default():
    train(NUM_BATCHES, BATCH_SIZE, LEARNING_RATE)

Batch: 0, Validation loss: 0.6910560727119446, Validation accuracy: 0.0989999994635582
Batch: 25, Training loss: 0.5869064927101135, Training accuracy: 0.0625
Batch: 50, Training loss: 0.4755508303642273, Training accuracy: 0.09375
Batch: 75, Training loss: 0.4092184603214264, Training accuracy: 0.03125
Batch: 100, Validation loss: 0.369232177734375, Validation accuracy: 0.0989999994635582
Batch: 125, Training loss: 0.3442865014076233, Training accuracy: 0.0625
Batch: 150, Training loss: 0.3285561203956604, Training accuracy: 0.15625
Batch: 175, Training loss: 0.31170812249183655, Training accuracy: 0.25
Batch: 200, Validation loss: 0.3115042448043823, Validation accuracy: 0.22280000150203705
Batch: 225, Training loss: 0.2478412389755249, Training accuracy: 0.5
Batch: 250, Training loss: 0.1931648701429367, Training accuracy: 0.6875
Batch: 275, Training loss: 0.10294921696186066, Training accuracy: 0.78125
Batch: 300, Validation loss: 0.3323472738265991, Validation accuracy: 0.46779999

Wow! Amazing! In the same number of batches, this new network achieved a decent 96.1% accuracy on the test set!

## Low Level Network

This version of the network uses the helper functions in `tf.nn`, which are at a lower level than those in the `tf.layers` package. This is useful because sometimes we want more control over how a feature is implemented.

**NOTE**: In order to understand the implementation details present in the following cell, it's a good idea to first [read the original paper](https://arxiv.org/abs/1502.03167).
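
As a quick reference, the transform described in the paper can be summarized as follows. This is a condensed restatement, not a substitute for the paper: $m$ is the batch size and $x_i$ is the activation of one unit (or one channel position) for example $i$, while $\gamma$, $\beta$, and $\epsilon$ correspond to the `gamma`, `beta`, and `epsilon` variables used in the code.

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2
$$

$$
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
$$

At inference time, $\mu_B$ and $\sigma_B^2$ are replaced by population estimates, which the code tracks as `population_mean` and `population_variance`.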

In [6]:
def fully_connected(previous_layer, number_of_units, is_training):
    layer = tf.layers.dense(previous_layer, number_of_units, use_bias=False, activation=None)
    
    gamma = tf.Variable(tf.ones([number_of_units]))
    beta = tf.Variable(tf.zeros([number_of_units]))
    
    population_mean = tf.Variable(tf.zeros([number_of_units]), trainable=False)
    population_variance = tf.Variable(tf.ones([number_of_units]), trainable=False)
    
    epsilon = 0.001
    
    def batch_normalization_training():
        batch_mean, batch_variance =  tf.nn.moments(layer, [0])
        
        decay = 0.99
        
        train_mean = tf.assign(population_mean, population_mean * decay + batch_mean * (1 - decay))
        train_variance = tf.assign(population_variance, population_variance * decay + batch_variance * (1 - decay))
        
        with tf.control_dependencies([train_mean, train_variance]):
            return tf.nn.batch_normalization(layer, batch_mean, batch_variance, beta, gamma, epsilon)
        
    def batch_normalization_inference():
        return tf.nn.batch_normalization(layer, population_mean, population_variance, beta, gamma, epsilon)
    
    batch_normalized_output = tf.cond(is_training, batch_normalization_training, batch_normalization_inference)
    return tf.nn.relu(batch_normalized_output)

def conv_layer(previous_layer, layer_depth, is_training):
    strides = 2 if layer_depth % 3 == 0 else 1
    
    input_channels = previous_layer.get_shape().as_list()[3]
    output_channels = layer_depth * 4
    
    weights = tf.Variable(tf.truncated_normal((3, 3, input_channels, output_channels), stddev=0.05))
    
    layer = tf.nn.conv2d(previous_layer, weights, strides=(1, strides, strides, 1), padding='SAME')
    
    gamma = tf.Variable(tf.ones([output_channels]))
    beta = tf.Variable(tf.zeros([output_channels]))
    
    population_mean = tf.Variable(tf.zeros([output_channels]), trainable=False)
    population_variance = tf.Variable(tf.ones([output_channels]), trainable=False)
    
    epsilon = 0.001
    
    def batch_normalization_training():
        batch_mean, batch_variance =  tf.nn.moments(layer, [0, 1, 2], keep_dims=False)
        
        decay = 0.99
        
        train_mean = tf.assign(population_mean, population_mean * decay + batch_mean * (1 - decay))
        train_variance = tf.assign(population_variance, population_variance * decay + batch_variance * (1 - decay))
        
        with tf.control_dependencies([train_mean, train_variance]):
            return tf.nn.batch_normalization(layer, batch_mean, batch_variance, beta, gamma, epsilon)
        
    def batch_normalization_inference():
        return tf.nn.batch_normalization(layer, population_mean, population_variance, beta, gamma, epsilon)
    
    batch_normalized_output = tf.cond(is_training, batch_normalization_training, batch_normalization_inference)
    return tf.nn.relu(batch_normalized_output)

def train(number_of_batches, batch_size, learning_rate):
    inputs = tf.placeholder(tf.float32, (None, 28, 28, 1))
    labels = tf.placeholder(tf.float32, (None, 10))
    is_training = tf.placeholder(tf.bool)
    
    network = inputs
    for i in range(1, 20):
        network = conv_layer(network, layer_depth=i, is_training=is_training)
        
    # Flatten
    original_shape = network.get_shape().as_list()
    network = tf.reshape(network, shape=(-1, original_shape[1] * original_shape[2] * original_shape[3]))
    
    network = fully_connected(network, 100, is_training)
    
    logits = tf.layers.dense(network, 10)
    
    model_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
    train_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(model_loss)
    
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Time to train
    with tf.Session() as s:
        s.run(tf.global_variables_initializer())
        
        for i in range(number_of_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            
            s.run(train_optimizer, feed_dict={inputs: batch_xs, 
                                              labels: batch_ys, 
                                              is_training: True})
            
            # Check validation or training loss and accuracy
            if i % 100 == 0:
                loss, acc = s.run([model_loss, accuracy], feed_dict={inputs: mnist.validation.images, 
                                                                     labels: mnist.validation.labels,
                                                                     is_training: False})
                print(f'Batch: {i}, Validation loss: {loss}, Validation accuracy: {acc}')
            elif i % 25 == 0:
                loss, acc = s.run([model_loss, accuracy], feed_dict={inputs: batch_xs, 
                                                                     labels: batch_ys,
                                                                     is_training: False})
                print(f'Batch: {i}, Training loss: {loss}, Training accuracy: {acc}')
              
        # Final accuracy for both validation and test sets
        acc = s.run(accuracy, feed_dict={inputs: mnist.validation.images, 
                                         labels: mnist.validation.labels, 
                                         is_training: False})
        print(f'Final validation accuracy: {acc}')
        acc = s.run(accuracy, feed_dict={inputs: mnist.test.images, 
                                         labels: mnist.test.labels, 
                                         is_training: False})
        print(f'Final test accuracy: {acc}')
        
        # Score the first 100 test images.
        correct = 0.0
        for i in range(100):
            correct += s.run(accuracy, feed_dict={inputs: [mnist.test.images[i]], 
                                                  labels: [mnist.test.labels[i]],
                                                  is_training: False})
            
        print(f'Accuracy on 100 samples: {correct/100}')
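
The two branches selected by `tf.cond` above can be hard to follow inside a graph. The following framework-free sketch mimics the same logic for a single unit, assuming the same `decay` and `epsilon` values; the `BatchNorm1D` class and the sample activations are hypothetical, and `gamma` and `beta` are left at their initial values instead of being learned.

```python
import math

class BatchNorm1D:
    """Batch normalization for a single unit, written without TensorFlow."""

    def __init__(self, decay=0.99, epsilon=0.001):
        self.gamma = 1.0                # scale (not trained in this sketch)
        self.beta = 0.0                 # shift (not trained in this sketch)
        self.population_mean = 0.0
        self.population_variance = 1.0
        self.decay = decay
        self.epsilon = epsilon

    def __call__(self, batch, is_training):
        if is_training:
            # Training: normalize with the statistics of the current batch.
            mean = sum(batch) / len(batch)
            variance = sum((x - mean) ** 2 for x in batch) / len(batch)
            # Update the exponential moving averages used at inference time.
            self.population_mean = (self.population_mean * self.decay
                                    + mean * (1 - self.decay))
            self.population_variance = (self.population_variance * self.decay
                                        + variance * (1 - self.decay))
        else:
            # Inference: reuse the statistics accumulated during training.
            mean, variance = self.population_mean, self.population_variance
        scale = self.gamma / math.sqrt(variance + self.epsilon)
        return [(x - mean) * scale + self.beta for x in batch]

layer = BatchNorm1D()
batch = [4.0, 6.0, 8.0, 10.0]  # hypothetical activations: mean 7, variance 5

for _ in range(1000):
    layer(batch, is_training=True)

print(f'Population mean: {layer.population_mean:.3f}')
print(f'Population variance: {layer.population_variance:.3f}')

# A single example can be normalized at inference time...
single_example = layer([7.0], is_training=False)
print(f'Inference output for 7.0: {single_example[0]:.4f}')

# ...whereas batch statistics of a single example always collapse to beta.
degenerate = BatchNorm1D()([123.0], is_training=True)
print(f'Training-mode output for a batch of one: {degenerate[0]}')
```

The last two lines hint at why the notebook scores test images one at a time with `is_training: False`: with a batch of one, the batch mean equals the input and the variance is zero, so only the population statistics give a meaningful result.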

Finally, let's train the network using the lower level helper functions we just implemented.

In [7]:
NUM_BATCHES = 800
BATCH_SIZE = 32
LEARNING_RATE = 0.002

tf.reset_default_graph()
with tf.Graph().as_default():
    train(NUM_BATCHES, BATCH_SIZE, LEARNING_RATE)

Batch: 0, Validation loss: 0.6910102963447571, Validation accuracy: 0.11259999871253967
Batch: 25, Training loss: 0.5824524164199829, Training accuracy: 0.21875
Batch: 50, Training loss: 0.46968093514442444, Training accuracy: 0.09375
Batch: 75, Training loss: 0.4036182761192322, Training accuracy: 0.0625
Batch: 100, Validation loss: 0.36633265018463135, Validation accuracy: 0.10019999742507935
Batch: 125, Training loss: 0.34632745385169983, Training accuracy: 0.125
Batch: 150, Training loss: 0.3398407995700836, Training accuracy: 0.09375
Batch: 175, Training loss: 0.3682170510292053, Training accuracy: 0.0625
Batch: 200, Validation loss: 0.39046332240104675, Validation accuracy: 0.0868000015616417
Batch: 225, Training loss: 0.4842556416988373, Training accuracy: 0.03125
Batch: 250, Training loss: 0.5430601239204407, Training accuracy: 0.09375
Batch: 275, Training loss: 0.5405094027519226, Training accuracy: 0.125
Batch: 300, Validation loss: 0.6523125171661377, Validation accuracy: 0.

Great! It works. Batch normalization is a powerful tool, as we've seen in this notebook. Most of the time we can safely use the higher level functions in the `tf.layers` package, but it's also worth knowing how to work with the low level functions, as they give us more control over our solutions.