Skip to content

Commit

Permalink
Merge pull request #1029 from memimo/children2
Browse files Browse the repository at this point in the history
Pass children as argument to Brick
  • Loading branch information
rizar committed Mar 14, 2016
2 parents 0bd6ee8 + e7c6e84 commit 0e1696d
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 49 deletions.
14 changes: 8 additions & 6 deletions blocks/bricks/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,12 @@ class SequenceContentAttention(GenericSequenceAttention, Initializable):
@lazy(allocation=['match_dim'])
def __init__(self, match_dim, state_transformer=None,
attended_transformer=None, energy_computer=None, **kwargs):
super(SequenceContentAttention, self).__init__(**kwargs)
if not state_transformer:
state_transformer = Linear(use_bias=False)
self.match_dim = match_dim
self.state_transformer = state_transformer

self.state_transformers = Parallel(input_names=self.state_names,
self.state_transformers = Parallel(input_names=kwargs['state_names'],
prototype=state_transformer,
name="state_trans")
if not attended_transformer:
Expand All @@ -325,8 +324,10 @@ def __init__(self, match_dim, state_transformer=None,
self.attended_transformer = attended_transformer
self.energy_computer = energy_computer

self.children = [self.state_transformers, attended_transformer,
energy_computer]
children = [self.state_transformers, attended_transformer,
energy_computer] + kwargs.get('children', [])
super(SequenceContentAttention, self).__init__(children=children,
**kwargs)

def _push_allocation_config(self):
self.state_transformers.input_dims = self.state_dims
Expand Down Expand Up @@ -540,7 +541,6 @@ def __init__(self, transition, attention, distribute=None,
add_contexts=True,
attended_name=None, attended_mask_name=None,
**kwargs):
super(AttentionRecurrent, self).__init__(**kwargs)
self._sequence_names = list(transition.apply.sequences)
self._state_names = list(transition.apply.states)
self._context_names = list(transition.apply.contexts)
Expand Down Expand Up @@ -575,7 +575,9 @@ def __init__(self, transition, attention, distribute=None,
name for name in self._glimpse_names
if name in self.attention.take_glimpses.inputs]

self.children = [self.transition, self.attention, self.distribute]
children = [self.transition, self.attention, self.distribute]
children += kwargs.get('children', [])
super(AttentionRecurrent, self).__init__(children=children, **kwargs)

def _push_allocation_config(self):
self.attention.state_dims = self.transition.get_dims(
Expand Down
9 changes: 6 additions & 3 deletions blocks/bricks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,15 @@ class Brick(Annotation):
#: See :attr:`Brick.print_shapes`
print_shapes = False

def __init__(self, name=None):
def __init__(self, name=None, children=None):
if name is None:
name = self.__class__.__name__.lower()
self.name = name

self.children = []
if children is None:
children = []

self.name = name
self.children = children
self.parents = []

self.allocated = False
Expand Down
21 changes: 11 additions & 10 deletions blocks/bricks/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
"""
@lazy(allocation=['dim'])
def __init__(self, dim, activation, **kwargs):
    # Set up the transition; the activation brick becomes a child so that
    # allocation/initialization configuration is pushed to it.
    self.dim = dim
    # pop() (not get()) so 'children' is removed from kwargs; otherwise it
    # would be passed to Brick.__init__ twice (explicitly and via **kwargs)
    # and raise TypeError whenever a caller supplies children=...
    children = [activation] + kwargs.pop('children', [])
    super(SimpleRecurrent, self).__init__(children=children, **kwargs)

@property
def W(self):
Expand Down Expand Up @@ -370,12 +370,12 @@ class LSTM(BaseRecurrent, Initializable):
"""
@lazy(allocation=['dim'])
def __init__(self, dim, activation=None, **kwargs):
    self.dim = dim

    # Default cell activation is tanh, matching the standard LSTM equations.
    if not activation:
        activation = Tanh()
    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children = [activation] + kwargs.pop('children', [])
    super(LSTM, self).__init__(children=children, **kwargs)

def get_dim(self, name):
if name == 'inputs':
Expand Down Expand Up @@ -513,7 +513,6 @@ class GatedRecurrent(BaseRecurrent, Initializable):
@lazy(allocation=['dim'])
def __init__(self, dim, activation=None, gate_activation=None,
**kwargs):
super(GatedRecurrent, self).__init__(**kwargs)
self.dim = dim

if not activation:
Expand All @@ -523,7 +522,8 @@ def __init__(self, dim, activation=None, gate_activation=None,
self.activation = activation
self.gate_activation = gate_activation

self.children = [activation, gate_activation]
children = [activation, gate_activation] + kwargs.get('children', [])
super(GatedRecurrent, self).__init__(children=children, **kwargs)

@property
def state_to_state(self):
Expand Down Expand Up @@ -629,12 +629,13 @@ class Bidirectional(Initializable):

@lazy()
def __init__(self, prototype, **kwargs):
    self.prototype = prototype

    # Two independent deep copies of the prototype: one per direction, so
    # the forward and backward networks do not share parameters.
    children = [copy.deepcopy(prototype) for _ in range(2)]
    children[0].name = 'forward'
    children[1].name = 'backward'
    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children += kwargs.pop('children', [])
    super(Bidirectional, self).__init__(children=children, **kwargs)

@application
def apply(self, *args, **kwargs):
Expand Down
38 changes: 20 additions & 18 deletions blocks/bricks/sequence_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,14 @@ class BaseSequenceGenerator(Initializable):
"""
@lazy()
def __init__(self, readout, transition, fork, **kwargs):
    # Keep references to the three components; they are also registered as
    # children so configuration is pushed down to them.
    self.readout = readout
    self.transition = transition
    self.fork = fork

    children = [self.readout, self.fork, self.transition]
    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children += kwargs.pop('children', [])
    super(BaseSequenceGenerator, self).__init__(children=children,
                                                **kwargs)

@property
def _state_names(self):
Expand Down Expand Up @@ -508,27 +510,27 @@ class Readout(AbstractReadout):
def __init__(self, emitter=None, feedback_brick=None,
             merge=None, merge_prototype=None,
             post_merge=None, merged_dim=None, **kwargs):
    # Defaults read 'readout_dim' / 'source_names' straight from kwargs
    # (the attributes are only set by the superclass constructor, which
    # now runs last).  NOTE(review): this makes 'readout_dim' effectively
    # required whenever a default component is constructed — confirm the
    # lazy-initialization path still tolerates its absence.
    if not emitter:
        emitter = TrivialEmitter(kwargs['readout_dim'])
    if not feedback_brick:
        feedback_brick = TrivialFeedback(kwargs['readout_dim'])
    if not merge:
        merge = Merge(input_names=kwargs['source_names'],
                      prototype=merge_prototype)
    if not post_merge:
        post_merge = Bias(dim=kwargs['readout_dim'])
    if not merged_dim:
        merged_dim = kwargs['readout_dim']
    self.emitter = emitter
    self.feedback_brick = feedback_brick
    self.merge = merge
    self.post_merge = post_merge
    self.merged_dim = merged_dim

    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children = [self.emitter, self.feedback_brick, self.merge,
                self.post_merge] + kwargs.pop('children', [])
    super(Readout, self).__init__(children=children, **kwargs)

def _push_allocation_config(self):
self.emitter.readout_dim = self.get_dim('readouts')
Expand Down Expand Up @@ -684,10 +686,10 @@ class SoftmaxEmitter(AbstractEmitter, Initializable, Random):
"""
def __init__(self, initial_output=0, **kwargs):
    # Value emitted before the first real output has been generated.
    self.initial_output = initial_output
    self.softmax = NDimensionalSoftmax()
    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children = [self.softmax] + kwargs.pop('children', [])
    super(SoftmaxEmitter, self).__init__(children=children, **kwargs)

@application
def probs(self, readouts):
Expand Down Expand Up @@ -743,13 +745,12 @@ class LookupFeedback(AbstractFeedback, Initializable):
"""
def __init__(self, num_outputs=None, feedback_dim=None, **kwargs):
    self.num_outputs = num_outputs
    self.feedback_dim = feedback_dim

    # weights_init is no longer forwarded here; initialization config is
    # pushed to the child in _push_initialization_config instead.
    self.lookup = LookupTable(num_outputs, feedback_dim)
    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children = [self.lookup] + kwargs.pop('children', [])
    super(LookupFeedback, self).__init__(children=children, **kwargs)

def _push_allocation_config(self):
self.lookup.length = self.num_outputs
Expand Down Expand Up @@ -784,14 +785,15 @@ class FakeAttentionRecurrent(AbstractAttentionRecurrent, Initializable):
"""
def __init__(self, transition, **kwargs):
    self.transition = transition

    # Mirror the wrapped transition's interface; no glimpses are produced
    # since there is no real attention mechanism here.
    self.state_names = transition.apply.states
    self.context_names = transition.apply.contexts
    self.glimpse_names = []

    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children = [self.transition] + kwargs.pop('children', [])
    super(FakeAttentionRecurrent, self).__init__(children=children,
                                                 **kwargs)

@application
def apply(self, *args, **kwargs):
Expand Down
7 changes: 4 additions & 3 deletions blocks/bricks/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ class Sequence(Brick):
"""
def __init__(self, application_methods, **kwargs):
    self.application_methods = application_methods

    # Deduplicate owning bricks while preserving order: one brick may own
    # several of the chained application methods.
    seen = set()
    children = [app.brick for app in application_methods
                if not (app.brick in seen or seen.add(app.brick))]
    # pop() keeps 'children' out of **kwargs so Brick.__init__ does not
    # receive the keyword twice (which would raise TypeError).
    children += kwargs.pop('children', [])
    super(Sequence, self).__init__(children=children, **kwargs)

@application
def apply(self, *args):
Expand Down
5 changes: 2 additions & 3 deletions blocks/bricks/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,10 @@ class LinearMaxout(Initializable, Feedforward):
"""
@lazy(allocation=['input_dim', 'output_dim', 'num_pieces'])
def __init__(self, input_dim, output_dim, num_pieces, **kwargs):
super(LinearMaxout, self).__init__(**kwargs)
self.linear = Linear()
self.maxout = Maxout()
self.children = [self.linear,
self.maxout]
children = [self.linear, self.maxout] + kwargs.get('children', [])
super(LinearMaxout, self).__init__(children=children, **kwargs)

self.input_dim = input_dim
self.output_dim = output_dim
Expand Down
15 changes: 9 additions & 6 deletions docs/create_your_own_brick.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ of :class:`.Brick` for a precise description of the life-cycle of a brick):

* :meth:`.Brick.__init__`: you should pass by argument the attributes of your
brick. It is also in this method that you should create the potential
"children bricks" that belong to your brick (in that case, you have to put
the children bricks into ``self.children``). The initialization of the
"children bricks" that belong to your brick (in that case, you have to pass
the children bricks to ``super().__init__``). The initialization of the
attributes can be lazy as described later in the tutorial.
* :meth:`apply`: you need to implement a method that actually
implements the operation of the brick, taking as arguments the inputs
Expand Down Expand Up @@ -210,10 +210,11 @@ specify the ``input_dim`` of ``brick2`` directly at its creation.
>>> class ChainOfTwoFeedforward(Feedforward):
... """Two sequential Feedforward bricks."""
... def __init__(self, brick1, brick2, **kwargs):
... super(Feedforward, self).__init__(**kwargs)
... self.brick1 = brick1
... self.brick2 = brick2
... self.children = [self.brick1, self.brick2]
... children = [self.brick1, self.brick2]
... children += kwargs.get('children', [])
... super(Feedforward, self).__init__(children=children, **kwargs)
...
... @property
... def input_dim(self):
Expand Down Expand Up @@ -370,12 +371,14 @@ One can also create the brick using :class:`Linear` children bricks, which
>>> class ParallelLinear2(Initializable):
... def __init__(self, input_dim1, input_dim2, output_dim1, output_dim2,
... **kwargs):
... super(ParallelLinear2, self).__init__(**kwargs)
... self.linear1 = Linear(input_dim1, output_dim1,
... use_bias=False, **kwargs)
... self.linear2 = Linear(input_dim2, output_dim2,
... use_bias=False, **kwargs)
... self.children = [self.linear1, self.linear2]
... children = [self.linear1, self.linear2]
... children += kwargs.get('children', [])
... super(ParallelLinear2, self).__init__(children=children,
... **kwargs)
...
... @application(inputs=['input1_', 'input2_'], outputs=['output1',
... 'output2'])
Expand Down

0 comments on commit 0e1696d

Please sign in to comment.