From 5cd9c21ad52137d478b1c4849782eabcef64788d Mon Sep 17 00:00:00 2001
From: Alexei Baevski
Date: Tue, 15 May 2018 18:54:01 +0100
Subject: [PATCH] allow specifying max_tokens for generation (#280)

* allow specifying max_tokens for generation
---
 fairseq/options.py | 2 +-
 generate.py        | 9 +++++++--
 train.py           | 4 ++++
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/fairseq/options.py b/fairseq/options.py
index 8361032e6b..33ffbaf64b 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -106,7 +106,7 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                        help='Ignore too long or too short lines in valid and test set')
-    group.add_argument('--max-tokens', default=6000, type=int, metavar='N',
+    group.add_argument('--max-tokens', type=int, metavar='N',
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
diff --git a/generate.py b/generate.py
index bc44505f0b..40df7ab1b6 100644
--- a/generate.py
+++ b/generate.py
@@ -16,6 +16,10 @@

 def main(args):
     assert args.path is not None, '--path required for generation!'
+
+    if args.max_tokens is None and args.max_sentences is None:
+        args.max_tokens = 12000
+
     print(args)
     assert not args.sampling or args.nbest == args.beam, \
         '--sampling requires --nbest to be equal to --beam'
@@ -58,12 +62,13 @@ def main(args):
     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)
     align_dict = utils.load_align_dict(args.replace_unk)

-    # Load dataset (possibly sharded)
     max_positions = min(model.max_encoder_positions() for model in models)
+
     itr = dataset.eval_dataloader(
         args.gen_subset,
-        max_sentences=args.max_sentences or 128,
+        max_tokens=args.max_tokens,
+        max_sentences=args.max_sentences,
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
     )
diff --git a/train.py b/train.py
index b83a31b2a9..1582d19642 100644
--- a/train.py
+++ b/train.py
@@ -18,6 +18,10 @@


 def main(args):
+
+    if args.max_tokens is None:
+        args.max_tokens = 6000
+
     print(args)
     if not torch.cuda.is_available():
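
The net effect of the patch: --max-tokens no longer has an argparse default, so
each entry point picks its own fallback. generate.py now passes --max-tokens
through to the batch iterator, capping batches at 12000 tokens only when the
user sets neither --max-tokens nor --max-sentences (before this change,
generation ignored --max-tokens and always batched 128 sentences at a time
unless --max-sentences was given), while train.py keeps the previous
6000-token default. The Python below is a minimal sketch restating that
fallback logic; the resolve_* helper names are hypothetical and do not exist
in the fairseq codebase:

    # Sketch of the defaulting behavior introduced by this patch. The
    # resolve_* helpers are illustrative stand-ins, not fairseq functions.

    def resolve_generation_batching(max_tokens, max_sentences):
        # generate.py: fall back to a 12000-token batch cap only when the
        # user set neither --max-tokens nor --max-sentences.
        if max_tokens is None and max_sentences is None:
            max_tokens = 12000
        return max_tokens, max_sentences

    def resolve_training_batching(max_tokens):
        # train.py: training keeps the old 6000-token default.
        if max_tokens is None:
            max_tokens = 6000
        return max_tokens

For example, invoking generate.py with --max-tokens 4000 and no
--max-sentences now yields token-capped batches instead of the fixed
128-sentence batches used before this change.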