Remove use_dropout from snt.nets.MLP.
PiperOrigin-RevId: 256664562
Change-Id: I64d859a63d70ff59b6a5100d54c1fffa0882bbf0
tomhennigan authored and sonnet-copybara committed Jul 5, 2019
1 parent da4f2a9 commit 63a3d26
Showing 2 changed files with 10 additions and 27 deletions.
sonnet/src/nets/mlp.py (20 changes: 6 additions & 14 deletions)
@@ -32,7 +32,6 @@ def __init__(self,
                b_init=None,
                with_bias=True,
                activation=tf.nn.relu,
-               use_dropout=False,
                dropout_rate=None,
                activate_final=False,
                name=None):
@@ -46,9 +45,8 @@
       with_bias: Whether or not to apply a bias in each layer.
       activation: Activation function to apply between linear layers. Defaults
         to ReLU.
-      use_dropout: Whether to apply dropout to all hidden layers. Default
-        `False`.
-      dropout_rate: Dropout rate applied if using dropout.
+      dropout_rate: Dropout rate to apply, a rate of `None` (the default) or `0`
+        means no dropout will be applied.
       activate_final: Whether or not to activate the final layer of the MLP.
       name: Optional name for this module.
@@ -58,18 +56,12 @@
     if not with_bias and b_init is not None:
       raise ValueError("When with_bias=False b_init must not be set.")

-    if use_dropout and dropout_rate is None:
-      raise ValueError("When use_dropout=True dropout_rate must be set.")
-    elif not use_dropout and dropout_rate is not None:
-      raise ValueError("When use_dropout=False dropout_rate must not be set.")
-
     super(MLP, self).__init__(name=name)
     self._with_bias = with_bias
     self._w_init = w_init
     self._b_init = b_init
     self._activation = activation
     self._activate_final = activate_final
-    self._use_dropout = use_dropout
     self._dropout_rate = dropout_rate
     self._layers = []
     for index, output_size in enumerate(output_sizes):
@@ -92,10 +84,11 @@ def __call__(self, inputs, is_training=None):
     Returns:
       output: The output of the model of size `[batch_size, output_size]`.
     """
-    if self._use_dropout and is_training is None:
+    use_dropout = self._dropout_rate not in (None, 0)
+    if use_dropout and is_training is None:
       raise ValueError(
           "The `is_training` argument is required when dropout is used.")
-    elif not self._use_dropout and is_training is not None:
+    elif not use_dropout and is_training is not None:
       raise ValueError(
           "The `is_training` argument should only be used with dropout.")

@@ -105,7 +98,7 @@
       inputs = layer(inputs)
       if i < (num_layers - 1) or self._activate_final:
         # Only perform dropout if we are activating the output.
-        if self._use_dropout and is_training:
+        if use_dropout and is_training:
           inputs = tf.nn.dropout(inputs, rate=self._dropout_rate)
         inputs = self._activation(inputs)

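For context on the `tf.nn.dropout` call kept in the hunk above: in TF 2.x its `rate` argument is the drop probability (not a keep probability), and surviving elements are rescaled so the expected activations are unchanged. A minimal sketch, independent of Sonnet:

import tensorflow as tf

x = tf.ones([4, 4])
# rate=0.5 zeroes each element with probability 0.5 and scales the survivors
# by 1 / (1 - 0.5) = 2, so the expected sum of y matches that of x.
y = tf.nn.dropout(x, rate=0.5)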
@@ -150,7 +143,6 @@ def reverse(self, activate_final=None, name=None):
         b_init=self._b_init,
         with_bias=self._with_bias,
         activation=self._activation,
-        use_dropout=self._use_dropout,
         dropout_rate=self._dropout_rate,
         activate_final=activate_final,
         name=name)
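Taken together, the mlp.py changes make `dropout_rate` the single switch for dropout: a non-zero rate enables it and requires `is_training`, while `None` or `0` disables it and forbids `is_training`. A minimal sketch of the resulting call pattern (assuming Sonnet 2 imported as `snt`; layer sizes and batch shape are illustrative):

import tensorflow as tf
import sonnet as snt

# Passing a non-zero dropout_rate enables dropout; use_dropout is gone.
mlp = snt.nets.MLP([128, 64, 10], dropout_rate=0.5)

# With dropout enabled, is_training must be supplied explicitly.
train_out = mlp(tf.ones([8, 32]), is_training=True)   # dropout applied
eval_out = mlp(tf.ones([8, 32]), is_training=False)    # dropout skipped

# With dropout_rate left as None (or 0), is_training must be omitted.
plain = snt.nets.MLP([128, 64, 10])
out = plain(tf.ones([8, 32]))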
sonnet/src/nets/mlp_test.py (17 changes: 4 additions & 13 deletions)
@@ -32,9 +32,9 @@ def test_b_init_when_with_bias_false(self):
       mlp.MLP([1], with_bias=False, b_init=object())

   @parameterized.parameters(
-      itertools.product((1, 2, 3), (True, False), (0.1, 0.0)))
-  def test_submodules(self, num_layers, use_dropout, dropout_rate):
-    mod = mlp.MLP([1] * num_layers)
+      itertools.product((1, 2, 3), (0.1, 0.0, None)))
+  def test_submodules(self, num_layers, dropout_rate):
+    mod = mlp.MLP([1] * num_layers, dropout_rate=dropout_rate)
     self.assertLen(mod.submodules, num_layers)

   @parameterized.parameters(1, 2, 3)
@@ -122,16 +122,8 @@ def test_reverse_activation(self):
     mod(tf.ones([1, 1]))
     self.assertEqual(activation.count, 2)

-  def test_dropout_requires_rate(self):
-    with self.assertRaisesRegexp(ValueError, "dropout_rate must be set"):
-      mlp.MLP([1, 1], use_dropout=True)
-
-  def test_no_dropout_rejects_rate(self):
-    with self.assertRaisesRegexp(ValueError, "dropout_rate must not be set"):
-      mlp.MLP([1, 1], dropout_rate=0.5)
-
   def test_dropout_requires_is_training(self):
-    mod = mlp.MLP([1, 1], use_dropout=True, dropout_rate=0.5)
+    mod = mlp.MLP([1, 1], dropout_rate=0.5)
     with self.assertRaisesRegexp(ValueError, "is_training.* is required"):
       mod(tf.ones([1, 1]))

@@ -153,7 +145,6 @@ def test_reverse_activate_final(self, activate_final):
   def test_applies_activation_with_dropout(self, use_dropout, is_training):
     activation = CountingActivation()
     mod = mlp.MLP([1, 1, 1],
-                  use_dropout=use_dropout,
                   dropout_rate=(0.5 if use_dropout else None),
                   activation=activation)
     mod(tf.ones([1, 1]), is_training=(is_training if use_dropout else None))
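As a side note on the updated `test_submodules` parameterization above: `itertools.product` crosses the layer counts with the dropout rates, and adding `None` to the rate tuple covers the case where dropout is disabled entirely. A quick illustration in plain Python, outside the test:

import itertools

# Three depths crossed with three rates gives nine cases; 0.0 and None both
# take the "no dropout" branch in the new MLP.
cases = list(itertools.product((1, 2, 3), (0.1, 0.0, None)))
# [(1, 0.1), (1, 0.0), (1, None), (2, 0.1), ..., (3, 0.0), (3, None)]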
