Modularize transformer encoder (#465)
* Implement residuals (with layer norm) and config

* Add SeqTransducer that wraps Transform

* Remove a little bit of verbosity

* Made positional seq transducer

* Fixed tests

* Renamed TransformerAdamTrainer to NoamTrainer (following tensor2tensor)
neubig committed Jul 16, 2018
1 parent 8143fad commit 455e51e
Showing 13 changed files with 273 additions and 337 deletions.
2 changes: 1 addition & 1 deletion examples/16_transformer.yaml
@@ -23,7 +23,7 @@ transformer: !Experiment
batcher: !SentShuffleBatcher
batch_size: 100
restart_trainer: False
trainer: !TransformerAdamTrainer
trainer: !NoamTrainer
alpha: 1.0
warmup_steps: 4000
lr_decay: 1.0
101 changes: 101 additions & 0 deletions examples/21_self_attention.yaml
@@ -0,0 +1,101 @@
# A setup using self-attention
self_attention: !Experiment
exp_global: !ExpGlobal
model_file: '{EXP_DIR}/models/{EXP}.mod'
log_file: '{EXP_DIR}/logs/{EXP}.log'
default_layer_dim: 512
dropout: 0.3
placeholders:
DATA_IN: examples/data
DATA_OUT: examples/preproc
preproc: !PreprocRunner
overwrite: False
tasks:
- !PreprocVocab
in_files:
- '{DATA_IN}/train.ja'
- '{DATA_IN}/train.en'
out_files:
- '{DATA_OUT}/train.ja.vocab'
- '{DATA_OUT}/train.en.vocab'
specs:
- filenum: all
filters:
- !VocabFiltererFreq
min_freq: 2
model: !DefaultTranslator
src_reader: !PlainTextReader
vocab: !Vocab {vocab_file: '{DATA_OUT}/train.ja.vocab'}
trg_reader: !PlainTextReader
vocab: !Vocab {vocab_file: '{DATA_OUT}/train.en.vocab'}
src_embedder: !SimpleWordEmbedder
emb_dim: 512
encoder: !ModularSeqTransducer
modules:
- !PositionalSeqTransducer
input_dim: 512
max_pos: 100
- !ResidualSeqTransducer
input_dim: 512
child: !MultiHeadAttentionSeqTransducer
num_heads: 8
layer_norm: True
- !ResidualSeqTransducer
input_dim: 512
child: !ModularSeqTransducer
input_dim: 512
modules:
- !TransformSeqTransducer
transform: !NonLinear
activation: relu
- !TransformSeqTransducer
transform: !Linear {}
layer_norm: True
- !ResidualSeqTransducer
input_dim: 512
child: !MultiHeadAttentionSeqTransducer
num_heads: 8
layer_norm: True
- !ResidualSeqTransducer
input_dim: 512
child: !ModularSeqTransducer
input_dim: 512
modules:
- !TransformSeqTransducer
transform: !NonLinear
activation: relu
- !TransformSeqTransducer
transform: !Linear {}
layer_norm: True
attender: !MlpAttender
hidden_dim: 512
state_dim: 512
input_dim: 512
trg_embedder: !SimpleWordEmbedder
emb_dim: 512
decoder: !AutoRegressiveDecoder
rnn: !UniLSTMSeqTransducer
layers: 1
transform: !AuxNonLinear
output_dim: 512
activation: 'tanh'
bridge: !CopyBridge {}
train: !SimpleTrainingRegimen
batcher: !SrcBatcher
batch_size: 32
trainer: !NoamTrainer
alpha: 1.0
warmup_steps: 4000
run_for_epochs: 2
src_file: examples/data/train.ja
trg_file: examples/data/train.en
dev_tasks:
- !LossEvalTask
src_file: examples/data/head.ja
ref_file: examples/data/head.en
evaluate:
- !AccuracyEvalTask
eval_metrics: bleu
src_file: examples/data/head.ja
ref_file: examples/data/head.en
hyp_file: examples/output/{EXP}.test_hyp
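The encoder above hand-assembles the standard Transformer block structure: positional embeddings, then two rounds of a residual multi-head self-attention sub-layer followed by a residual position-wise feed-forward sub-layer (NonLinear + Linear), each with layer norm. As a rough illustration of what the !ResidualSeqTransducer wrapper contributes, here is a minimal DyNet-style sketch; the real class lives in xnmt.residual and is not shown in this diff, so the function below is an assumption based only on the child, input_dim, and layer_norm fields used in the config.

import dynet as dy

# Illustrative only: residual connection around an arbitrary child transducer,
# with optional per-position layer norm (ln_g/ln_b are the learned gain/bias).
def residual_transduce(child_transduce, inputs, ln_g=None, ln_b=None):
  """inputs: list of per-position DyNet vectors; child_transduce: fn(list) -> list."""
  outputs = [h + x for h, x in zip(child_transduce(inputs), inputs)]  # residual add
  if ln_g is not None:
    outputs = [dy.layer_norm(o, ln_g, ln_b) for o in outputs]         # layer norm per position
  return outputs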
10 changes: 8 additions & 2 deletions test/config/encoders.yaml
@@ -47,6 +47,7 @@ exp1-lstm-encoder: !Experiment
input_feeding: True
bridge: !CopyBridge {}
inference: !AutoRegressiveInference {}

exp2-residual-encoder: !Experiment
kwargs:
<< : *defaults
@@ -57,9 +58,14 @@ exp2-residual-encoder: !Experiment
vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
src_embedder: !SimpleWordEmbedder
emb_dim: 64
encoder: !ResidualLSTMSeqTransducer
layers: 2
encoder: !ModularSeqTransducer
input_dim: 64
modules:
- !ResidualSeqTransducer
child: !BiLSTMSeqTransducer
input_dim: 64
- !BiLSTMSeqTransducer
input_dim: 64
attender: !MlpAttender
state_dim: 64
hidden_dim: 64
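Note that the dedicated ResidualLSTMSeqTransducer is gone from this test config: the same architecture is now composed from generic pieces, with a ResidualSeqTransducer wrapping a BiLSTM inside a ModularSeqTransducer. The chaining a ModularSeqTransducer presumably performs is plain sequential transduction; a minimal sketch under that assumption (the actual class is not part of this diff):

# Hypothetical sketch: run each module's transduce() on the previous output.
# The real ModularSeqTransducer in xnmt may also track masks and final states.
def modular_transduce(modules, seq):
  for module in modules:          # e.g. [ResidualSeqTransducer(child=BiLSTM...), BiLSTMSeqTransducer(...)]
    seq = module.transduce(seq)   # ExpressionSequence in, ExpressionSequence out
  return seq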
36 changes: 18 additions & 18 deletions test/test_encoder.py
@@ -13,7 +13,6 @@
from xnmt.lstm import UniLSTMSeqTransducer, BiLSTMSeqTransducer
from xnmt.param_collection import ParamManager
from xnmt.pyramidal import PyramidalLSTMSeqTransducer
from xnmt.residual import ResidualLSTMSeqTransducer
from xnmt.scorer import Softmax
from xnmt.self_attention import MultiHeadAttentionSeqTransducer
from xnmt.transform import NonLinear
@@ -84,23 +83,24 @@ def test_uni_lstm_encoder_len(self):
)
self.assert_in_out_len_equal(model)

def test_res_lstm_encoder_len(self):
layer_dim = 512
model = DefaultTranslator(
src_reader=self.src_reader,
trg_reader=self.trg_reader,
src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
encoder=ResidualLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim, layers=3),
attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
decoder=AutoRegressiveDecoder(input_dim=layer_dim,
trg_embed_dim=layer_dim,
rnn=UniLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim, decoder_input_dim=layer_dim, yaml_path="model.decoder.rnn"),
transform=NonLinear(input_dim=layer_dim*2, output_dim=layer_dim),
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.assert_in_out_len_equal(model)
# TODO: Update this to the new residual LSTM transducer framework
# def test_res_lstm_encoder_len(self):
# layer_dim = 512
# model = DefaultTranslator(
# src_reader=self.src_reader,
# trg_reader=self.trg_reader,
# src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
# encoder=ResidualLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim, layers=3),
# attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
# trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
# decoder=AutoRegressiveDecoder(input_dim=layer_dim,
# trg_embed_dim=layer_dim,
# rnn=UniLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim, decoder_input_dim=layer_dim, yaml_path="model.decoder.rnn"),
# transform=NonLinear(input_dim=layer_dim*2, output_dim=layer_dim),
# scorer=Softmax(input_dim=layer_dim, vocab_size=100),
# bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
# )
# self.assert_in_out_len_equal(model)
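  # A possible rewrite under the new modular framework (untested sketch; the
  # constructor signatures of ModularSeqTransducer and ResidualSeqTransducer are
  # assumptions, since neither class appears in this diff):
  #   encoder=ModularSeqTransducer(input_dim=layer_dim, modules=[
  #     ResidualSeqTransducer(input_dim=layer_dim,
  #                           child=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim)),
  #     BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim)]),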

def test_py_lstm_encoder_len(self):
layer_dim = 512
2 changes: 1 addition & 1 deletion xnmt/__init__.py
@@ -30,7 +30,6 @@
import xnmt.evaluator
import xnmt.exp_global
import xnmt.experiment
import xnmt.ff
import xnmt.fixed_size_att
import xnmt.hyper_parameters
import xnmt.inference
@@ -41,6 +40,7 @@
import xnmt.model_base
import xnmt.optimizer
import xnmt.param_init
import xnmt.positional
import xnmt.preproc_runner
import xnmt.pyramidal
import xnmt.residual
25 changes: 0 additions & 25 deletions xnmt/embedder.py
@@ -277,31 +277,6 @@ def embed(self, x):
ret = dy.noise(ret, self.weight_noise)
return ret

class PositionEmbedder(Embedder, Serializable):

yaml_tag = '!PositionEmbedder'

@serializable_init
@register_xnmt_handler
def __init__(self, max_pos: int, emb_dim: int = Ref("exp_global.default_layer_dim"),
param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
"""
max_pos: largest embedded position
emb_dim: embedding size
param_init: how to initialize embedding matrix
"""
self.max_pos = max_pos
self.emb_dim = emb_dim
param_collection = ParamManager.my_params(self)
param_init = param_init
dim = (self.emb_dim, max_pos)
self.embeddings = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))

def embed(self, word): raise NotImplementedError("Position-embedding for individual words not implemented yet.")
def embed_sent(self, sent_len):
embeddings = dy.strided_select(dy.parameter(self.embeddings), [1,1], [0,0], [self.emb_dim, sent_len])
return ExpressionSequence(expr_tensor=embeddings, mask=None)

class NoopEmbedder(Embedder, Serializable):
"""
This embedder performs no lookups but only passes through the inputs.
51 changes: 0 additions & 51 deletions xnmt/ff.py

This file was deleted.

6 changes: 3 additions & 3 deletions xnmt/optimizer.py
@@ -204,7 +204,7 @@ def __init__(self, alpha=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8, update_every
super().__init__(optimizer=dy.AdamTrainer(ParamManager.global_collection(), alpha, beta_1, beta_2, eps),
skip_noisy=skip_noisy)

class TransformerAdamTrainer(XnmtOptimizer, Serializable):
class NoamTrainer(XnmtOptimizer, Serializable):
"""
Proposed in the paper "Attention is all you need" (https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) [Page 7, Eq. 3]
  In this schedule, the learning rate of the Adam optimizer is increased linearly for the first warmup_steps updates and then decayed in proportion to the inverse square root of the step number
@@ -220,7 +220,7 @@ class TransformerAdamTrainer(XnmtOptimizer, Serializable):
values, and abort a step if the norm of the gradient exceeds four standard deviations of the
moving average. Reference: https://arxiv.org/pdf/1804.09849.pdf
"""
yaml_tag = '!TransformerAdamTrainer'
yaml_tag = '!NoamTrainer'

@serializable_init
def __init__(self, alpha=1.0, dim=512, warmup_steps=4000, beta_1=0.9, beta_2=0.98, eps=1e-9,
@@ -277,4 +277,4 @@ def learning_rate(self):
return 1.0
@learning_rate.setter
def learning_rate(self, value):
pass
pass
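For reference, the schedule that the renamed NoamTrainer implements (Eq. 3, page 7 of "Attention is all you need", scaled by the alpha argument above) can be written as a small helper. This is only an illustration of the formula, not the class's actual update code:

# Noam learning-rate schedule, matching the constructor defaults above (illustrative).
def noam_lr(step, dim=512, warmup_steps=4000, alpha=1.0):
  step = max(step, 1)  # guard against step 0
  return alpha * dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

The rate grows linearly for the first warmup_steps updates and then decays proportionally to 1/sqrt(step).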
80 changes: 80 additions & 0 deletions xnmt/positional.py
@@ -0,0 +1,80 @@
import dynet as dy
from typing import List

from xnmt.embedder import Embedder
from xnmt.expression_sequence import ExpressionSequence
from xnmt.param_collection import ParamManager
from xnmt.param_init import ParamInitializer, GlorotInitializer, ZeroInitializer
from xnmt.persistence import serializable_init, Serializable, bare, Ref
from xnmt.transducer import SeqTransducer, FinalTransducerState

class PositionEmbedder(Embedder, Serializable):

yaml_tag = '!PositionEmbedder'

@serializable_init
def __init__(self, max_pos: int, emb_dim: int = Ref("exp_global.default_layer_dim"),
param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
"""
max_pos: largest embedded position
emb_dim: embedding size
param_init: how to initialize embedding matrix
"""
self.max_pos = max_pos
self.emb_dim = emb_dim
param_collection = ParamManager.my_params(self)
param_init = param_init
dim = (self.emb_dim, max_pos)
self.embeddings = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))

def embed(self, word): raise NotImplementedError("Position-embedding for individual words not implemented yet.")
def embed_sent(self, sent_len):
embeddings = dy.strided_select(dy.parameter(self.embeddings), [1,1], [0,0], [self.emb_dim, sent_len])
return ExpressionSequence(expr_tensor=embeddings, mask=None)

# Note: alternatively, this could wrap "PositionEmbedder", but it seems to me
# that PositionEmbedder is probably not necessary in the first place, so
# it probably makes more sense to have this as a SeqTransducer that
# adds positional embeddings to an input
class PositionalSeqTransducer(SeqTransducer, Serializable):
yaml_tag = '!PositionalSeqTransducer'

@serializable_init
def __init__(self,
max_pos: int,
op: str = 'sum',
emb_type: str = 'param',
input_dim: int = Ref("exp_global.default_layer_dim"),
param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
"""
max_pos: largest embedded position
op: how to combine positional encodings with the original encodings, can be "sum" or "concat"
    emb_type: what type of embeddings to use; "param" = parameterized (others, such as the trigonometric embeddings, are TODO)
input_dim: embedding size
param_init: how to initialize embedding matrix
"""
self.max_pos = max_pos
self.input_dim = input_dim
self.op = op
self.emb_type = emb_type
param_init = param_init
dim = (self.input_dim, max_pos)
param_collection = ParamManager.my_params(self)
self.embedder = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))

def get_final_states(self) -> List[FinalTransducerState]:
return self._final_states

def transduce(self, src: ExpressionSequence) -> ExpressionSequence:
sent_len = len(src)
embeddings = dy.strided_select(dy.parameter(self.embedder), [1,1], [0,0], [self.input_dim, sent_len])
if self.op == 'sum':
output = embeddings + src.as_tensor()
elif self.op == 'concat':
output = dy.concatenate([embeddings, src.as_tensor()])
else:
      raise ValueError(f'Illegal op {self.op} in PositionalSeqTransducer (options are "sum"/"concat")')
output_seq = ExpressionSequence(expr_tensor=output, mask=src.mask)
self._final_states = [FinalTransducerState(output_seq[-1])]
return output_seq

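The docstring above notes that only emb_type='param' (a learned position matrix) is implemented and that trigonometric embeddings are still a TODO. For context, the fixed sinusoidal encodings from the Transformer paper would fill the same (input_dim x sent_len) matrix without any learned parameters; the sketch below is purely illustrative and not part of this commit:

import math
import numpy as np
import dynet as dy

def sinusoidal_positions(input_dim, sent_len):
  """Fixed (input_dim x sent_len) positional encodings, per Vaswani et al. (2017)."""
  enc = np.zeros((input_dim, sent_len))
  for pos in range(sent_len):
    for i in range(0, input_dim, 2):            # even rows: sine, odd rows: cosine
      enc[i, pos] = math.sin(pos / (10000 ** (i / input_dim)))
      if i + 1 < input_dim:
        enc[i + 1, pos] = math.cos(pos / (10000 ** (i / input_dim)))
  return dy.inputTensor(enc)                    # same shape as the strided_select slice above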