From b7c12efa3e148f0d5ba760ade58b121e40afa564 Mon Sep 17 00:00:00 2001 From: Nisarg Jhaveri Date: Tue, 10 Apr 2018 11:22:20 +0530 Subject: [PATCH] Move _standardize_args to recurrent, remove duplication --- keras/layers/convolutional_recurrent.py | 5 +- keras/layers/recurrent.py | 85 +++++++++++++------------ keras/layers/wrappers.py | 48 ++------------ 3 files changed, 50 insertions(+), 88 deletions(-) diff --git a/keras/layers/convolutional_recurrent.py b/keras/layers/convolutional_recurrent.py index e64fbe15db70..9c720673ee44 100644 --- a/keras/layers/convolutional_recurrent.py +++ b/keras/layers/convolutional_recurrent.py @@ -11,6 +11,7 @@ from .. import regularizers from .. import constraints from .recurrent import _generate_dropout_mask +from .recurrent import _standardize_args import numpy as np import warnings @@ -270,8 +271,8 @@ def get_initial_state(self, inputs): return [initial_state] def __call__(self, inputs, initial_state=None, constants=None, **kwargs): - inputs, initial_state, constants = self._standardize_args( - inputs, initial_state, constants) + inputs, initial_state, constants = _standardize_args( + inputs, initial_state, constants, self._num_constants) if initial_state is None and constants is None: return super(ConvRNN2D, self).__call__(inputs, **kwargs) diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py index 952c6c77ce81..582858d7a1ca 100644 --- a/keras/layers/recurrent.py +++ b/keras/layers/recurrent.py @@ -493,8 +493,8 @@ def get_initial_state(self, inputs): return [K.tile(initial_state, [1, self.cell.state_size])] def __call__(self, inputs, initial_state=None, constants=None, **kwargs): - inputs, initial_state, constants = self._standardize_args( - inputs, initial_state, constants) + inputs, initial_state, constants = _standardize_args( + inputs, initial_state, constants, self._num_constants) if initial_state is None and constants is None: return super(RNN, self).__call__(inputs, **kwargs) @@ -633,46 +633,6 @@ def step(inputs, states): else: return output - def _standardize_args(self, inputs, initial_state, constants): - """Standardize `__call__` to a single list of tensor inputs. - - When running a model loaded from file, the input tensors - `initial_state` and `constants` can be passed to `RNN.__call__` as part - of `inputs` instead of by the dedicated keyword arguments. This method - makes sure the arguments are separated and that `initial_state` and - `constants` are lists of tensors (or None). - - # Arguments - inputs: tensor or list/tuple of tensors - initial_state: tensor or list of tensors or None - constants: tensor or list of tensors or None - - # Returns - inputs: tensor - initial_state: list of tensors or None - constants: list of tensors or None - """ - if isinstance(inputs, list): - assert initial_state is None and constants is None - if self._num_constants is not None: - constants = inputs[-self._num_constants:] - inputs = inputs[:-self._num_constants] - if len(inputs) > 1: - initial_state = inputs[1:] - inputs = inputs[0] - - def to_list_or_none(x): - if x is None or isinstance(x, list): - return x - if isinstance(x, tuple): - return list(x) - return [x] - - initial_state = to_list_or_none(initial_state) - constants = to_list_or_none(constants) - - return inputs, initial_state, constants - def reset_states(self, states=None): if not self.stateful: raise AttributeError('Layer must be stateful.') @@ -2272,3 +2232,44 @@ def dropped_inputs(): dropped_inputs, ones, training=training) + + +def _standardize_args(inputs, initial_state, constants, num_constants): + """Standardize `__call__` to a single list of tensor inputs. + + When running a model loaded from file, the input tensors + `initial_state` and `constants` can be passed to `RNN.__call__` as part + of `inputs` instead of by the dedicated keyword arguments. This method + makes sure the arguments are separated and that `initial_state` and + `constants` are lists of tensors (or None). + + # Arguments + inputs: tensor or list/tuple of tensors + initial_state: tensor or list of tensors or None + constants: tensor or list of tensors or None + + # Returns + inputs: tensor + initial_state: list of tensors or None + constants: list of tensors or None + """ + if isinstance(inputs, list): + assert initial_state is None and constants is None + if num_constants is not None: + constants = inputs[-num_constants:] + inputs = inputs[:-num_constants] + if len(inputs) > 1: + initial_state = inputs[1:] + inputs = inputs[0] + + def to_list_or_none(x): + if x is None or isinstance(x, list): + return x + if isinstance(x, tuple): + return list(x) + return [x] + + initial_state = to_list_or_none(initial_state) + constants = to_list_or_none(constants) + + return inputs, initial_state, constants diff --git a/keras/layers/wrappers.py b/keras/layers/wrappers.py index ae956f80e099..276729294243 100644 --- a/keras/layers/wrappers.py +++ b/keras/layers/wrappers.py @@ -12,6 +12,8 @@ from ..utils.generic_utils import has_arg from .. import backend as K +from . import recurrent + class Wrapper(Layer): """Abstract wrapper base class. @@ -316,8 +318,8 @@ def compute_output_shape(self, input_shape): return output_shape def __call__(self, inputs, initial_state=None, constants=None, **kwargs): - inputs, initial_state, constants = self._standardize_args( - inputs, initial_state, constants) + inputs, initial_state, constants = recurrent._standardize_args( + inputs, initial_state, constants, self._num_constants) if initial_state is None and constants is None: return super(Bidirectional, self).__call__(inputs, **kwargs) @@ -434,48 +436,6 @@ def call(self, return [output] + states return output - def _standardize_args(self, inputs, initial_state, constants): - """Standardize `__call__` to a single list of tensor inputs. - - This method is taken from `RNN` as it is. - - When running a model loaded from file, the input tensors - `initial_state` and `constants` can be passed to `RNN.__call__` as part - of `inputs` instead of by the dedicated keyword arguments. This method - makes sure the arguments are separated and that `initial_state` and - `constants` are lists of tensors (or None). - - # Arguments - inputs: tensor or list/tuple of tensors - initial_state: tensor or list of tensors or None - constants: tensor or list of tensors or None - - # Returns - inputs: tensor - initial_state: list of tensors or None - constants: list of tensors or None - """ - if isinstance(inputs, list): - assert initial_state is None and constants is None - if self._num_constants is not None: - constants = inputs[-self._num_constants:] - inputs = inputs[:-self._num_constants] - if len(inputs) > 1: - initial_state = inputs[1:] - inputs = inputs[0] - - def to_list_or_none(x): - if x is None or isinstance(x, list): - return x - if isinstance(x, tuple): - return list(x) - return [x] - - initial_state = to_list_or_none(initial_state) - constants = to_list_or_none(constants) - - return inputs, initial_state, constants - def reset_states(self): self.forward_layer.reset_states() self.backward_layer.reset_states()