Skip to content

Commit

Permalink
support writing out nbest list v2 (#446)
Browse files Browse the repository at this point in the history
* decoder interface and type annotations

* add inference base class; rename SimpleInference to SequenceInference; some refactoring to outputs and SequenceInference

* add sequence labeler and classifier

* move output processing from generator models to inference

* rename SequenceInference to AutoRegressiveInference

* improve doc for inference

* fix some serialization issues

* remove some model-specific code

* cleanup: unused inference src_mask, inconsistent input reader read_sent

* rename ClassifierInference to IndependentOutputInference

* update generate, use independent inference for seq labeler

* fix warning

* refactor MLE loss to be agnostic of translator internals, rename to AutoRegressiveMLELoss

* fix single-quote docstring warning

* fix further warnings

* some renaming plus doc updates for translator

* refactor MLP class

* un-comment sentencepiece import

* avoid code duplication between calc_loss and generate methods

* share code between inference classes

* IndependentOutputInference supports forced decoding etc

* small cleanup

* support forced decoding and fix batch loss for sequence labeler

* forced decoding for classifier

* rename transducer.__call__() to transduce() to simplify multiple inheritance

* more principled use of batcher in inference

* batch decoding for sequence classifier

* DefaultTranslator: fix masking for (looped) batch decoding

* made parameters of generate() clearer + other minor doc fixes

* Added type annotation for transducers

* clean up output interface

* support writing out nbest list

* some fixes related to reporting

* fix unit tests

* Separated softmax and projection (#440)

* Started separating out softmax

* Started fixing tests

* Fixed more tests

* Fixed remainder of running tests

* Fixed the rest of tests

* Added AuxNonLinear

* Updated examples (many were already broken?)

* Fixed recipes

* Removed MLP class

* Added some doc

* fix problem when calling a super constructor that is wrapped in serialized_init

* Added some doc

* fix / clean up sequence labeler

* fix using scorer

* document how to run test configs directly

* fix examples

* update api doc

* Removed extraneous yaml file

* Update to doc

* attempt to fix travis

* represent command line args as normal dictionary

* temporarily disable travis cache

* undo previous commit

* fix serialization problem

* downgrade pyyaml

* support writing out nbest list
  • Loading branch information
msperber authored and neubig committed Jun 27, 2018
1 parent e3c7656 commit 316c72c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 26 deletions.
7 changes: 4 additions & 3 deletions xnmt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,15 @@ def _generate_output(self, generator: model_base.GeneratorModel, src_corpus: Seq
if forced_ref_corpus: ref_batch = ref_batches[batch_i]
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
outputs = self.generate_one(generator, src_batch, range(cur_sent_i,cur_sent_i+len(src_batch)), ref_batch)
# If debugging forced decoding, make sure it matches
for i in range(len(src_batch)):
for i in range(len(outputs)):
if assert_scores is not None:
# If debugging forced decoding, make sure it matches
assert len(src_batch) == len(outputs), "debug forced decoding not supported with nbest inference"
if (abs(outputs[i].score - assert_scores[cur_sent_i + i]) / abs(assert_scores[cur_sent_i + i])) > 1e-5:
raise ValueError(
f'Forced decoding score {outputs[0].score} and loss {assert_scores[cur_sent_i + i]} do not match at '
f'sentence {cur_sent_i + i}')
output_txt = self.post_processor.process_output(outputs[i])
output_txt = outputs[i].apply_post_processor(self.post_processor)
fp.write(f"{output_txt}\n")
cur_sent_i += len(src_batch)

Expand Down
55 changes: 41 additions & 14 deletions xnmt/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def readable_actions(self) -> Sequence[str]:
"""
raise NotImplementedError('must be implemented by subclasses')

def apply_post_processor(self, output_processor: 'OutputProcessor') -> str:
  """Render this output as a string by running its readable actions through the given processor."""
  actions = self.readable_actions()
  return output_processor.process_output(actions)

def __str__(self):
return " ".join(self.readable_actions())

Expand Down Expand Up @@ -91,14 +94,41 @@ def readable_actions(self):
ret.append(self.vocab[action] if self.vocab else str(action))
return ret


class NbestOutput(Output):
  """
  Output in the context of an nbest list.

  Wraps a base output and renders it as one nbest entry of the form
  ``nbest_id ||| content`` (or ``nbest_id ||| score ||| content`` when scores are printed).

  Args:
    base_output: The base output object
    nbest_id: The sentence id in the nbest list
    print_score: If True, print nbest_id, score, content separated by ``|||``. If False, drop the score.
  """
  def __init__(self, base_output: Output, nbest_id: int, print_score: bool = False) -> None:
    super().__init__(actions=base_output.actions, score=base_output.score)
    self.base_output = base_output
    self.nbest_id = nbest_id
    self.print_score = print_score

  def readable_actions(self) -> Sequence[str]:
    # Delegate to the wrapped output; only the string formatting differs.
    return self.base_output.readable_actions()

  def __str__(self):
    return self._make_nbest_entry(" ".join(self.readable_actions()))

  def apply_post_processor(self, output_processor: 'OutputProcessor') -> str:
    # Post-process only the content part; id/score decoration is added afterwards.
    return self._make_nbest_entry(output_processor.process_output(self.readable_actions()))

  def _make_nbest_entry(self, content_str: str) -> str:
    """Decorate already-formatted content with the nbest id (and optionally the score)."""
    entries = [str(self.nbest_id), content_str]
    if self.print_score:
      # Score goes between id and content: id ||| score ||| content
      entries.insert(1, str(self.base_output.score))
    return " ||| ".join(entries)

class OutputProcessor(object):
# TODO: this should be refactored so that multiple processors can be chained
def process_output(self, output: Output) -> str:
def process_output(self, output_actions: Sequence) -> str:
"""
Produce a string-representation of an Output object.
Produce a string-representation of an output.
Args:
output: object holding output actions
output_actions: readable output actions
Returns:
string representation
Expand All @@ -123,8 +153,8 @@ class PlainTextOutputProcessor(OutputProcessor, Serializable):
Handles the typical case of writing plain text, with one sentence per line.
"""
yaml_tag = "!PlainTextOutputProcessor"
def process_output(self, output):
return " ".join(output.readable_actions())
def process_output(self, output_actions):
  """Join the readable actions with single spaces, producing one plain-text sentence."""
  separator = " "
  return separator.join(output_actions)

class JoinCharTextOutputProcessor(PlainTextOutputProcessor, Serializable):
"""
Expand All @@ -137,9 +167,8 @@ class JoinCharTextOutputProcessor(PlainTextOutputProcessor, Serializable):
def __init__(self, space_token="__"):
  # Token that stood in for a literal space during tokenization; it is
  # mapped back to " " when joining characters in process_output.
  self.space_token = space_token

def process_output(self, output):
word_list = output.readable_actions()
return "".join(" " if s==self.space_token else s for s in word_list)
def process_output(self, output_actions):
  """Concatenate character tokens, converting each space token back into a literal space."""
  pieces = []
  for tok in output_actions:
    pieces.append(" " if tok == self.space_token else tok)
  return "".join(pieces)

class JoinBPETextOutputProcessor(PlainTextOutputProcessor, Serializable):
"""
Expand All @@ -152,9 +181,8 @@ class JoinBPETextOutputProcessor(PlainTextOutputProcessor, Serializable):
def __init__(self, merge_indicator="@@"):
  # Pre-compute "indicator + space" so process_output can undo all BPE
  # merges with a single str.replace() pass.
  self.merge_indicator_with_space = merge_indicator + " "

def process_output(self, output):
word_list = output.readable_actions()
return " ".join(word_list).replace(self.merge_indicator_with_space, "")
def process_output(self, output_actions):
  """Join subword tokens with spaces, then splice words back together at BPE merge markers."""
  joined = " ".join(output_actions)
  return joined.replace(self.merge_indicator_with_space, "")

class JoinPieceTextOutputProcessor(PlainTextOutputProcessor, Serializable):
"""
Expand All @@ -167,6 +195,5 @@ class JoinPieceTextOutputProcessor(PlainTextOutputProcessor, Serializable):
def __init__(self, space_token="\u2581"):
  # Sentencepiece-style word-boundary marker (U+2581 LOWER ONE EIGHTH BLOCK);
  # converted back to a literal space in process_output.
  self.space_token = space_token

def process_output(self, output):
word_list = output.readable_actions()
return "".join(word_list).replace(self.space_token, " ").strip()
def process_output(self, output_actions):
  """Concatenate piece tokens, turn boundary markers into spaces, and trim the edges."""
  joined = "".join(output_actions)
  return joined.replace(self.space_token, " ").strip()
26 changes: 17 additions & 9 deletions xnmt/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from xnmt.loss import FactoredLossExpr
from xnmt.loss_calculator import LossCalculator
from xnmt.lstm import BiLSTMSeqTransducer
from xnmt.output import TextOutput, Output
from xnmt.output import TextOutput, Output, NbestOutput
import xnmt.plot
from xnmt.reports import Reportable
from xnmt.persistence import serializable_init, Serializable, bare
Expand Down Expand Up @@ -183,10 +183,22 @@ def generate(self, src: Batch, idx: Sequence[int], search_strategy: SearchStrate
search_outputs = search_strategy.generate_output(self, initial_state,
src_length=[len(sent)],
forced_trg_ids=cur_forced_trg)
best_output = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)[0]
output_actions = [x for x in best_output.word_ids[0]]
attentions = [x for x in best_output.attentions[0]]
score = best_output.score[0]
sorted_outputs = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)
assert len(sorted_outputs) >= 1
for curr_output in sorted_outputs:
output_actions = [x for x in curr_output.word_ids[0]]
attentions = [x for x in curr_output.attentions[0]]
score = curr_output.score[0]
if len(sorted_outputs) == 1:
outputs.append(TextOutput(actions=output_actions,
vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
score=score))
else:
outputs.append(NbestOutput(TextOutput(actions=output_actions,
vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
score=score),
nbest_id=idx[sent_i]))

# In case of reporting
if self.report_path is not None:
if self.reporting_src_vocab:
Expand All @@ -208,10 +220,6 @@ def generate(self, src: Batch, idx: Sequence[int], search_strategy: SearchStrate
self.set_report_resource("src_words", src_words)
self.set_report_path('{}.{}'.format(self.report_path, str(idx[sent_i])))
self.generate_report(self.report_type)
# Append output to the outputs
outputs.append(TextOutput(actions=output_actions,
vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
score=score))
return outputs

def generate_one_step(self, current_word: Any, current_state: AutoRegressiveDecoderState) -> TranslatorOutput:
Expand Down

0 comments on commit 316c72c

Please sign in to comment.