Skip to content

Commit

Permalink
Retrieval objective somewhat implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Jul 4, 2017
1 parent 68917ea commit 8745a36
Show file tree
Hide file tree
Showing 11 changed files with 22,027 additions and 24 deletions.
500 changes: 500 additions & 0 deletions examples/data/dev.allids

Large diffs are not rendered by default.

500 changes: 500 additions & 0 deletions examples/data/test.allids

Large diffs are not rendered by default.

10,000 changes: 10,000 additions & 0 deletions examples/data/train.allids

Large diffs are not rendered by default.

11,000 changes: 11,000 additions & 0 deletions examples/data/traindevtest.en

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ defaults:
src_reader: !PlainTextReader {}
trg_reader: !PlainTextReader {}
model: !DefaultTranslator
input_embedder: !SimpleWordEmbedder
src_embedder: !SimpleWordEmbedder
# vocab_size: 100
emb_dim: 64
encoder: !BiLSTMEncoder
Expand All @@ -30,7 +30,7 @@ defaults:
state_dim: 64
hidden_dim: 64
input_dim: 64
output_embedder: !SimpleWordEmbedder
trg_embedder: !SimpleWordEmbedder
# vocab_size: 100
emb_dim: 64
decoder: !MlpSoftmaxDecoder
Expand Down
4 changes: 3 additions & 1 deletion examples/retrieval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ defaults:
trg_reader: !IDReader {}
model: !DotProductRetriever
src_embedder: !SimpleWordEmbedder
vocab_size: 5000 # TODO: set this automatically
emb_dim: 512
src_encoder: !BiLSTMEncoder
layers: 1
trg_embedder: !SimpleWordEmbedder
vocab_size: 5000 # TODO: set this automatically
emb_dim: 512
trg_encoder: !BiLSTMEncoder
layers: 1
database: !StandardRetrievalDatabase
reader: !PlainTextReader {}
database_file: examples/traindevtest.en
database_file: examples/data/traindevtest.en
decode:
src_file: examples/data/test.ja
evaluate:
Expand Down
4 changes: 2 additions & 2 deletions examples/speech.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ defaults:
src_reader: !ContVecReader {}
trg_reader: !PlainTextReader {}
model: !DefaultTranslator
input_embedder: !NoopEmbedder
src_embedder: !NoopEmbedder
emb_dim: 240
encoder: !PyramidalLSTMEncoder
layers: 1
Expand All @@ -32,7 +32,7 @@ defaults:
state_dim: 64
hidden_dim: 64
input_dim: 64
output_embedder: !SimpleWordEmbedder
trg_embedder: !SimpleWordEmbedder
emb_dim: 64
decoder: !MlpSoftmaxDecoder
layers: 1
Expand Down
4 changes: 2 additions & 2 deletions examples/standard.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ defaults:
src_reader: !PlainTextReader {}
trg_reader: !PlainTextReader {}
model: !DefaultTranslator
input_embedder: !SimpleWordEmbedder
src_embedder: !SimpleWordEmbedder
emb_dim: 512
encoder: !BiLSTMEncoder
layers: 1
attender: !StandardAttender
hidden_dim: 512
state_dim: 512
input_dim: 512
output_embedder: !SimpleWordEmbedder
trg_embedder: !SimpleWordEmbedder
emb_dim: 512
decoder: !MlpSoftmaxDecoder
layers: 1
Expand Down
3 changes: 2 additions & 1 deletion xnmt/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class StandardRetrievalDatabase(Serializable):
def __init__(self, reader, database_file):
self.reader = reader
self.database_file = database_file
self.database = reader.read_file(reader)
self.database = reader.read_file(database_file)

##### The actual retriever class

Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(self, src_embedder, src_encoder, trg_embedder, trg_encoder, databas
:param src_encoder: An encoder for the source language
:param trg_embedder: A word embedder for the target language
:param trg_encoder: An encoder for the target language
:param database: A database of things to retrieve
'''
self.src_embedder = src_embedder
self.src_encoder = src_encoder
Expand Down
30 changes: 15 additions & 15 deletions xnmt/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,44 +48,44 @@ class DefaultTranslator(Translator, Serializable):
yaml_tag = u'!DefaultTranslator'


def __init__(self, input_embedder, encoder, attender, output_embedder, decoder):
def __init__(self, src_embedder, encoder, attender, trg_embedder, decoder):
'''Constructor.
:param input_embedder: A word embedder for the input language
:param src_embedder: A word embedder for the input language
:param encoder: An encoder to generate encoded inputs
:param attender: An attention module
:param output_embedder: A word embedder for the output language
:param trg_embedder: A word embedder for the output language
:param decoder: A decoder
'''
self.input_embedder = input_embedder
self.src_embedder = src_embedder
self.encoder = encoder
self.attender = attender
self.output_embedder = output_embedder
self.trg_embedder = trg_embedder
self.decoder = decoder

def shared_params(self):
return [
set(["input_embedder.emb_dim", "encoder.input_dim"]),
set(["src_embedder.emb_dim", "encoder.input_dim"]),
set(["encoder.hidden_dim", "attender.input_dim", "decoder.input_dim"]), # TODO: encoder.hidden_dim may not always exist (e.g. for CNN encoders), need to deal with that case
set(["attender.state_dim", "decoder.lstm_dim"]),
set(["output_embedder.emb_dim", "decoder.trg_embed_dim"]),
set(["trg_embedder.emb_dim", "decoder.trg_embed_dim"]),
]
def dependent_init_params(self):
return [
DependentInitParam(component_name="input_embedder", param_name="vocab_size", value_fct=lambda: self.context["corpus_parser"].src_reader.vocab_size()),
DependentInitParam(component_name="src_embedder", param_name="vocab_size", value_fct=lambda: self.context["corpus_parser"].src_reader.vocab_size()),
DependentInitParam(component_name="decoder", param_name="vocab_size", value_fct=lambda: self.context["corpus_parser"].trg_reader.vocab_size()),
DependentInitParam(component_name="output_embedder", param_name="vocab_size", value_fct=lambda: self.context["corpus_parser"].trg_reader.vocab_size()),
DependentInitParam(component_name="trg_embedder", param_name="vocab_size", value_fct=lambda: self.context["corpus_parser"].trg_reader.vocab_size()),
]

def get_train_test_components(self):
return [self.encoder, self.decoder]

def calc_loss(self, src, trg):
embeddings = self.input_embedder.embed_sent(src)
embeddings = self.src_embedder.embed_sent(src)
encodings = self.encoder.transduce(embeddings)
self.attender.start_sent(encodings)
self.decoder.initialize()
self.decoder.add_input(self.output_embedder.embed(0)) # XXX: HACK, need to initialize decoder better
self.decoder.add_input(self.trg_embedder.embed(0)) # XXX: HACK, need to initialize decoder better
losses = []

# single mode
Expand All @@ -94,7 +94,7 @@ def calc_loss(self, src, trg):
context = self.attender.calc_context(self.decoder.state.output())
word_loss = self.decoder.calc_loss(context, ref_word)
losses.append(word_loss)
self.decoder.add_input(self.output_embedder.embed(ref_word))
self.decoder.add_input(self.trg_embedder.embed(ref_word))

# minibatch mode
else:
Expand All @@ -110,7 +110,7 @@ def calc_loss(self, src, trg):
word_loss = dy.sum_batches(word_loss * mask_exp)
losses.append(word_loss)

self.decoder.add_input(self.output_embedder.embed(ref_word))
self.decoder.add_input(self.trg_embedder.embed(ref_word))

return dy.esum(losses)

Expand All @@ -122,11 +122,11 @@ def translate(self, src, search_strategy=None):
if not Batcher.is_batch_sent(src):
src = Batcher.mark_as_batch([src])
for sents in src:
embeddings = self.input_embedder.embed_sent(src)
embeddings = self.src_embedder.embed_sent(src)
encodings = self.encoder.transduce(embeddings)
self.attender.start_sent(encodings)
self.decoder.initialize()
output.append(search_strategy.generate_output(self.decoder, self.attender, self.output_embedder, src_length=len(sents)))
output.append(search_strategy.generate_output(self.decoder, self.attender, self.trg_embedder, src_length=len(sents)))
return output


2 changes: 1 addition & 1 deletion xnmt/xnmt_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def xnmt_decode(args, model_elements=None):
src_vocab = Vocab(model_params.src_vocab)
trg_vocab = Vocab(model_params.trg_vocab)

translator = DefaultTranslator(model_params.input_embedder, model_params.encoder, model_params.attender, model_params.output_embedder, model_params.decoder)
translator = DefaultTranslator(model_params.src_embedder, model_params.encoder, model_params.attender, model_params.trg_embedder, model_params.decoder)

else:
corpus_parser, translator = model_elements
Expand Down

0 comments on commit 8745a36

Please sign in to comment.