In [1]:
import tensorflow as tf

In [2]:
"""
The batch normalization layer does not normalize with the current batch's statistics
unless its `training` argument is set to True.
With the default training=False it normalizes with its moving averages instead. These
start at mean 0 and variance 1, so the output is simply x / sqrt(1 + epsilon), where the
default epsilon is 0.001: -10 -> -9.995.
"""

tf.reset_default_graph()

x = tf.placeholder(tf.float32, [None, 1], 'x')
y = tf.layers.batch_normalization(x)

# The context manager closes the session even if a run raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    y_out = sess.run(y, feed_dict={x: [[-10], [0], [10]]})

print(y_out)


[[-9.995004]
 [ 0.      ]
 [ 9.995004]]


In [3]:
"""
With `training=True` the batch normalization layer ignores its moving averages and
normalizes each batch with that batch's own mean and variance:
- [-10, 0, 10] has mean 0 and (population) variance 200/3, which gives about [-1.2247, 0, 1.2247].
- A single-sample batch is its own mean, so it always normalizes to exactly 0.
"""

tf.reset_default_graph()

x = tf.placeholder(tf.float32, [None, 1], 'x')
y = tf.layers.batch_normalization(x, training=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # No update ops are fetched, so nothing in the layer changes between the two runs.
    # Reusing one graph/session therefore gives the same results as rebuilding it.
    y_out = sess.run(y, feed_dict={x: [[-10], [0], [10]]})
    y_single = sess.run(y, feed_dict={x: [[-10]]})

print(y_out)
print(y_single)


[[-1.2247357]
 [ 0.       ]
 [ 1.2247357]]
[[0.]]


In [4]:
# Building the layer registered its moving-average update ops in the graph's
# UPDATE_OPS collection; they only run when something fetches them.
tf.get_default_graph().get_collection(tf.GraphKeys.UPDATE_OPS)

[<tf.Operation 'batch_normalization/AssignMovingAvg' type=AssignSub>,
 <tf.Operation 'batch_normalization/AssignMovingAvg_1' type=AssignSub>]

In [5]:
"""
To update the two moving-average variables (mean and variance) that
tf.layers.batch_normalization creates automatically, the two ops in UPDATE_OPS
must be evaluated together with the layer output while a batch is fed through it.
"""

tf.reset_default_graph()

x = tf.placeholder(tf.float32, [None, 1], 'x')
y = tf.layers.batch_normalization(x, training=True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetching update_ops alongside y runs them; only y's value is kept.
    y_1 = sess.run([y, update_ops], feed_dict={x: [[-10], [0], [10]]})[0]
    # training=True is hard-wired, so this batch is again normalized with its own
    # statistics; the just-updated moving averages are not used.
    y_2 = sess.run(y, feed_dict={x: [[-10]]})

print(y_1)
print(y_2)

[[-1.2247357]
 [ 0.       ]
 [ 1.2247357]]
[[0.]]


In [6]:
"""
y_1 and y_2 stay the same. With `training=True` the layer always normalizes with the
statistics of the current batch; the moving averages are only *used* when `training`
is False.

Making `training` a placeholder (here one with a default of False) lets us choose the
mode per sess.run call:
- Set it to True while feeding the larger batch, so that batch's statistics update the
  moving averages.
- Set it to False for the single-sample batch. Strictly, this is unnecessary because
  False is the placeholder's default, but it makes the intent explicit.

Summary:
- `training=True`: normalize with the current batch's mean/variance and, if the update
  ops are run, move the moving averages toward those statistics.
- `training=False`: still normalize, but with the stored moving mean/variance instead of
  the batch statistics.
So feed the larger (training) batch with `training=True` and the small (inference)
batch with `training=False`.
"""

tf.reset_default_graph()

is_training = tf.placeholder_with_default(False, (), 'is_training')
x = tf.placeholder(tf.float32, [None, 1], 'x')
y = tf.layers.batch_normalization(x, training=is_training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training step: batch statistics are used, and fetching update_ops moves the moving
    # averages toward them. Without update_ops in this fetch the averages would keep their
    # initial values (mean 0, variance 1), and y_2 would match the very first cell (~ -9.995).
    y_1 = sess.run([y, update_ops], feed_dict={x: [[-10], [0], [10]], is_training: True})[0]
    # Inference step: the moving averages replace the batch statistics.
    y_2 = sess.run(y, feed_dict={x: [[-10]], is_training: False})

print(y_1)
print(y_2)

[[-1.2247357]
 [ 0.       ]
 [ 1.2247357]]
[[-7.766966]]


In [7]:
"""
Odd at first sight: -7.77 is neither 0, which batch statistics give for a single sample,
nor -1.22, which it would be if normalized with the statistics of the larger batch
[-10, 0, 10].

The reason is that the moving averages update slowly. Each run of the update ops moves
them only a small step toward the current batch statistics: the default momentum of
tf.layers.batch_normalization is 0.99, i.e. a 1% step.
After a single update, the moving variance is 0.99 * 1 + 0.01 * 200/3 ≈ 1.657, so
-10 / sqrt(1.657 + 0.001) ≈ -7.77.
If we feed the larger batch many times, the averages converge and the small batch is
normalized the same way as the larger one:
"""

tf.reset_default_graph()

is_training = tf.placeholder_with_default(False, (), 'is_training')
x = tf.placeholder(tf.float32, [None, 1], 'x')
y = tf.layers.batch_normalization(x, training=is_training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# With momentum 0.99, the initial value's weight after n updates is 0.99**n
# (0.99**1000 ≈ 4e-5), so 1000 updates are enough to converge.
n_updates = 1000

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _step in range(n_updates):
        y_1 = sess.run([y, update_ops], feed_dict={x: [[-10], [0], [10]], is_training: True})[0]
    y_2 = sess.run(y, feed_dict={x: [[-10]], is_training: False})

print(y_1)
print(y_2)

[[-1.2247357]
 [ 0.       ]
 [ 1.2247357]]
[[-1.224762]]


In [8]:
"""
So far the update ops had to be fetched explicitly in every sess.run call.
It is more convenient to attach them to y as control dependencies: TensorFlow then runs
them automatically whenever the tensor y is evaluated.
"""

tf.reset_default_graph()

x_1 = [[-10], [0], [10]]  # larger batch, fed while training
x_2 = [[-10]]             # single sample, fed at inference

is_training = tf.placeholder_with_default(False, (), 'is_training')
x = tf.placeholder(tf.float32, [None, 1], 'x')
bn_output = tf.layers.batch_normalization(x, training=is_training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# Only ops *created* inside this context receive the dependency, hence the identity op.
with tf.control_dependencies(update_ops):
    y = tf.identity(bn_output)

# The session is deliberately NOT closed here: the cells below read the trained
# moving statistics through `sess`.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

train_feed = {x: x_1, is_training: True}
for _ in range(1000):
    y_1 = sess.run(y, feed_dict=train_feed)
# is_training is not fed, so its default (False) selects the moving averages.
# NOTE(review): the update ops still run here via the control dependency; y_2 matching
# the previous cell suggests they leave the averages unchanged when training is False,
# which should be confirmed.
y_2 = sess.run(y, feed_dict={x: x_2})

print(y_1)
print(y_2)

[[-1.2247357]
 [ 0.       ]
 [ 1.2247357]]
[[-1.224762]]


In [9]:
"""
Typically, is_training is set to True during training and to False during inference.
The values stored by the batch normalization layer can be examined through its variables:
"""

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'batch_normalization/moving_mean:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'batch_normalization/moving_variance:0' shape=(1,) dtype=float32_ref>]

In [10]:
# Read the trained moving statistics straight from the graph by name.
# Re-entering the root variable scope with AUTO_REUSE would instead try to
# *create* a new variable if the name were misspelled.
# `sess` is the session left open by the control-dependency training cell.
# Note: `moving_average` holds the layer's moving *mean*; the name is kept for the next cell.
graph = tf.get_default_graph()
moving_average, moving_variance = sess.run([
    graph.get_tensor_by_name('batch_normalization/moving_mean:0'),
    graph.get_tensor_by_name('batch_normalization/moving_variance:0'),
])

In [11]:
# The layer's moving *mean* (read from batch_normalization/moving_mean).
# It matches the mean of the training batch [-10, 0, 10], i.e. 0.
moving_average

array([0.], dtype=float32)

In [12]:
# Approaches 200/3 ≈ 66.67, the (population) variance of the training batch [-10, 0, 10].
moving_variance

array([66.66382], dtype=float32)