Implemented snt.BatchNormV2, which differs from snt.BatchNorm in the following ways (see the usage sketch below):

* Automatically computes updates to moving statistics by default (i.e. update_ops_collection=None).
* Uses moving statistics by default when testing (i.e. test_local_stats=False).
* Takes a data_format string (NC/NWC/NCW/NHWC/NCHW/NDHWC/NCDHW) rather than axes; reduces along all non-C axes.
* Uses fused batch normalization by default. If the data_format isn't NHWC or NCHW, reshapes the batch internally.
* Uses flat variables for the moving statistics, scale, and offset so that they can be shared between different data_formats.

PiperOrigin-RevId: 179819339
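
For context, a minimal usage sketch of the new module (names are illustrative; constructor arguments other than data_format are assumed to follow snt.BatchNorm, with the new defaults described above):

import tensorflow as tf
import sonnet as snt

# NCHW input: BatchNormV2 reduces over the N, H and W axes (all non-C axes)
# and uses the fused kernel by default.
images = tf.placeholder(tf.float32, shape=[16, 3, 32, 32])
bn = snt.BatchNormV2(data_format="NCHW")

# With update_ops_collection=None (the default), the moving statistics are
# updated as a side effect of the is_training=True call; nothing needs to be
# added to, or run from, tf.GraphKeys.UPDATE_OPS.
train_out = bn(images, is_training=True)

# At test time the moving statistics are used by default
# (test_local_stats=False).
test_out = bn(images, is_training=False)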
Deepmind authored and diegolascasas committed Jan 8, 2018
1 parent b920570 commit 9a547ed
Showing 5 changed files with 1,293 additions and 5 deletions (the new batch_norm_v2.py module and its test are not reproduced below).
1 change: 1 addition & 0 deletions sonnet/__init__.py
@@ -71,6 +71,7 @@
 from sonnet.python.modules.basic_rnn import ModelRNN
 from sonnet.python.modules.basic_rnn import VanillaRNN
 from sonnet.python.modules.batch_norm import BatchNorm
+from sonnet.python.modules.batch_norm_v2 import BatchNormV2
 from sonnet.python.modules.clip_gradient import clip_gradient
 from sonnet.python.modules.conv import CausalConv1D
 from sonnet.python.modules.conv import Conv1D
2 changes: 2 additions & 0 deletions sonnet/python/BUILD
@@ -57,6 +57,7 @@ py_library(
         "modules/attention.py",
         "modules/basic_rnn.py",
         "modules/batch_norm.py",
+        "modules/batch_norm_v2.py",
         "modules/block_matrix.py",
         "modules/clip_gradient.py",
         "modules/conv.py",
@@ -134,6 +135,7 @@ module_tests = [
     ("basic_test", "", "small"),
     ("basic_rnn_test", "", "medium"),
     ("batch_norm_test", "", "small"),
+    ("batch_norm_v2_test", "", "small"),
     ("layer_norm_test", "", "small"),
     ("block_matrix_test", "", "small"),
     ("clip_gradient_test", "", "small"),
13 changes: 8 additions & 5 deletions sonnet/python/modules/batch_norm_test.py
@@ -138,27 +138,30 @@ def _get_inputs(self, dtype=tf.float32):
 
     return v, input_v, inputs
 
-  def testShiftImproveStatistics(self):
-    """Test that using moving_mean as shift improves statistics."""
+  def testUpdateImproveStatistics(self):
+    """Test that updating the moving_mean improves statistics."""
 
     _, _, inputs = self._get_inputs()
 
     # Use small decay_rate to update faster.
     bn = snt.BatchNorm(offset=False, scale=False, decay_rate=0.1)
-    out1 = bn(inputs, is_training=True)
+    out1 = bn(inputs, is_training=False, test_local_stats=False)
+
+    # Build the update ops.
+    bn(inputs, is_training=True)
 
     with self.test_session() as sess:
       sess.run(tf.global_variables_initializer())
       out_v = sess.run(out1)
 
       # Before updating the moving_mean the results are off.
-      self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-6, atol=1e-5)
+      self.assertBetween(np.max(np.abs(np.zeros([7, 6]) - out_v)), 2, 5)
 
       sess.run(tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
 
       # After updating the moving_mean the results are better.
       out_v = sess.run(out1)
-      self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-6, atol=1e-6)
+      self.assertBetween(np.max(np.abs(np.zeros([7, 6]) - out_v)), 1, 2)
 
   @parameterized.named_parameters(
       ("Float16", tf.float16),
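
The updated test exercises the snt.BatchNorm contract, in which the moving statistics only change when the ops collected in tf.GraphKeys.UPDATE_OPS are run. For comparison, a sketch of the same training-time bookkeeping under both modules (assuming BatchNormV2 accepts the same offset/scale/decay_rate arguments as BatchNorm):

import tensorflow as tf
import sonnet as snt

inputs = tf.placeholder(tf.float32, shape=[7, 6])

# snt.BatchNorm: the moving-statistics updates land in UPDATE_OPS and must
# be run explicitly, e.g. grouped with the optimizer step.
bn_v1 = snt.BatchNorm(offset=False, scale=False, decay_rate=0.1)
out_v1 = bn_v1(inputs, is_training=True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_step = tf.group(*update_ops)

# snt.BatchNormV2 with update_ops_collection=None (the default): the update
# is built into the is_training=True graph, so there are no extra ops to run.
bn_v2 = snt.BatchNormV2(offset=False, scale=False, decay_rate=0.1)
out_v2 = bn_v2(inputs, is_training=True)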
