Skip to content

Commit

Permalink
Merge pull request #84 from neulab/decoder-dim
Browse files Browse the repository at this point in the history
setting decoder RNN input dim to size of trg embeddings, not of encoder output
  • Loading branch information
msperber committed Jun 3, 2017
2 parents d4c6661 + dda1d88 commit 1727d1c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions xnmt/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def rnn_from_spec(spec, num_layers, input_dim, hidden_dim, model, residual_to_ou
class MlpSoftmaxDecoder(Decoder):
# TODO: This should probably take a softmax object, which can be normal or class-factored, etc.
# For now the default behavior is hard coded.
def __init__(self, layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, model,
def __init__(self, layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, model, trg_embed_dim,
fwd_lstm=None):
self.input_dim = input_dim
if fwd_lstm == None:
Expand All @@ -40,7 +40,7 @@ def __init__(self, layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, mode
self.fwd_lstm = fwd_lstm
self.mlp = MLP(input_dim + lstm_dim, mlp_hidden_dim, vocab_size, model)
self.state = None
self.serialize_params = [layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, model]
self.serialize_params = [layers, input_dim, lstm_dim, mlp_hidden_dim, vocab_size, model, trg_embed_dim]

def initialize(self):
self.state = self.fwd_lstm.initial_state()
Expand Down
4 changes: 2 additions & 2 deletions xnmt/xnmt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ def create_model(self):
self.attender = StandardAttender(self.attention_context_dim, self.output_state_dim, self.attender_hidden_dim,
self.model)

decoder_rnn = Decoder.rnn_from_spec(self.args.decoder_type, self.args.decoder_layers, self.attention_context_dim,
decoder_rnn = Decoder.rnn_from_spec(self.args.decoder_type, self.args.decoder_layers, self.output_word_emb_dim,
self.output_state_dim, self.model, self.args.residual_to_output)
self.decoder = MlpSoftmaxDecoder(self.args.decoder_layers, self.attention_context_dim, self.output_state_dim,
self.output_mlp_hidden_dim, len(self.output_reader.vocab),
self.model, decoder_rnn)
self.model, self.output_word_emb_dim, decoder_rnn)

self.translator = DefaultTranslator(self.input_embedder, self.encoder, self.attender, self.output_embedder, self.decoder)
self.model_params = ModelParams(self.encoder, self.attender, self.decoder, self.input_reader.vocab.i2w,
Expand Down

0 comments on commit 1727d1c

Please sign in to comment.