Modularize transformer encoder (#465)
* Implement residuals (with layer norm) and config (see the sketch below)
* Add SeqTransducer that wraps Transform
* Remove a little bit of verbosity
* Made positional seq transducer
* Fixed tests
* Renamed TransformerAdamTrainer to NoamTrainer (following tensor2tensor)
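For orientation, a minimal numpy sketch of what a residual-with-layer-norm wrapper computes, assuming the post-norm arrangement of the original Transformer; the exact composition order inside xnmt's ResidualSeqTransducer is not shown in this commit, and the helper names here are illustrative only:

import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each position's feature vector (a column of the
    # (dim, time) matrix) to zero mean and unit variance.
    return (x - x.mean(axis=0, keepdims=True)) / (x.std(axis=0, keepdims=True) + eps)

def residual_transduce(child, x, use_layer_norm=True):
    # Hypothetical stand-in for ResidualSeqTransducer: add the child
    # transducer's output back onto its input, then optionally normalize.
    out = x + child(x)
    return layer_norm(out) if use_layer_norm else out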
Showing 13 changed files with 273 additions and 337 deletions.
@@ -0,0 +1,101 @@
# A setup using self-attention
self_attention: !Experiment
  exp_global: !ExpGlobal
    model_file: '{EXP_DIR}/models/{EXP}.mod'
    log_file: '{EXP_DIR}/logs/{EXP}.log'
    default_layer_dim: 512
    dropout: 0.3
    placeholders:
      DATA_IN: examples/data
      DATA_OUT: examples/preproc
  preproc: !PreprocRunner
    overwrite: False
    tasks:
    - !PreprocVocab
      in_files:
      - '{DATA_IN}/train.ja'
      - '{DATA_IN}/train.en'
      out_files:
      - '{DATA_OUT}/train.ja.vocab'
      - '{DATA_OUT}/train.en.vocab'
      specs:
      - filenum: all
        filters:
        - !VocabFiltererFreq
          min_freq: 2
  model: !DefaultTranslator
    src_reader: !PlainTextReader
      vocab: !Vocab {vocab_file: '{DATA_OUT}/train.ja.vocab'}
    trg_reader: !PlainTextReader
      vocab: !Vocab {vocab_file: '{DATA_OUT}/train.en.vocab'}
    src_embedder: !SimpleWordEmbedder
      emb_dim: 512
    encoder: !ModularSeqTransducer
      modules:
      - !PositionalSeqTransducer
        input_dim: 512
        max_pos: 100
      - !ResidualSeqTransducer
        input_dim: 512
        child: !MultiHeadAttentionSeqTransducer
          num_heads: 8
        layer_norm: True
      - !ResidualSeqTransducer
        input_dim: 512
        child: !ModularSeqTransducer
          input_dim: 512
          modules:
          - !TransformSeqTransducer
            transform: !NonLinear
              activation: relu
          - !TransformSeqTransducer
            transform: !Linear {}
        layer_norm: True
      - !ResidualSeqTransducer
        input_dim: 512
        child: !MultiHeadAttentionSeqTransducer
          num_heads: 8
        layer_norm: True
      - !ResidualSeqTransducer
        input_dim: 512
        child: !ModularSeqTransducer
          input_dim: 512
          modules:
          - !TransformSeqTransducer
            transform: !NonLinear
              activation: relu
          - !TransformSeqTransducer
            transform: !Linear {}
        layer_norm: True
    attender: !MlpAttender
      hidden_dim: 512
      state_dim: 512
      input_dim: 512
    trg_embedder: !SimpleWordEmbedder
      emb_dim: 512
    decoder: !AutoRegressiveDecoder
      rnn: !UniLSTMSeqTransducer
        layers: 1
      transform: !AuxNonLinear
        output_dim: 512
        activation: 'tanh'
      bridge: !CopyBridge {}
  train: !SimpleTrainingRegimen
    batcher: !SrcBatcher
      batch_size: 32
    trainer: !NoamTrainer
      alpha: 1.0
      warmup_steps: 4000
    run_for_epochs: 2
    src_file: examples/data/train.ja
    trg_file: examples/data/train.en
    dev_tasks:
    - !LossEvalTask
      src_file: examples/data/head.ja
      ref_file: examples/data/head.en
  evaluate:
  - !AccuracyEvalTask
    eval_metrics: bleu
    src_file: examples/data/head.ja
    ref_file: examples/data/head.en
    hyp_file: examples/output/{EXP}.test_hyp
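The encoder above stacks two Transformer layers: each pair of !ResidualSeqTransducer blocks wraps a multi-head self-attention sublayer and a position-wise feed-forward sublayer (NonLinear followed by Linear). The !NoamTrainer follows the tensor2tensor learning-rate schedule from "Attention is All You Need"; as a rough sketch of that schedule with this config's values (alpha=1.0, warmup_steps=4000, default_layer_dim=512), where the noam_lr helper is illustrative and not xnmt API:

def noam_lr(step, d_model=512, warmup_steps=4000, alpha=1.0):
    # Learning rate grows linearly for warmup_steps, then decays as step**-0.5.
    return alpha * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

print(noam_lr(1), noam_lr(4000), noam_lr(100000))  # peaks near step 4000 (~7e-4)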
@@ -0,0 +1,80 @@
from typing import List

import dynet as dy

from xnmt.embedder import Embedder
from xnmt.expression_sequence import ExpressionSequence
from xnmt.param_collection import ParamManager
from xnmt.param_init import ParamInitializer, GlorotInitializer
from xnmt.persistence import serializable_init, Serializable, bare, Ref
from xnmt.transducer import SeqTransducer, FinalTransducerState


class PositionEmbedder(Embedder, Serializable):

  yaml_tag = '!PositionEmbedder'

  @serializable_init
  def __init__(self, max_pos: int, emb_dim: int = Ref("exp_global.default_layer_dim"),
               param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
    """
    max_pos: largest embedded position
    emb_dim: embedding size
    param_init: how to initialize the embedding matrix
    """
    self.max_pos = max_pos
    self.emb_dim = emb_dim
    param_collection = ParamManager.my_params(self)
    dim = (self.emb_dim, max_pos)
    self.embeddings = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))

  def embed(self, word):
    raise NotImplementedError("Position embedding for individual words is not implemented yet.")

  def embed_sent(self, sent_len):
    # Select the first sent_len columns of the (emb_dim, max_pos) embedding matrix.
    embeddings = dy.strided_select(dy.parameter(self.embeddings), [1, 1], [0, 0], [self.emb_dim, sent_len])
    return ExpressionSequence(expr_tensor=embeddings, mask=None)


# Note: alternatively, this could wrap PositionEmbedder, but PositionEmbedder is
# probably not necessary in the first place, so it makes more sense to have this
# as a SeqTransducer that adds positional embeddings to its input.
class PositionalSeqTransducer(SeqTransducer, Serializable):

  yaml_tag = '!PositionalSeqTransducer'

  @serializable_init
  def __init__(self,
               max_pos: int,
               op: str = 'sum',
               emb_type: str = 'param',
               input_dim: int = Ref("exp_global.default_layer_dim"),
               param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
    """
    max_pos: largest embedded position
    op: how to combine positional encodings with the original encodings, can be "sum" or "concat"
    emb_type: what type of embeddings to use; "param" = parameterized (others, such as trigonometric embeddings, are TODO)
    input_dim: embedding size
    param_init: how to initialize the embedding matrix
    """
    self.max_pos = max_pos
    self.input_dim = input_dim
    self.op = op
    self.emb_type = emb_type
    dim = (self.input_dim, max_pos)
    param_collection = ParamManager.my_params(self)
    self.embedder = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))

  def get_final_states(self) -> List[FinalTransducerState]:
    return self._final_states

  def transduce(self, src: ExpressionSequence) -> ExpressionSequence:
    sent_len = len(src)
    # Select the first sent_len columns of the (input_dim, max_pos) position matrix.
    embeddings = dy.strided_select(dy.parameter(self.embedder), [1, 1], [0, 0], [self.input_dim, sent_len])
    if self.op == 'sum':
      output = embeddings + src.as_tensor()
    elif self.op == 'concat':
      output = dy.concatenate([embeddings, src.as_tensor()])
    else:
      raise ValueError(f'Illegal op {self.op} in PositionalSeqTransducer (options are "sum"/"concat")')
    output_seq = ExpressionSequence(expr_tensor=output, mask=src.mask)
    self._final_states = [FinalTransducerState(output_seq[-1])]
    return output_seq
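The docstring above notes that trigonometric embeddings are still a TODO. For reference, a hedged numpy sketch of the standard sinusoidal encodings from the Transformer paper that such an emb_type could compute; the function and any option name are hypothetical, not part of this commit:

import numpy as np

def sinusoidal_encodings(max_pos, emb_dim):
    # Assumes emb_dim is even. Even rows get sin, odd rows get cos, with
    # wavelengths forming a geometric progression up to 10000 * 2 * pi.
    pos = np.arange(max_pos)[np.newaxis, :]          # (1, max_pos)
    i = np.arange(0, emb_dim, 2)[:, np.newaxis]      # (emb_dim/2, 1)
    angle = pos / np.power(10000.0, i / emb_dim)     # (emb_dim/2, max_pos)
    pe = np.zeros((emb_dim, max_pos))
    pe[0::2, :] = np.sin(angle)
    pe[1::2, :] = np.cos(angle)
    return pe  # same (input_dim, max_pos) layout as the learned parameter above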