Skip to content

Commit

Permalink
Remove skip_connnection option from ConvLSTM.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 192131304
  • Loading branch information
fvioladm authored and diegolascasas committed Apr 10, 2018
1 parent 600268f commit a86f044
Showing 1 changed file with 0 additions and 12 deletions.
12 changes: 0 additions & 12 deletions sonnet/python/modules/gated_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,6 @@ def __init__(self,
rate=1,
padding=conv.SAME,
use_bias=True,
skip_connection=False,
forget_bias=1.0,
initializers=None,
partitioners=None,
Expand All @@ -1259,8 +1258,6 @@ def __init__(self,
convolution. Cannot be > 1 if any of stride is also > 1.
padding: Padding algorithm, either `snt.SAME` or `snt.VALID`.
use_bias: Use bias in convolutions.
skip_connection: If set to `True`, concatenate the input to the output
of the conv LSTM. Default: `False`.
forget_bias: Forget bias.
initializers: Dict containing ops to initialize the convolutional weights.
partitioners: Optional dict containing partitioners to partition
Expand All @@ -1281,9 +1278,6 @@ def __init__(self,

self._conv_class = self._get_conv_class(conv_ndims)

if skip_connection and stride != 1:
raise ValueError("`stride` needs to be 1 when using skip connection")

if conv_ndims != len(input_shape)-1:
raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
input_shape, conv_ndims))
Expand All @@ -1297,16 +1291,13 @@ def __init__(self,
self._padding = padding
self._use_bias = use_bias
self._forget_bias = forget_bias
self._skip_connection = skip_connection
self._initializers = initializers
self._partitioners = partitioners
self._regularizers = regularizers

self._total_output_channels = output_channels
if self._stride != 1:
self._total_output_channels //= self._stride * self._stride
if self._skip_connection:
self._total_output_channels += self._input_shape[-1]

self._convolutions = collections.defaultdict(self._new_convolution)

Expand Down Expand Up @@ -1352,9 +1343,6 @@ def _build(self, inputs, state):
next_cell = tf.sigmoid(forget_gate + self._forget_bias) * cell
next_cell += tf.sigmoid(input_gate) * tf.tanh(next_input)
output = tf.tanh(next_cell) * tf.sigmoid(output_gate)

if self._skip_connection:
output = tf.concat([output, inputs], axis=-1)
return output, (output, next_cell)


Expand Down

0 comments on commit a86f044

Please sign in to comment.