Skip to content

Commit

Permalink
Fixed report breaking
Browse files Browse the repository at this point in the history
  • Loading branch information
philip30 committed Sep 27, 2018
1 parent b2db445 commit d72529b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
28 changes: 25 additions & 3 deletions test/config/seg_report.yaml
Expand Up @@ -21,7 +21,23 @@ report: !Experiment
loss_calculator: !FeedbackLoss
child_loss: !MLELoss {}
repeat: 2
dev_tasks:
- !LossEvalTask
src_file: examples/data/head.ja
ref_file: examples/data/head.en
- !AccuracyEvalTask
eval_metrics: bleu,wer
src_file: examples/data/dev.ja
ref_file: examples/data/dev.en
hyp_file: test/tmp/{EXP}.test_hyp
inference: !AutoRegressiveInference
reporter: !SegmentationReporter
report_path: examples/output/test.segment

evaluate:
- !LossEvalTask
src_file: examples/data/head.ja
ref_file: examples/data/head.en
- !AccuracyEvalTask
eval_metrics: bleu,wer
src_file: examples/data/dev.ja
Expand All @@ -30,7 +46,13 @@ report: !Experiment
inference: !AutoRegressiveInference
reporter: !SegmentationReporter
report_path: examples/output/test.segment
- !LossEvalTask
src_file: examples/data/head.ja
ref_file: examples/data/head.en
- !AccuracyEvalTask
eval_metrics: bleu,wer
src_file: examples/data/dev.ja
ref_file: examples/data/dev.en
hyp_file: test/tmp/{EXP}.test_hyp
inference: !AutoRegressiveInference
reporter: !SegmentationReporter
report_path: examples/output/test.segment


15 changes: 12 additions & 3 deletions xnmt/specialized_encoders/segmenting_encoder/segmenting_encoder.py
Expand Up @@ -68,6 +68,14 @@ def __init__(self, embed_encoder=bare(IdentitySeqTransducer),
self.compute_report = compute_report
self.reporter = reporter
self.no_char_embed = issubclass(segment_composer.__class__, VocabBasedComposer)
# Others
self.segmenting_action = None
self.compose_output = None
self.segment_actions = None
self.seg_size_unpadded = None
self.src_sent = None
self.reward = None
self.train = None

def shared_params(self):
return [{".embed_encoder.hidden_dim",".policy_learning.policy_network.input_dim"},
Expand Down Expand Up @@ -106,7 +114,8 @@ def transduce(self, embed_sent: ExpressionSequence) -> List[ExpressionSequence]:
self.compose_output = outputs
self.segment_actions = actions
if not self.train and self.compute_report:
self.report_sent_info({"segment_actions": actions})
if len(actions) == 1: # Support only AccuracyEvalTask
self.report_sent_info({"segment_actions": actions})

@handle_xnmt_event
def on_calc_additional_loss(self, trg, generator, generator_loss):
Expand Down Expand Up @@ -232,14 +241,14 @@ def sparse_to_dense(self, actions, length):
def pad(self, outputs):
# Padding
max_col = max(len(xs) for xs in outputs)
P0 = dy.vecInput(outputs[0][0].dim()[0][0])
p0 = dy.vecInput(outputs[0][0].dim()[0][0])
masks = np.zeros((len(outputs), max_col), dtype=int)
modified = False
ret = []
for xs, mask in zip(outputs, masks):
deficit = max_col - len(xs)
if deficit > 0:
xs.extend([P0 for _ in range(deficit)])
xs.extend([p0 for _ in range(deficit)])
mask[-deficit:] = 1
modified = True
ret.append(dy.concatenate_cols(xs))
Expand Down

0 comments on commit d72529b

Please sign in to comment.