Skip to content

Commit

Permalink
Comments explaining tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
dwf committed Jan 19, 2016
1 parent cac5319 commit a4d0728
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/bricks/test_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


def random_unif(rng, dim, low=1, high=10):
"""Generate some floatX uniform random numbers."""
return (rng.uniform(low, high, size=dim)
.astype(theano.config.floatX))

Expand Down Expand Up @@ -61,6 +62,7 @@ def check(input_dim, expected_shape, broadcastable=None, save_memory=True):


def apply_setup(input_dim, broadcastable, save_memory):
"""Common setup code."""
bn = BatchNormalization(input_dim, broadcastable, save_memory,
epsilon=1e-4)
bn.initialize()
Expand All @@ -73,6 +75,7 @@ def apply_setup(input_dim, broadcastable, save_memory):


def test_batch_normalization_inference_apply():
"""Test that BatchNormalization.apply works in inference mode."""
def check(input_dim, variable_dim, broadcastable=None, save_memory=True):
bn, x, y = apply_setup(input_dim, broadcastable, save_memory)
rng = numpy.random.RandomState((2015, 12, 16))
Expand Down Expand Up @@ -179,6 +182,7 @@ def normalize(x):


def test_batch_normalization_image_size_setter():
"""Test that setting image_size on a BatchNormalization works."""
bn = BatchNormalization()
bn.image_size = (5, 4)
assert bn.input_dim == (None, 5, 4)
Expand All @@ -187,6 +191,7 @@ def test_batch_normalization_image_size_setter():


def test_spatial_batch_normalization():
"""Smoke test for SpatialBatchNormalization."""
def check(*input_dim):
sbn = SpatialBatchNormalization(input_dim)
sbn.initialize()
Expand All @@ -204,6 +209,7 @@ def check(*input_dim):


def test_raise_exception_spatial():
"""Test that SpatialBatchNormalization raises an expected exception."""
# Work around a stupid bug in nose2 that unpacks the tuple into
# separate arguments.
yield assert_raises, (ValueError, SpatialBatchNormalization, (5,))
Expand All @@ -222,6 +228,7 @@ def do_not_fail(*input_dim):


def test_batch_normalization_inside_convolutional_sequence():
"""Test that BN bricks work in ConvolutionalSequences."""
conv_seq = ConvolutionalSequence(
[Convolutional(filter_size=(3, 3), num_filters=4),
BatchNormalization(broadcastable=(False, True, True)),
Expand Down Expand Up @@ -256,6 +263,7 @@ def test_batch_normalization_inside_convolutional_sequence():


def test_batch_normalized_mlp_construction():
"""Test that BatchNormalizedMLP performs construction correctly."""
mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9])
assert all(isinstance(a, Sequence) for a in mlp.activations)
assert all(isinstance(a.children[0], BatchNormalization)
Expand All @@ -265,6 +273,7 @@ def test_batch_normalized_mlp_construction():


def test_batch_normalized_mlp_allocation():
"""Test that BatchNormalizedMLP performs allocation correctly."""
mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9])
mlp.allocate()
assert mlp.activations[0].children[0].input_dim == 7
Expand All @@ -282,6 +291,7 @@ def test_batch_normalized_mlp_transformed():


def test_batch_normalized_mlp_save_memory_propagated():
"""Test that setting save_memory on a BatchNormalizedMLP works."""
mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9],
save_memory=False)
assert not any(act.children[0].save_memory for act in mlp.activations)
Expand Down

0 comments on commit a4d0728

Please sign in to comment.