Skip to content

Commit

Permalink
Merge pull request #1078 from SwordYork/master
Browse files: browse the repository at this point in the history
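Handle adding children to kwargs appropriately if a subclass has already added children: each brick now extends the existing `kwargs['children']` list (via `kwargs.setdefault('children', []).extend(children)`) instead of passing `children=children` alongside `**kwargs`.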
  • Loading branch information
dwf committed May 5, 2016
2 parents dba9741 + 4fbde83 commit bf89d49
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 35 deletions.
10 changes: 5 additions & 5 deletions blocks/bricks/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ def __init__(self, match_dim, state_transformer=None,
self.energy_computer = energy_computer

children = [self.state_transformers, attended_transformer,
energy_computer] + kwargs.get('children', [])
super(SequenceContentAttention, self).__init__(children=children,
**kwargs)
energy_computer]
kwargs.setdefault('children', []).extend(children)
super(SequenceContentAttention, self).__init__(**kwargs)

def _push_allocation_config(self):
self.state_transformers.input_dims = self.state_dims
Expand Down Expand Up @@ -576,8 +576,8 @@ def __init__(self, transition, attention, distribute=None,
if name in self.attention.take_glimpses.inputs]

children = [self.transition, self.attention, self.distribute]
children += kwargs.get('children', [])
super(AttentionRecurrent, self).__init__(children=children, **kwargs)
kwargs.setdefault('children', []).extend(children)
super(AttentionRecurrent, self).__init__(**kwargs)

def _push_allocation_config(self):
self.attention.state_dims = self.transition.get_dims(
Expand Down
20 changes: 11 additions & 9 deletions blocks/bricks/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,9 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
@lazy(allocation=['dim'])
def __init__(self, dim, activation, **kwargs):
self.dim = dim
children = [activation] + kwargs.get('children', [])
super(SimpleRecurrent, self).__init__(children=children, **kwargs)
children = [activation]
kwargs.setdefault('children', []).extend(children)
super(SimpleRecurrent, self).__init__(**kwargs)

@property
def W(self):
Expand Down Expand Up @@ -384,9 +385,9 @@ def __init__(self, dim, activation=None, gate_activation=None, **kwargs):
self.activation = activation
self.gate_activation = gate_activation

children = ([self.activation, self.gate_activation] +
kwargs.get('children', []))
super(LSTM, self).__init__(children=children, **kwargs)
children = [self.activation, self.gate_activation]
kwargs.setdefault('children', []).extend(children)
super(LSTM, self).__init__(**kwargs)

def get_dim(self, name):
if name == 'inputs':
Expand Down Expand Up @@ -532,8 +533,9 @@ def __init__(self, dim, activation=None, gate_activation=None,
self.activation = activation
self.gate_activation = gate_activation

children = [activation, gate_activation] + kwargs.get('children', [])
super(GatedRecurrent, self).__init__(children=children, **kwargs)
children = [activation, gate_activation]
kwargs.setdefault('children', []).extend(children)
super(GatedRecurrent, self).__init__(**kwargs)

@property
def state_to_state(self):
Expand Down Expand Up @@ -644,8 +646,8 @@ def __init__(self, prototype, **kwargs):
children = [copy.deepcopy(prototype) for _ in range(2)]
children[0].name = 'forward'
children[1].name = 'backward'
children += kwargs.get('children', [])
super(Bidirectional, self).__init__(children=children, **kwargs)
kwargs.setdefault('children', []).extend(children)
super(Bidirectional, self).__init__(**kwargs)

@application
def apply(self, *args, **kwargs):
Expand Down
26 changes: 14 additions & 12 deletions blocks/bricks/sequence_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,8 @@ def __init__(self, readout, transition, fork, **kwargs):
self.fork = fork

children = [self.readout, self.fork, self.transition]
children += kwargs.get('children', [])
super(BaseSequenceGenerator, self).__init__(children=children,
**kwargs)
kwargs.setdefault('children', []).extend(children)
super(BaseSequenceGenerator, self).__init__(**kwargs)

@property
def _state_names(self):
Expand Down Expand Up @@ -529,8 +528,9 @@ def __init__(self, emitter=None, feedback_brick=None,
self.merged_dim = merged_dim

children = [self.emitter, self.feedback_brick, self.merge,
self.post_merge] + kwargs.get('children', [])
super(Readout, self).__init__(children=children, **kwargs)
self.post_merge]
kwargs.setdefault('children', []).extend(children)
super(Readout, self).__init__(**kwargs)

def _push_allocation_config(self):
self.emitter.readout_dim = self.get_dim('readouts')
Expand Down Expand Up @@ -688,8 +688,9 @@ class SoftmaxEmitter(AbstractEmitter, Initializable, Random):
def __init__(self, initial_output=0, **kwargs):
self.initial_output = initial_output
self.softmax = NDimensionalSoftmax()
children = [self.softmax] + kwargs.get('children', [])
super(SoftmaxEmitter, self).__init__(children=children, **kwargs)
children = [self.softmax]
kwargs.setdefault('children', []).extend(children)
super(SoftmaxEmitter, self).__init__(**kwargs)

@application
def probs(self, readouts):
Expand Down Expand Up @@ -749,8 +750,9 @@ def __init__(self, num_outputs=None, feedback_dim=None, **kwargs):
self.feedback_dim = feedback_dim

self.lookup = LookupTable(num_outputs, feedback_dim)
children = [self.lookup] + kwargs.get('children', [])
super(LookupFeedback, self).__init__(children=children, **kwargs)
children = [self.lookup]
kwargs.setdefault('children', []).extend(children)
super(LookupFeedback, self).__init__(**kwargs)

def _push_allocation_config(self):
self.lookup.length = self.num_outputs
Expand Down Expand Up @@ -791,9 +793,9 @@ def __init__(self, transition, **kwargs):
self.context_names = transition.apply.contexts
self.glimpse_names = []

children = [self.transition] + kwargs.get('children', [])
super(FakeAttentionRecurrent, self).__init__(children=children,
**kwargs)
children = [self.transition]
kwargs.setdefault('children', []).extend(children)
super(FakeAttentionRecurrent, self).__init__(**kwargs)

@application
def apply(self, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions blocks/bricks/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self, application_methods, **kwargs):
seen = set()
children = [app.brick for app in application_methods
if not (app.brick in seen or seen.add(app.brick))]
children += kwargs.get('children', [])
super(Sequence, self).__init__(children=children, **kwargs)
kwargs.setdefault('children', []).extend(children)
super(Sequence, self).__init__(**kwargs)

@application
def apply(self, *args):
Expand Down
5 changes: 3 additions & 2 deletions blocks/bricks/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,9 @@ class LinearMaxout(Initializable, Feedforward):
def __init__(self, input_dim, output_dim, num_pieces, **kwargs):
self.linear = Linear()
self.maxout = Maxout()
children = [self.linear, self.maxout] + kwargs.get('children', [])
super(LinearMaxout, self).__init__(children=children, **kwargs)
children = [self.linear, self.maxout]
kwargs.setdefault('children', []).extend(children)
super(LinearMaxout, self).__init__(**kwargs)

self.input_dim = input_dim
self.output_dim = output_dim
Expand Down
9 changes: 4 additions & 5 deletions docs/create_your_own_brick.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ specify the ``input_dim`` of ``brick2`` directly at its creation.
... self.brick1 = brick1
... self.brick2 = brick2
... children = [self.brick1, self.brick2]
... children += kwargs.get('children', [])
... super(Feedforward, self).__init__(children=children, **kwargs)
... kwargs.setdefault('children', []).extend(children)
... super(Feedforward, self).__init__(**kwargs)
...
... @property
... def input_dim(self):
Expand Down Expand Up @@ -376,9 +376,8 @@ One can also create the brick using :class:`Linear` children bricks, which
... self.linear2 = Linear(input_dim2, output_dim2,
... use_bias=False, **kwargs)
... children = [self.linear1, self.linear2]
... children += kwargs.get('children', [])
... super(ParallelLinear2, self).__init__(children=children,
... **kwargs)
... kwargs.setdefault('children', []).extend(children)
... super(ParallelLinear2, self).__init__(**kwargs)
...
... @application(inputs=['input1_', 'input2_'], outputs=['output1',
... 'output2'])
Expand Down

0 comments on commit bf89d49

Please sign in to comment.