Commit
Support for multiple inputs v2 (#445)
* decoder interface and type annotations
* add inference base class; rename SimpleInference to SequenceInference; some refactoring to outputs and SequenceInference
* add sequence labeler and classifier
* move output processing from generator models to inference
* rename SequenceInference to AutoRegressiveInference
* improve doc for inference
* fix some serialization issues
* remove some model-specific code
* cleanup: unused inference src_mask, inconsistent input reader read_sent
* rename ClassifierInference to IndependentOutputInference
* update generate, use independent inference for seq labeler
* fix warning
* refactor MLE loss to be agnostic of translator internals, rename to AutoRegressiveMLELoss
* fix single-quote docstring warning
* fix further warnings
* some renaming plus doc updates for translator
* refactor MLP class
* un-comment sentencepiece import
* avoid code duplication between calc_loss and generate methods
* share code between inference classes
* IndependentOutputInference supports forced decoding etc
* small cleanup
* support forced decoding and fix batch loss for sequence labeler
* forced decoding for classifier
* rename transducer.__call__() to transduce() to simplify multiple inheritance
* more principled use of batcher in inference
* batch decoding for sequence classifier
* DefaultTranslator: fix masking for (looped) batch decoding
* made parameters of generate() clearer + other minor doc fixes
* Added type annotation for transducers
* clean up output interface
* some fixes related to reporting
* fix unit tests
* Separated softmax and projection (#440)
* Started separating out softmax
* Started fixing tests
* Fixed more tests
* Fixed remainder of running tests
* Fixed the rest of tests
* Added AuxNonLinear
* Updated examples (many were already broken?)
* Fixed recipes
* Removed MLP class
* Added some doc
* fix problem when calling a super constructor that is wrapped in serialized_init
* Added some doc
* fix / clean up sequence labeler
* fix using scorer
* document how to run test configs directly
* fix examples
* update api doc
* Removed extraneous yaml file
* Update to doc
* attempt to fix travis
* represent command line args as normal dictionary
* temporarily disable travis cache
* undo previous commit
* fix serialization problem
* downgrade pyyaml
* read_sent_len
* remove unused include_vocab_reference
* implement compound reader and input
* fix unit tests
* fix count_sents()
* add simple unit test for compound input reader
* fixes to computing batch size
* make plaintext reader vocab optional
* introduce compound batch, use sent_len() and batch_size() instead of len()
* update doc
* make compound batch iterable
* more consistency for (unsupported) inference batching
* support batch slicing
* fix unit tests
Showing 25 changed files with 392 additions and 192 deletions.
@@ -0,0 +1,32 @@
import unittest

from xnmt import input_reader
import xnmt.vocab
import xnmt.input

class TestInputReader(unittest.TestCase):

  def test_one_file_multiple_readers(self):
    vocab = xnmt.vocab.Vocab(vocab_file="examples/data/head.en.vocab")
    cr = input_reader.CompoundReader(readers = [input_reader.PlainTextReader(vocab),
                                                input_reader.PlainTextReader(read_sent_len=True)])
    sents = list(cr.read_sents(filename="examples/data/head.en"))
    self.assertEqual(len(sents), 10)
    self.assertIsInstance(sents[0], xnmt.input.CompoundInput)
    self.assertEqual(" ".join([vocab.i2w[w] for w in sents[0].inputs[0].words]), "can you do it in one day ? </s>")
    self.assertEqual(sents[0].inputs[1].value, len("can you do it in one day ?".split()))

  def test_multiple_files_multiple_readers(self):
    vocab_en = xnmt.vocab.Vocab(vocab_file="examples/data/head.en.vocab")
    vocab_ja = xnmt.vocab.Vocab(vocab_file="examples/data/head.ja.vocab")
    cr = input_reader.CompoundReader(readers = [input_reader.PlainTextReader(vocab_en),
                                                input_reader.PlainTextReader(vocab_ja)])
    sents = list(cr.read_sents(filename=["examples/data/head.en", "examples/data/head.ja"]))
    self.assertEqual(len(sents), 10)
    self.assertIsInstance(sents[0], xnmt.input.CompoundInput)
    self.assertEqual(" ".join([vocab_en.i2w[w] for w in sents[0].inputs[0].words]), "can you do it in one day ? </s>")
    self.assertEqual(" ".join([vocab_ja.i2w[w] for w in sents[0].inputs[1].words]), "君 は 1 日 で それ が でき ま す か 。 </s>")


if __name__ == '__main__':
  unittest.main()
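
The two tests above exercise a compound reader in two modes: one file shared by several sub-readers, and one file per sub-reader. The sketch below illustrates that behaviour in isolation, under stated assumptions. It is not xnmt's implementation: `SketchCompoundReader`, `WordsReader`, `LengthReader`, and `demo` are hypothetical names introduced here, and plain tuples stand in for `CompoundInput`.

```python
import os
import tempfile
from typing import Iterator, List, Sequence, Tuple, Union


class WordsReader:
    """Hypothetical sub-reader: yields each line as tokens plus an end marker."""

    def read_sents(self, filename: str) -> Iterator[List[str]]:
        with open(filename, encoding="utf-8") as f:
            for line in f:
                yield line.split() + ["</s>"]


class LengthReader:
    """Hypothetical sub-reader: yields only the token count of each line."""

    def read_sents(self, filename: str) -> Iterator[int]:
        with open(filename, encoding="utf-8") as f:
            for line in f:
                yield len(line.split())


class SketchCompoundReader:
    """Illustrative only: combines the outputs of several sub-readers per sentence."""

    def __init__(self, readers: Sequence) -> None:
        if len(readers) < 2:
            raise ValueError("need at least two sub-readers")
        self.readers = list(readers)

    def read_sents(self, filename: Union[str, Sequence[str]]) -> Iterator[Tuple]:
        # A single filename is shared by every sub-reader;
        # a sequence supplies exactly one file per sub-reader.
        if isinstance(filename, str):
            filenames = [filename] * len(self.readers)
        else:
            filenames = list(filename)
        if len(filenames) != len(self.readers):
            raise ValueError("expected one filename per sub-reader")
        streams = [r.read_sents(fn) for r, fn in zip(self.readers, filenames)]
        # zip advances all sub-readers in lockstep and stops at the shortest stream.
        for parts in zip(*streams):
            yield parts


def demo() -> List[Tuple]:
    """Read one temporary file with two sub-readers and return the combined items."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "head.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write("can you do it in one day ?\nyes\n")
        reader = SketchCompoundReader([WordsReader(), LengthReader()])
        return list(reader.read_sents(path))
```

Calling `demo()` returns one tuple per input line, with the token list in the first slot and the token count in the second, which mirrors what the first test checks through `sents[0].inputs[0]` and `sents[0].inputs[1]`.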