Skip to content

Commit

Permalink
Merge pull request #761 from udibr/RecurrentStack
Browse files Browse the repository at this point in the history
fix minor bugs when handling masks in SequenceGenerator/Attention
  • Loading branch information
dmitriy-serdyuk committed Jul 14, 2015
2 parents cfd0633 + efbe048 commit 1e0aca9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
14 changes: 8 additions & 6 deletions blocks/bricks/attention.py
Expand Up @@ -605,8 +605,8 @@ def take_glimpses(self, **kwargs):
kwargs.pop(self.preprocessed_attended_name, None),
kwargs.pop(self.attended_mask_name, None),
**dict_union(states, glimpses_needed))
if kwargs:
raise ValueError("extra args to take_glimpses: {}".format(kwargs))
# At this point kwargs may contain additional items.
# e.g. AttentionRecurrent.transition.apply.contexts
return result

@take_glimpses.property('outputs')
Expand Down Expand Up @@ -634,13 +634,15 @@ def compute_states(self, **kwargs):
Current states computed by `self.transition`.
"""
# Masks are not mandatory, that's why 'must_have=False'
sequences = dict_subset(kwargs, self._sequence_names,
pop=True, must_have=False)
# make sure we are not popping the mask
normal_inputs = [name for name in self._sequence_names
if 'mask' not in name]
sequences = dict_subset(kwargs, normal_inputs, pop=True)
glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
if self.add_contexts:
kwargs.pop(self.attended_name)
kwargs.pop(self.attended_mask_name)
# attended_mask_name can be optional
kwargs.pop(self.attended_mask_name, None)

sequences.update(self.distribute.apply(
as_dict=True, **dict_subset(dict_union(sequences, glimpses),
Expand Down
6 changes: 4 additions & 2 deletions blocks/bricks/sequence_generators.py
Expand Up @@ -248,7 +248,8 @@ def cost_matrix(self, application_call, outputs, mask=None, **kwargs):

# Prepare input for the iterative part
states = dict_subset(kwargs, self._state_names, must_have=False)
contexts = dict_subset(kwargs, self._context_names)
# masks in context are optional (e.g. `attended_mask`)
contexts = dict_subset(kwargs, self._context_names, must_have=False)
feedback = self.readout.feedback(outputs)
inputs = self.fork.apply(feedback, as_dict=True)

Expand Down Expand Up @@ -297,7 +298,8 @@ def generate(self, outputs, **kwargs):
"""
states = dict_subset(kwargs, self._state_names)
contexts = dict_subset(kwargs, self._context_names)
# masks in context are optional (e.g. `attended_mask`)
contexts = dict_subset(kwargs, self._context_names, must_have=False)
glimpses = dict_subset(kwargs, self._glimpse_names)

next_glimpses = self.transition.take_glimpses(
Expand Down
10 changes: 8 additions & 2 deletions blocks/monitoring/evaluators.py
Expand Up @@ -262,9 +262,15 @@ def _compile(self):
if self.theano_buffer.accumulation_updates:
updates = OrderedDict()
updates.update(self.theano_buffer.accumulation_updates)
if self.updates:
updates.update(self.updates)
inputs += self.theano_buffer.inputs
if self.updates:
# Handle the case in which we dont have any theano variables
# to evaluate but we do have MonitoredQuantity
# that may require an update of their own
if updates is None:
updates = self.updates
else:
updates.update(self.updates)
inputs += self.monitored_quantities_buffer.inputs
outputs = self.monitored_quantities_buffer.requires

Expand Down

0 comments on commit 1e0aca9

Please sign in to comment.