Commit: Fixed reporting
philip30 committed Nov 1, 2018
1 parent de291b7 commit 67a88d5
Showing 14 changed files with 36 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ test/tmp
test/config/logs/
test/config/models/
test/config/reports/
test/config/hyp/
/*.yaml
examples/preproc/
.project
3 changes: 2 additions & 1 deletion examples/07_load_finetune.yaml
@@ -67,5 +67,6 @@ exp2-finetune-model: !LoadSerialized
val: 0.5
- path: train.dev_zero
val: True

- path: status
val: null

2 changes: 2 additions & 0 deletions examples/08_load_eval_beam.yaml
@@ -60,6 +60,8 @@ exp2-eval-model: !LoadSerialized
overwrite: # list of [path, value] pairs. Value can be scalar or an arbitrary object
- path: train # skip the training loop
val: null
- path: status
val: null
- path: model.inference.search_strategy.beam_size # try some new beam settings
val: 10
- path: evaluate
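Note on the overwrite block above: each [path, value] pair patches one field of the deserialized experiment before it runs, with the dotted path walking nested components (model.inference.search_strategy.beam_size, for instance), and the new path: status entries null out the saved experiment status. The following is a rough standalone sketch of that dotted-path idea in Python; apply_overwrite and the toy classes are invented for illustration and are not xnmt's LoadSerialized implementation.

# Toy sketch of dotted-path overwrites; not xnmt's real LoadSerialized logic.
class SearchStrategy:
    def __init__(self):
        self.beam_size = 1

class Inference:
    def __init__(self):
        self.search_strategy = SearchStrategy()

class Model:
    def __init__(self):
        self.inference = Inference()

class Experiment:
    def __init__(self):
        self.model = Model()
        self.status = "done"

def apply_overwrite(root, path, val):
    # Walk every path component except the last, then assign the final attribute.
    *parents, leaf = path.split(".")
    obj = root
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, leaf, val)

exp = Experiment()
apply_overwrite(exp, "model.inference.search_strategy.beam_size", 10)
apply_overwrite(exp, "status", None)  # mirrors the "- path: status / val: null" entry
print(exp.model.inference.search_strategy.beam_size, exp.status)  # 10 None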
2 changes: 0 additions & 2 deletions examples/14_report.yaml
@@ -4,8 +4,6 @@
# to the inference class.
!Experiment
name: report
exp_global: !ExpGlobal
compute_report: True
model: !DefaultTranslator
src_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.ja.vocab}
1 change: 0 additions & 1 deletion test/config/report.yaml
@@ -1,7 +1,6 @@
report: !Experiment
exp_global: !ExpGlobal
default_layer_dim: 64
compute_report: True
model: !DefaultTranslator
src_reader: !PlainTextReader
vocab: !Vocab {vocab_file: examples/data/head.ja.vocab}
1 change: 0 additions & 1 deletion test/config/seg_report.yaml
@@ -1,6 +1,5 @@
report: !Experiment
exp_global: !ExpGlobal
compute_report: True
default_layer_dim: 64
model: !DefaultTranslator
src_reader: !CharFromWordTextReader
1 change: 0 additions & 1 deletion test/config/speech.yaml
@@ -3,7 +3,6 @@ speech: !Experiment
save_num_checkpoints: 2
default_layer_dim: 32
dropout: 0.4
compute_report: True
preproc: !PreprocRunner
overwrite: False
tasks:
7 changes: 4 additions & 3 deletions xnmt/eval/tasks.py
@@ -93,13 +93,13 @@ def eval(self) -> 'EvalScore':
loss_val += loss.get_factored_loss_val(comb_method=self.loss_comb_method)

loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

#
return LossScore(sum(loss_stats.values()),
loss_stats=loss_stats,
num_ref_words = ref_words_cnt,
desc=self.desc)

class AccuracyEvalTask(EvalTask, reports.Reportable, Serializable):
class AccuracyEvalTask(EvalTask, Serializable):
"""
A task that does evaluation of some measure of accuracy.
@@ -133,7 +133,8 @@ def __init__(self, src_file: Union[str,Sequence[str]], ref_file: Union[str,Seque

def eval(self):
event_trigger.set_train(False)
self.report_corpus_info({"ref_file":self.ref_file})
if issubclass(self.model.__class__, reports.Reportable):
self.model.report_corpus_info({"ref_file": self.ref_file})
self.inference.perform_inference(generator=self.model,
src_file=self.src_file,
trg_file=self.hyp_file)
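The guard added above only forwards corpus-level report info when the evaluated model actually mixes in reports.Reportable; for ordinary classes, issubclass(obj.__class__, C) is the same test as isinstance(obj, C). A minimal sketch with stand-in classes (they are not xnmt's, only the guard pattern mirrors the change):

# Stand-in classes; only the guard pattern mirrors the change in tasks.py.
class Reportable:
    def report_corpus_info(self, glob_info):
        print("collected for report:", glob_info)

class ReportingModel(Reportable):
    pass

class PlainModel:
    pass

def maybe_report(model, ref_file):
    # Equivalent in spirit to: if issubclass(model.__class__, Reportable): ...
    if isinstance(model, Reportable):
        model.report_corpus_info({"ref_file": ref_file})

maybe_report(ReportingModel(), "refs.txt")  # prints the dict
maybe_report(PlainModel(), "refs.txt")      # silently skipped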
5 changes: 5 additions & 0 deletions xnmt/event_trigger.py
@@ -57,3 +57,8 @@ def calc_additional_loss(trg: Union[sent.Sentence, batchers.Batch],
@events.register_xnmt_event_assign
def get_report_input(context={}) -> dict:
return context

@events.register_xnmt_event
def set_reporting(reporting):
pass

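set_reporting is a broadcast event: inference fires it once per run (see xnmt/inferences.py below, where the argument is self.reporter is not None), and any registered object whose class handles the event receives it through a method named on_set_reporting, as Reportable now does in xnmt/reports.py. A rough self-contained sketch of that publish/subscribe shape follows; the tiny registry is a mock, not xnmt's events module.

# Mock event bus; xnmt's real decorators in xnmt.events are more involved.
_handlers = []

def register_handler(obj):
    _handlers.append(obj)

def trigger(event_name, *args):
    # Call on_<event_name> on every registered object that defines it.
    for obj in _handlers:
        method = getattr(obj, "on_" + event_name, None)
        if method is not None:
            method(*args)

class ReportableLike:
    def __init__(self):
        register_handler(self)
        self._is_reporting = False

    def on_set_reporting(self, is_reporting):
        self._is_reporting = is_reporting

    def is_reporting(self):
        return self._is_reporting

model = ReportableLike()
trigger("set_reporting", True)  # what inference does when a reporter is configured
print(model.is_reporting())     # True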
2 changes: 0 additions & 2 deletions xnmt/experiments.py
@@ -47,7 +47,6 @@ def __init__(self,
truncate_dec_batches: bool = False,
save_num_checkpoints: numbers.Integral = 1,
loss_comb_method: str = "sum",
compute_report: bool = False,
commandline_args: dict = {},
placeholders: Dict[str, Any] = {}) -> None:
self.model_file = model_file
@@ -61,7 +60,6 @@ def __init__(self,
self.commandline_args = commandline_args
self.save_num_checkpoints = save_num_checkpoints
self.loss_comb_method = loss_comb_method
self.compute_report = compute_report
self.placeholders = placeholders

class Experiment(Serializable):
1 change: 1 addition & 0 deletions xnmt/inferences.py
@@ -150,6 +150,7 @@ def _generate_output(self, generator: 'models.GeneratorModel', src_file: str,
with open(trg_file, 'wt', encoding='utf-8') as fp: # Saving the translated output to a trg file
cur_sent_i = 0
src_batch, ref_batch, assert_batch = [], [], []
event_trigger.set_reporting(self.reporter is not None)
for curr_sent_i, (src_line, ref_line, assert_line) in enumerate(zip(src_in, forced_ref_in, assert_in)):
if self.max_num_sents and cur_sent_i >= self.max_num_sents:
break
12 changes: 5 additions & 7 deletions xnmt/models/translators.py
@@ -96,8 +96,7 @@ def __init__(self,
trg_embedder: Embedder=bare(SimpleWordEmbedder),
decoder: Decoder=bare(AutoRegressiveDecoder),
inference: inferences.AutoRegressiveInference=bare(inferences.AutoRegressiveInference),
truncate_dec_batches:bool=False,
compute_report:bool = Ref("exp_global.compute_report", default=False)):
truncate_dec_batches:bool=False):
super().__init__(src_reader=src_reader, trg_reader=trg_reader)
self.src_embedder = src_embedder
self.encoder = encoder
@@ -106,15 +105,13 @@ def __init__(self,
self.decoder = decoder
self.inference = inference
self.truncate_dec_batches = truncate_dec_batches
self.compute_report = compute_report

def shared_params(self):
return [{".src_embedder.emb_dim", ".encoder.input_dim"},
{".encoder.hidden_dim", ".attender.input_dim", ".decoder.input_dim"},
{".attender.state_dim", ".decoder.rnn.hidden_dim"},
{".trg_embedder.emb_dim", ".decoder.trg_embed_dim"}]


def _encode_src(self, src: Union[batchers.Batch, sent.Sentence]):
embeddings = self.src_embedder.embed_sent(src)
encoding = self.encoder.transduce(embeddings)
@@ -249,11 +246,12 @@ def generate(self,
outputs.append(out_sent)
else:
outputs.append(sent.NbestSentence(base_sent=out_sent, nbest_id=src[0].idx))
if self.compute_report:

if self.is_reporting():
attentions = np.concatenate([x.npvalue() for x in attentions], axis=1)
self.report_sent_info({"attentions": attentions,
"src": src[0],
"output": outputs[0]})
"src": src[0],
"output": outputs[0]})

return outputs

18 changes: 14 additions & 4 deletions xnmt/reports.py
@@ -61,8 +61,8 @@ class Reportable(object):

@register_xnmt_handler
def __init__(self) -> None:
self._sent_info_list = []

pass
def report_sent_info(self, sent_info: Dict[str, Any]) -> None:
"""
Add key/value pairs belonging to the current sentence for reporting.
@@ -90,13 +90,14 @@ def report_corpus_info(self, glob_info: Dict[str, Any]) -> None:
self._glob_info_list.update(glob_info)

@handle_xnmt_event
def on_get_report_input(self, context=ReportInfo()):
def on_get_report_input(self, context):
if hasattr(self, "_glob_info_list"):
context.glob_info.update(self._glob_info_list)
if not hasattr(self, "_sent_info_list"):
return context
if len(context.sent_info)>0:
assert len(context.sent_info) == len(self._sent_info_list)
assert len(context.sent_info) == len(self._sent_info_list), \
"{} != {}".format(len(context.sent_info), len(self._sent_info_list))
else:
context.sent_info = []
for _ in range(len(self._sent_info_list)): context.sent_info.append({})
@@ -105,6 +106,14 @@ def on_get_report_input(self, context=ReportInfo()):
self._sent_info_list.clear()
return context

@handle_xnmt_event
def on_set_reporting(self, is_reporting):
self._sent_info_list = []
self._is_reporting = is_reporting

def is_reporting(self):
return self._is_reporting if hasattr(self, "_is_reporting") else False

class Reporter(object):
"""
A base class for a reporter that collects reportable information, formats it and writes it to disk.
Expand Down Expand Up @@ -461,6 +470,7 @@ def create_sent_report(self, segment_actions, src, **kwargs):
def conclude_report(self):
if hasattr(self, "report_fp") and self.report_fp:
self.report_fp.close()
self.report_fp = None


class OOVStatisticsReporter(Reporter, Serializable):
@@ -56,16 +56,14 @@ def __init__(self, embed_encoder=bare(IdentitySeqTransducer),
length_prior=None,
eps_greedy=None,
sample_during_search=False,
reporter=None,
compute_report=Ref("exp_global.compute_report", default=False)):
reporter=None):
self.embed_encoder = self.add_serializable_component("embed_encoder", embed_encoder, lambda: embed_encoder)
self.segment_composer = self.add_serializable_component("segment_composer", segment_composer, lambda: segment_composer)
self.final_transducer = self.add_serializable_component("final_transducer", final_transducer, lambda: final_transducer)
self.policy_learning = self.add_serializable_component("policy_learning", policy_learning, lambda: policy_learning) if policy_learning is not None else None
self.length_prior = self.add_serializable_component("length_prior", length_prior, lambda: length_prior) if length_prior is not None else None
self.eps_greedy = self.add_serializable_component("eps_greedy", eps_greedy, lambda: eps_greedy) if eps_greedy is not None else None
self.sample_during_search = sample_during_search
self.compute_report = compute_report
self.reporter = reporter
self.no_char_embed = issubclass(segment_composer.__class__, VocabBasedComposer)
# Others
@@ -113,7 +111,7 @@ def transduce(self, embed_sent: ExpressionSequence) -> List[ExpressionSequence]:
self.seg_size_unpadded = seg_size_unpadded
self.compose_output = outputs
self.segment_actions = actions
if not self.train and self.compute_report:
if not self.train and self.is_reporting():
if len(actions) == 1: # Support only AccuracyEvalTask
self.report_sent_info({"segment_actions": actions})

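Taken together, the commit replaces the compute_report configuration flag with a runtime switch: inference triggers set_reporting depending on whether a reporter is configured, Reportable.on_set_reporting stores the flag and resets the per-sentence info list, and models check is_reporting() before calling report_sent_info (as the translator and the segmenting transducer hunks above now do). A condensed sketch of that pattern; the classes below are illustrative and not copied from xnmt.

# Illustrative only; in xnmt the base class is xnmt.reports.Reportable and
# the flag is flipped via the set_reporting event rather than a direct call.
class Reportable:
    def on_set_reporting(self, is_reporting):
        self._sent_info_list = []
        self._is_reporting = is_reporting

    def is_reporting(self):
        return getattr(self, "_is_reporting", False)

    def report_sent_info(self, sent_info):
        self._sent_info_list.append(sent_info)

class ToyTranslator(Reportable):
    def generate(self, src):
        output = src.upper()      # stand-in for real decoding
        if self.is_reporting():   # collect report data only when a reporter is attached
            self.report_sent_info({"src": src, "output": output})
        return output

model = ToyTranslator()
model.on_set_reporting(True)      # normally reached through event_trigger.set_reporting(...)
model.generate("hello")
print(model._sent_info_list)      # [{'src': 'hello', 'output': 'HELLO'}]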
