Skip to content

Commit

Permalink
Merge pull request #1134 from rizar/return_initial_states_and_outputs
Browse files Browse the repository at this point in the history
Fix a bug involving return_initial_states=True and outputs
  • Loading branch information
rizar committed Aug 2, 2016
2 parents 2c1d7a5 + 81c8c79 commit a3c8404
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
2 changes: 1 addition & 1 deletion blocks/bricks/recurrent/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_dim(self, name):
if name == 'mask':
return 0
if name in (SimpleRecurrent.apply.sequences +
SimpleRecurrent.apply.states):
SimpleRecurrent.apply.states):
return self.dim
return super(SimpleRecurrent, self).get_dim(name)

Expand Down
11 changes: 6 additions & 5 deletions blocks/bricks/recurrent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def recurrent_apply(brick, application, application_call,
if key not in scan_arguments}
for value in rest_kwargs.values():
if (isinstance(value, Variable) and not
is_shared_variable(value)):
is_shared_variable(value)):
logger.warning("unknown input {}".format(value) +
unknown_scan_input)

Expand Down Expand Up @@ -227,10 +227,11 @@ def scan_function(*args):
result = pack(result)
if return_initial_states:
# Undo Subtensor
for i in range(len(states_given)):
assert isinstance(result[i].owner.op,
tensor.subtensor.Subtensor)
result[i] = result[i].owner.inputs[0]
for i, info in enumerate(outputs_info):
if info is not None:
assert isinstance(result[i].owner.op,
tensor.subtensor.Subtensor)
result[i] = result[i].owner.inputs[0]
if updates:
application_call.updates = dict_union(application_call.updates,
updates)
Expand Down
19 changes: 19 additions & 0 deletions tests/bricks/test_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,25 @@ def do():
assert_raises(KeyError, do)


class RecurrentBrickWithOutputs(BaseRecurrent):
    """Minimal recurrent brick for testing ``return_initial_states``.

    Its ``apply`` declares a single state (``'states'``) and two outputs,
    listing the non-state output ``'outputs'`` *before* ``'states'``, so
    the position of the state among the outputs is not zero.

    """

    @recurrent(sequences=[], contexts=[],
               states=['states'], outputs=['outputs', 'states'])
    def apply(self, states):
        """Return ``states + 1`` for both ``'outputs'`` and ``'states'``."""
        return states + 1, states + 1

    def get_dim(self, name):
        """Return 4 regardless of `name`."""
        return 4


def test_return_initial_states_with_outputs():
    """Check ``return_initial_states=True`` when an output precedes the state.

    Applies `RecurrentBrickWithOutputs` for 3 steps with batch size 5 and
    asserts that the first entry of ``outputs`` is all ones while the first
    entry of ``states`` is all zeros, both of shape ``(5, 4)``.

    """
    brick = RecurrentBrickWithOutputs()
    outputs, states = brick.apply(
        n_steps=3, batch_size=5, return_initial_states=True)
    # Only the state sequence is expected to start with the initial value;
    # the plain output starts with the result of the first step.
    assert_allclose(outputs.eval()[0], numpy.ones((5, 4)))
    assert_allclose(states.eval()[0], numpy.zeros((5, 4)))


class TestSimpleRecurrent(unittest.TestCase):
def setUp(self):
self.simple = SimpleRecurrent(dim=3, weights_init=Constant(2),
Expand Down

0 comments on commit a3c8404

Please sign in to comment.