Merge pull request #111 from neulab/yaml-v2
More flexible config
neubig committed Jul 3, 2017
2 parents 7cf915a + ded0100 commit 39f3797
Showing 23 changed files with 729 additions and 485 deletions.
39 changes: 29 additions & 10 deletions examples/debug.yaml
@@ -1,4 +1,4 @@
# This is a super-small config just for debugging
# Small config to help refactoring to make the model completely YAML based
defaults:
experiment:
model_file: examples/output/<EXP>.mod
@@ -9,15 +9,34 @@ defaults:
decode_every: 1
eval_metrics: bleu,wer
train:
train_src: examples/data/head.ja
train_trg: examples/data/head.en
dev_src: examples/data/head.ja
dev_trg: examples/data/head.en
default_layer_dim: 64
dropout: 0.5
encoder:
type: BiLSTM
dropout: 0.0
training_corpus: !BilingualTrainingCorpus
train_src: examples/data/head.ja
train_trg: examples/data/head.en
dev_src: examples/data/head.ja
dev_trg: examples/data/head.en
corpus_parser: !BilingualCorpusParser
src_reader: !PlainTextReader {}
trg_reader: !PlainTextReader {}
model: !DefaultTranslator
input_embedder: !SimpleWordEmbedder
# vocab_size: 100
emb_dim: 64
encoder: !BiLSTMEncoder
layers: 1
input_dim: 64
attender: !StandardAttender
state_dim: 64
hidden_dim: 64
input_dim: 64
output_embedder: !SimpleWordEmbedder
# vocab_size: 100
emb_dim: 64
decoder: !MlpSoftmaxDecoder
layers: 1
# vocab_size: 100
mlp_hidden_dim: 64
decode:
src_file: examples/data/head.ja
evaluate:
@@ -29,9 +48,9 @@ debug-config-1layer:

debug-config-2layers:
train:
decoder_layers: 2
dropout: 0.2

debug-config-2layers-finetune:
train:
decoder_layers: 2
dropout: 0.2
pretrained_model_file: examples/output/debug-config-2layers.mod
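The rewritten debug.yaml above illustrates the core of the new format: model components are declared directly in the config through YAML tags such as !BilingualTrainingCorpus, !SimpleWordEmbedder, !BiLSTMEncoder, !StandardAttender, and !MlpSoftmaxDecoder, instead of flat options under train:. The sketch below shows one way such tags can be resolved with PyYAML; the PR's actual mechanism lives in xnmt/serializer.py, which is not part of this excerpt, so the YAMLObject-based approach and the class stubs here are illustrative assumptions only.

import yaml

# Sketch only: xnmt's real Serializable machinery (xnmt/serializer.py) is not
# shown in this diff. PyYAML's YAMLObject gives a similar effect: a "!Tag"
# node constructs the matching class, with the mapping's keys stored as
# instance attributes.

class SimpleWordEmbedder(yaml.YAMLObject):
    yaml_tag = u'!SimpleWordEmbedder'

class BiLSTMEncoder(yaml.YAMLObject):
    yaml_tag = u'!BiLSTMEncoder'

config = yaml.load("""
model:
  input_embedder: !SimpleWordEmbedder {emb_dim: 64}
  encoder: !BiLSTMEncoder {layers: 1, input_dim: 64}
""", Loader=yaml.Loader)

enc = config["model"]["encoder"]
print(type(enc).__name__, enc.layers, enc.input_dim)  # -> BiLSTMEncoder 1 64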
1 change: 1 addition & 0 deletions examples/modular.yaml
@@ -1,3 +1,4 @@
# TODO: This config file is in the old format. We need to make it match debug.yaml
# This is a super-small config just for debugging
defaults:
experiment:
1 change: 1 addition & 0 deletions examples/preproc.yaml
@@ -1,3 +1,4 @@
# TODO: This config file is in the old format. We need to make it match debug.yaml
defaults:
experiment:
model_file: examples/output/<EXP>.mod
1 change: 1 addition & 0 deletions examples/random-search.yaml
@@ -1,3 +1,4 @@
# TODO: This config file is in the old format. We need to make it match debug.yaml
# Demonstrates random search over model parameters
defaults:
experiment:
1 change: 1 addition & 0 deletions examples/random-search2.yaml
@@ -1,3 +1,4 @@
# TODO: This config file is in the old format. We need to make it match debug.yaml
# Demonstrates random search over search parameters
defaults:
experiment:
51 changes: 27 additions & 24 deletions examples/speech.yaml
@@ -1,3 +1,4 @@
# TODO: This config file is in the old format. We need to make it match debug.yaml
# This config file replicates the Listen-Attend-Spell architecture: https://arxiv.org/pdf/1508.01211.pdf
# Compared to the conventional attentional model, we remove input embeddings and instead read in feature vectors directly.
# The pyramidal LSTM reduces the length of the input sequence by a factor of 2 per layer (except for the first layer).
@@ -12,33 +13,35 @@ defaults:
decode_every: 1
eval_metrics: cer,wer
train:
train_src: examples/data/synth.contvec.npz
train_trg: examples/data/synth.char
dev_src: examples/data/synth.contvec.npz
dev_trg: examples/data/synth.char
# choose pyramidal LSTM encoder:
encoder:
type: PyramidalBiLSTM
downsampling_method: skip
# indicate the dimension of the feature vectors:
input_word_embed_dim: 240
# indicates that the src-side data is continuous-space vectors, contained in a numpy archive (see input.py for details):
input_format: contvec
training_corpus: !BilingualTrainingCorpus
train_src: examples/data/synth.contvec.npz
train_trg: examples/data/synth.char
dev_src: examples/data/synth.contvec.npz
dev_trg: examples/data/synth.char
corpus_parser: !BilingualCorpusParser
src_reader: !ContVecReader {}
trg_reader: !PlainTextReader {}
model: !DefaultTranslator
input_embedder: !NoopEmbedder
emb_dim: 240
encoder: !PyramidalLSTMEncoder
layers: 1
downsampling_method: skip
input_dim: 240
hidden_dim: 63
attender: !StandardAttender
state_dim: 64
hidden_dim: 64
input_dim: 64
output_embedder: !SimpleWordEmbedder
emb_dim: 64
decoder: !MlpSoftmaxDecoder
layers: 1
mlp_hidden_dim: 64
decode:
src_file: examples/data/synth.contvec.npz
input_format: contvec
evaluate:
ref_file: examples/data/synth.char

speech:
train:
encoder:
layers: 3
hidden_dim: 64
downsampling_method: concat
output_word_embed_dim: 64
output_state_dim: 64
attender_hidden_dim: 64
output_mlp_hidden_dim: 64
attention_context_dim: 64
speech-2layers:

2 changes: 2 additions & 0 deletions examples/standard.yaml
@@ -1,3 +1,5 @@
# TODO: This config file is in the old format. We need to make it match debug.yaml

defaults:
experiment:
model_file: examples/output/<EXP>.mod
22 changes: 16 additions & 6 deletions xnmt/attender.py
@@ -1,15 +1,19 @@
import dynet as dy
from batcher import *
from serializer import *
import model_globals


class Attender:
'''
A template class for functions implementing attention.
'''

'''
Implement things.
'''
def __init__(self, input_dim):
"""
:param input_dim: every attender needs an input_dim
"""
pass

def start_sent(self, sent):
raise NotImplementedError('start_sent must be implemented for Attender subclasses')
@@ -18,18 +22,23 @@ def calc_attention(self, state):
raise NotImplementedError('calc_attention must be implemented for Attender subclasses')


class StandardAttender(Attender):
class StandardAttender(Attender, Serializable):
'''
Implements the attention model of Bahdanau et al. (2014)
'''

def __init__(self, input_dim, state_dim, hidden_dim, model):
yaml_tag = u'!StandardAttender'

def __init__(self, input_dim, state_dim, hidden_dim):
self.input_dim = input_dim
self.state_dim = state_dim
self.hidden_dim = hidden_dim
model = model_globals.model
self.pW = model.add_parameters((hidden_dim, input_dim))
self.pV = model.add_parameters((hidden_dim, state_dim))
self.pb = model.add_parameters(hidden_dim)
self.pU = model.add_parameters((1, hidden_dim))
self.curr_sent = None
self.serialize_params = [input_dim, state_dim, hidden_dim, model]

def start_sent(self, sent):
self.curr_sent = sent
@@ -53,3 +62,4 @@ def calc_context(self, state):
attention = self.calc_attention(state)
I = dy.concatenate_cols(self.curr_sent)
return I * attention
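The pattern above recurs throughout this PR: components such as StandardAttender no longer take the DyNet model as a constructor argument but read it from a shared model_globals module, which lets them be instantiated directly from the YAML config. model_globals.py itself is not shown in this excerpt; the sketch below is only an assumption about its minimal shape, based on the attributes the diffs read (model, default_layer_dim, dropout).

# model_globals.py -- minimal sketch (assumption; the real module is not shown
# in this excerpt). Holds the shared DyNet parameter collection plus the
# experiment-wide defaults that components fall back to when the YAML config
# omits a value.
import dynet as dy

model = dy.Model()        # shared parameter collection; in xnmt this is presumably
                          # assigned by the experiment driver at startup
default_layer_dim = 512   # fallback layer dimension
dropout = 0.0             # fallback dropout rate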

18 changes: 16 additions & 2 deletions xnmt/decoder.py
@@ -3,6 +3,8 @@
import inspect
from batcher import *
from translator import TrainTestInterface
from serializer import Serializable
import model_globals

class Decoder(TrainTestInterface):
'''
@@ -29,12 +31,24 @@ def rnn_from_spec(spec, num_layers, input_dim, hidden_dim, model, residual_to_ou
raise RuntimeError("Unknown decoder type {}".format(spec))


class MlpSoftmaxDecoder(RnnDecoder):
class MlpSoftmaxDecoder(RnnDecoder, Serializable):
# TODO: This should probably take a softmax object, which can be normal or class-factored, etc.
# For now the default behavior is hard coded.
def __init__(self, layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, model, trg_embed_dim, dropout,

yaml_tag = u'!MlpSoftmaxDecoder'

def __init__(self, layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, trg_embed_dim, dropout=None,
rnn_spec="lstm", residual_to_output=False):
self.layers = layers
self.lstm_dim = lstm_dim
self.mlp_hidden_dim = mlp_hidden_dim
self.vocab_size = vocab_size
self.trg_embed_dim = trg_embed_dim
self.rnn_spec = rnn_spec
self.residual_to_output = residual_to_output
model = model_globals.model
self.input_dim = input_dim
if dropout is None: dropout = model_globals.dropout
self.dropout = dropout
self.fwd_lstm = RnnDecoder.rnn_from_spec(rnn_spec, layers, trg_embed_dim, lstm_dim, model, residual_to_output)
self.mlp = MLP(input_dim + lstm_dim, mlp_hidden_dim, vocab_size, model)
16 changes: 11 additions & 5 deletions xnmt/embedder.py
@@ -2,6 +2,9 @@

from batcher import *
import dynet as dy
from serializer import Serializable
import model_globals
import yaml

class Embedder:
"""
@@ -88,16 +91,18 @@ def as_tensor(self):
self.expr_tensor = dy.concatenate(list(map(lambda x:dy.transpose(x), self)))
return self.expr_tensor

class SimpleWordEmbedder(Embedder):
class SimpleWordEmbedder(Embedder, Serializable):
"""
Simple word embeddings via lookup.
"""

def __init__(self, vocab_size, emb_dim, model):
yaml_tag = u'!SimpleWordEmbedder'

def __init__(self, vocab_size, emb_dim):
self.vocab_size = vocab_size
self.emb_dim = emb_dim
model = model_globals.model
self.embeddings = model.add_lookup_parameters((vocab_size, emb_dim))
self.serialize_params = [vocab_size, emb_dim, model]

def embed(self, x):
# single mode
@@ -119,7 +124,7 @@ def embed_sent(self, sent):

return ExpressionSequence(expr_list=embeddings)

class NoopEmbedder(Embedder):
class NoopEmbedder(Embedder, Serializable):
"""
This embedder performs no lookups but only passes through the inputs.
@@ -129,9 +134,10 @@ class NoopEmbedder(Embedder):
This is useful e.g. to stack several encoders, where the second encoder performs no
lookups.
"""

yaml_tag = u'!NoopEmbedder'
def __init__(self, emb_dim, model):
self.emb_dim = emb_dim
self.serialize_params = [emb_dim, model]

def embed(self, x):
if isinstance(x, dy.Expression): return x
23 changes: 17 additions & 6 deletions xnmt/encoder.py
@@ -6,6 +6,9 @@
from embedder import ExpressionSequence
from translator import TrainTestInterface
import inspect
from serializer import Serializable
import model_globals
import yaml

class Encoder(TrainTestInterface):
"""
@@ -48,16 +51,23 @@ class BuilderEncoder(Encoder):
def transduce(self, sent):
return self.builder.transduce(sent)

class BiLSTMEncoder(BuilderEncoder):
def __init__(self, model, global_train_params, input_dim=512, layers=1, hidden_dim=None, dropout=None):
if hidden_dim is None: hidden_dim = global_train_params.get("default_layer_dim", 512)
if dropout is None: dropout = global_train_params.get("dropout", 0.0)
class BiLSTMEncoder(BuilderEncoder, Serializable):
yaml_tag = u'!BiLSTMEncoder'

def __init__(self, input_dim=None, layers=1, hidden_dim=None, dropout=None):
model = model_globals.model
if input_dim is None: input_dim = model_globals.default_layer_dim
if hidden_dim is None: hidden_dim = model_globals.default_layer_dim
if dropout is None: dropout = model_globals.dropout
self.input_dim = input_dim
self.layers = layers
self.hidden_dim = hidden_dim
self.dropout = dropout
self.builder = dy.BiRNNBuilder(layers, input_dim, hidden_dim, model, dy.VanillaLSTMBuilder)
self.serialize_params = [model, global_train_params, input_dim, layers, hidden_dim, dropout]
def set_train(self, val):
self.builder.set_dropout(self.dropout if val else 0.0)


class ResidualLSTMEncoder(BuilderEncoder):
def __init__(self, model, global_train_params, input_dim=512, layers=1, hidden_dim=None, residual_to_output=False, dropout=None):
if hidden_dim is None: hidden_dim = global_train_params.get("default_layer_dim", 512)
@@ -78,7 +88,8 @@ def __init__(self, model, global_train_params, input_dim=512, layers=1, hidden_d
def set_train(self, val):
self.builder.set_dropout(self.dropout if val else 0.0)

class PyramidalLSTMEncoder(BuilderEncoder):
class PyramidalLSTMEncoder(BuilderEncoder, Serializable):
yaml_tag = "!PyramidalLSTMEncoder"
def __init__(self, model, global_train_params, input_dim=512, layers=1, hidden_dim=None, downsampling_method="skip", dropout=None):
if hidden_dim is None: hidden_dim = global_train_params.get("default_layer_dim", 512)
if dropout is None: dropout = global_train_params.get("dropout", 0.0)
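With the new constructors, omitted dimensions and dropout rates fall back to the shared defaults rather than being threaded through global_train_params. A rough usage sketch, assuming DyNet is installed and model_globals has the shape sketched earlier (the actual experiment driver is not part of this excerpt):

import dynet as dy
import model_globals
from encoder import BiLSTMEncoder

# Assumed setup: the experiment driver initializes the shared globals first.
model_globals.model = dy.Model()
model_globals.default_layer_dim = 64
model_globals.dropout = 0.1

enc = BiLSTMEncoder(layers=1)   # input_dim/hidden_dim fall back to 64, dropout to 0.1
enc.set_train(False)            # evaluation mode: dropout forced to 0.0

dy.renew_cg()
sent = [dy.inputVector([0.0] * 64) for _ in range(5)]  # five dummy 64-dim input vectors
outputs = enc.transduce(sent)
print(len(outputs), outputs[0].dim())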
