Commit: Support for multiple inputs v2 (#445)
* decoder interface and type annotations

* add inference base class; rename SimpleInference to SequenceInference; some refactoring to outputs and SequenceInference

* add sequence labeler and classifier

* move output processing from generator models to inference

* rename SequenceInference to AutoRegressiveInference

* improve doc for inference

* fix some serialization issues

* remove some model-specific code

* cleanup: unused inference src_mask, inconsistent input reader read_sent

* rename ClassifierInference to IndependentOutputInference

* update generate, use independent inference for seq labeler

* fix warning

* refactor MLE loss to be agnostic of translator internals, rename to AutoRegressiveMLELoss

* fix single-quote docstring warning

* fix further warnings

* some renaming plus doc updates for translator

* refactor MLP class

* un-comment sentencepiece import

* avoid code duplication between calc_loss and generate methods

* share code between inference classes

* IndependentOutputInference supports forced decoding etc

* small cleanup

* support forced decoding and fix batch loss for sequence labeler

* forced decoding for classifier

* rename transducer.__call__() to transduce() to simplify multiple inheritance

* more principled use of batcher in inference

* batch decoding for sequence classifier

* DefaultTranslator: fix masking for (looped) batch decoding

* made parameters of generate() clearer + other minor doc fixes

* Added type annotation for transducers

* clean up output interface

* some fixes related to reporting

* fix unit tests

* Separated softmax and projection (#440)

* Started separating out softmax

* Started fixing tests

* Fixed more tests

* Fixed remainder of running tests

* Fixed the rest of tests

* Added AuxNonLinear

* Updated examples (many were already broken?)

* Fixed recipes

* Removed MLP class

* Added some doc

* fix problem when calling a super constructor that is wrapped in serialized_init

* Added some doc

* fix / clean up sequence labeler

* fix using scorer

* document how to run test configs directly

* fix examples

* update api doc

* Removed extraneous yaml file

* Update to doc

* attempt to fix travis

* represent command line args as normal dictionary

* temporarily disable travis cache

* undo previous commit

* fix serialization problem

* downgrade pyyaml

* read_sent_len

* remove unused include_vocab_reference

* implement compound reader and input (a sketch follows this list)

* fix unit tests

* fix count_sents()

* add simple unit test for compound input reader

* fixes to computing batch size

* make plaintext reader vocab optional

* introduce compound batch, use sent_len() and batch_size() instead of len()

* update doc

* make compound batch iterable

* more consistency for (unsupported) inference batching

* support batch slicing

* fix unit tests
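
For orientation, here is a minimal sketch of the new compound-input API, condensed from the new test/test_input_reader.py included in this commit (CompoundReader, CompoundInput, read_sents, and read_sent_len are used there as shown; the loop body is illustrative):

import xnmt.input
import xnmt.vocab
from xnmt import input_reader

# One file, two readers: the first yields word IDs, the second the
# sentence length (via the new read_sent_len option).
vocab = xnmt.vocab.Vocab(vocab_file="examples/data/head.en.vocab")
reader = input_reader.CompoundReader(readers=[input_reader.PlainTextReader(vocab),
                                              input_reader.PlainTextReader(read_sent_len=True)])
for sent in reader.read_sents(filename="examples/data/head.en"):
  # Each sent is a CompoundInput holding one sub-input per reader.
  word_ids = sent.inputs[0].words
  length = sent.inputs[1].value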
msperber committed Jul 10, 2018
1 parent 6fbbba4 commit a2ce074
Showing 25 changed files with 392 additions and 192 deletions.
4 changes: 2 additions & 2 deletions test/config/encoders.yaml
@@ -104,7 +104,7 @@ exp3-pyramidal-encoder: !Experiment
       bridge: !CopyBridge {}
   inference: !AutoRegressiveInference
     batcher: !InOrderBatcher
-      batch_size: 5
+      batch_size: 1
       pad_src_to_multiple: 4
 exp4-modular-encoder: !Experiment
   kwargs:
@@ -146,6 +146,6 @@ exp4-modular-encoder: !Experiment
       bridge: !CopyBridge {}
   inference: !AutoRegressiveInference
     batcher: !InOrderBatcher
-      batch_size: 5
+      batch_size: 1
       pad_src_to_multiple: 4
2 changes: 1 addition & 1 deletion test/config/forced.yaml
@@ -51,7 +51,7 @@ forced: !Experiment
     inference: !AutoRegressiveInference
       mode: forceddebug
       ref_file: examples/data/head.en
-      batcher: !InOrderBatcher { batch_size: 5 }
+      batcher: !InOrderBatcher { batch_size: 1 }
     src_file: examples/data/head.ja
     ref_file: examples/data/head.en
     hyp_file: test/tmp/{EXP}.forceddebug_hyp
5 changes: 4 additions & 1 deletion test/config/prior_segmenting.yaml
@@ -4,12 +4,15 @@ prior-segmenting: !Experiment
   dropout: 0.5
   model: !DefaultTranslator
     src_reader: !SegmentationTextReader
-      vocab: !Vocab {vocab_file: examples/data/head-char.ja.vocab}
+      vocab: !Vocab
+        vocab_file: examples/data/head-char.ja.vocab
+        _xnmt_id: src_vocab
     trg_reader: !PlainTextReader
       vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
     src_embedder: !SimpleWordEmbedder
       emb_dim: 16
     encoder: !SegmentingSeqTransducer
+      vocab: !Ref { name: src_vocab }
       debug: True
       embed_encoder: !BiLSTMSeqTransducer
         input_dim: 16
5 changes: 4 additions & 1 deletion test/config/segmenting.yaml
@@ -4,14 +4,17 @@ debug-segmenting: !Experiment
   dropout: 0.5
   model: !DefaultTranslator
     src_reader: !SegmentationTextReader
-      vocab: !Vocab {vocab_file: examples/data/head-char.ja.vocab}
+      vocab: !Vocab
+        _xnmt_id: src_vocab
+        vocab_file: examples/data/head-char.ja.vocab
     trg_reader: !PlainTextReader
       vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
     calc_global_fertility: True
     calc_attention_entropy: True
     src_embedder: !SimpleWordEmbedder
       emb_dim: 16
     encoder: !SegmentingSeqTransducer
+      vocab: !Ref { name: src_vocab }
       # Components
       embed_encoder: !BiLSTMSeqTransducer
         input_dim: 16
8 changes: 4 additions & 4 deletions test/test_batcher.py
@@ -42,21 +42,21 @@ def test_batch_random_no_ties(self):
     trg_sents = [xnmt.input.SimpleSentenceInput([0] * ((i+3)%6 + 1)) for i in range(1,7)]
     my_batcher = xnmt.batcher.SrcBatcher(batch_size=3, src_pad_token=1, trg_pad_token=2)
     _, trg = my_batcher.pack(src_sents, trg_sents)
-    l0 = len(trg[0][0])
+    l0 = trg[0].sent_len()
     for _ in range(10):
       _, trg = my_batcher.pack(src_sents, trg_sents)
-      l = len(trg[0][0])
+      l = trg[0].sent_len()
       self.assertTrue(l==l0)

   def test_batch_random_ties(self):
     src_sents = [xnmt.input.SimpleSentenceInput([0] * 5) for _ in range(1,7)]
     trg_sents = [xnmt.input.SimpleSentenceInput([0] * ((i+3)%6 + 1)) for i in range(1,7)]
     my_batcher = xnmt.batcher.SrcBatcher(batch_size=3, src_pad_token=1, trg_pad_token=2)
     _, trg = my_batcher.pack(src_sents, trg_sents)
-    l0 = len(trg[0][0])
+    l0 = trg[0].sent_len()
     for _ in range(10):
       _, trg = my_batcher.pack(src_sents, trg_sents)
-      l = len(trg[0][0])
+      l = trg[0].sent_len()
       if l!=l0: return
     self.assertTrue(False)
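Per the commit message, sentences and batches now answer sent_len() and batch_size() rather than supporting len() directly. A rough sketch of the pattern the updated test follows (SrcBatcher, pack, and sent_len appear in the diff above; batch_size() is assumed from the commit message):

import xnmt.batcher
import xnmt.input

src = [xnmt.input.SimpleSentenceInput([0] * 5) for _ in range(6)]
trg = [xnmt.input.SimpleSentenceInput([0] * 4) for _ in range(6)]
batcher = xnmt.batcher.SrcBatcher(batch_size=3, src_pad_token=1, trg_pad_token=2)
src_batches, trg_batches = batcher.pack(src, trg)
padded_len = trg_batches[0].sent_len()   # padded sentence length within the batch
num_sents = trg_batches[0].batch_size()  # number of sentences in the batch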
4 changes: 2 additions & 2 deletions test/test_beam_search.py
@@ -53,7 +53,7 @@ def assert_forced_decoding(self, sent_id):
     dy.renew_cg()
     outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[sent_id]]), [sent_id], BeamSearch(),
                                   forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[sent_id]]))
-    self.assertItemsEqual(self.trg_data[sent_id], outputs[0].actions)
+    self.assertItemsEqual(self.trg_data[sent_id].words, outputs[0].actions)

   def test_forced_decoding(self):
     for i in range(1):
@@ -129,7 +129,7 @@ def test_single(self):
                                   forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
     dy.renew_cg()
     train_loss = self.model.calc_loss(src=self.src_data[0],
-                                      trg=outputs[0].actions,
+                                      trg=outputs[0],
                                       loss_calculator=AutoRegressiveMLELoss()).value()

     self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
4 changes: 2 additions & 2 deletions test/test_decoding.py
@@ -55,7 +55,7 @@ def assert_forced_decoding(self, sent_id):
     dy.renew_cg()
     outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[sent_id]]), [sent_id], self.search,
                                   forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[sent_id]]))
-    self.assertItemsEqual(self.trg_data[sent_id], outputs[0].actions)
+    self.assertItemsEqual(self.trg_data[sent_id].words, outputs[0].actions)

   def test_forced_decoding(self):
     for i in range(1):
@@ -134,7 +134,7 @@ def test_single(self):

     dy.renew_cg()
     train_loss = self.model.calc_loss(src=self.src_data[0],
-                                      trg=outputs[0].actions,
+                                      trg=outputs[0],
                                       loss_calculator=AutoRegressiveMLELoss()).value()

     self.assertAlmostEqual(-output_score, train_loss, places=5)
2 changes: 1 addition & 1 deletion test/test_encoder.py
@@ -121,7 +121,7 @@ def test_py_lstm_encoder_len(self):
     self.set_train(True)
     for sent_i in range(10):
       dy.renew_cg()
-      src = self.src_data[sent_i].get_padded_sent(Vocab.ES, 4 - (len(self.src_data[sent_i]) % 4))
+      src = self.src_data[sent_i].get_padded_sent(Vocab.ES, 4 - (self.src_data[sent_i].sent_len() % 4))
       self.start_sent(src)
       embeddings = model.src_embedder.embed_sent(src)
       encodings = model.encoder.transduce(embeddings)
32 changes: 32 additions & 0 deletions test/test_input_reader.py
@@ -0,0 +1,32 @@
+import unittest
+
+from xnmt import input_reader
+import xnmt.vocab
+import xnmt.input
+
+class TestInputReader(unittest.TestCase):
+
+  def test_one_file_multiple_readers(self):
+    vocab = xnmt.vocab.Vocab(vocab_file="examples/data/head.en.vocab")
+    cr = input_reader.CompoundReader(readers = [input_reader.PlainTextReader(vocab),
+                                                input_reader.PlainTextReader(read_sent_len=True)])
+    sents = list(cr.read_sents(filename="examples/data/head.en"))
+    self.assertEqual(len(sents), 10)
+    self.assertIsInstance(sents[0], xnmt.input.CompoundInput)
+    self.assertEqual(" ".join([vocab.i2w[w] for w in sents[0].inputs[0].words]), "can you do it in one day ? </s>")
+    self.assertEqual(sents[0].inputs[1].value, len("can you do it in one day ?".split()))
+
+  def test_multiple_files_multiple_readers(self):
+    vocab_en = xnmt.vocab.Vocab(vocab_file="examples/data/head.en.vocab")
+    vocab_ja = xnmt.vocab.Vocab(vocab_file="examples/data/head.ja.vocab")
+    cr = input_reader.CompoundReader(readers = [input_reader.PlainTextReader(vocab_en),
+                                                input_reader.PlainTextReader(vocab_ja)])
+    sents = list(cr.read_sents(filename=["examples/data/head.en", "examples/data/head.ja"]))
+    self.assertEqual(len(sents), 10)
+    self.assertIsInstance(sents[0], xnmt.input.CompoundInput)
+    self.assertEqual(" ".join([vocab_en.i2w[w] for w in sents[0].inputs[0].words]), "can you do it in one day ? </s>")
+    self.assertEqual(" ".join([vocab_ja.i2w[w] for w in sents[0].inputs[1].words]), "君 は 1 日 で それ が でき ま す か 。 </s>")
+
+
+if __name__ == '__main__':
+  unittest.main()
22 changes: 14 additions & 8 deletions test/test_training.py
@@ -10,7 +10,7 @@
 from xnmt.embedder import SimpleWordEmbedder
 from xnmt.eval_task import LossEvalTask
 import xnmt.events
-from xnmt.input_reader import PlainTextReader
+from xnmt.input_reader import PlainTextReader, SimpleSentenceInput
 from xnmt.lstm import UniLSTMSeqTransducer, BiLSTMSeqTransducer
 from xnmt.loss_calculator import AutoRegressiveMLELoss
 from xnmt.optimizer import AdamTrainer
@@ -40,17 +40,20 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
     """
     batch_size=5
     src_sents = self.src_data[:batch_size]
-    src_min = min([len(x) for x in src_sents])
+    src_min = min([x.sent_len() for x in src_sents])
     src_sents_trunc = [s[:src_min] for s in src_sents]
     for single_sent in src_sents_trunc:
       single_sent[src_min-1] = Vocab.ES
       while len(single_sent)%pad_src_to_multiple != 0:
         single_sent.append(Vocab.ES)
     trg_sents = self.trg_data[:batch_size]
-    trg_min = min([len(x) for x in trg_sents])
+    trg_min = min([x.sent_len() for x in trg_sents])
     trg_sents_trunc = [s[:trg_min] for s in trg_sents]
     for single_sent in trg_sents_trunc: single_sent[trg_min-1] = Vocab.ES

+    src_sents_trunc = [SimpleSentenceInput(s) for s in src_sents_trunc]
+    trg_sents_trunc = [SimpleSentenceInput(s) for s in trg_sents_trunc]
+
     single_loss = 0.0
     for sent_id in range(batch_size):
       dy.renew_cg()
@@ -172,19 +175,22 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
     """
     batch_size = 5
     src_sents = self.src_data[:batch_size]
-    src_min = min([len(x) for x in src_sents])
+    src_min = min([x.sent_len() for x in src_sents])
     src_sents_trunc = [s[:src_min] for s in src_sents]
     for single_sent in src_sents_trunc:
       single_sent[src_min-1] = Vocab.ES
       while len(single_sent)%pad_src_to_multiple != 0:
         single_sent.append(Vocab.ES)
-    trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: len(x), reverse=True)
-    trg_max = max([len(x) for x in trg_sents])
+    trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: x.sent_len(), reverse=True)
+    trg_max = max([x.sent_len() for x in trg_sents])
     trg_masks = Mask(np.zeros([batch_size, trg_max]))
     for i in range(batch_size):
-      for j in range(len(trg_sents[i]), trg_max):
+      for j in range(trg_sents[i].sent_len(), trg_max):
         trg_masks.np_arr[i,j] = 1.0
-    trg_sents_padded = [[w for w in s] + [Vocab.ES]*(trg_max-len(s)) for s in trg_sents]
+    trg_sents_padded = [[w for w in s] + [Vocab.ES]*(trg_max-s.sent_len()) for s in trg_sents]
+
+    src_sents_trunc = [SimpleSentenceInput(s) for s in src_sents_trunc]
+    trg_sents_padded = [SimpleSentenceInput(s) for s in trg_sents_padded]
+
     single_loss = 0.0
     for sent_id in range(batch_size):
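The changes above follow the new input interface: token-ID lists are manipulated as plain Python lists and only then wrapped in SimpleSentenceInput before being passed to the model. A condensed, illustrative version (the sample IDs are made up):

from xnmt.input_reader import SimpleSentenceInput

raw_sents = [[4, 7, 9, 2], [5, 3, 2]]                # plain token-ID lists
sents = [SimpleSentenceInput(s) for s in raw_sents]  # wrap for the model
assert sents[0].sent_len() == 4                      # sentences answer sent_len(), not len()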
