Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
farizrahman4u committed May 7, 2017
1 parent 5635454 commit a982d34
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions recurrentshop/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,6 @@ def __init__(self, state_sync=False, decode=False, output_length=None, return_st
state_initializer = initializers.get(state_initializer)
self._state_initializer = state_initializer


@property
def state_initializer(self):
if self._state_initializer is None:
Expand Down Expand Up @@ -875,26 +874,29 @@ def num_states(self):

def add(self, cell):
self.cells.append(cell)
cell_input_shape = cell.batch_input_shape
if set(map(type, list(set(cell_input_shape) - set([None])))) != set([int]):
cell_input_shape = cell_input_shape[0]
if len(self.cells) == 1:
cell_input_shape = cell.batch_input_shape
if set(map(type, list(set(cell_input_shape) - set([None])))) != set([int]):
cell_input_shape = cell_input_shape[0]
if self.decode:
self.input_spec = InputSpec(shape=cell_input_shape)
else:
self.input_spec = InputSpec(shape=cell_input_shape[:1] + (None,) + cell_input_shape[1:])
batch_size = cell_input_shape[0]
if batch_size is not None:
self.batch_size = batch_size
if not self.stateful:
self.states = [None] * self.num_states

def build(self, input_shape):
if hasattr(self, 'model'):
del self.model
# Try and get batch size for initializer
for cell in self.cells:
if hasattr(cell, 'batch_input_shape'):
if cell.batch_input_shape[0] is not None:
self.batch_size = cell.batch_input_shape[0]
break
if not hasattr(self, 'batch_size'):
if hasattr(self, 'batch_input_shape'):
batch_size = self.batch_input_shape[0]
if batch_size is not None:
self.batch_size = batch_size
if self.state_sync:
if type(input_shape) is list:
x_shape = input_shape[0]
Expand Down Expand Up @@ -1009,7 +1011,7 @@ def build(self, input_shape):
def get_config(self):
if not hasattr(self, 'model'):
if len(self.cells) > 0:
input_shape = self.cells[0].batch_input_shape
input_shape = self.input_spec.shape
self.build(input_shape)
else:
self.model = None
Expand Down

0 comments on commit a982d34

Please sign in to comment.