Added MLP convenience class (#469)
* Added MLP convenience class

* Updated to MLP
neubig committed Jul 17, 2018
1 parent 7c473ae commit 0baa5ce
Showing 3 changed files with 46 additions and 34 deletions.
examples/21_self_attention.yaml: 48 changes (16 additions, 32 deletions)
@@ -35,38 +35,22 @@ self_attention: !Experiment
       - !PositionalSeqTransducer
         input_dim: 512
         max_pos: 100
-      - !ResidualSeqTransducer
-        input_dim: 512
-        child: !MultiHeadAttentionSeqTransducer
-          num_heads: 8
-        layer_norm: True
-      - !ResidualSeqTransducer
-        input_dim: 512
-        child: !ModularSeqTransducer
-          input_dim: 512
-          modules:
-          - !TransformSeqTransducer
-            transform: !NonLinear
-              activation: relu
-          - !TransformSeqTransducer
-            transform: !Linear {}
-        layer_norm: True
-      - !ResidualSeqTransducer
-        input_dim: 512
-        child: !MultiHeadAttentionSeqTransducer
-          num_heads: 8
-        layer_norm: True
-      - !ResidualSeqTransducer
-        input_dim: 512
-        child: !ModularSeqTransducer
-          input_dim: 512
-          modules:
-          - !TransformSeqTransducer
-            transform: !NonLinear
-              activation: relu
-          - !TransformSeqTransducer
-            transform: !Linear {}
-        layer_norm: True
+      - !ModularSeqTransducer
+        modules: !Repeat
+          times: 2
+          content: !ModularSeqTransducer
+            modules:
+            - !ResidualSeqTransducer
+              input_dim: 512
+              child: !MultiHeadAttentionSeqTransducer
+                num_heads: 8
+              layer_norm: True
+            - !ResidualSeqTransducer
+              input_dim: 512
+              child: !TransformSeqTransducer
+                transform: !MLP
+                  activation: relu
+              layer_norm: True
     attender: !MlpAttender
       hidden_dim: 512
       state_dim: 512
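
The refactored config folds the two identical transformer blocks into one !Repeat, and replaces the hand-written NonLinear + Linear pair with the new !MLP convenience class. As a sanity check, here is a small pure-Python sketch of the layer stack that `!MLP {activation: relu}` builds with its defaults (hidden_layers=1, all dimensions falling back to exp_global.default_layer_dim, assumed to be 512 as in the rest of this config). It mirrors the MLP.__init__ added in xnmt/transform.py below, with tuples standing in for the NonLinear/Linear objects:

input_dim = hidden_dim = output_dim = 512  # all default to exp_global.default_layer_dim
hidden_layers = 1                          # MLP's default

layers = [("NonLinear", input_dim, hidden_dim, "relu")]   # first hidden layer
layers += [("NonLinear", hidden_dim, hidden_dim, "relu")  # further hidden layers
           for _ in range(1, hidden_layers)]              # (none here)
layers += [("Linear", hidden_dim, output_dim)]            # output projection

# -> [('NonLinear', 512, 512, 'relu'), ('Linear', 512, 512)]:
# the same NonLinear -> Linear pair the old config spelled out by hand.
print(layers)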
xnmt/classifier.py: 5 changes (3 additions, 2 deletions)
@@ -9,15 +9,16 @@ class SequenceClassifier(model_base.ConditionedModel, model_base.GeneratorModel,
"""
A sequence classifier.
Runs embeddings through an encoder, feeds the average over all encoder outputs to a MLP softmax output layer.
Runs embeddings through an encoder, feeds the average over all encoder outputs to a transform and scoring layer.
Args:
src_reader: A reader for the source side.
trg_reader: A reader for the target side.
src_embedder: A word embedder for the input language
encoder: An encoder to generate encoded inputs
inference: how to perform inference
mlp: final prediction MLP layer
transform: A transform performed before the scoring function
scorer: A scoring function over the multiple choices
"""

yaml_tag = '!SequenceClassifier'
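
The new docstring reflects the refactor from a single mlp component to separate transform and scorer components. A minimal sketch of the data flow it describes, using hypothetical stand-in callables and numpy arrays in place of xnmt's DyNet expressions:

import numpy as np

def classify(src_sent, embedder, encoder, transform, scorer):
    embeddings = embedder(src_sent)      # one vector per source token
    encodings = encoder(embeddings)      # (seq_len, hidden_dim) encoder outputs
    pooled = np.mean(encodings, axis=0)  # average over all encoder outputs
    return scorer(transform(pooled))     # scores over the output classes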
xnmt/transform.py: 27 changes (27 additions, 0 deletions)
@@ -165,6 +165,33 @@ def __init__(self,
     )
     self.save_processed_arg("input_dim", original_input_dim)
 
+class MLP(Transform, Serializable):
+  """
+  A multi-layer perceptron: ``hidden_layers`` NonLinear transforms of equal hidden
+  dimension and activation, followed by a Linear transform to the output dimension.
+  """
+  yaml_tag = "!MLP"
+
+  @serializable_init
+  def __init__(self,
+               input_dim: int = Ref("exp_global.default_layer_dim"),
+               hidden_dim: int = Ref("exp_global.default_layer_dim"),
+               output_dim: int = Ref("exp_global.default_layer_dim"),
+               bias: bool = True,
+               activation: str = 'tanh',
+               hidden_layers: int = 1,
+               param_init=Ref("exp_global.param_init", default=bare(GlorotInitializer)),
+               bias_init=Ref("exp_global.bias_init", default=bare(ZeroInitializer))):
+    if hidden_layers > 0:
+      # First hidden layer maps input_dim -> hidden_dim.
+      self.layers = [NonLinear(input_dim=input_dim, output_dim=hidden_dim, bias=bias,
+                               activation=activation, param_init=param_init, bias_init=bias_init)]
+      # Any additional hidden layers keep the hidden dimension.
+      self.layers += [NonLinear(input_dim=hidden_dim, output_dim=hidden_dim, bias=bias,
+                                activation=activation, param_init=param_init, bias_init=bias_init)
+                      for _ in range(1, hidden_layers)]
+      # Final linear projection to the output dimension.
+      self.layers += [Linear(input_dim=hidden_dim, output_dim=output_dim, bias=bias,
+                             param_init=param_init, bias_init=bias_init)]
+    else:
+      # With no hidden layers, use a single Linear projection rather than
+      # leaving the layer list empty (which would make the MLP an identity).
+      self.layers = [Linear(input_dim=input_dim, output_dim=output_dim, bias=bias,
+                            param_init=param_init, bias_init=bias_init)]
+
+  def __call__(self, expr: dy.Expression) -> dy.Expression:
+    for layer in self.layers:
+      expr = layer(expr)
+    return expr
+
 class TransformSeqTransducer(SeqTransducer, Serializable):
   yaml_tag = '!TransformSeqTransducer'
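
A hypothetical usage sketch of the new class (not part of the commit; it assumes MLP can be constructed directly in Python outside the YAML deserializer, with xnmt's DyNet parameter collection already initialized):

import dynet as dy
from xnmt.transform import MLP

# Illustrative dimensions: with hidden_layers=1 this builds
# NonLinear(512 -> 2048, relu) followed by Linear(2048 -> 512),
# i.e. a transformer-style feed-forward block.
mlp = MLP(input_dim=512, hidden_dim=2048, output_dim=512,
          activation='relu', hidden_layers=1)

dy.renew_cg()
h = mlp(dy.inputVector([0.0] * 512))  # a 512-dimensional output expression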
