Modularize generate.py (#351)
Summary:
Pull Request resolved: pytorch/translate#351

This makes it easier for tasks to plug in to generate.py and interactive.py.
Pull Request resolved: #520

Differential Revision: D14183881

Pulled By: myleott

fbshipit-source-id: ede5e53ddc1215ed3b12b8f1eba048c946913c33
myleott authored and facebook-github-bot committed Feb 22, 2019
1 parent 08e866f commit b65c579
Showing 11 changed files with 374 additions and 401 deletions.
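
What the refactor buys: generation is no longer baked into the scorer/generator object, so callers hand the models in at call time (see `scorer.generate(models, sample)` in the eval_lm.py diff below), and a task can wrap or replace that step. As a rough sketch only, a custom task might intercept generation through a hook like the one below; the `inference_step` name is an assumption for illustration, not something this diff introduces.

import torch
from fairseq.tasks import FairseqTask


class MyTask(FairseqTask):
    # Hedged sketch: the hook name `inference_step` is assumed for
    # illustration. With models passed per call, the task owns decoding
    # and can add constraints, rescoring, etc. here.
    def inference_step(self, generator, models, sample):
        with torch.no_grad():
            return generator.generate(models, sample)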
22 changes: 15 additions & 7 deletions eval_lm.py
@@ -76,6 +76,8 @@ def main(parsed_args):
         model.make_generation_fast_()
         if args.fp16:
             model.half()
+        if use_cuda:
+            model.cuda()
 
     assert len(models) > 0
 
@@ -95,9 +97,7 @@ def main(parsed_args):
     ).next_epoch_itr(shuffle=False)
 
     gen_timer = StopwatchMeter()
-    scorer = SequenceScorer(models, task.target_dictionary)
-    if use_cuda:
-        scorer.cuda()
+    scorer = SequenceScorer(task.target_dictionary)
 
     score_sum = 0.
     count = 0
@@ -113,10 +113,18 @@ def main(parsed_args):
     word_stats = dict()
 
     with progress_bar.build_progress_bar(args, itr) as t:
-        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
         wps_meter = TimeMeter()
-        for _, src_tokens, __, hypos in results:
-            for hypo in hypos:
+        for sample in t:
+            sample = utils.move_to_cuda(sample) if use_cuda else sample
+            if 'net_input' not in sample:
+                continue
+
+            gen_timer.start()
+            hypos = scorer.generate(models, sample)
+            gen_timer.stop(sample['ntokens'])
+
+            for hypos_i in hypos:
+                hypo = hypos_i[0]
                 pos_scores = hypo['positional_scores']
 
                 skipped_toks = 0
@@ -162,7 +170,7 @@ def main(parsed_args):
                 if args.output_word_probs:
                     print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
 
-            wps_meter.update(src_tokens.size(0))
+            wps_meter.update(sample['ntokens'])
             t.log({'wps': round(wps_meter.avg)})
 
     avg_nll_loss = -score_sum / count
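Taken together, the eval_lm.py hunks change the scoring convention: the scorer is built from the target dictionary alone, and the models (and any CUDA placement) are supplied by the caller per batch. A minimal sketch of the resulting loop, assuming `models`, `task`, and `itr` come from the usual checkpoint-loading and dataset setup that surrounds these hunks:

import torch
from fairseq import utils
from fairseq.sequence_scorer import SequenceScorer

use_cuda = torch.cuda.is_available()
scorer = SequenceScorer(task.target_dictionary)  # models are no longer bound at construction

for sample in itr:
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    if 'net_input' not in sample:
        continue
    hypos = scorer.generate(models, sample)  # one hypothesis list per input
    top = hypos[0][0]                        # best hypothesis for the first input
    pos_scores = top['positional_scores']    # per-token scores, as used above

Note that CUDA handling moves with the data rather than the scorer (`utils.move_to_cuda(sample)` replaces `scorer.cuda()`), which is what lets the same scorer object serve CPU and GPU callers.
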
14 changes: 1 addition & 13 deletions fairseq/data/backtranslation_dataset.py
@@ -69,9 +69,6 @@ class BacktranslationDataset(FairseqDataset):
         backtranslation_fn (callable): function to call to generate
             backtranslations. This is typically the `generate` method of a
             :class:`~fairseq.sequence_generator.SequenceGenerator` object.
-        max_len_a, max_len_b (int, int): will be used to compute
-            `maxlen = max_len_a * src_len + max_len_b`, which will be passed
-            into *backtranslation_fn*.
         output_collater (callable, optional): function to call on the
             backtranslated samples to create the final batch
             (default: ``tgt_dataset.collater``).
@@ -82,16 +79,12 @@ def __init__(
         self,
         tgt_dataset,
         backtranslation_fn,
-        max_len_a,
-        max_len_b,
         output_collater=None,
         cuda=True,
         **kwargs
     ):
         self.tgt_dataset = tgt_dataset
         self.backtranslation_fn = backtranslation_fn
-        self.max_len_a = max_len_a
-        self.max_len_b = max_len_b
         self.output_collater = output_collater if output_collater is not None \
             else tgt_dataset.collater
         self.cuda = cuda if torch.cuda.is_available() else False
@@ -130,12 +123,7 @@ def collater(self, samples):
             samples=samples,
             collate_fn=self.tgt_dataset.collater,
             generate_fn=(
-                lambda net_input: self.backtranslation_fn(
-                    net_input,
-                    maxlen=int(
-                        self.max_len_a * net_input['src_tokens'].size(1) + self.max_len_b
-                    ),
-                )
+                lambda net_input: self.backtranslation_fn(net_input)
             ),
             cuda=self.cuda,
         )
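With `max_len_a`/`max_len_b` removed from `BacktranslationDataset`, any output-length limit now has to live inside `backtranslation_fn` itself. A hedged sketch of what a caller might pass instead; `backward_generator` and `backward_models` are illustrative names for a generator and model ensemble configured elsewhere with their own length settings:

# Illustrative only: `backward_generator` / `backward_models` are assumed
# to be set up elsewhere; length limits now belong to the generator.
def backtranslation_fn(net_input):
    sample = {'net_input': net_input}
    return backward_generator.generate(backward_models, sample)

dataset = BacktranslationDataset(
    tgt_dataset=tgt_dataset,
    backtranslation_fn=backtranslation_fn,
)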
(Diffs for the remaining 9 changed files are not shown here.)
