Skip to content

Commit

Permalink
Improve reporting (#449)
Browse files Browse the repository at this point in the history
* initial report refactoring

* reporting working

* clean up

* re-add max_num_sents feature to inference

* fix unit tests

* fix report unit test and give error message if compute_report was not enabled

* create parent dir for reports and set default locations

* draft for segmenting reporter

* attention reports with speech on source side

* add doc; share some code via HTML reporter base class
  • Loading branch information
msperber committed Jul 10, 2018
1 parent a2ce074 commit b2e43df
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 385 deletions.
37 changes: 11 additions & 26 deletions examples/14_report.yaml
Original file line number Diff line number Diff line change
@@ -1,47 +1,32 @@
# This example demonstrates reporting of attention matrices.
# XNMT supports writing out reports, such as attention matrices generated during inference.
# These are generally created by setting exp_global.compute_report to True, and adding one or several reporters
# to the inference class.
report: !Experiment
exp_global: !ExpGlobal
default_layer_dim: 256
dropout: 0.0
compute_report: True
model: !DefaultTranslator
src_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.ja.vocab}
trg_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
src_embedder: !SimpleWordEmbedder
emb_dim: 256
encoder: !BiLSTMSeqTransducer
layers: 1
input_dim: 256
attender: !MlpAttender
state_dim: 256
hidden_dim: 256
input_dim: 256
trg_embedder: !SimpleWordEmbedder
emb_dim: 256
decoder: !AutoRegressiveDecoder
rnn: !UniLSTMSeqTransducer
layers: 1
transform: !AuxNonLinear
output_dim: 256
bridge: !NoBridge {}
inference: !AutoRegressiveInference {}
train: !SimpleTrainingRegimen
run_for_epochs: 1
trainer: !AdamTrainer
alpha: 0.01
run_for_epochs: 2
src_file: examples/data/head.ja
trg_file: examples/data/head.en
dev_tasks:
- !LossEvalTask
src_file: examples/data/head.ja
ref_file: examples/data/head.en
train: !SimpleTrainingRegimen
run_for_epochs: 0
src_file: examples/data/head.ja
trg_file: examples/data/head.en
evaluate:
- !AccuracyEvalTask
eval_metrics: bleu
src_file: examples/data/head.ja
ref_file: examples/data/head.en
hyp_file: examples/output/{EXP}.test_hyp
inference: !AutoRegressiveInference
report_path: examples/output/{EXP}.report
report_type: html, file
reporter: !AttentionHtmlReporter { report_path: 'examples/output/{EXP}.report' }

30 changes: 3 additions & 27 deletions test/config/report.yaml
Original file line number Diff line number Diff line change
@@ -1,44 +1,20 @@
report: !Experiment
exp_global: !ExpGlobal
default_layer_dim: 256
dropout: 0.0
compute_report: True
model: !DefaultTranslator
src_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.ja.vocab}
trg_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
src_embedder: !SimpleWordEmbedder
emb_dim: 256
encoder: !BiLSTMSeqTransducer
layers: 1
input_dim: 256
attender: !MlpAttender
state_dim: 256
hidden_dim: 256
input_dim: 256
trg_embedder: !SimpleWordEmbedder
emb_dim: 256
decoder: !AutoRegressiveDecoder
rnn: !UniLSTMSeqTransducer
layers: 1
bridge: !NoBridge {}
inference: !AutoRegressiveInference {}
train: !SimpleTrainingRegimen
run_for_epochs: 1
trainer: !AdamTrainer
alpha: 0.01
run_for_epochs: 0
src_file: examples/data/head.ja
trg_file: examples/data/head.en
dev_tasks:
- !LossEvalTask
src_file: examples/data/head.ja
ref_file: examples/data/head.en
evaluate:
- !AccuracyEvalTask
eval_metrics: bleu,wer
src_file: examples/data/head.ja
ref_file: examples/data/head.en
hyp_file: test/tmp/{EXP}.test_hyp
inference: !AutoRegressiveInference
report_path: test/tmp/{EXP}.report
report_type: html, file
reporter: !AttentionHtmlReporter {}
5 changes: 1 addition & 4 deletions test/config/segmenting.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,8 @@ debug-segmenting: !Experiment
rnn: !UniLSTMSeqTransducer
layers: 1
bridge: !CopyBridge {}
inference: !AutoRegressiveInference
report_path: test/tmp/{EXP}.report
report_type: html, file
train: !SimpleTrainingRegimen
run_for_epochs: 3
run_for_epochs: 1
src_file: examples/data/head-char.ja
trg_file: examples/data/head.en
dev_tasks:
Expand Down
7 changes: 3 additions & 4 deletions test/config/speech.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ speech: !Experiment
save_num_checkpoints: 2
default_layer_dim: 32
dropout: 0.4
compute_report: True
preproc: !PreprocRunner
overwrite: False
tasks:
Expand Down Expand Up @@ -35,10 +36,8 @@ speech: !Experiment
transpose: True
trg_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
inference: !AutoRegressiveInference
report_path: test/tmp/report/
train: !SimpleTrainingRegimen
run_for_epochs: 1
run_for_epochs: 0
batcher: !SrcBatcher
pad_src_to_multiple: 4
batch_size: 3
Expand All @@ -57,7 +56,6 @@ speech: !Experiment
hyp_file: test/tmp/{EXP}.dev_hyp
inference: !AutoRegressiveInference
post_process: join-char
report_path: test/tmp/report.{EXP}
batcher: !InOrderBatcher
_xnmt_id: inference_batcher
pad_src_to_multiple: 4
Expand All @@ -72,3 +70,4 @@ speech: !Experiment
inference: !AutoRegressiveInference
post_process: join-char
batcher: !Ref { name: inference_batcher }
reporter: !AttentionHtmlReporter {}
2 changes: 2 additions & 0 deletions xnmt/exp_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self,
truncate_dec_batches: bool = False,
save_num_checkpoints: int = 1,
loss_comb_method: str = "sum",
compute_report: bool = False,
commandline_args: dict = {},
placeholders: Dict[str, str] = {}) -> None:
self.model_file = model_file
Expand All @@ -53,4 +54,5 @@ def __init__(self,
self.commandline_args = commandline_args
self.save_num_checkpoints = save_num_checkpoints
self.loss_comb_method = loss_comb_method
self.compute_report = compute_report
self.placeholders = placeholders
62 changes: 26 additions & 36 deletions xnmt/inference.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from collections.abc import Iterable
import collections.abc
from typing import List, Optional, Tuple, Sequence, Union

from xnmt.settings import settings

import dynet as dy

import xnmt.input
import xnmt.batcher
from xnmt import loss, loss_calculator, model_base, output, reports, search_strategy, vocab, util
from xnmt.persistence import serializable_init, Serializable, Ref, bare
from xnmt import loss, loss_calculator, model_base, output, reports, search_strategy, util
from xnmt.persistence import serializable_init, Serializable, bare

NO_DECODING_ATTEMPTED = "@@NO_DECODING_ATTEMPTED@@"

Expand All @@ -30,18 +29,21 @@ class Inference(object):
for debugging purposes.
* ``score``: output scores, useful for rescoring
batcher: inference batcher, needed e.g. in connection with ``pad_src_token_to_multiple``
reporter: a reporter to create reports for each decoded sentence
"""
def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = None, ref_file: Optional[str] = None,
max_src_len: Optional[int] = None, max_num_sents: Optional[int] = None,
mode: str = "onebest",
batcher: xnmt.batcher.InOrderBatcher = bare(xnmt.batcher.InOrderBatcher, batch_size=1)):
batcher: xnmt.batcher.InOrderBatcher = bare(xnmt.batcher.InOrderBatcher, batch_size=1),
reporter: Union[None, reports.Reporter, Sequence[reports.Reporter]] = None):
self.src_file = src_file
self.trg_file = trg_file
self.ref_file = ref_file
self.max_src_len = max_src_len
self.max_num_sents = max_num_sents
self.mode = mode
self.batcher = batcher
self.reporter = reporter

def generate_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input, src_i: int, forced_ref_ids) -> List[output.Output]:
# TODO: src should probably be a batch of inputs for consistency with return values being a batch of outputs
Expand All @@ -67,7 +69,7 @@ def perform_inference(self, generator: model_base.GeneratorModel, src_file: str

ref_corpus, src_corpus = self._read_corpus(generator, src_file, mode=self.mode, ref_file=self.ref_file)

self._init_generator(generator)
generator.set_train(False)

ref_scores = None
if self.mode == 'score':
Expand Down Expand Up @@ -115,6 +117,7 @@ def _generate_output(self, generator: model_base.GeneratorModel, src_corpus: Seq
if forced_ref_corpus: ref_batch = ref_batches[batch_i]
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
outputs = self.generate_one(generator, src_batch, range(cur_sent_i,cur_sent_i+batch_size), ref_batch)
if self.reporter: self._create_report()
for i in range(len(outputs)):
if assert_scores is not None:
# If debugging forced decoding, make sure it matches
Expand All @@ -126,6 +129,16 @@ def _generate_output(self, generator: model_base.GeneratorModel, src_corpus: Seq
output_txt = outputs[i].apply_post_processor(self.post_processor)
fp.write(f"{output_txt}\n")
cur_sent_i += batch_size
if self.max_num_sents and cur_sent_i >= self.max_num_sents: break

def _create_report(self):
    """Feed report inputs to every configured reporter.

    ``self.reporter`` may hold a single reporter or a collection of reporters.
    It is normalized to an indexable sequence and stored back on the instance.
    The first reporter supplies the report inputs through
    ``get_report_input(context={})``; each returned item is then unpacked as
    keyword arguments into ``create_report`` of every reporter, in order.

    An empty collection of reporters is a no-op.
    """
    # Internal invariant: the caller only invokes this when a reporter is set.
    assert self.reporter is not None
    if not isinstance(self.reporter, collections.abc.Iterable):
        # A single reporter was given: wrap it so the loops below are uniform.
        self.reporter = [self.reporter]
    elif not isinstance(self.reporter, collections.abc.Sequence):
        # Generators, sets etc. are iterable but cannot be indexed with [0].
        self.reporter = list(self.reporter)
    if not self.reporter:
        # No reporters at all: nothing to query and nothing to create.
        return
    report_inputs = self.reporter[0].get_report_input(context={})
    for report_input in report_inputs:
        for reporter in self.reporter:
            reporter.create_report(**report_input)

def _compute_losses(self, generator, ref_corpus, src_corpus, max_num_sents) -> List[float]:
batched_src, batched_ref = self.batcher.pack(src_corpus, ref_corpus)
Expand All @@ -134,15 +147,13 @@ def _compute_losses(self, generator, ref_corpus, src_corpus, max_num_sents) -> L
if max_num_sents and sent_count >= max_num_sents: break
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
loss_expr = self.compute_losses_one(generator, src, ref)
if isinstance(loss_expr.value(), Iterable):
if isinstance(loss_expr.value(), collections.abc.Iterable):
ref_scores.extend(loss_expr.value())
else:
ref_scores.append(loss_expr.value())
ref_scores = [-x for x in ref_scores]
return ref_scores

def _init_generator(self, generator: model_base.GeneratorModel) -> None:
generator.set_train(False)

@staticmethod
def _write_rescored_output(ref_scores: Sequence[float], ref_file: str, trg_file: str) -> None:
Expand Down Expand Up @@ -219,9 +230,10 @@ def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = Non
max_src_len: Optional[int] = None, max_num_sents: Optional[int] = None,
post_process: Union[str, output.OutputProcessor] = bare(output.PlainTextOutputProcessor),
mode: str = "onebest",
batcher: xnmt.batcher.InOrderBatcher = bare(xnmt.batcher.InOrderBatcher, batch_size=1)):
batcher: xnmt.batcher.InOrderBatcher = bare(xnmt.batcher.InOrderBatcher, batch_size=1),
reporter: Union[None, reports.Reporter, Sequence[reports.Reporter]] = None):
super().__init__(src_file=src_file, trg_file=trg_file, ref_file=ref_file, max_src_len=max_src_len,
max_num_sents=max_num_sents, mode=mode, batcher=batcher)
max_num_sents=max_num_sents, mode=mode, batcher=batcher, reporter=reporter)
self.post_processor = output.OutputProcessor.get_output_processor(post_process)

def generate_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input, src_i: int, forced_ref_ids)\
Expand All @@ -248,8 +260,6 @@ class AutoRegressiveInference(Inference, Serializable):
max_num_sents: Stop decoding after the first n sentences.
post_process: post-processing of translation outputs
(available string shortcuts: ``none``,``join-char``,``join-bpe``,``join-piece``)
report_path: a path to which decoding reports will be written
report_type: report to generate ``file/html``. Can be multiple, separate with comma.
search_strategy: a search strategy used during decoding.
mode: type of decoding to perform.
Expand All @@ -267,16 +277,14 @@ class AutoRegressiveInference(Inference, Serializable):
def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = None, ref_file: Optional[str] = None,
max_src_len: Optional[int] = None, max_num_sents: Optional[int] = None,
post_process: Union[str, output.OutputProcessor] = bare(output.PlainTextOutputProcessor),
report_path: Optional[str] = None, report_type: str = "html",
search_strategy: search_strategy.SearchStrategy = bare(search_strategy.BeamSearch),
mode: str = "onebest",
batcher: xnmt.batcher.InOrderBatcher = bare(xnmt.batcher.InOrderBatcher, batch_size=1)):
batcher: xnmt.batcher.InOrderBatcher = bare(xnmt.batcher.InOrderBatcher, batch_size=1),
reporter: Union[None, reports.Reporter, Sequence[reports.Reporter]] = None):
super().__init__(src_file=src_file, trg_file=trg_file, ref_file=ref_file, max_src_len=max_src_len,
max_num_sents=max_num_sents, mode=mode, batcher=batcher)
max_num_sents=max_num_sents, mode=mode, batcher=batcher, reporter=reporter)

self.post_processor = output.OutputProcessor.get_output_processor(post_process)
self.report_path = report_path
self.report_type = report_type
self.search_strategy = search_strategy

def generate_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input, src_i: int, forced_ref_ids)\
Expand All @@ -288,21 +296,3 @@ def compute_losses_one(self, generator: model_base.GeneratorModel, src: xnmt.inp
ref: xnmt.input.Input) -> loss.FactoredLossExpr:
loss_expr = generator.calc_loss(src, ref, loss_calculator=loss_calculator.AutoRegressiveMLELoss())
return loss_expr

def _init_generator(self, generator: model_base.GeneratorModel) -> None:
generator.set_train(False)

is_reporting = issubclass(generator.__class__, reports.Reportable) and self.report_path is not None
src_vocab = generator.src_reader.vocab if hasattr(generator.src_reader, "vocab") else None
trg_vocab = generator.trg_reader.vocab if hasattr(generator.trg_reader, "vocab") else None

generator.initialize_generator(report_path=self.report_path,
report_type=self.report_type)
if hasattr(generator, "set_trg_vocab"):
generator.set_trg_vocab(trg_vocab)
if hasattr(generator, "set_reporting_src_vocab"):
generator.set_reporting_src_vocab(src_vocab)
if is_reporting:
generator.set_report_resource("src_vocab", src_vocab)
generator.set_report_resource("trg_vocab", trg_vocab)

24 changes: 23 additions & 1 deletion xnmt/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import matplotlib.pyplot as plt

from xnmt import util

def plot_attention(src_words, trg_words, attention_matrix, file_name=None):
"""This takes in source and target words and an attention matrix (in numpy format)
and prints a visualization of this to a file.
Expand All @@ -15,11 +17,12 @@ def plot_attention(src_words, trg_words, attention_matrix, file_name=None):
where rows correspond to source words, and columns correspond to target words
file_name: the name of the file to which we write the attention
"""
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(8.0, 8.0))
# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(attention_matrix.shape[1]) + 0.5, minor=False)
ax.set_yticks(np.arange(attention_matrix.shape[0]) + 0.5, minor=False)
ax.invert_yaxis()
if not src_words: plt.yticks([], [])

# label axes by words
ax.set_xticklabels(trg_words, minor=False)
Expand All @@ -31,8 +34,27 @@ def plot_attention(src_words, trg_words, attention_matrix, file_name=None):
plt.colorbar()

if file_name is not None:
util.make_parent_dir(file_name)
plt.savefig(file_name, dpi=100)
else:
plt.show()
plt.close()

def plot_speech_features(feature_matrix, file_name=None, vertical=True):
    """Plot a speech feature matrix as a heat map without axes.

    Args:
      feature_matrix: a two-dimensional numpy array; values are expected to lie
                      between zero and one, since the color scale is fixed to
                      that range.
      file_name: the name of the file to which the plot is written; if ``None``,
                 the plot is shown instead of being saved.
      vertical: if True, the matrix is transposed before plotting.
    """
    # A new figure becomes the current one; the pyplot calls below act on it.
    plt.subplots(figsize=(1.5, 8.0))
    try:
        if vertical:
            feature_matrix = feature_matrix.T
        plt.pcolor(feature_matrix, cmap=plt.cm.magma, vmin=0, vmax=1)
        plt.axis('off')
        if file_name is not None:
            util.make_parent_dir(file_name)
            plt.savefig(file_name, dpi=100)
        else:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

0 comments on commit b2e43df

Please sign in to comment.