Commit b7c12ef

Move _standardize_args to recurrent, remove duplication
nisargjhaveri committed Apr 10, 2018
1 parent 5d4e068 commit b7c12ef
Showing 3 changed files with 50 additions and 88 deletions.
5 changes: 3 additions & 2 deletions keras/layers/convolutional_recurrent.py
@@ -11,6 +11,7 @@
 from .. import regularizers
 from .. import constraints
 from .recurrent import _generate_dropout_mask
+from .recurrent import _standardize_args
 
 import numpy as np
 import warnings
@@ -270,8 +271,8 @@ def get_initial_state(self, inputs):
             return [initial_state]
 
     def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
-        inputs, initial_state, constants = self._standardize_args(
-            inputs, initial_state, constants)
+        inputs, initial_state, constants = _standardize_args(
+            inputs, initial_state, constants, self._num_constants)
 
         if initial_state is None and constants is None:
             return super(ConvRNN2D, self).__call__(inputs, **kwargs)
85 changes: 43 additions & 42 deletions keras/layers/recurrent.py
@@ -493,8 +493,8 @@ def get_initial_state(self, inputs):
             return [K.tile(initial_state, [1, self.cell.state_size])]
 
     def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
-        inputs, initial_state, constants = self._standardize_args(
-            inputs, initial_state, constants)
+        inputs, initial_state, constants = _standardize_args(
+            inputs, initial_state, constants, self._num_constants)
 
         if initial_state is None and constants is None:
             return super(RNN, self).__call__(inputs, **kwargs)
@@ -633,46 +633,6 @@ def step(inputs, states):
         else:
             return output
 
-    def _standardize_args(self, inputs, initial_state, constants):
-        """Standardize `__call__` to a single list of tensor inputs.
-
-        When running a model loaded from file, the input tensors
-        `initial_state` and `constants` can be passed to `RNN.__call__` as part
-        of `inputs` instead of by the dedicated keyword arguments. This method
-        makes sure the arguments are separated and that `initial_state` and
-        `constants` are lists of tensors (or None).
-
-        # Arguments
-            inputs: tensor or list/tuple of tensors
-            initial_state: tensor or list of tensors or None
-            constants: tensor or list of tensors or None
-
-        # Returns
-            inputs: tensor
-            initial_state: list of tensors or None
-            constants: list of tensors or None
-        """
-        if isinstance(inputs, list):
-            assert initial_state is None and constants is None
-            if self._num_constants is not None:
-                constants = inputs[-self._num_constants:]
-                inputs = inputs[:-self._num_constants]
-            if len(inputs) > 1:
-                initial_state = inputs[1:]
-                inputs = inputs[0]
-
-        def to_list_or_none(x):
-            if x is None or isinstance(x, list):
-                return x
-            if isinstance(x, tuple):
-                return list(x)
-            return [x]
-
-        initial_state = to_list_or_none(initial_state)
-        constants = to_list_or_none(constants)
-
-        return inputs, initial_state, constants
-
     def reset_states(self, states=None):
         if not self.stateful:
             raise AttributeError('Layer must be stateful.')
@@ -2272,3 +2232,44 @@ def dropped_inputs():
         dropped_inputs,
         ones,
         training=training)
+
+
+def _standardize_args(inputs, initial_state, constants, num_constants):
+    """Standardize `__call__` to a single list of tensor inputs.
+
+    When running a model loaded from file, the input tensors
+    `initial_state` and `constants` can be passed to `RNN.__call__` as part
+    of `inputs` instead of by the dedicated keyword arguments. This method
+    makes sure the arguments are separated and that `initial_state` and
+    `constants` are lists of tensors (or None).
+
+    # Arguments
+        inputs: tensor or list/tuple of tensors
+        initial_state: tensor or list of tensors or None
+        constants: tensor or list of tensors or None
+
+    # Returns
+        inputs: tensor
+        initial_state: list of tensors or None
+        constants: list of tensors or None
+    """
+    if isinstance(inputs, list):
+        assert initial_state is None and constants is None
+        if num_constants is not None:
+            constants = inputs[-num_constants:]
+            inputs = inputs[:-num_constants]
+        if len(inputs) > 1:
+            initial_state = inputs[1:]
+            inputs = inputs[0]
+
+    def to_list_or_none(x):
+        if x is None or isinstance(x, list):
+            return x
+        if isinstance(x, tuple):
+            return list(x)
+        return [x]
+
+    initial_state = to_list_or_none(initial_state)
+    constants = to_list_or_none(constants)
+
+    return inputs, initial_state, constants
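
For reference, a minimal usage sketch (not part of the commit) of the relocated helper. `_standardize_args` only inspects the container types of its arguments, so plain placeholder strings can stand in for tensors here:

from keras.layers.recurrent import _standardize_args

x, s1, s2, c = 'inputs', 'state_1', 'state_2', 'constant'

# Flat-list spelling, as produced when a saved model is reloaded and every
# tensor arrives through `inputs`: the trailing entry is split off as the
# single constant, the middle entries become the initial state.
inputs, initial_state, constants = _standardize_args([x, s1, s2, c], None, None, 1)
# -> 'inputs', ['state_1', 'state_2'], ['constant']

# Keyword spelling: single tensors are wrapped into one-element lists.
inputs, initial_state, constants = _standardize_args(x, s1, c, 1)
# -> 'inputs', ['state_1'], ['constant']
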
48 changes: 4 additions & 44 deletions keras/layers/wrappers.py
@@ -12,6 +12,8 @@
 from ..utils.generic_utils import has_arg
 from .. import backend as K
 
+from . import recurrent
+
 
 class Wrapper(Layer):
     """Abstract wrapper base class.
@@ -316,8 +318,8 @@ def compute_output_shape(self, input_shape):
         return output_shape
 
     def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
-        inputs, initial_state, constants = self._standardize_args(
-            inputs, initial_state, constants)
+        inputs, initial_state, constants = recurrent._standardize_args(
+            inputs, initial_state, constants, self._num_constants)
 
         if initial_state is None and constants is None:
             return super(Bidirectional, self).__call__(inputs, **kwargs)
@@ -434,48 +436,6 @@ def call(self,
             return [output] + states
         return output
 
-    def _standardize_args(self, inputs, initial_state, constants):
-        """Standardize `__call__` to a single list of tensor inputs.
-
-        This method is taken from `RNN` as it is.
-
-        When running a model loaded from file, the input tensors
-        `initial_state` and `constants` can be passed to `RNN.__call__` as part
-        of `inputs` instead of by the dedicated keyword arguments. This method
-        makes sure the arguments are separated and that `initial_state` and
-        `constants` are lists of tensors (or None).
-
-        # Arguments
-            inputs: tensor or list/tuple of tensors
-            initial_state: tensor or list of tensors or None
-            constants: tensor or list of tensors or None
-
-        # Returns
-            inputs: tensor
-            initial_state: list of tensors or None
-            constants: list of tensors or None
-        """
-        if isinstance(inputs, list):
-            assert initial_state is None and constants is None
-            if self._num_constants is not None:
-                constants = inputs[-self._num_constants:]
-                inputs = inputs[:-self._num_constants]
-            if len(inputs) > 1:
-                initial_state = inputs[1:]
-                inputs = inputs[0]
-
-        def to_list_or_none(x):
-            if x is None or isinstance(x, list):
-                return x
-            if isinstance(x, tuple):
-                return list(x)
-            return [x]
-
-        initial_state = to_list_or_none(initial_state)
-        constants = to_list_or_none(constants)
-
-        return inputs, initial_state, constants
-
     def reset_states(self):
         self.forward_layer.reset_states()
         self.backward_layer.reset_states()
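
For background, and unchanged by this commit, this is the calling convention the helper reconciles: an RNN layer accepts its extra tensors either through dedicated keyword arguments or folded into a flat input list, which is how they arrive when a saved model is reloaded. A rough functional-API sketch, assuming Keras 2.x with a working backend and purely illustrative shapes:

from keras.layers import Input, LSTM

x = Input(shape=(10, 8))   # sequence input
h0 = Input(shape=(32,))    # initial hidden state
c0 = Input(shape=(32,))    # initial cell state

lstm = LSTM(32)

# Keyword spelling: states passed via `initial_state`.
y1 = lstm(x, initial_state=[h0, c0])

# Flat-list spelling: the same tensors folded into `inputs`, which
# `_standardize_args` separates back out inside `RNN.__call__`.
y2 = lstm([x, h0, c0])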
