Skip to content

Commit

Permalink
max_num_sents for LossEvalTask (#518)
Browse files Browse the repository at this point in the history
  • Loading branch information
msperber authored and neubig committed Aug 16, 2018
1 parent ce31628 commit 662cffe
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xnmt/eval/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class LossEvalTask(EvalTask, Serializable):
loss_calculator: loss calculator
max_src_len: omit sentences with source length greater than specified number
max_trg_len: omit sentences with target length greater than specified number
max_num_sents: compute loss only for the first n sentences in the given corpus
loss_comb_method: method for combining loss across batch elements ('sum' or 'avg').
desc: description to pass on to computed score objects
"""
Expand All @@ -47,6 +48,7 @@ def __init__(self,
loss_calculator: LossCalculator = bare(MLELoss),
max_src_len: Optional[int] = None,
max_trg_len: Optional[int] = None,
max_num_sents: Optional[int] = None,
loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
desc: Any = None):
self.model = model
Expand All @@ -57,6 +59,7 @@ def __init__(self,
self.src_data = None
self.max_src_len = max_src_len
self.max_trg_len = max_trg_len
self.max_num_sents = max_num_sents
self.loss_comb_method = loss_comb_method
self.desc=desc

Expand All @@ -75,6 +78,7 @@ def eval(self) -> 'EvalScore':
src_file=self.src_file,
trg_file=self.ref_file,
batcher=self.batcher,
max_num_sents=self.max_num_sents,
max_src_len=self.max_src_len,
max_trg_len=self.max_trg_len)
loss_val = FactoredLossVal()
Expand Down

0 comments on commit 662cffe

Please sign in to comment.